oflakne26 committed on
Commit 9db93ec
1 Parent(s): 0abcce5

Update main.py

Files changed (1)
  1. main.py +42 -3
main.py CHANGED
@@ -30,7 +30,7 @@ class InputData(BaseModel):
     json_prompt: str
     history: str = ""

-@app.post("/generate-response/")
+@app.post("/generate-response/")
 async def generate_response(data: InputData) -> Any:
     client = InferenceClient(model=data.model, token=HF_TOKEN)

@@ -50,10 +50,49 @@ async def generate_response(data: InputData) -> Any:

     seed = random.randint(0, 2**32 - 1)

-    models_to_try = [data.model] + FALLBACK_MODELS
+    try:
+        response = client.text_generation(inputs,
+                                          temperature=1.0,
+                                          max_new_tokens=1000,
+                                          seed=seed)
+
+        strict_response = str(response)
+
+        repaired_response = repair_json(strict_response,
+                                        return_objects=True)
+
+        if isinstance(repaired_response, str):
+            raise HTTPException(status_code=500, detail="Invalid response from model")
+        else:
+            cleaned_response = {}
+            for key, value in repaired_response.items():
+                cleaned_key = key.replace("###", "")
+                cleaned_response[cleaned_key] = value
+
+            for i, text in enumerate(cleaned_response["New response"]):
+                if i <= 2:
+                    sentences = tokenizer.tokenize(text)
+                    if sentences:
+                        cleaned_response["New response"][i] = sentences[0]
+                else:
+                    del cleaned_response["New response"][i]
+            if cleaned_response.get("Sentence count"):
+                if cleaned_response["Sentence count"] > 3:
+                    cleaned_response["Sentence count"] = 3
+            else:
+                cleaned_response["Sentence count"] = len(cleaned_response["New response"])
+
+            data.history += str(cleaned_response)
+
+            return cleaned_response
+
+    except Exception as e:
+        print(f"Primary model {data.model} failed with error: {e}")

-    for model in models_to_try:
+        # If the primary model fails, try fallback models
+        for model in FALLBACK_MODELS:
             try:
+                client = InferenceClient(model=model, token=HF_TOKEN)
                 response = client.text_generation(inputs,
                                                   temperature=1.0,
                                                   max_new_tokens=1000,
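Note on the repair step: repair_json comes from the json-repair package; with return_objects=True it returns a parsed Python object when it can salvage the text, and a plain string when it cannot, which is the failure signal the handler checks with isinstance(repaired_response, str). A minimal sketch of that repair-and-cleanup in isolation, assuming a hypothetical malformed reply (the sample string and values are made up; the "###"-prefixed key and the sentence-count clamp mirror the handler above, while the tokenizer-based sentence trimming is omitted because the tokenizer's definition is not visible in this diff):

from json_repair import repair_json

# Hypothetical malformed model output: single quotes, a trailing comma,
# and a "###"-decorated key, the kind of reply the handler anticipates.
raw = "{'###New response': ['One. Two.', 'Three.'], 'Sentence count': 5,}"

repaired = repair_json(raw, return_objects=True)  # parsed object, or str on failure
if isinstance(repaired, str):
    raise ValueError("Invalid response from model")

# Strip the "###" decoration from keys, as the endpoint does.
cleaned = {key.replace("###", ""): value for key, value in repaired.items()}

# Clamp the advertised sentence count at 3, matching the handler.
if cleaned.get("Sentence count"):
    cleaned["Sentence count"] = min(cleaned["Sentence count"], 3)
else:
    cleaned["Sentence count"] = len(cleaned["New response"])

print(cleaned)  # {'New response': ['One. Two.', 'Three.'], 'Sentence count': 3}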
 
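With this change, a request that fails on the requested model is retried against each entry of FALLBACK_MODELS before an error surfaces. A hedged example of exercising the endpoint from a local run (the model field is implied by data.model in the handler, json_prompt and history come from the InputData schema above; the host, port, model id, and prompt here are made up):

import requests

# Assumes a hypothetical local run of the app, e.g.: uvicorn main:app --port 8000
payload = {
    "model": "mistralai/Mistral-7B-Instruct-v0.2",  # made-up example model id
    "json_prompt": '{"instruction": "Reply in JSON."}',  # made-up prompt
    "history": "",
}

resp = requests.post("http://localhost:8000/generate-response/", json=payload)
resp.raise_for_status()
print(resp.json())  # cleaned dict, e.g. "New response" and "Sentence count" keys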