plano/model_server/app/function_calling/model_utils.py

158 lines
5.5 KiB
Python
Raw Normal View History

2024-09-25 12:03:44 -07:00
import json
import hashlib
import app.commons.constants as const
2024-10-31 14:49:03 -07:00
import random
from fastapi import Response
from pydantic import BaseModel
from app.commons.utilities import get_model_server_logger
2024-10-30 17:00:30 -07:00
from typing import Any, Dict, List, Optional
logger = get_model_server_logger()
class Message(BaseModel):
    # A single chat turn exchanged with the function-calling model.
    # Every field is optional and falls back to an empty value, so callers
    # may supply only the fields that apply to the turn they are describing.

    # Speaker of the turn; process_messages() special-cases the value "tool".
    role: Optional[str] = ""
    # Text body of the turn.
    content: Optional[str] = ""
    # Tool invocations attached to the turn; process_messages() reads the
    # "function" key of the first entry.
    tool_calls: Optional[list[dict[str, Any]]] = []
    # Identifier of the tool call this turn relates to (not read in this module).
    tool_call_id: Optional[str] = ""
class ChatMessage(BaseModel):
    # Request body accepted by chat_completion().

    # Conversation history, in the order it is passed to process_messages().
    messages: list[Message]
    # Tool definitions; handed as-is to the handler's _format_system() to
    # build the system prompt.
    tools: list[dict[str, Any]]
2024-10-30 17:00:30 -07:00
class Choice(BaseModel):
    # One candidate answer inside a ChatCompletionResponse.

    # The assistant message produced for this choice.
    message: Message
    # Defaults to "stop"; chat_completion() never overrides it.
    finish_reason: Optional[str] = "stop"
    # Position of this choice in the response; defaults to 0.
    index: Optional[int] = 0
2024-10-30 17:00:30 -07:00
class ChatCompletionResponse(BaseModel):
    # Response envelope returned by chat_completion().

    # Generated choices; chat_completion() always supplies exactly one.
    choices: List[Choice]
    # Name of the serving model; chat_completion() overwrites the default
    # with the id reported by the backing client.
    model: Optional[str] = "Arch-Function"
    # Left at the empty-string default by chat_completion().
    created: Optional[str] = ""
    # Left at the empty-string default by chat_completion().
    id: Optional[str] = ""
    # Object type tag included in the serialized response.
    object: Optional[str] = "chat_completion"
2024-10-30 17:00:30 -07:00
def process_messages(history: list[Message]):
    """Convert chat history into plain ``{"role", "content"}`` dicts.

    Each message is rewritten according to the first rule that matches:

    * It carries tool calls: emitted as an ``assistant`` turn whose content is
      the JSON-encoded ``"function"`` entry of the single tool call, wrapped
      in ``<tool_call>`` tags.
    * Its role is ``"tool"``: emitted as a ``user`` turn whose content is the
      original content wrapped in ``<tool_response>`` tags.
    * Anything else: role and content are copied through unchanged.

    Args:
        history: Messages in conversation order.

    Returns:
        A new list of dicts, one per input message, in the same order.

    Raises:
        ValueError: If a message carries more than one tool call.
    """
    converted = []
    for entry in history:
        calls = entry.tool_calls
        if calls:
            call_count = len(calls)
            if call_count > 1:
                error_msg = f"Only one tool call is supported, tools counts: {call_count}"
                logger.error(error_msg)
                raise ValueError(error_msg)
            payload = json.dumps(calls[0]["function"])
            role = "assistant"
            content = f"<tool_call>\n{payload}\n</tool_call>"
        elif entry.role == "tool":
            # Tool results are replayed to the model as user turns.
            role = "user"
            content = f"<tool_response>\n{entry.content}\n</tool_response>"
        else:
            role = entry.role
            content = entry.content
        converted.append({"role": role, "content": content})
    return converted
2024-10-07 15:21:05 -07:00
2024-11-06 16:16:08 -08:00
def _delta_text(chunk) -> str:
    """Return the text carried by one streamed chunk, or ``""`` if it has none.

    Streamed chunks may have an empty ``choices`` list, or a delta whose
    ``content`` is missing or ``None`` (for example role-only or final
    chunks); all of those are treated as "no text" instead of raising.
    """
    choices = getattr(chunk, "choices", None)
    if not choices:
        return ""
    delta = getattr(choices[0], "delta", None)
    return getattr(delta, "content", None) or ""


