diff --git a/tests/e2e/test_prompt_gateway.py b/tests/e2e/test_prompt_gateway.py index f122ad30..d65aae08 100644 --- a/tests/e2e/test_prompt_gateway.py +++ b/tests/e2e/test_prompt_gateway.py @@ -42,9 +42,11 @@ def test_prompt_gateway(stream): assert "role" in choices[0]["delta"] role = choices[0]["delta"]["role"] assert role == "assistant" - tool_calls = choices[0].get("delta", {}).get("tool_calls", []) + print(f"choices: {choices}") + tool_call_str = choices[0].get("delta", {}).get("content", "") + tool_calls = json.loads(tool_call_str).get("tool_calls", []) assert len(tool_calls) > 0 - tool_call = tool_calls[0]["function"] + tool_call = tool_calls[0] location = tool_call["arguments"]["location"] assert expected_tool_call["arguments"]["location"] in location.lower() del expected_tool_call["arguments"]["location"]