mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixed bugs in function_calling.rs that were breaking tests. All good now
This commit is contained in:
parent
60e489099d
commit
1f5784a9ff
5 changed files with 71 additions and 15 deletions
|
|
@ -596,7 +596,7 @@ static_resources:
|
|||
clusters:
|
||||
|
||||
- name: arch
|
||||
connect_timeout: 0.5s
|
||||
connect_timeout: 5s
|
||||
type: LOGICAL_DNS
|
||||
dns_lookup_family: V4_ONLY
|
||||
lb_policy: ROUND_ROBIN
|
||||
|
|
|
|||
|
|
@ -987,10 +987,13 @@ impl ArchFunctionHandler {
|
|||
|
||||
info!("[arch-fc]: raw model response: {}", response_dict.raw_response);
|
||||
|
||||
// General model response (no intent matched - should route to default target)
|
||||
let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) {
|
||||
// When arch-fc returns a "response" field, it means no intent was matched
|
||||
// Return empty content and empty tool_calls so prompt_gateway routes to default target
|
||||
ResponseMessage {
|
||||
role: Role::Assistant,
|
||||
content: response_dict.response.clone(),
|
||||
content: Some(String::new()),
|
||||
refusal: None,
|
||||
annotations: None,
|
||||
audio: None,
|
||||
|
|
@ -1105,6 +1108,14 @@ impl ArchFunctionHandler {
|
|||
}
|
||||
};
|
||||
|
||||
// Create metadata with the raw model response
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
"x-arch-fc-model-response".to_string(),
|
||||
serde_json::to_value(&response_dict.raw_response)
|
||||
.unwrap_or_else(|_| Value::String(response_dict.raw_response.clone())),
|
||||
);
|
||||
|
||||
let chat_completion_response = ChatCompletionsResponse {
|
||||
id: format!("chatcmpl-{}", uuid::Uuid::new_v4()),
|
||||
object: Some("chat.completion".to_string()),
|
||||
|
|
@ -1125,7 +1136,7 @@ impl ArchFunctionHandler {
|
|||
},
|
||||
system_fingerprint: None,
|
||||
service_tier: None,
|
||||
metadata: None,
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
info!("[response arch-fc]: {:?}", chat_completion_response);
|
||||
|
|
|
|||
|
|
@ -264,13 +264,6 @@ impl StreamContext {
|
|||
.tool_calls
|
||||
.clone_into(&mut self.tool_calls);
|
||||
|
||||
if self.tool_calls.as_ref().unwrap().len() > 1 {
|
||||
warn!(
|
||||
"multiple tool calls not supported yet, tool_calls count found: {}",
|
||||
self.tool_calls.as_ref().unwrap().len()
|
||||
);
|
||||
}
|
||||
|
||||
if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() {
|
||||
// This means that Arch FC did not have enough information to resolve the function call
|
||||
// Arch FC probably responded with a message asking for more information.
|
||||
|
|
@ -314,6 +307,14 @@ impl StreamContext {
|
|||
);
|
||||
}
|
||||
|
||||
// At this point, we know tool_calls is not None and not empty
|
||||
if self.tool_calls.as_ref().unwrap().len() > 1 {
|
||||
warn!(
|
||||
"multiple tool calls not supported yet, tool_calls count found: {}",
|
||||
self.tool_calls.as_ref().unwrap().len()
|
||||
);
|
||||
}
|
||||
|
||||
// update prompt target name from the tool call response
|
||||
callout_context.prompt_target_name =
|
||||
Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone());
|
||||
|
|
|
|||
|
|
@ -22,6 +22,28 @@ from common import (
|
|||
)
|
||||
|
||||
|
||||
def normalize_tool_call_arguments(tool_call):
|
||||
"""
|
||||
Normalize tool call arguments to ensure they are always a dict.
|
||||
|
||||
According to OpenAI API spec, the 'arguments' field should be a JSON string,
|
||||
but for easier testing we parse it into a dict here.
|
||||
|
||||
Args:
|
||||
tool_call: A tool call dict that may have 'arguments' as either a string or dict
|
||||
|
||||
Returns:
|
||||
A tool call dict with 'arguments' guaranteed to be a dict
|
||||
"""
|
||||
if "arguments" in tool_call and isinstance(tool_call["arguments"], str):
|
||||
try:
|
||||
tool_call["arguments"] = json.loads(tool_call["arguments"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If parsing fails, keep it as is
|
||||
pass
|
||||
return tool_call
|
||||
|
||||
|
||||
def test_prompt_gateway(httpserver: HTTPServer):
|
||||
simple_fixture = TEST_CASE_FIXTURES["SIMPLE"]
|
||||
input = simple_fixture["input"]
|
||||
|
|
@ -67,7 +89,7 @@ def test_prompt_gateway(httpserver: HTTPServer):
|
|||
tool_calls_message = arch_messages[0]
|
||||
tool_calls = tool_calls_message.get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls[0]["function"])
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,28 @@ def cleanup_tool_call(tool_call):
|
|||
return tool_call.strip()
|
||||
|
||||
|
||||
def normalize_tool_call_arguments(tool_call):
|
||||
"""
|
||||
Normalize tool call arguments to ensure they are always a dict.
|
||||
|
||||
According to OpenAI API spec, the 'arguments' field should be a JSON string,
|
||||
but for easier testing we parse it into a dict here.
|
||||
|
||||
Args:
|
||||
tool_call: A tool call dict that may have 'arguments' as either a string or dict
|
||||
|
||||
Returns:
|
||||
A tool call dict with 'arguments' guaranteed to be a dict
|
||||
"""
|
||||
if "arguments" in tool_call and isinstance(tool_call["arguments"], str):
|
||||
try:
|
||||
tool_call["arguments"] = json.loads(tool_call["arguments"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If parsing fails, keep it as is
|
||||
pass
|
||||
return tool_call
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_prompt_gateway(stream):
|
||||
expected_tool_call = {
|
||||
|
|
@ -62,7 +84,7 @@ def test_prompt_gateway(stream):
|
|||
print("cleaned_tool_call_str: ", cleaned_tool_call_str)
|
||||
tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls[0])
|
||||
location = tool_call["arguments"]["location"]
|
||||
assert expected_tool_call["arguments"]["location"] in location.lower()
|
||||
del expected_tool_call["arguments"]["location"]
|
||||
|
|
@ -106,7 +128,7 @@ def test_prompt_gateway(stream):
|
|||
print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json))
|
||||
tool_calls_list = cleaned_tool_call_json.get("tool_calls", [])
|
||||
assert len(tool_calls_list) > 0
|
||||
tool_call = tool_calls_list[0]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls_list[0])
|
||||
location = tool_call["arguments"]["location"]
|
||||
assert expected_tool_call["arguments"]["location"] in location.lower()
|
||||
del expected_tool_call["arguments"]["location"]
|
||||
|
|
@ -241,7 +263,7 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
assert role == "assistant"
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls[0]["function"])
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
|
|
@ -275,7 +297,7 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
tool_calls_message = arch_messages[0]
|
||||
tool_calls = tool_calls_message.get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls[0]["function"])
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue