This commit is contained in:
cotran 2024-11-04 10:21:11 -08:00
parent 13fac83381
commit 0910fcdcfa
2 changed files with 13 additions and 23 deletions

View file

@ -85,18 +85,18 @@ async def chat_completion(
# Retrieve the first token, handling the Stream object carefully
if prefill_enabled:
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=True,
extra_body=const.arch_function_generation_params,
)
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=prefill_enabled,
extra_body=const.arch_function_generation_params,
)
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
if prefill_enabled:
first_token_content = ""
for token in resp:
first_token_content = token.choices[
@ -133,17 +133,7 @@ async def chat_completion(
if hasattr(token.choices[0].delta, "content"):
full_response += token.choices[0].delta.content
else:
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=const.arch_function_generation_params,
)
full_response = resp.choices[0].message.content
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
full_response = resp.choices[0].message.content
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)

View file

@ -73,7 +73,7 @@ async def test_chat_completion(mock_hanlder, mock_client):
mock_hanlder._format_system.return_value = "<formatted_tools>"
response = Response()
chat_response = await chat_completion(request, response)
chat_response = await chat_completion(request, response, prefill_enabled=True)
assert isinstance(chat_response, ChatCompletionResponse)
assert chat_response.choices[0].message.content is not None