mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
improve e2e tests
This commit is contained in:
parent
9b1e9ba49d
commit
2517120eeb
8 changed files with 169 additions and 23 deletions
|
|
@ -207,7 +207,7 @@ fn successful_request_to_open_ai_chat_completions() {
|
|||
)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
|
|
@ -325,7 +325,7 @@ fn request_ratelimited() {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
// .expect_metric_increment("active_http_calls", 1)
|
||||
|
|
@ -388,7 +388,7 @@ fn request_not_ratelimited() {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
||||
.returning(Some(chat_completions_request_body))
|
||||
// The actual call is not important in this test, we just need to grab the token_id
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Trace), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
// .expect_metric_increment("active_http_calls", 1)
|
||||
|
|
|
|||
|
|
@ -304,7 +304,9 @@ impl HttpContext for StreamContext {
|
|||
),
|
||||
];
|
||||
|
||||
let response_str = to_server_events(chunks);
|
||||
let mut response_str = to_server_events(chunks);
|
||||
// append the original response from the model to the stream
|
||||
response_str.push_str(&body_utf8);
|
||||
self.set_http_response_body(0, body_size, response_str.as_bytes());
|
||||
self.tool_calls = None;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -76,7 +76,7 @@ async def default_target(req: DefaultTargetRequest, res: Response):
|
|||
"choices": [
|
||||
{
|
||||
"message": {
|
||||
"role": "user",
|
||||
"role": "assistant",
|
||||
"content": "I can help you with weather forecast or insurance claim details",
|
||||
},
|
||||
"finish_reason": "completed",
|
||||
|
|
|
|||
|
|
@ -7,10 +7,16 @@ listener:
|
|||
connect_timeout: 0.005s
|
||||
|
||||
llm_providers:
|
||||
- name: gpt-3.5
|
||||
- name: gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-3.5-turbo
|
||||
model: gpt-4o-mini
|
||||
default: true
|
||||
|
||||
- name: gpt-3.5-turbo-0125
|
||||
access_key: $OPENAI_API_KEY
|
||||
provider: openai
|
||||
model: gpt-3.5-turbo-0125
|
||||
|
||||
- name: gpt-4o
|
||||
access_key: $OPENAI_API_KEY
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
|
||||
|
|
@ -7,6 +8,7 @@ PROMPT_GATEWAY_ENDPOINT = os.getenv(
|
|||
LLM_GATEWAY_ENDPOINT = os.getenv(
|
||||
"LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1/chat/completions"
|
||||
)
|
||||
ARCH_STATE_HEADER = "x-arch-state"
|
||||
|
||||
|
||||
def get_data_chunks(stream, n=1):
|
||||
|
|
@ -22,3 +24,19 @@ def get_data_chunks(stream, n=1):
|
|||
if len(chunks) >= n:
|
||||
break
|
||||
return chunks
|
||||
|
||||
|
||||
def get_arch_messages(response_json):
    """Return the arch gateway messages embedded in a response's metadata.

    The gateway state is read from ``response_json["metadata"]`` under the
    ``ARCH_STATE_HEADER`` key. That value is a JSON-encoded object whose
    ``"messages"`` field is itself a JSON-encoded list, so two decoding
    passes are needed.

    Args:
        response_json: Parsed response body (a mapping), or ``None``.

    Returns:
        The decoded list of messages, or ``[]`` when the response is empty,
        has no (or an empty/null) ``"metadata"`` entry, or carries no
        messages in its state.

    Raises:
        json.JSONDecodeError: If the state or its ``"messages"`` field is
            present but not valid JSON.
    """
    if not response_json or "metadata" not in response_json:
        return []

    # "metadata" may be present but null/empty; treat that the same as absent
    # instead of failing with AttributeError on ``None.get``.
    metadata = response_json["metadata"]
    if not metadata:
        return []

    # First pass: the state header value is a JSON string (null -> empty object).
    arch_state = json.loads(metadata.get(ARCH_STATE_HEADER) or "{}")

    # Second pass: "messages" is stored as a JSON string inside the state.
    return json.loads(arch_state.get("messages") or "[]")
|
||||
|
|
|
|||
31
e2e_tests/poetry.lock
generated
31
e2e_tests/poetry.lock
generated
|
|
@ -311,6 +311,24 @@ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.1
|
|||
[package.extras]
|
||||
toml = ["tomli"]
|
||||
|
||||
[[package]]
|
||||
name = "deepdiff"
|
||||
version = "8.0.1"
|
||||
description = "Deep Difference and Search of any Python object/data. Recreate objects by adding adding deltas to each other."
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "deepdiff-8.0.1-py3-none-any.whl", hash = "sha256:42e99004ce603f9a53934c634a57b04ad5900e0d8ed0abb15e635767489cbc05"},
|
||||
{file = "deepdiff-8.0.1.tar.gz", hash = "sha256:245599a4586ab59bb599ca3517a9c42f3318ff600ded5e80a3432693c8ec3c4b"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
orderly-set = "5.2.2"
|
||||
|
||||
[package.extras]
|
||||
cli = ["click (==8.1.7)", "pyyaml (==6.0.1)"]
|
||||
optimize = ["orjson"]
|
||||
|
||||
[[package]]
|
||||
name = "exceptiongroup"
|
||||
version = "1.2.2"
|
||||
|
|
@ -361,6 +379,17 @@ files = [
|
|||
{file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "orderly-set"
|
||||
version = "5.2.2"
|
||||
description = "Orderly set"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "orderly_set-5.2.2-py3-none-any.whl", hash = "sha256:f7a37c95a38c01cdfe41c3ffb62925a318a2286ea0a41790c057fc802aec54da"},
|
||||
{file = "orderly_set-5.2.2.tar.gz", hash = "sha256:52a18b86aaf3f5d5a498bbdb27bf3253a4e5c57ab38e5b7a56fa00115cd28448"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "outcome"
|
||||
version = "1.3.0.post0"
|
||||
|
|
@ -670,4 +699,4 @@ h11 = ">=0.9.0,<1"
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "298580707d8fb1a66f6d363810a710aff42e66d25eb02f6ecb50fba428e76851"
|
||||
content-hash = "6ae4fa6397091b87b63698201a08d7d97628ed65992d46514f118768b46b99ce"
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ pytest = "^7.3.1"
|
|||
requests = "^2.29.0"
|
||||
selenium = "^4.11.2"
|
||||
pytest-sugar = "^1.0.0"
|
||||
deepdiff = "^8.0.1"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest-cov = "^4.1.0"
|
||||
|
|
|
|||
|
|
@ -1,12 +1,18 @@
|
|||
import json
|
||||
import pytest
|
||||
import requests
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
from common import PROMPT_GATEWAY_ENDPOINT, get_data_chunks
|
||||
from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_prompt_gateway(stream):
|
||||
expected_tool_call = {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "seattle", "days": 10},
|
||||
}
|
||||
|
||||
body = {
|
||||
"messages": [
|
||||
{
|
||||
|
|
@ -19,14 +25,56 @@ def test_prompt_gateway(stream):
|
|||
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
|
||||
assert response.status_code == 200
|
||||
if stream:
|
||||
chunks = get_data_chunks(response)
|
||||
assert len(chunks) > 0
|
||||
chunks = get_data_chunks(response, n=20)
|
||||
assert len(chunks) > 2
|
||||
|
||||
# first chunk is tool calls (role = assistant)
|
||||
response_json = json.loads(chunks[0])
|
||||
# if its streaming we return tool call and api call in first two chunks
|
||||
assert response_json.get("model").startswith("Arch")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
role = choices[0]["delta"]["role"]
|
||||
assert role == "assistant"
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
# second chunk is api call result (role = tool)
|
||||
response_json = json.loads(chunks[1])
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
role = choices[0]["delta"]["role"]
|
||||
assert role == "tool"
|
||||
|
||||
# third..end chunk is summarization (role = assistant)
|
||||
response_json = json.loads(chunks[2])
|
||||
assert response_json.get("model").startswith("gpt-4o-mini")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
role = choices[0]["delta"]["role"]
|
||||
assert role == "assistant"
|
||||
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("gpt-4o-mini")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["message"]
|
||||
assert choices[0]["message"]["role"] == "assistant"
|
||||
# now verify arch_messages (tool call and api response) that are sent as response metadata
|
||||
arch_messages = get_arch_messages(response_json)
|
||||
assert len(arch_messages) == 2
|
||||
tool_calls_message = arch_messages[0]
|
||||
tool_calls = tool_calls_message.get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
|
|
@ -46,8 +94,13 @@ def test_prompt_gateway_arch_direct_response(stream):
|
|||
chunks = get_data_chunks(response, n=3)
|
||||
assert len(chunks) > 0
|
||||
response_json = json.loads(chunks[0])
|
||||
# if its streaming we return tool call and api call in first two chunks
|
||||
# make sure arch responded directly
|
||||
assert response_json.get("model").startswith("Arch")
|
||||
# and tool call is null
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
assert len(tool_calls) == 0
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("Arch")
|
||||
|
|
@ -74,8 +127,13 @@ def test_prompt_gateway_param_gathering(stream):
|
|||
chunks = get_data_chunks(response, n=3)
|
||||
assert len(chunks) > 0
|
||||
response_json = json.loads(chunks[0])
|
||||
# if its streaming we return tool call and api call in first two chunks
|
||||
# make sure arch responded directly
|
||||
assert response_json.get("model").startswith("Arch")
|
||||
# and tool call is null
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
assert len(tool_calls) == 0
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("Arch")
|
||||
|
|
@ -91,6 +149,7 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
"name": "weather_forecast",
|
||||
"arguments": {"city": "seattle", "days": 2},
|
||||
}
|
||||
|
||||
body = {
|
||||
"messages": [
|
||||
{
|
||||
|
|
@ -112,34 +171,56 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
|
||||
assert response.status_code == 200
|
||||
if stream:
|
||||
chunks = get_data_chunks(response, n=3)
|
||||
assert len(chunks) > 0
|
||||
chunks = get_data_chunks(response, n=20)
|
||||
assert len(chunks) > 2
|
||||
|
||||
# first chunk is tool calls
|
||||
response_json = json.loads(chunks[0].lower())
|
||||
assert response_json.get("model").startswith("arch")
|
||||
# first chunk is tool calls (role = assistant)
|
||||
response_json = json.loads(chunks[0])
|
||||
assert response_json.get("model").startswith("Arch")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
role = choices[0]["delta"]["role"]
|
||||
assert role == "assistant"
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
assert tool_call == expected_tool_call
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
# second chunk is api call result
|
||||
# second chunk is api call result (role = tool)
|
||||
response_json = json.loads(chunks[1])
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
role = choices[0]["delta"]["role"]
|
||||
assert role == "tool"
|
||||
|
||||
# third..end chunk is summarization
|
||||
# third..end chunk is summarization (role = assistant)
|
||||
response_json = json.loads(chunks[2])
|
||||
# if its streaming we return tool call and api call in first two chunks
|
||||
assert response_json.get("model").startswith("gpt-4o-mini")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
role = choices[0]["delta"]["role"]
|
||||
assert role == "assistant"
|
||||
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("gpt-4o-mini")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["message"]
|
||||
assert choices[0]["message"]["role"] == "assistant"
|
||||
# now verify arch_messages (tool call and api response) that are sent as response metadata
|
||||
arch_messages = get_arch_messages(response_json)
|
||||
assert len(arch_messages) == 2
|
||||
tool_calls_message = arch_messages[0]
|
||||
tool_calls = tool_calls_message.get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
|
|
@ -160,6 +241,9 @@ def test_prompt_gateway_default_target(stream):
|
|||
assert len(chunks) > 0
|
||||
response_json = json.loads(chunks[0])
|
||||
assert response_json.get("model").startswith("api_server")
|
||||
assert len(response_json.get("choices", [])) > 0
|
||||
assert response_json.get("choices")[0]["delta"]["role"] == "assistant"
|
||||
|
||||
response_json = json.loads(chunks[1])
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
|
|
@ -170,3 +254,9 @@ def test_prompt_gateway_default_target(stream):
|
|||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("api_server")
|
||||
assert len(response_json.get("choices")) > 0
|
||||
assert response_json.get("choices")[0]["message"]["role"] == "assistant"
|
||||
assert (
|
||||
response_json.get("choices")[0]["message"]["content"]
|
||||
== "I can help you with weather forecast or insurance claim details"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue