fix e2e tests

This commit is contained in:
Adil Hafeez 2025-04-14 10:46:42 -07:00
parent 750a162856
commit d1a745c324
No known key found for this signature in database
GPG key ID: 9B18EF7691369645

View file

@ -2,6 +2,7 @@ import json
import pytest
import requests
from deepdiff import DeepDiff
import re
from common import (
PROMPT_GATEWAY_ENDPOINT,
@ -11,6 +12,15 @@ from common import (
)
def cleanup_tool_call(tool_call):
pattern = r"```json\n(.*?)\n```"
match = re.search(pattern, tool_call, re.DOTALL)
if match:
tool_call = match.group(1)
return tool_call.strip()
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway(stream):
expected_tool_call = {
@ -44,7 +54,10 @@ def test_prompt_gateway(stream):
assert role == "assistant"
print(f"choices: {choices}")
tool_call_str = choices[0].get("delta", {}).get("content", "")
tool_calls = json.loads(tool_call_str).get("tool_calls", [])
print("tool_call_str: ", tool_call_str)
cleaned_tool_call_str = cleanup_tool_call(tool_call_str)
print("cleaned_tool_call_str: ", cleaned_tool_call_str)
tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]
location = tool_call["arguments"]["location"]
@ -80,11 +93,17 @@ def test_prompt_gateway(stream):
assert choices[0]["message"]["role"] == "assistant"
# now verify arch_messages (tool call and api response) that are sent as response metadata
arch_messages = get_arch_messages(response_json)
print("arch_messages: ", json.dumps(arch_messages))
assert len(arch_messages) == 2
tool_calls_message = arch_messages[0]
tool_calls = tool_calls_message.get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
print("tool_calls_message: ", tool_calls_message)
tool_calls = tool_calls_message.get("content", [])
cleaned_tool_call_str = cleanup_tool_call(tool_calls)
cleaned_tool_call_json = json.loads(cleaned_tool_call_str)
print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json))
tool_calls_list = cleaned_tool_call_json.get("tool_calls", [])
assert len(tool_calls_list) > 0
tool_call = tool_calls_list[0]
location = tool_call["arguments"]["location"]
assert expected_tool_call["arguments"]["location"] in location.lower()
del expected_tool_call["arguments"]["location"]