mirror of
https://github.com/katanemo/plano.git
synced 2026-04-26 09:16:24 +02:00
refactor demos (#398)
This commit is contained in:
parent
2bd61d628c
commit
b3c95a6698
93 changed files with 338 additions and 1042 deletions
60
demos/shared/test_runner/test_demos.py
Normal file
60
demos/shared/test_runner/test_demos.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
import os
|
||||
from common import get_arch_messages
|
||||
import pytest
|
||||
import requests
|
||||
from deepdiff import DeepDiff
|
||||
import logging
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

# Chat-completions endpoint the tests post to; the ARCHGW_ENDPOINT environment
# variable overrides the localhost default.
ARCHGW_ENDPOINT = os.getenv(
    "ARCHGW_ENDPOINT", "http://localhost:10000/v1/chat/completions"
)

# Load test data from YAML file. The path comes from the TEST_DATA environment
# variable (default: test_data.yaml in the working directory). The encoding is
# pinned to UTF-8 so the fixture is decoded the same way on every platform
# instead of depending on the locale's default encoding.
with open(os.getenv("TEST_DATA", "test_data.yaml"), "r", encoding="utf-8") as file:
    test_data_yaml = yaml.safe_load(file)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "test_data",
    [
        pytest.param(test_case, id=test_case["id"])
        for test_case in test_data_yaml["test_cases"]
    ],
)
def test_demos(test_data):
    """Post one YAML-defined test case to the gateway and validate the reply.

    Each test case must provide the keys ``input`` (the JSON request body),
    ``expected_tools`` (the tool calls expected in the response metadata) and
    ``expected_output_contains`` (a substring expected in the assistant
    message, compared case-insensitively).
    """
    # Named request_body rather than `input` to avoid shadowing the builtin.
    request_body = test_data["input"]
    expected_tools = test_data["expected_tools"]
    expected_output_contains = test_data["expected_output_contains"]

    # Without a timeout, requests waits forever on an unresponsive server;
    # a generous bound makes a hung gateway fail the test instead of blocking
    # the whole run.
    response = requests.post(ARCHGW_ENDPOINT, json=request_body, timeout=120)
    assert response.status_code == 200
    # ensure that response is json
    assert response.headers["content-type"] == "application/json"

    response_json = response.json()
    # Check presence first so a missing field is reported as an assertion
    # failure rather than an AttributeError on None.
    model = response_json.get("model")
    assert model is not None, "response is missing the 'model' field"
    assert model.startswith("gpt-4o")
    choices = response_json.get("choices", [])
    assert len(choices) > 0

    # ensure that model responded according to the expectation
    message = choices[0]["message"]
    assert "role" in message
    assert message["role"] == "assistant"
    assert expected_output_contains.lower() in message["content"].lower()

    # now verify arch_messages (tool call and api response) that are sent as response metadata
    arch_messages = get_arch_messages(response_json)
    assert len(arch_messages) == 2
    tool_calls_message = arch_messages[0]
    tool_calls = tool_calls_message.get("tool_calls", [])
    assert len(tool_calls) > 0

    # remove dynamic id from tool_calls so the comparison only covers stable fields
    for tool_call in tool_calls:
        tool_call.pop("id", None)
    diff = DeepDiff(expected_tools, tool_calls, ignore_string_case=True)
    assert not diff
|
||||
Loading…
Add table
Add a link
Reference in a new issue