async def chat_completion(req: ChatMessage, res: Response):
    """Run a function-calling chat completion and wrap the result.

    The tool definitions are encoded into a system prompt, the history is
    flattened with ``process_messages`` and the conversation is sent to the
    function-calling client.

    When ``const.PREFILL_ENABLED`` is set, the request is streamed and the
    first non-empty token is inspected:

    * If it equals ``const.TOOL_CALL_TOKEN`` the rest of the stream is
      collected as the full response.
    * Otherwise the stream is closed, a randomly chosen entry of
      ``const.PREFILL_LIST`` is appended as a partial assistant turn, and a
      second, non-streamed request asks the model to continue that turn.

    When prefill is disabled, the single non-streamed response is used.

    NOTE(review): the client calls below are not awaited and the stream is
    consumed with a plain ``for`` loop, so they appear to be synchronous and
    would block the event loop while running - confirm against the client
    configured in ``app.commons.constants``.

    Args:
        req: Request body with the conversation history and tool definitions.
        res: FastAPI response object (not used by this function).

    Returns:
        A ``ChatCompletionResponse`` with one choice. If tool calls were
        extracted, the message has empty content and the tool calls attached;
        otherwise it carries the model text and no tool calls.

    Raises:
        ValueError: Propagated from ``process_messages`` when a history
            message carries more than one tool call.
        Exception: Any error raised by the initial completion request is
            logged and re-raised unchanged.
    """
    logger.info("starting request")

    tools_encoded = const.arch_function_hanlder._format_system(req.tools)

    messages = [{"role": "system", "content": tools_encoded}]
    messages.extend(
        {"role": message["role"], "content": message["content"]}
        for message in process_messages(req.messages)
    )

    client_model_name = const.arch_function_client.models.list().data[0].id

    logger.info(
        f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
    )

    try:
        resp = const.arch_function_client.chat.completions.create(
            messages=messages,
            model=client_model_name,
            stream=const.PREFILL_ENABLED,
            extra_body=const.arch_function_generation_params,
        )
    except Exception as e:
        logger.error(f"model_server <= arch_function: error: {e}")
        raise

    if const.PREFILL_ENABLED:
        # Pull chunks until the first one that carries non-whitespace text.
        # Chunks without text (content is None or missing) are skipped
        # instead of raising AttributeError on ``None.strip()``.
        first_token_content = ""
        for token in resp:
            first_token_content = _delta_text(token).strip()
            if first_token_content:
                break

        # Check if the first token requires tool call handling
        if first_token_content != const.TOOL_CALL_TOKEN:
            # Engage pre-filling response if no tool call is indicated
            resp.close()
            logger.info("Tool call is not found! Engage pre filling")
            prefill_content = random.choice(const.PREFILL_LIST)
            messages.append({"role": "assistant", "content": prefill_content})

            # Send a new completion request with the updated messages:
            # the model will continue the final message in the chat instead
            # of starting a new one, so add_generation_prompt (which tells
            # the template to add the tokens that start a bot response) is
            # disabled.
            extra_body = {
                **const.arch_function_generation_params,
                "continue_final_message": True,
                "add_generation_prompt": False,
            }
            pre_fill_resp = const.arch_function_client.chat.completions.create(
                messages=messages,
                model=client_model_name,
                stream=False,
                extra_body=extra_body,
            )
            full_response = pre_fill_resp.choices[0].message.content
        else:
            # Gather the remainder of the stream. Text-less chunks contribute
            # an empty string rather than failing on ``str += None``.
            pieces = [first_token_content]
            pieces.extend(_delta_text(token) for token in resp)
            full_response = "".join(pieces)
    else:
        logger.info("Stream is disabled, not engaging pre-filling")
        full_response = resp.choices[0].message.content

    # Normalise a falsy result (e.g. None) to a list so the log line below
    # can always iterate it.
    tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response) or []

    if tool_calls:
        message = Message(content="", tool_calls=tool_calls)
    else:
        message = Message(content=full_response, tool_calls=[])

    choice = Choice(message=message)
    chat_completion_response = ChatCompletionResponse(
        choices=[choice], model=client_model_name
    )

    tool_functions = [tool_call["function"] for tool_call in tool_calls]
    logger.info(
        f"model_server <= arch_function: (tools): {json.dumps(tool_functions)}"
    )
    logger.info(
        f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
    )

    return chat_completion_response