From 2517120eebd0610116e97381ada1dbb7ffab0bbc Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 28 Oct 2024 12:12:15 -0700 Subject: [PATCH] improve e2e tests --- crates/llm_gateway/tests/integration.rs | 6 +- crates/prompt_gateway/src/http_context.rs | 4 +- demos/function_calling/api_server/app/main.py | 2 +- demos/llm_routing/arch_config.yaml | 10 +- e2e_tests/common.py | 18 +++ e2e_tests/poetry.lock | 31 ++++- e2e_tests/pyproject.toml | 1 + e2e_tests/test_prompt_gateway.py | 120 +++++++++++++++--- 8 files changed, 169 insertions(+), 23 deletions(-) diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs index 7ec92ccd..cc17c738 100644 --- a/crates/llm_gateway/tests/integration.rs +++ b/crates/llm_gateway/tests/integration.rs @@ -207,7 +207,7 @@ fn successful_request_to_open_ai_chat_completions() { ) .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) @@ -325,7 +325,7 @@ fn request_ratelimited() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) // The actual call is not important in this test, we just need to grab the token_id - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) // .expect_metric_increment("active_http_calls", 1) @@ -388,7 +388,7 @@ fn request_not_ratelimited() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) // The actual call is not important in this test, we just need to grab the token_id - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) // .expect_metric_increment("active_http_calls", 1) diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 9871cc63..596a6a4e 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -304,7 +304,9 @@ impl HttpContext for StreamContext { ), ]; - let response_str = to_server_events(chunks); + let mut response_str = to_server_events(chunks); + // append the original response from the model to the stream + response_str.push_str(&body_utf8); self.set_http_response_body(0, body_size, response_str.as_bytes()); self.tool_calls = None; } diff --git a/demos/function_calling/api_server/app/main.py b/demos/function_calling/api_server/app/main.py index ed7c8f32..e87a3a21 100644 --- a/demos/function_calling/api_server/app/main.py +++ b/demos/function_calling/api_server/app/main.py @@ -76,7 +76,7 @@ async def default_target(req: DefaultTargetRequest, res: Response): "choices": [ { "message": { - "role": "user", + "role": "assistant", "content": "I can help you with weather forecast or insurance claim details", }, "finish_reason": "completed", diff --git a/demos/llm_routing/arch_config.yaml b/demos/llm_routing/arch_config.yaml index e99b9687..620a1d10 100644 --- a/demos/llm_routing/arch_config.yaml +++ b/demos/llm_routing/arch_config.yaml @@ -7,10 +7,16 @@ listener: connect_timeout: 0.005s llm_providers: - - name: gpt-3.5 + - name: gpt-4o-mini access_key: $OPENAI_API_KEY provider: openai - model: gpt-3.5-turbo + model: gpt-4o-mini + default: true + + - name: gpt-3.5-turbo-0125 + access_key: $OPENAI_API_KEY + provider: openai + model: gpt-3.5-turbo-0125 - name: gpt-4o access_key: $OPENAI_API_KEY diff --git a/e2e_tests/common.py b/e2e_tests/common.py index cd9e53e3..7ccee7c4 100644 --- a/e2e_tests/common.py +++ b/e2e_tests/common.py @@ -1,3 +1,4 @@ +import json import os @@ -7,6 +8,7 @@ PROMPT_GATEWAY_ENDPOINT = os.getenv( LLM_GATEWAY_ENDPOINT = os.getenv( "LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1/chat/completions" ) +ARCH_STATE_HEADER = "x-arch-state" def get_data_chunks(stream, n=1): @@ -22,3 +24,19 @@ def get_data_chunks(stream, n=1): if len(chunks) >= n: break return chunks + + +def get_arch_messages(response_json): + arch_messages = [] + if response_json and "metadata" in response_json: + # load arch_state from metadata + arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") + # parse arch_state into json object + arch_state = json.loads(arch_state_str) + # load messages from arch_state + arch_messages_str = arch_state.get("messages", "[]") + # parse messages into json object + arch_messages = json.loads(arch_messages_str) + # append messages from arch gateway to history + return arch_messages + return [] diff --git a/e2e_tests/poetry.lock b/e2e_tests/poetry.lock index cdacfd08..68ebfcf5 100644 --- a/e2e_tests/poetry.lock +++ b/e2e_tests/poetry.lock @@ -311,6 +311,24 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1 [package.extras] toml = ["tomli"] +[[package]] +name = "deepdiff" +version = "8.0.1" +description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other." +optional = false +python-versions = ">=3.8" +files = [ + {file = "deepdiff-8.0.1-py3-none-any.whl", hash = "sha256:42e99004ce603f9a53934c634a57b04ad5900e0d8ed0abb15e635767489cbc05"}, + {file = "deepdiff-8.0.1.tar.gz", hash = "sha256:245599a4586ab59bb599ca3517a9c42f3318ff600ded5e80a3432693c8ec3c4b"}, +] + +[package.dependencies] +orderly-set = "5.2.2" + +[package.extras] +cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"] +optimize = ["orjson"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -361,6 +379,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "orderly-set" +version = "5.2.2" +description = "Orderly set" +optional = false +python-versions = ">=3.8" +files = [ + {file = "orderly_set-5.2.2-py3-none-any.whl", hash = "sha256:f7a37c95a38c01cdfe41c3ffb62925a318a2286ea0a41790c057fc802aec54da"}, + {file = "orderly_set-5.2.2.tar.gz", hash = "sha256:52a18b86aaf3f5d5a498bbdb27bf3253a4e5c57ab38e5b7a56fa00115cd28448"}, +] + [[package]] name = "outcome" version = "1.3.0.post0" @@ -670,4 +699,4 @@ h11 = ">=0.9.0,<1" [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "298580707d8fb1a66f6d363810a710aff42e66d25eb02f6ecb50fba428e76851" +content-hash = "6ae4fa6397091b87b63698201a08d7d97628ed65992d46514f118768b46b99ce" diff --git a/e2e_tests/pyproject.toml b/e2e_tests/pyproject.toml index 45cfbc25..68724c18 100644 --- a/e2e_tests/pyproject.toml +++ b/e2e_tests/pyproject.toml @@ -13,6 +13,7 @@ pytest = "^7.3.1" requests = "^2.29.0" selenium = "^4.11.2" pytest-sugar = "^1.0.0" +deepdiff = "^8.0.1" [tool.poetry.dev-dependencies] pytest-cov = "^4.1.0" diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py index 00ab20ad..31f305d4 100644 --- a/e2e_tests/test_prompt_gateway.py +++ b/e2e_tests/test_prompt_gateway.py @@ -1,12 +1,18 @@ import json import pytest import requests +from deepdiff import DeepDiff -from common import PROMPT_GATEWAY_ENDPOINT, get_data_chunks +from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks @pytest.mark.parametrize("stream", [True, False]) def test_prompt_gateway(stream): + expected_tool_call = { + "name": "weather_forecast", + "arguments": {"city": "seattle", "days": 10}, + } + body = { "messages": [ { @@ -19,14 +25,56 @@ def test_prompt_gateway(stream): response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream) assert response.status_code == 200 if stream: - chunks = get_data_chunks(response) - assert len(chunks) > 0 + chunks = get_data_chunks(response, n=20) + assert len(chunks) > 2 + + # first chunk is tool calls (role = assistant) response_json = json.loads(chunks[0]) - # if its streaming we return tool call and api call in first two chunks assert response_json.get("model").startswith("Arch") + choices = response_json.get("choices", []) + assert len(choices) > 0 + assert "role" in choices[0]["delta"] + role = choices[0]["delta"]["role"] + assert role == "assistant" + tool_calls = choices[0].get("delta", {}).get("tool_calls", []) + assert len(tool_calls) > 0 + tool_call = tool_calls[0]["function"] + diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True) + assert not diff + + # second chunk is api call result (role = tool) + response_json = json.loads(chunks[1]) + choices = response_json.get("choices", []) + assert len(choices) > 0 + assert "role" in choices[0]["delta"] + role = choices[0]["delta"]["role"] + assert role == "tool" + + # third..end chunk is summarization (role = assistant) + response_json = json.loads(chunks[2]) + assert response_json.get("model").startswith("gpt-4o-mini") + choices = response_json.get("choices", []) + assert len(choices) > 0 + assert "role" in choices[0]["delta"] + role = choices[0]["delta"]["role"] + assert role == "assistant" + else: response_json = response.json() assert response_json.get("model").startswith("gpt-4o-mini") + choices = response_json.get("choices", []) + assert len(choices) > 0 + assert "role" in choices[0]["message"] + assert choices[0]["message"]["role"] == "assistant" + # now verify arch_messages (tool call and api response) that are sent as response metadata + arch_messages = get_arch_messages(response_json) + assert len(arch_messages) == 2 + tool_calls_message = arch_messages[0] + tool_calls = tool_calls_message.get("tool_calls", []) + assert len(tool_calls) > 0 + tool_call = tool_calls[0]["function"] + diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True) + assert not diff @pytest.mark.parametrize("stream", [True, False]) @@ -46,8 +94,13 @@ def test_prompt_gateway_arch_direct_response(stream): chunks = get_data_chunks(response, n=3) assert len(chunks) > 0 response_json = json.loads(chunks[0]) - # if its streaming we return tool call and api call in first two chunks + # make sure arch responded directly assert response_json.get("model").startswith("Arch") + # and tool call is null + choices = response_json.get("choices", []) + assert len(choices) > 0 + tool_calls = choices[0].get("delta", {}).get("tool_calls", []) + assert len(tool_calls) == 0 else: response_json = response.json() assert response_json.get("model").startswith("Arch") @@ -74,8 +127,13 @@ def test_prompt_gateway_param_gathering(stream): chunks = get_data_chunks(response, n=3) assert len(chunks) > 0 response_json = json.loads(chunks[0]) - # if its streaming we return tool call and api call in first two chunks + # make sure arch responded directly assert response_json.get("model").startswith("Arch") + # and tool call is null + choices = response_json.get("choices", []) + assert len(choices) > 0 + tool_calls = choices[0].get("delta", {}).get("tool_calls", []) + assert len(tool_calls) == 0 else: response_json = response.json() assert response_json.get("model").startswith("Arch") @@ -91,6 +149,7 @@ def test_prompt_gateway_param_tool_call(stream): "name": "weather_forecast", "arguments": {"city": "seattle", "days": 2}, } + body = { "messages": [ { @@ -112,34 +171,56 @@ def test_prompt_gateway_param_tool_call(stream): response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream) assert response.status_code == 200 if stream: - chunks = get_data_chunks(response, n=3) - assert len(chunks) > 0 + chunks = get_data_chunks(response, n=20) + assert len(chunks) > 2 - # first chunk is tool calls - response_json = json.loads(chunks[0].lower()) - assert response_json.get("model").startswith("arch") + # first chunk is tool calls (role = assistant) + response_json = json.loads(chunks[0]) + assert response_json.get("model").startswith("Arch") choices = response_json.get("choices", []) assert len(choices) > 0 + assert "role" in choices[0]["delta"] + role = choices[0]["delta"]["role"] + assert role == "assistant" tool_calls = choices[0].get("delta", {}).get("tool_calls", []) assert len(tool_calls) > 0 tool_call = tool_calls[0]["function"] - assert tool_call == expected_tool_call + diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True) + assert not diff - # second chunk is api call result + # second chunk is api call result (role = tool) response_json = json.loads(chunks[1]) choices = response_json.get("choices", []) assert len(choices) > 0 + assert "role" in choices[0]["delta"] role = choices[0]["delta"]["role"] assert role == "tool" - # third..end chunk is summarization + # third..end chunk is summarization (role = assistant) response_json = json.loads(chunks[2]) - # if its streaming we return tool call and api call in first two chunks assert response_json.get("model").startswith("gpt-4o-mini") + choices = response_json.get("choices", []) + assert len(choices) > 0 + assert "role" in choices[0]["delta"] + role = choices[0]["delta"]["role"] + assert role == "assistant" else: response_json = response.json() assert response_json.get("model").startswith("gpt-4o-mini") + choices = response_json.get("choices", []) + assert len(choices) > 0 + assert "role" in choices[0]["message"] + assert choices[0]["message"]["role"] == "assistant" + # now verify arch_messages (tool call and api response) that are sent as response metadata + arch_messages = get_arch_messages(response_json) + assert len(arch_messages) == 2 + tool_calls_message = arch_messages[0] + tool_calls = tool_calls_message.get("tool_calls", []) + assert len(tool_calls) > 0 + tool_call = tool_calls[0]["function"] + diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True) + assert not diff @pytest.mark.parametrize("stream", [True, False]) @@ -160,6 +241,9 @@ def test_prompt_gateway_default_target(stream): assert len(chunks) > 0 response_json = json.loads(chunks[0]) assert response_json.get("model").startswith("api_server") + assert len(response_json.get("choices", [])) > 0 + assert response_json.get("choices")[0]["delta"]["role"] == "assistant" + response_json = json.loads(chunks[1]) choices = response_json.get("choices", []) assert len(choices) > 0 @@ -170,3 +254,9 @@ def test_prompt_gateway_default_target(stream): else: response_json = response.json() assert response_json.get("model").startswith("api_server") + assert len(response_json.get("choices")) > 0 + assert response_json.get("choices")[0]["message"]["role"] == "assistant" + assert ( + response_json.get("choices")[0]["message"]["content"] + == "I can help you with weather forecast or insurance claim details" + )