fixed bugs in function_calling.rs that were breaking tests. All good now

This commit is contained in:
Salman Paracha 2025-11-15 14:52:55 -08:00
parent 60e489099d
commit 1f5784a9ff
5 changed files with 71 additions and 15 deletions

View file

@ -22,6 +22,28 @@ from common import (
)
def normalize_tool_call_arguments(tool_call):
"""
Normalize tool call arguments to ensure they are always a dict.
According to OpenAI API spec, the 'arguments' field should be a JSON string,
but for easier testing we parse it into a dict here.
Args:
tool_call: A tool call dict that may have 'arguments' as either a string or dict
Returns:
A tool call dict with 'arguments' guaranteed to be a dict
"""
if "arguments" in tool_call and isinstance(tool_call["arguments"], str):
try:
tool_call["arguments"] = json.loads(tool_call["arguments"])
except (json.JSONDecodeError, TypeError):
# If parsing fails, keep it as is
pass
return tool_call
def test_prompt_gateway(httpserver: HTTPServer):
simple_fixture = TEST_CASE_FIXTURES["SIMPLE"]
input = simple_fixture["input"]
@ -67,7 +89,7 @@ def test_prompt_gateway(httpserver: HTTPServer):
tool_calls_message = arch_messages[0]
tool_calls = tool_calls_message.get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
tool_call = normalize_tool_call_arguments(tool_calls[0]["function"])
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
assert not diff

View file

@ -24,6 +24,28 @@ def cleanup_tool_call(tool_call):
return tool_call.strip()
def normalize_tool_call_arguments(tool_call):
"""
Normalize tool call arguments to ensure they are always a dict.
According to OpenAI API spec, the 'arguments' field should be a JSON string,
but for easier testing we parse it into a dict here.
Args:
tool_call: A tool call dict that may have 'arguments' as either a string or dict
Returns:
A tool call dict with 'arguments' guaranteed to be a dict
"""
if "arguments" in tool_call and isinstance(tool_call["arguments"], str):
try:
tool_call["arguments"] = json.loads(tool_call["arguments"])
except (json.JSONDecodeError, TypeError):
# If parsing fails, keep it as is
pass
return tool_call
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway(stream):
expected_tool_call = {
@ -62,7 +84,7 @@ def test_prompt_gateway(stream):
print("cleaned_tool_call_str: ", cleaned_tool_call_str)
tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]
tool_call = normalize_tool_call_arguments(tool_calls[0])
location = tool_call["arguments"]["location"]
assert expected_tool_call["arguments"]["location"] in location.lower()
del expected_tool_call["arguments"]["location"]
@ -106,7 +128,7 @@ def test_prompt_gateway(stream):
print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json))
tool_calls_list = cleaned_tool_call_json.get("tool_calls", [])
assert len(tool_calls_list) > 0
tool_call = tool_calls_list[0]
tool_call = normalize_tool_call_arguments(tool_calls_list[0])
location = tool_call["arguments"]["location"]
assert expected_tool_call["arguments"]["location"] in location.lower()
del expected_tool_call["arguments"]["location"]
@ -241,7 +263,7 @@ def test_prompt_gateway_param_tool_call(stream):
assert role == "assistant"
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
tool_call = normalize_tool_call_arguments(tool_calls[0]["function"])
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
assert not diff
@ -275,7 +297,7 @@ def test_prompt_gateway_param_tool_call(stream):
tool_calls_message = arch_messages[0]
tool_calls = tool_calls_message.get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
tool_call = normalize_tool_call_arguments(tool_calls[0]["function"])
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
assert not diff