add prefill and test

This commit is contained in:
cotran 2024-10-30 17:00:30 -07:00
parent bb9a774a72
commit 5919f8b9b9
3 changed files with 154 additions and 17 deletions

View file

@ -5,17 +5,17 @@ import app.commons.constants as const
from fastapi import Response
from pydantic import BaseModel
from app.commons.utilities import get_model_server_logger
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
logger = get_model_server_logger()
class Message(BaseModel):
role: str
content: str = ""
tool_calls: List[Dict[str, Any]] = []
tool_call_id: str = ""
role: Optional[str] = ""
content: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
tool_call_id: Optional[str] = ""
class ChatMessage(BaseModel):
@ -23,6 +23,14 @@ class ChatMessage(BaseModel):
tools: List[Dict[str, Any]]
class Choice(BaseModel):
message: Message
class ChatCompletionResponse(BaseModel):
choices: List[Choice]
def process_messages(history: list[Message]):
updated_history = []
for hist in history:
@ -70,23 +78,63 @@ async def chat_completion(req: ChatMessage, res: Response):
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
stream=True,
extra_body=const.arch_function_generation_params,
)
tool_calls = const.arch_function_hanlder.extract_tool_calls(
resp.choices[0].message.content
)
# Retrieve the first token, handling the Stream object carefully
first_token_content = ""
try:
while True:
first_token = next(resp) # Synchronously retrieve tokens
first_token_content = first_token.choices[
0
].delta.content.strip() # Clean up the content
if first_token_content: # Break if it's non-empty
break
except StopIteration:
print("No non-empty tokens found.")
return None
# Check if the first token requires tool call handling
if first_token_content != "<tool_call>":
# Engage pre-filling response if no tool call is indicated
logger.info("Tool call is not found! Engage pre filling")
messages.append({"role": "assistant", "content": "Sure!"})
# Send a new completion request with the updated messages
pre_fill_resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=const.arch_function_generation_params,
)
full_response = pre_fill_resp.choices[0].message.content
else:
# Initialize full response and iterate over tokens to gather the full response
full_response = "<tool_call>"
try:
while True:
token = next(resp) # Retrieve each token synchronously
if hasattr(token.choices[0].delta, "content"):
full_response += token.choices[0].delta.content
except StopIteration:
pass # End of stream
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
if tool_calls:
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
message = Message(content="", tool_calls=tool_calls)
else:
message = Message(content=full_response, tool_calls=[])
choice = Choice(message=message)
chat_completion_response = ChatCompletionResponse(choices=[choice])
logger.info(
f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
)
logger.info(
f"model_server <= arch_function: response body: {json.dumps(resp.to_dict())}"
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
)
return resp
return chat_completion_response