mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
update response from upstream llm to now include it in dict with "response"
This commit is contained in:
parent
5bd991e97b
commit
f2323f771c
4 changed files with 30 additions and 11 deletions
|
|
@ -88,6 +88,18 @@ def chat(
|
|||
|
||||
yield "", conversation, history, debug_output, model_selector
|
||||
|
||||
# update assistant response to have correct format
|
||||
# arch-fc 1.1 expects following format:
|
||||
# {
|
||||
# "response": "<assistant response>",
|
||||
# }
|
||||
|
||||
if not history[-1]["model"].startswith("Arch"):
|
||||
assistant_response = {
|
||||
"response": history[-1]["content"],
|
||||
}
|
||||
history[-1]["content"] = json.dumps(assistant_response)
|
||||
|
||||
|
||||
def main():
|
||||
with gr.Blocks(
|
||||
|
|
|
|||
|
|
@ -154,7 +154,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
return fixed_str
|
||||
|
||||
def _parse_model_resonse(self, content: str) -> Dict[str, any]:
|
||||
def _parse_model_response(self, content: str) -> Dict[str, any]:
|
||||
"""
|
||||
Extracts tool call information from a given string.
|
||||
|
||||
|
|
@ -212,7 +212,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
response_dict["is_valid"] = False
|
||||
response_dict["error_message"] = f"Fail to parse model responses: {e}"
|
||||
|
||||
return response_dict
|
||||
return content, response_dict
|
||||
|
||||
def _convert_data_type(self, value: str, target_type: str):
|
||||
# TODO: Add more conversion rules as needed
|
||||
|
|
@ -408,9 +408,13 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
self.hallucination_state.tokens
|
||||
)
|
||||
|
||||
logger.info(f"[arch-fc]: raw model response: {model_response}")
|
||||
# Extract tool calls from model response
|
||||
response_dict = self._parse_model_resonse(model_response)
|
||||
raw_model_response_json_fixed, response_dict = self._parse_model_response(
|
||||
model_response
|
||||
)
|
||||
logger.info(
|
||||
f"[arch-fc]: raw model response (json fixed): {raw_model_response_json_fixed}"
|
||||
)
|
||||
|
||||
# General model response
|
||||
if response_dict.get("response", ""):
|
||||
|
|
@ -462,12 +466,12 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_message)],
|
||||
model=self.model_name,
|
||||
metadata={"x-arch-fc-model-response": model_response},
|
||||
metadata={"x-arch-fc-model-response": raw_model_response_json_fixed},
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump())}"
|
||||
f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump(exclude_none=True))}"
|
||||
)
|
||||
|
||||
return chat_completion_response
|
||||
|
|
|
|||
|
|
@ -161,9 +161,10 @@ class ArchBaseHandler:
|
|||
# sample response below
|
||||
# "content": "<tool_response>\n{'name': 'get_stock_price', 'result': '$196.66'}\n</tool_response>"
|
||||
# msg[idx-1] contains tool call = '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}'
|
||||
func_name = json.loads(messages[idx - 1].content)["tool_calls"][
|
||||
0
|
||||
].get("name", "no_name")
|
||||
tool_call_msg = messages[idx - 1].content
|
||||
func_name = json.loads(tool_call_msg)["tool_calls"][0].get(
|
||||
"name", "no_name"
|
||||
)
|
||||
tool_response = {
|
||||
"name": func_name,
|
||||
"result": content,
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ async def models():
|
|||
@app.post("/function_calling")
|
||||
async def function_calling(req: ChatMessage, res: Response):
|
||||
logger.info("[Endpoint: /function_calling]")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
|
||||
|
||||
final_response: ChatCompletionResponse = None
|
||||
error_messages = None
|
||||
|
|
@ -115,9 +115,11 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
except ValueError as e:
|
||||
res.statuscode = 503
|
||||
error_messages = f"[{handler_name}] - Error in tool call extraction: {e}"
|
||||
raise
|
||||
except StopIteration as e:
|
||||
res.statuscode = 500
|
||||
error_messages = f"[{handler_name}] - Error in hallucination check: {e}"
|
||||
raise
|
||||
except Exception as e:
|
||||
res.status_code = 500
|
||||
error_messages = f"[{handler_name}] - Error in ChatCompletion: {e}"
|
||||
|
|
@ -133,7 +135,7 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
@app.post("/guardrails")
|
||||
async def guardrails(req: GuardRequest, res: Response, max_num_words=300):
|
||||
logger.info("[Endpoint: /guardrails] - Gateway")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump())}")
|
||||
logger.info(f"[request body]: {json.dumps(req.model_dump(exclude_none=True))}")
|
||||
|
||||
final_response: GuardResponse = None
|
||||
error_messages = None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue