From e74a3e1e38a5bdbac87f15813ece9661bb2213bd Mon Sep 17 00:00:00 2001 From: cotran Date: Tue, 5 Nov 2024 08:42:57 -0800 Subject: [PATCH] add e2e test --- e2e_tests/.vscode/launch.json | 15 ++++++++ e2e_tests/common.py | 3 +- e2e_tests/test_prompt_gateway.py | 37 +++++++++++++++++++ .../app/function_calling/model_utils.py | 1 + model_server/app/main.py | 9 +++-- 5 files changed, 61 insertions(+), 4 deletions(-) create mode 100644 e2e_tests/.vscode/launch.json diff --git a/e2e_tests/.vscode/launch.json b/e2e_tests/.vscode/launch.json new file mode 100644 index 00000000..6a211d8e --- /dev/null +++ b/e2e_tests/.vscode/launch.json @@ -0,0 +1,15 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + } + ] +} diff --git a/e2e_tests/common.py b/e2e_tests/common.py index 7ccee7c4..2c1e9f76 100644 --- a/e2e_tests/common.py +++ b/e2e_tests/common.py @@ -30,7 +30,8 @@ def get_arch_messages(response_json): arch_messages = [] if response_json and "metadata" in response_json: # load arch_state from metadata - arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") + arch_state_str = response_json.get("metadata") or {} + arch_state_str = arch_state_str.get(ARCH_STATE_HEADER, "{}") # parse arch_state into json object arch_state = json.loads(arch_state_str) # load messages from arch_state diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py index 31f305d4..0c40f638 100644 --- a/e2e_tests/test_prompt_gateway.py +++ b/e2e_tests/test_prompt_gateway.py @@ -2,6 +2,7 @@ import json import pytest import requests from deepdiff import DeepDiff +import model_server.app.commons.constants as const from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks @@ -260,3 +261,39 @@ def test_prompt_gateway_default_target(stream): response_json.get("choices")[0]["message"]["content"] == "I can help you with weather forecast or insurance claim details" ) + + +@pytest.mark.parametrize("prefill_enabled", [True, False]) +def test_prompt_gateway_arch_prefill(prefill_enabled): + body = { + "messages": [ + { + "role": "user", + "content": "how is the weather", + } + ], + "prefill_enabled": prefill_enabled, + } + response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body) + assert response.status_code == 200 + if prefill_enabled: + chunks = get_data_chunks(response, n=3) + assert len(chunks) > 0 + response_json = json.loads(chunks[0]) + # make sure arch responded directly + assert response_json.get("model").startswith("Arch") + # and tool call is null + choices = response_json.get("choices", []) + assert len(choices) > 0 + tool_calls = choices[0].get("delta", {}).get("tool_calls", []) + assert len(tool_calls) == 0 + assistant_message = choices[0].get("delta", {}).get("content", "") + assert assistant_message in const.prefill_list + + else: + response_json = response.json() + assert response_json.get("model").startswith("Arch") + choices = response_json.get("choices", []) + assert len(choices) > 0 + message = choices[0]["message"]["content"] + assert "Could you provide the following details days" not in message diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py index 093eee9c..35ac9c6e 100644 --- a/model_server/app/function_calling/model_utils.py +++ b/model_server/app/function_calling/model_utils.py @@ -133,6 +133,7 @@ async def chat_completion( if hasattr(token.choices[0].delta, "content"): full_response += token.choices[0].delta.content else: + logger.info("Stream is disabled, not engaging pre-filling") full_response = resp.choices[0].message.content tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response) diff --git a/model_server/app/main.py b/model_server/app/main.py index fdf091f0..9f46457b 100644 --- a/model_server/app/main.py +++ b/model_server/app/main.py @@ -6,7 +6,7 @@ import app.prompt_guard.model_utils as guard_utils from typing import List, Dict from pydantic import BaseModel -from fastapi import FastAPI, Response, HTTPException +from fastapi import FastAPI, Response, HTTPException, Request from app.function_calling.model_utils import ChatMessage from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler @@ -214,9 +214,12 @@ async def hallucination(req: HallucinationRequest, res: Response): @app.post("/v1/chat/completions") -async def chat_completion(req: ChatMessage, res: Response): +async def chat_completion(req: ChatMessage, res: Response, request: Request): try: - result = await arch_function_chat_completion(req, res) + prefill_enabled = ( + request.query_params.get("prefill_enabled", "true").lower() == "true" + ) + result = await arch_function_chat_completion(req, res, prefill_enabled) return result except Exception as e: logger.error(f"Error in chat_completion: {e}")