update response from upstream llm to now include it in dict with "response"

This commit is contained in:
Adil Hafeez 2025-03-31 18:42:46 -07:00
parent 5bd991e97b
commit f2323f771c
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
4 changed files with 30 additions and 11 deletions

View file

@ -88,6 +88,18 @@ def chat(
yield "", conversation, history, debug_output, model_selector
# update assistant response to have correct format
# arch-fc 1.1 expects following format:
# {
# "response": "<assistant response>",
# }
if not history[-1]["model"].startswith("Arch"):
assistant_response = {
"response": history[-1]["content"],
}
history[-1]["content"] = json.dumps(assistant_response)
def main():
with gr.Blocks(

View file

@ -154,7 +154,7 @@ class ArchFunctionHandler(ArchBaseHandler):
return fixed_str
def _parse_model_resonse(self, content: str) -> Dict[str, any]:
def _parse_model_response(self, content: str) -> Dict[str, any]:
"""
Extracts tool call information from a given string.
@ -212,7 +212,7 @@ class ArchFunctionHandler(ArchBaseHandler):
response_dict["is_valid"] = False
response_dict["error_message"] = f"Fail to parse model responses: {e}"
return response_dict
return content, response_dict
def _convert_data_type(self, value: str, target_type: str):
# TODO: Add more conversion rules as needed
@ -408,9 +408,13 @@ class ArchFunctionHandler(ArchBaseHandler):
self.hallucination_state.tokens
)
logger.info(f"[arch-fc]: raw model response: {model_response}")
# Extract tool calls from model response
response_dict = self._parse_model_resonse(model_response)
raw_model_response_json_fixed, response_dict = self._parse_model_response(
model_response
)
logger.info(
f"[arch-fc]: raw model response (json fixed): {raw_model_response_json_fixed}"
)
# General model response
if response_dict.get("response", ""):
@ -462,12 +466,12 @@ class ArchFunctionHandler(ArchBaseHandler):
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_message)],
model=self.model_name,
metadata={"x-arch-fc-model-response": model_response},
metadata={"x-arch-fc-model-response": raw_model_response_json_fixed},
role="assistant",
)
logger.info(
f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump())}"
f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump(exclude_none=True))}"
)
return chat_completion_response

View file

@ -161,9 +161,10 @@ class ArchBaseHandler:
# sample response below
# "content": "<tool_response>\n{'name': 'get_stock_price', 'result': '$196.66'}\n</tool_response>"
# msg[idx-1] contains tool call = '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}'
func_name = json.loads(messages[idx - 1].content)["tool_calls"][
0
].get("name", "no_name")
tool_call_msg = messages[idx - 1].content
func_name = json.loads(tool_call_msg)["tool_calls"][0].get(
"name", "no_name"
)
tool_response = {
"name": func_name,
"result": content,

View file

@ -71,7 +71,7 @@ async def models():
@app.post("/function_calling")
async def function_calling(req: ChatMessage, res: Response):
logger.info("[Endpoint: /function_calling]")
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
final_response: ChatCompletionResponse = None
error_messages = None
@ -115,9 +115,11 @@ async def function_calling(req: ChatMessage, res: Response):
except ValueError as e:
res.statuscode = 503
error_messages = f"[{handler_name}] - Error in tool call extraction: {e}"
raise
except StopIteration as e:
res.statuscode = 500
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
raise
except Exception as e:
res.status_code = 500
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
@ -133,7 +135,7 @@ async def function_calling(req: ChatMessage, res: Response):
@app.post("/guardrails")
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
logger.info("[Endpoint: /guardrails] - Gateway")
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
final_response: GuardResponse = None
error_messages = None