This commit is contained in:
cotran 2024-10-31 14:49:03 -07:00
parent b15736a7fd
commit cd22c71690
3 changed files with 53 additions and 43 deletions

View file

@ -1,7 +1,7 @@
import json
import hashlib
import app.commons.constants as const
import random
from fastapi import Response
from pydantic import BaseModel
from app.commons.utilities import get_model_server_logger
@ -64,7 +64,9 @@ def process_messages(history: list[Message]):
return updated_history
async def chat_completion(req: ChatMessage, res: Response):
async def chat_completion(
req: ChatMessage, res: Response, prefill_enabled: bool = True
):
logger.info("starting request")
tools_encoded = const.arch_function_hanlder._format_system(req.tools)
@ -81,55 +83,61 @@ async def chat_completion(req: ChatMessage, res: Response):
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
)
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=True,
extra_body=const.arch_function_generation_params,
)
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
# Retrieve the first token, handling the Stream object carefully
first_token_content = ""
try:
while True:
first_token = next(resp) # Synchronously retrieve tokens
first_token_content = first_token.choices[
if prefill_enabled:
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=True,
extra_body=const.arch_function_generation_params,
)
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
first_token_content = ""
for token in resp:
first_token_content = token.choices[
0
].delta.content.strip() # Clean up the content
if first_token_content: # Break if it's non-empty
break
except StopIteration:
print("No non-empty tokens found.")
return None
# Check if the first token requires tool call handling
if first_token_content != "<tool_call>":
# Engage pre-filling response if no tool call is indicated
logger.info("Tool call is not found! Engage pre filling")
messages.append({"role": "assistant", "content": "Sure!"})
# Check if the first token requires tool call handling
if first_token_content != "<tool_call>":
# Engage pre-filling response if no tool call is indicated
resp.close()
logger.info("Tool call is not found! Engage pre filling")
prefill_content = random.choice(const.prefill_list)
messages.append({"role": "assistant", "content": prefill_content})
# Send a new completion request with the updated messages
pre_fill_resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=const.arch_function_generation_params,
)
full_response = pre_fill_resp.choices[0].message.content
else:
# Initialize full response and iterate over tokens to gather the full response
full_response = "<tool_call>"
try:
while True:
token = next(resp) # Retrieve each token synchronously
# Send a new completion request with the updated messages
pre_fill_resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=const.arch_function_generation_params,
)
full_response = pre_fill_resp.choices[0].message.content
else:
# Initialize full response and iterate over tokens to gather the full response
full_response = first_token_content
for token in resp:
if hasattr(token.choices[0].delta, "content"):
full_response += token.choices[0].delta.content
except StopIteration:
pass # End of stream
else:
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=const.arch_function_generation_params,
)
full_response = resp.choices[0].message.content
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)