diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py index 9c2da39b..093eee9c 100644 --- a/model_server/app/function_calling/model_utils.py +++ b/model_server/app/function_calling/model_utils.py @@ -85,18 +85,18 @@ async def chat_completion( # Retrieve the first token, handling the Stream object carefully - if prefill_enabled: - try: - resp = const.arch_function_client.chat.completions.create( - messages=messages, - model=client_model_name, - stream=True, - extra_body=const.arch_function_generation_params, - ) - except Exception as e: - logger.error(f"model_server <= arch_function: error: {e}") - raise + try: + resp = const.arch_function_client.chat.completions.create( + messages=messages, + model=client_model_name, + stream=prefill_enabled, + extra_body=const.arch_function_generation_params, + ) + except Exception as e: + logger.error(f"model_server <= arch_function: error: {e}") + raise + if prefill_enabled: first_token_content = "" for token in resp: first_token_content = token.choices[ @@ -133,17 +133,7 @@ async def chat_completion( if hasattr(token.choices[0].delta, "content"): full_response += token.choices[0].delta.content else: - try: - resp = const.arch_function_client.chat.completions.create( - messages=messages, - model=client_model_name, - stream=False, - extra_body=const.arch_function_generation_params, - ) - full_response = resp.choices[0].message.content - except Exception as e: - logger.error(f"model_server <= arch_function: error: {e}") - raise + full_response = resp.choices[0].message.content tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response) diff --git a/model_server/app/tests/test_function_calling.py b/model_server/app/tests/test_function_calling.py index c10e4b6f..05ecd971 100644 --- a/model_server/app/tests/test_function_calling.py +++ b/model_server/app/tests/test_function_calling.py @@ -73,7 +73,7 @@ async def test_chat_completion(mock_hanlder, mock_client): mock_hanlder._format_system.return_value = "" response = Response() - chat_response = await chat_completion(request, response) + chat_response = await chat_completion(request, response, prefill_enabled=True) assert isinstance(chat_response, ChatCompletionResponse) assert chat_response.choices[0].message.content is not None