From 68c2243e83ea0a043427bdc50a89b9771f5b39c7 Mon Sep 17 00:00:00 2001 From: cotran Date: Wed, 6 Nov 2024 16:16:08 -0800 Subject: [PATCH] update fix --- e2e_tests/common.py | 3 +- e2e_tests/test_prompt_gateway.py | 57 ++++++++++++------- model_server/app/commons/constants.py | 1 + .../app/function_calling/model_utils.py | 8 +-- model_server/app/main.py | 5 +- .../app/tests/test_function_calling.py | 2 +- 6 files changed, 43 insertions(+), 33 deletions(-) diff --git a/e2e_tests/common.py b/e2e_tests/common.py index 2c1e9f76..7ccee7c4 100644 --- a/e2e_tests/common.py +++ b/e2e_tests/common.py @@ -30,8 +30,7 @@ def get_arch_messages(response_json): arch_messages = [] if response_json and "metadata" in response_json: # load arch_state from metadata - arch_state_str = response_json.get("metadata") or {} - arch_state_str = arch_state_str.get(ARCH_STATE_HEADER, "{}") + arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") # parse arch_state into json object arch_state = json.loads(arch_state_str) # load messages from arch_state diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py index 79c04aa7..d614a089 100644 --- a/e2e_tests/test_prompt_gateway.py +++ b/e2e_tests/test_prompt_gateway.py @@ -262,8 +262,8 @@ def test_prompt_gateway_default_target(stream): ) -@pytest.mark.parametrize("prefill_enabled", [True, False]) -def test_prompt_gateway_arch_prefill(prefill_enabled): +@pytest.mark.parametrize("stream", [True, False]) +def test_prompt_gateway_arch_prefill(stream): body = { "messages": [ { @@ -271,29 +271,44 @@ def test_prompt_gateway_arch_prefill(prefill_enabled): "content": "how is the weather", } ], - "prefill_enabled": prefill_enabled, + "stream": stream, } response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body) assert response.status_code == 200 - response_json = response.json() - assert response_json.get("model").startswith("Arch") - choices = response_json.get("choices", []) - assert len(choices) > 0 - if 
prefill_enabled: - prefill_list = [ - "May", - "Could", - "Sure", - "Definitely", - "Certainly", - "Of course", - "Can", - ] - assistant_message = choices[0]["message"]["content"] - assert any( - assistant_message.startswith(word) for word in prefill_list - ), f"Expected assistant message to start with one of {prefill_list}, but got '{assistant_message}'" + if stream: + chunks = get_data_chunks(response, n=3) + assert len(chunks) > 0 + response_json = json.loads(chunks[0]) + # make sure arch responded directly + assert response_json.get("model").startswith("Arch") + # and tool call is null + choices = response_json.get("choices", []) + assert len(choices) > 0 + tool_calls = choices[0].get("delta", {}).get("tool_calls", []) + assert len(tool_calls) == 0 + response_json = json.loads(chunks[1]) + choices = response_json.get("choices", []) + assert len(choices) > 0 + message = choices[0]["delta"]["content"] else: + response_json = response.json() + assert response_json.get("model").startswith("Arch") + choices = response_json.get("choices", []) + assert len(choices) > 0 message = choices[0]["message"]["content"] assert "Could you provide the following details days" not in message + + prefill_list = [ + "May", + "Could", + "Sure", + "Definitely", + "Certainly", + "Of course", + "Can", + ] + + assert any( + message.startswith(word) for word in prefill_list + ), f"Expected assistant message to start with one of {prefill_list}, but got '{message}'" diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py index e1060bc8..78877bc6 100644 --- a/model_server/app/commons/constants.py +++ b/model_server/app/commons/constants.py @@ -9,6 +9,7 @@ logger = utils.get_model_server_logger() arch_function_hanlder = ArchFunctionHandler() prefill_list = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"] +prefill_enabled = True arch_function_endpoint = "https://api.fc.archgw.com/v1" arch_function_client = 
utils.get_client(arch_function_endpoint) arch_function_generation_params = { diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py index 35ac9c6e..f595571b 100644 --- a/model_server/app/function_calling/model_utils.py +++ b/model_server/app/function_calling/model_utils.py @@ -64,9 +64,7 @@ def process_messages(history: list[Message]): return updated_history -async def chat_completion( - req: ChatMessage, res: Response, prefill_enabled: bool = True -): +async def chat_completion(req: ChatMessage, res: Response): logger.info("starting request") tools_encoded = const.arch_function_hanlder._format_system(req.tools) @@ -89,14 +87,14 @@ async def chat_completion( resp = const.arch_function_client.chat.completions.create( messages=messages, model=client_model_name, - stream=prefill_enabled, + stream=const.prefill_enabled, extra_body=const.arch_function_generation_params, ) except Exception as e: logger.error(f"model_server <= arch_function: error: {e}") raise - if prefill_enabled: + if const.prefill_enabled: first_token_content = "" for token in resp: first_token_content = token.choices[ diff --git a/model_server/app/main.py b/model_server/app/main.py index 9f46457b..801bc36d 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -216,10 +216,7 @@ async def hallucination(req: HallucinationRequest, res: Response): @app.post("/v1/chat/completions") async def chat_completion(req: ChatMessage, res: Response, request: Request): try: - prefill_enabled = ( - request.query_params.get("prefill_enabled", "true").lower() == "true" - ) - result = await arch_function_chat_completion(req, res, prefill_enabled) + result = await arch_function_chat_completion(req, res) return result except Exception as e: logger.error(f"Error in chat_completion: {e}") diff --git a/model_server/app/tests/test_function_calling.py b/model_server/app/tests/test_function_calling.py index 05ecd971..c10e4b6f 100644 --- 
a/model_server/app/tests/test_function_calling.py +++ b/model_server/app/tests/test_function_calling.py @@ -73,7 +73,7 @@ async def test_chat_completion(mock_hanlder, mock_client): mock_hanlder._format_system.return_value = "" response = Response() - chat_response = await chat_completion(request, response, prefill_enabled=True) + chat_response = await chat_completion(request, response) assert isinstance(chat_response, ChatCompletionResponse) assert chat_response.choices[0].message.content is not None