improve e2e tests

This commit is contained in:
Adil Hafeez 2024-10-28 12:12:15 -07:00
parent 9b1e9ba49d
commit 2517120eeb
8 changed files with 169 additions and 23 deletions

View file

@ -207,7 +207,7 @@ fn successful_request_to_open_ai_chat_completions() {
)
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
@ -325,7 +325,7 @@ fn request_ratelimited() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
// .expect_metric_increment("active_http_calls", 1)
@ -388,7 +388,7 @@ fn request_not_ratelimited() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
// .expect_metric_increment("active_http_calls", 1)

View file

@ -304,7 +304,9 @@ impl HttpContext for StreamContext {
),
];
let response_str = to_server_events(chunks);
let mut response_str = to_server_events(chunks);
// append the original response from the model to the stream
response_str.push_str(&body_utf8);
self.set_http_response_body(0, body_size, response_str.as_bytes());
self.tool_calls = None;
}

View file

@ -76,7 +76,7 @@ async def default_target(req: DefaultTargetRequest, res: Response):
"choices": [
{
"message": {
"role": "user",
"role": "assistant",
"content": "I can help you with weather forecast or insurance claim details",
},
"finish_reason": "completed",

View file

@ -7,10 +7,16 @@ listener:
connect_timeout: 0.005s
llm_providers:
- name: gpt-3.5
- name: gpt-4o-mini
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-3.5-turbo
model: gpt-4o-mini
default: true
- name: gpt-3.5-turbo-0125
access_key: $OPENAI_API_KEY
provider: openai
model: gpt-3.5-turbo-0125
- name: gpt-4o
access_key: $OPENAI_API_KEY

View file

@ -1,3 +1,4 @@
import json
import os
@ -7,6 +8,7 @@ PROMPT_GATEWAY_ENDPOINT = os.getenv(
LLM_GATEWAY_ENDPOINT = os.getenv(
"LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1/chat/completions"
)
ARCH_STATE_HEADER = "x-arch-state"
def get_data_chunks(stream, n=1):
@ -22,3 +24,19 @@ def get_data_chunks(stream, n=1):
if len(chunks) >= n:
break
return chunks
def get_arch_messages(response_json):
arch_messages = []
if response_json and "metadata" in response_json:
# load arch_state from metadata
arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
# parse arch_state into json object
arch_state = json.loads(arch_state_str)
# load messages from arch_state
arch_messages_str = arch_state.get("messages", "[]")
# parse messages into json object
arch_messages = json.loads(arch_messages_str)
# append messages from arch gateway to history
return arch_messages
return []

31
e2e_tests/poetry.lock generated
View file

@ -311,6 +311,24 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1
[package.extras]
toml = ["tomli"]
[[package]]
name = "deepdiff"
version = "8.0.1"
description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
optional = false
python-versions = ">=3.8"
files = [
{file = "deepdiff-8.0.1-py3-none-any.whl", hash = "sha256:42e99004ce603f9a53934c634a57b04ad5900e0d8ed0abb15e635767489cbc05"},
{file = "deepdiff-8.0.1.tar.gz", hash = "sha256:245599a4586ab59bb599ca3517a9c42f3318ff600ded5e80a3432693c8ec3c4b"},
]
[package.dependencies]
orderly-set = "5.2.2"
[package.extras]
cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"]
optimize = ["orjson"]
[[package]]
name = "exceptiongroup"
version = "1.2.2"
@ -361,6 +379,17 @@ files = [
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
]
[[package]]
name = "orderly-set"
version = "5.2.2"
description = "Orderly set"
optional = false
python-versions = ">=3.8"
files = [
{file = "orderly_set-5.2.2-py3-none-any.whl", hash = "sha256:f7a37c95a38c01cdfe41c3ffb62925a318a2286ea0a41790c057fc802aec54da"},
{file = "orderly_set-5.2.2.tar.gz", hash = "sha256:52a18b86aaf3f5d5a498bbdb27bf3253a4e5c57ab38e5b7a56fa00115cd28448"},
]
[[package]]
name = "outcome"
version = "1.3.0.post0"
@ -670,4 +699,4 @@ h11 = ">=0.9.0,<1"
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "298580707d8fb1a66f6d363810a710aff42e66d25eb02f6ecb50fba428e76851"
content-hash = "6ae4fa6397091b87b63698201a08d7d97628ed65992d46514f118768b46b99ce"

View file

@ -13,6 +13,7 @@ pytest = "^7.3.1"
requests = "^2.29.0"
selenium = "^4.11.2"
pytest-sugar = "^1.0.0"
deepdiff = "^8.0.1"
[tool.poetry.dev-dependencies]
pytest-cov = "^4.1.0"

View file

@ -1,12 +1,18 @@
import json
import pytest
import requests
from deepdiff import DeepDiff
from common import PROMPT_GATEWAY_ENDPOINT, get_data_chunks
from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway(stream):
expected_tool_call = {
"name": "weather_forecast",
"arguments": {"city": "seattle", "days": 10},
}
body = {
"messages": [
{
@ -19,14 +25,56 @@ def test_prompt_gateway(stream):
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
assert response.status_code == 200
if stream:
chunks = get_data_chunks(response)
assert len(chunks) > 0
chunks = get_data_chunks(response, n=20)
assert len(chunks) > 2
# first chunk is tool calls (role = assistant)
response_json = json.loads(chunks[0])
# if its streaming we return tool call and api call in first two chunks
assert response_json.get("model").startswith("Arch")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["delta"]
role = choices[0]["delta"]["role"]
assert role == "assistant"
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
assert not diff
# second chunk is api call result (role = tool)
response_json = json.loads(chunks[1])
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["delta"]
role = choices[0]["delta"]["role"]
assert role == "tool"
# third..end chunk is summarization (role = assistant)
response_json = json.loads(chunks[2])
assert response_json.get("model").startswith("gpt-4o-mini")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["delta"]
role = choices[0]["delta"]["role"]
assert role == "assistant"
else:
response_json = response.json()
assert response_json.get("model").startswith("gpt-4o-mini")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["message"]
assert choices[0]["message"]["role"] == "assistant"
# now verify arch_messages (tool call and api response) that are sent as response metadata
arch_messages = get_arch_messages(response_json)
assert len(arch_messages) == 2
tool_calls_message = arch_messages[0]
tool_calls = tool_calls_message.get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
assert not diff
@pytest.mark.parametrize("stream", [True, False])
@ -46,8 +94,13 @@ def test_prompt_gateway_arch_direct_response(stream):
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
response_json = json.loads(chunks[0])
# if its streaming we return tool call and api call in first two chunks
# make sure arch responded directly
assert response_json.get("model").startswith("Arch")
# and tool call is null
choices = response_json.get("choices", [])
assert len(choices) > 0
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
assert len(tool_calls) == 0
else:
response_json = response.json()
assert response_json.get("model").startswith("Arch")
@ -74,8 +127,13 @@ def test_prompt_gateway_param_gathering(stream):
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
response_json = json.loads(chunks[0])
# if its streaming we return tool call and api call in first two chunks
# make sure arch responded directly
assert response_json.get("model").startswith("Arch")
# and tool call is null
choices = response_json.get("choices", [])
assert len(choices) > 0
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
assert len(tool_calls) == 0
else:
response_json = response.json()
assert response_json.get("model").startswith("Arch")
@ -91,6 +149,7 @@ def test_prompt_gateway_param_tool_call(stream):
"name": "weather_forecast",
"arguments": {"city": "seattle", "days": 2},
}
body = {
"messages": [
{
@ -112,34 +171,56 @@ def test_prompt_gateway_param_tool_call(stream):
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
assert response.status_code == 200
if stream:
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
chunks = get_data_chunks(response, n=20)
assert len(chunks) > 2
# first chunk is tool calls
response_json = json.loads(chunks[0].lower())
assert response_json.get("model").startswith("arch")
# first chunk is tool calls (role = assistant)
response_json = json.loads(chunks[0])
assert response_json.get("model").startswith("Arch")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["delta"]
role = choices[0]["delta"]["role"]
assert role == "assistant"
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
assert tool_call == expected_tool_call
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
assert not diff
# second chunk is api call result
# second chunk is api call result (role = tool)
response_json = json.loads(chunks[1])
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["delta"]
role = choices[0]["delta"]["role"]
assert role == "tool"
# third..end chunk is summarization
# third..end chunk is summarization (role = assistant)
response_json = json.loads(chunks[2])
# if its streaming we return tool call and api call in first two chunks
assert response_json.get("model").startswith("gpt-4o-mini")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["delta"]
role = choices[0]["delta"]["role"]
assert role == "assistant"
else:
response_json = response.json()
assert response_json.get("model").startswith("gpt-4o-mini")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["message"]
assert choices[0]["message"]["role"] == "assistant"
# now verify arch_messages (tool call and api response) that are sent as response metadata
arch_messages = get_arch_messages(response_json)
assert len(arch_messages) == 2
tool_calls_message = arch_messages[0]
tool_calls = tool_calls_message.get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
assert not diff
@pytest.mark.parametrize("stream", [True, False])
@ -160,6 +241,9 @@ def test_prompt_gateway_default_target(stream):
assert len(chunks) > 0
response_json = json.loads(chunks[0])
assert response_json.get("model").startswith("api_server")
assert len(response_json.get("choices", [])) > 0
assert response_json.get("choices")[0]["delta"]["role"] == "assistant"
response_json = json.loads(chunks[1])
choices = response_json.get("choices", [])
assert len(choices) > 0
@ -170,3 +254,9 @@ def test_prompt_gateway_default_target(stream):
else:
response_json = response.json()
assert response_json.get("model").startswith("api_server")
assert len(response_json.get("choices")) > 0
assert response_json.get("choices")[0]["message"]["role"] == "assistant"
assert (
response_json.get("choices")[0]["message"]["content"]
== "I can help you with weather forecast or insurance claim details"
)