From d1a745c3248d05066c881d0105bd1a165f6759f7 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 14 Apr 2025 10:46:42 -0700 Subject: [PATCH] fix e2e tests --- tests/e2e/test_prompt_gateway.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/tests/e2e/test_prompt_gateway.py b/tests/e2e/test_prompt_gateway.py index f70f2a33..87459f57 100644 --- a/tests/e2e/test_prompt_gateway.py +++ b/tests/e2e/test_prompt_gateway.py @@ -2,6 +2,7 @@ import json import pytest import requests from deepdiff import DeepDiff +import re from common import ( PROMPT_GATEWAY_ENDPOINT, @@ -11,6 +12,15 @@ from common import ( ) +def cleanup_tool_call(tool_call): + pattern = r"```json\n(.*?)\n```" + match = re.search(pattern, tool_call, re.DOTALL) + if match: + tool_call = match.group(1) + + return tool_call.strip() + + @pytest.mark.parametrize("stream", [True, False]) def test_prompt_gateway(stream): expected_tool_call = { @@ -44,7 +54,10 @@ def test_prompt_gateway(stream): assert role == "assistant" print(f"choices: {choices}") tool_call_str = choices[0].get("delta", {}).get("content", "") - tool_calls = json.loads(tool_call_str).get("tool_calls", []) + print("tool_call_str: ", tool_call_str) + cleaned_tool_call_str = cleanup_tool_call(tool_call_str) + print("cleaned_tool_call_str: ", cleaned_tool_call_str) + tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", []) assert len(tool_calls) > 0 tool_call = tool_calls[0] location = tool_call["arguments"]["location"] @@ -80,11 +93,17 @@ def test_prompt_gateway(stream): assert choices[0]["message"]["role"] == "assistant" # now verify arch_messages (tool call and api response) that are sent as response metadata arch_messages = get_arch_messages(response_json) + print("arch_messages: ", json.dumps(arch_messages)) assert len(arch_messages) == 2 tool_calls_message = arch_messages[0] - tool_calls = tool_calls_message.get("tool_calls", []) - assert len(tool_calls) > 0 - tool_call = tool_calls[0]["function"] + print("tool_calls_message: ", tool_calls_message) + tool_calls = tool_calls_message.get("content", []) + cleaned_tool_call_str = cleanup_tool_call(tool_calls) + cleaned_tool_call_json = json.loads(cleaned_tool_call_str) + print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json)) + tool_calls_list = cleaned_tool_call_json.get("tool_calls", []) + assert len(tool_calls_list) > 0 + tool_call = tool_calls_list[0] location = tool_call["arguments"]["location"] assert expected_tool_call["arguments"]["location"] in location.lower() del expected_tool_call["arguments"]["location"]