diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py index 67970bf9..e1060bc8 100644 --- a/model_server/app/commons/constants.py +++ b/model_server/app/commons/constants.py @@ -8,6 +8,7 @@ from app.prompt_guard.model_handler import ArchGuardHanlder logger = utils.get_model_server_logger() arch_function_hanlder = ArchFunctionHandler() +prefill_list = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"] arch_function_endpoint = "https://api.fc.archgw.com/v1" arch_function_client = utils.get_client(arch_function_endpoint) arch_function_generation_params = { diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py index 8fc49940..e3ceea51 100644 --- a/model_server/app/function_calling/model_utils.py +++ b/model_server/app/function_calling/model_utils.py @@ -1,7 +1,7 @@ import json import hashlib import app.commons.constants as const - +import random from fastapi import Response from pydantic import BaseModel from app.commons.utilities import get_model_server_logger @@ -64,7 +64,9 @@ def process_messages(history: list[Message]): return updated_history -async def chat_completion(req: ChatMessage, res: Response): +async def chat_completion( + req: ChatMessage, res: Response, prefill_enabled: bool = True +): logger.info("starting request") tools_encoded = const.arch_function_hanlder._format_system(req.tools) @@ -81,55 +83,61 @@ async def chat_completion(req: ChatMessage, res: Response): f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}" ) - try: - resp = const.arch_function_client.chat.completions.create( - messages=messages, - model=client_model_name, - stream=True, - extra_body=const.arch_function_generation_params, - ) - except Exception as e: - logger.error(f"model_server <= arch_function: error: {e}") - raise - # Retrieve the first token, handling the Stream object carefully - first_token_content = "" - try: - while True: - first_token = next(resp) # Synchronously retrieve tokens - first_token_content = first_token.choices[ + + if prefill_enabled: + try: + resp = const.arch_function_client.chat.completions.create( + messages=messages, + model=client_model_name, + stream=True, + extra_body=const.arch_function_generation_params, + ) + except Exception as e: + logger.error(f"model_server <= arch_function: error: {e}") + raise + first_token_content = "" + for token in resp: + first_token_content = token.choices[ 0 ].delta.content.strip() # Clean up the content if first_token_content: # Break if it's non-empty break - except StopIteration: - print("No non-empty tokens found.") - return None - # Check if the first token requires tool call handling - if first_token_content != "": - # Engage pre-filling response if no tool call is indicated - logger.info("Tool call is not found! Engage pre filling") - messages.append({"role": "assistant", "content": "Sure!"}) + # Check if the first token requires tool call handling + if first_token_content != "": + # Engage pre-filling response if no tool call is indicated + resp.close() + logger.info("Tool call is not found! Engage pre filling") + prefill_content = random.choice(const.prefill_list) + messages.append({"role": "assistant", "content": prefill_content}) - # Send a new completion request with the updated messages - pre_fill_resp = const.arch_function_client.chat.completions.create( - messages=messages, - model=client_model_name, - stream=False, - extra_body=const.arch_function_generation_params, - ) - full_response = pre_fill_resp.choices[0].message.content - else: - # Initialize full response and iterate over tokens to gather the full response - full_response = "" - try: - while True: - token = next(resp) # Retrieve each token synchronously + # Send a new completion request with the updated messages + pre_fill_resp = const.arch_function_client.chat.completions.create( + messages=messages, + model=client_model_name, + stream=False, + extra_body=const.arch_function_generation_params, + ) + full_response = pre_fill_resp.choices[0].message.content + else: + # Initialize full response and iterate over tokens to gather the full response + full_response = first_token_content + for token in resp: if hasattr(token.choices[0].delta, "content"): full_response += token.choices[0].delta.content - except StopIteration: - pass # End of stream + else: + try: + resp = const.arch_function_client.chat.completions.create( + messages=messages, + model=client_model_name, + stream=False, + extra_body=const.arch_function_generation_params, + ) + full_response = resp.choices[0].message.content + except Exception as e: + logger.error(f"model_server <= arch_function: error: {e}") + raise tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response) diff --git a/model_server/app/tests/test_function_calling.py b/model_server/app/tests/test_function_calling.py index 3b4a05a4..c10e4b6f 100644 --- a/model_server/app/tests/test_function_calling.py +++ b/model_server/app/tests/test_function_calling.py @@ -1,5 +1,6 @@ import pytest from unittest.mock import AsyncMock, MagicMock, patch +import app.commons.constants as const from fastapi import Response from app.function_calling.model_utils import ( process_messages, @@ -86,4 +87,4 @@ async def test_chat_completion(mock_hanlder, mock_client): second_call_args = mock_client.chat.completions.create.call_args_list[1][1] assert second_call_args["stream"] == False assert "model" in second_call_args - assert second_call_args["messages"][-1]["content"] == "Sure!" + assert second_call_args["messages"][-1]["content"] in const.prefill_list