From 1f383eafc43ae8d2a9219d654be084550ee52f18 Mon Sep 17 00:00:00 2001 From: cotran Date: Thu, 7 Nov 2024 11:15:03 -0800 Subject: [PATCH] address cmt --- e2e_tests/common.py | 10 ++++++++ e2e_tests/test_prompt_gateway.py | 25 ++++++------------- model_server/app/commons/constants.py | 5 ++-- .../app/function_calling/model_utils.py | 10 +++++--- 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/e2e_tests/common.py b/e2e_tests/common.py index 7ccee7c4..1edb6517 100644 --- a/e2e_tests/common.py +++ b/e2e_tests/common.py @@ -10,6 +10,16 @@ LLM_GATEWAY_ENDPOINT = os.getenv( ) ARCH_STATE_HEADER = "x-arch-state" +PREFILL_LIST = [ + "May", + "Could", + "Sure", + "Definitely", + "Certainly", + "Of course", + "Can", +] + def get_data_chunks(stream, n=1): chunks = [] diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py index e1ff8a9c..4a3e7eb7 100644 --- a/e2e_tests/test_prompt_gateway.py +++ b/e2e_tests/test_prompt_gateway.py @@ -3,20 +3,12 @@ import pytest import requests from deepdiff import DeepDiff -from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks - - -def get_prefill_list(): - prefill_list = [ - "May", - "Could", - "Sure", - "Definitely", - "Certainly", - "Of course", - "Can", - ] - return prefill_list +from common import ( + PROMPT_GATEWAY_ENDPOINT, + PREFILL_LIST, + get_arch_messages, + get_data_chunks, +) @pytest.mark.parametrize("stream", [True, False]) @@ -126,10 +118,9 @@ def test_prompt_gateway_arch_direct_response(stream): message = choices[0]["message"]["content"] assert "Could you provide the following details days" not in message - prefill_list = get_prefill_list() assert any( - message.startswith(word) for word in prefill_list - ), f"Expected assistant message to start with one of {prefill_list}, but got '{assistant_message}'" + message.startswith(word) for word in PREFILL_LIST + ), f"Expected assistant message to start with one of {PREFILL_LIST}, but got '{assistant_message}'" @pytest.mark.parametrize("stream", [True, False]) diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py index 78877bc6..d4e01d12 100644 --- a/model_server/app/commons/constants.py +++ b/model_server/app/commons/constants.py @@ -8,8 +8,9 @@ from app.prompt_guard.model_handler import ArchGuardHanlder logger = utils.get_model_server_logger() arch_function_hanlder = ArchFunctionHandler() -prefill_list = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"] -prefill_enabled = True +PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"] +PREFILL_ENABLED = True +TOOL_CALL_TOKEN = "" arch_function_endpoint = "https://api.fc.archgw.com/v1" arch_function_client = utils.get_client(arch_function_endpoint) arch_function_generation_params = { diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py index f595571b..6e7b926c 100644 --- a/model_server/app/function_calling/model_utils.py +++ b/model_server/app/function_calling/model_utils.py @@ -87,14 +87,14 @@ async def chat_completion(req: ChatMessage, res: Response): resp = const.arch_function_client.chat.completions.create( messages=messages, model=client_model_name, - stream=const.prefill_enabled, + stream=const.PREFILL_ENABLED, extra_body=const.arch_function_generation_params, ) except Exception as e: logger.error(f"model_server <= arch_function: error: {e}") raise - if const.prefill_enabled: + if const.PREFILL_ENABLED: first_token_content = "" for token in resp: first_token_content = token.choices[ @@ -104,14 +104,16 @@ async def chat_completion(req: ChatMessage, res: Response): break # Check if the first token requires tool call handling - if first_token_content != "": + if first_token_content != const.TOOL_CALL_TOKEN: # Engage pre-filling response if no tool call is indicated resp.close() logger.info("Tool call is not found! Engage pre filling") - prefill_content = random.choice(const.prefill_list) + prefill_content = random.choice(const.PREFILL_LIST) messages.append({"role": "assistant", "content": prefill_content}) # Send a new completion request with the updated messages + # the model will continue the final message in the chat instead of starting a new one + # disable add_generation_prompt which tells the template to add tokens that indicate the start of a bot response. extra_body = { **const.arch_function_generation_params, "continue_final_message": True,