2024-09-25 12:03:44 -07:00
|
|
|
import json
|
2024-10-05 19:25:16 -07:00
|
|
|
import hashlib
|
2024-10-09 18:04:52 -07:00
|
|
|
import app.commons.constants as const
|
|
|
|
|
|
|
|
|
|
from fastapi import Response
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from app.commons.utilities import get_model_server_logger
|
|
|
|
|
from typing import Any, Dict, List
|
|
|
|
|
|
2024-09-10 14:24:46 -07:00
|
|
|
|
2024-10-08 12:40:24 -07:00
|
|
|
# Module-level logger, used by process_messages and chat_completion below.
logger = get_model_server_logger()
|
|
|
|
|
|
2024-10-02 20:43:16 -07:00
|
|
|
|
2024-10-09 18:04:52 -07:00
|
|
|
class Message(BaseModel):
    """A single chat turn received from the client.

    Only ``role`` is required; every other field falls back to an empty value.
    ``tool_calls`` carries raw tool-call dicts; ``process_messages`` reads the
    ``"function"`` entry of the first one.
    """

    role: str
    content: str = ""
    # pydantic copies mutable defaults per instance, so this list is not shared.
    tool_calls: list[dict[str, Any]] = []
    tool_call_id: str = ""
|
2024-10-09 18:04:52 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class ChatMessage(BaseModel):
    """Request body for a chat completion: the conversation plus tool definitions.

    ``tools`` is passed as-is to the function handler's system-prompt formatter
    in ``chat_completion``.
    """

    messages: list[Message]
    tools: list[dict[str, Any]]
|
|
|
|
|
|
2024-09-10 14:24:46 -07:00
|
|
|
|
2024-10-18 13:25:39 -07:00
|
|
|
def process_messages(history: list[Message]):
    """Flatten chat history into plain ``{"role", "content"}`` dicts for the model.

    Conversion rules, checked in this order:

    * a message with ``tool_calls`` becomes an ``assistant`` turn whose content
      is the JSON-encoded ``"function"`` entry of its single tool call, wrapped
      in ``<tool_call>`` tags (any original ``content`` is not carried over);
    * a ``tool`` message becomes a ``user`` turn with its content wrapped in
      ``<tool_response>`` tags;
    * anything else is copied through with its role and content unchanged.

    Raises:
        ValueError: if a message carries more than one tool call (logged first).
    """
    converted = []
    for msg in history:
        if not msg.tool_calls:
            if msg.role == "tool":
                converted.append(
                    {
                        "role": "user",
                        "content": f"<tool_response>\n{msg.content}\n</tool_response>",
                    }
                )
            else:
                converted.append({"role": msg.role, "content": msg.content})
            continue

        call_count = len(msg.tool_calls)
        if call_count > 1:
            error_msg = f"Only one tool call is supported, tools counts: {call_count}"
            logger.error(error_msg)
            raise ValueError(error_msg)

        function_json = json.dumps(msg.tool_calls[0]["function"])
        converted.append(
            {
                "role": "assistant",
                "content": f"<tool_call>\n{function_json}\n</tool_call>",
            }
        )
    return converted
|
|
|
|
|
|
2024-10-07 15:21:05 -07:00
|
|
|
|
2024-09-10 14:24:46 -07:00
|
|
|
async def chat_completion(req: ChatMessage, res: Response):
    """Serve one chat-completion request through the Arch-Function model.

    The tool definitions in ``req.tools`` are rendered into a system prompt by
    ``const.arch_function_hanlder._format_system``. The conversation history is
    normalised by ``process_messages``, and the result is sent to the first
    model reported by ``const.arch_function_client``. If
    ``extract_tool_calls`` finds tool calls in the model's reply, they are
    moved onto ``message.tool_calls`` and ``message.content`` is cleared.

    Args:
        req: Incoming conversation messages and tool definitions.
        res: Not used by this function (presumably kept for the FastAPI
            route signature -- confirm against the router).

    Returns:
        The completion response object returned by the client, possibly with
        ``tool_calls`` populated as described above.

    Raises:
        ValueError: If a history message carries more than one tool call
            (raised by ``process_messages``).
        Exception: Any error from the completion request is logged and
            re-raised unchanged.
    """
    # Function-scope import keeps this block self-contained.
    import asyncio

    logger.info("starting request")

    tools_encoded = const.arch_function_hanlder._format_system(req.tools)
    messages = [{"role": "system", "content": tools_encoded}]
    # process_messages already yields dicts with exactly "role" and "content".
    messages.extend(process_messages(req.messages))

    # The client is used synchronously (its results are read without `await`),
    # so calling it directly would block the event loop for the whole network
    # round trip. Run the blocking calls in a worker thread instead.
    model_list = await asyncio.to_thread(const.arch_function_client.models.list)
    client_model_name = model_list.data[0].id

    logger.info(
        f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
    )

    try:
        resp = await asyncio.to_thread(
            const.arch_function_client.chat.completions.create,
            messages=messages,
            model=client_model_name,
            stream=False,
            extra_body=const.arch_function_generation_params,
        )
    except Exception as e:
        logger.error(f"model_server <= arch_function: error: {e}")
        raise

    # The model emits tool calls as text; parse them out of the reply content.
    tool_calls = const.arch_function_hanlder.extract_tool_calls(
        resp.choices[0].message.content
    )

    if tool_calls:
        # Surface the parsed calls in the structured field and drop the raw text.
        resp.choices[0].message.tool_calls = tool_calls
        resp.choices[0].message.content = None
        logger.info(
            f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
        )

    logger.info(
        f"model_server <= arch_function: response body: {json.dumps(resp.to_dict())}"
    )

    return resp
|