mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
update rest and other parts of the code to work with arch fc 1.1
This commit is contained in:
parent
e5949c584f
commit
b31a7a569a
8 changed files with 196 additions and 47 deletions
|
|
@ -197,12 +197,12 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
|
||||
response_dict["response"] = model_response.get("response", "")
|
||||
response_dict["required_functions"] = model_response.get(
|
||||
"required_functions", ""
|
||||
"required_functions", []
|
||||
)
|
||||
response_dict["clarification"] = model_response.get("clarification", "")
|
||||
|
||||
for tool_call in model_response.get("tool_calls", []):
|
||||
response_dict["tool_call"].append(
|
||||
response_dict["tool_calls"].append(
|
||||
{
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
|
|
@ -448,6 +448,7 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
if len(chunk.choices) > 0 and chunk.choices[0].delta.content:
|
||||
model_response += chunk.choices[0].delta.content
|
||||
|
||||
logger.info(f"[arch-fc]: raw model response: {model_response}")
|
||||
# Extract tool calls from model response
|
||||
response_dict = self._parse_model_resonse(model_response)
|
||||
|
||||
|
|
@ -499,10 +500,15 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
model_message = Message(content="", tool_calls=[])
|
||||
|
||||
chat_completion_response = ChatCompletionResponse(
|
||||
choices=[Choice(message=model_message)], model=self.model_name
|
||||
choices=[Choice(message=model_message)],
|
||||
model=self.model_name,
|
||||
metadata={"x-arch-fc-model-response": model_response},
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
logger.info(f"[response]: {json.dumps(chat_completion_response.model_dump())}")
|
||||
logger.info(
|
||||
f"[response arch-fc]: {json.dumps(chat_completion_response.model_dump())}"
|
||||
)
|
||||
|
||||
return chat_completion_response
|
||||
|
||||
|
|
|
|||
|
|
@ -142,7 +142,7 @@ class ArchBaseHandler:
|
|||
{"role": "system", "content": self._format_system_prompt(tools)}
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
for idx, message in enumerate(messages):
|
||||
role, content, tool_calls = (
|
||||
message.role,
|
||||
message.content,
|
||||
|
|
@ -158,9 +158,17 @@ class ArchBaseHandler:
|
|||
if metadata.get("optimize_context_window", "false").lower() == "true":
|
||||
content = f"<tool_response>\n\n</tool_response>"
|
||||
else:
|
||||
content = (
|
||||
f"<tool_response>\n{json.dumps(content)}\n</tool_response>"
|
||||
)
|
||||
# sample response below
|
||||
# "content": "<tool_response>\n{'name': 'get_stock_price', 'result': '$196.66'}\n</tool_response>"
|
||||
# messages[idx - 1].content is expected to hold the preceding tool call as JSON, e.g. '{"tool_calls": [{"name": "currency_exchange", "arguments": {"currency_symbol": "NZD"}}]}'
|
||||
func_name = json.loads(messages[idx - 1].content)["tool_calls"][
|
||||
0
|
||||
].get("name", "no_name")
|
||||
tool_response = {
|
||||
"name": func_name,
|
||||
"result": content,
|
||||
}
|
||||
content = f"<tool_response>\n{json.dumps(tool_response)}\n</tool_response>"
|
||||
|
||||
processed_messages.append({"role": role, "content": content})
|
||||
|
||||
|
|
|
|||
|
|
@ -87,16 +87,15 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
final_response = await model_handler.chat_completion(req)
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
if not final_response.metadata:
|
||||
final_response.metadata = {}
|
||||
|
||||
# Parameter gathering for detected intents
|
||||
if final_response.choices[0].message.content:
|
||||
final_response.metadata = {
|
||||
"function_latency": str(round(latency * 1000, 3)),
|
||||
}
|
||||
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
|
||||
# Function Calling
|
||||
elif final_response.choices[0].message.tool_calls:
|
||||
final_response.metadata = {
|
||||
"function_latency": str(round(latency * 1000, 3)),
|
||||
}
|
||||
final_response.metadata["function_latency"] = str(round(latency * 1000, 3))
|
||||
|
||||
# *********************************************************************************************
|
||||
# TODO: Put the following code back when hallucination check is ready
|
||||
|
|
@ -107,9 +106,7 @@ async def function_calling(req: ChatMessage, res: Response):
|
|||
# )
|
||||
# No intent detected
|
||||
else:
|
||||
final_response.metadata = {
|
||||
"intent_latency": str(round(latency * 1000, 3)),
|
||||
}
|
||||
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
|
||||
|
||||
if not use_agent_orchestrator:
|
||||
final_response.metadata["intent_latency"] = str(round(latency * 1000, 3))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue