diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index d6b95656..b1c4f487 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -596,7 +596,7 @@ static_resources: clusters: - name: arch - connect_timeout: 0.5s + connect_timeout: 5s type: LOGICAL_DNS dns_lookup_family: V4_ONLY lb_policy: ROUND_ROBIN diff --git a/crates/brightstaff/src/handlers/function_calling.rs b/crates/brightstaff/src/handlers/function_calling.rs index 6c853ad4..3d67ff91 100644 --- a/crates/brightstaff/src/handlers/function_calling.rs +++ b/crates/brightstaff/src/handlers/function_calling.rs @@ -987,10 +987,13 @@ impl ArchFunctionHandler { info!("[arch-fc]: raw model response: {}", response_dict.raw_response); + // General model response (no intent matched - should route to default target) let model_message = if response_dict.response.as_ref().map_or(false, |s| !s.is_empty()) { + // When arch-fc returns a "response" field, it means no intent was matched + // Return empty content and empty tool_calls so prompt_gateway routes to default target ResponseMessage { role: Role::Assistant, - content: response_dict.response.clone(), + content: Some(String::new()), refusal: None, annotations: None, audio: None, @@ -1105,6 +1108,14 @@ impl ArchFunctionHandler { } }; + // Create metadata with the raw model response + let mut metadata = HashMap::new(); + metadata.insert( + "x-arch-fc-model-response".to_string(), + serde_json::to_value(&response_dict.raw_response) + .unwrap_or_else(|_| Value::String(response_dict.raw_response.clone())), + ); + let chat_completion_response = ChatCompletionsResponse { id: format!("chatcmpl-{}", uuid::Uuid::new_v4()), object: Some("chat.completion".to_string()), @@ -1125,7 +1136,7 @@ impl ArchFunctionHandler { }, system_fingerprint: None, service_tier: None, - metadata: None, + metadata: Some(metadata), }; info!("[response arch-fc]: {:?}", chat_completion_response); diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 9efbba21..1e9d507b 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -264,13 +264,6 @@ impl StreamContext { .tool_calls .clone_into(&mut self.tool_calls); - if self.tool_calls.as_ref().unwrap().len() > 1 { - warn!( - "multiple tool calls not supported yet, tool_calls count found: {}", - self.tool_calls.as_ref().unwrap().len() - ); - } - if self.tool_calls.is_none() || self.tool_calls.as_ref().unwrap().is_empty() { // This means that Arch FC did not have enough information to resolve the function call // Arch FC probably responded with a message asking for more information. @@ -314,6 +307,14 @@ impl StreamContext { ); } + // At this point, we know tool_calls is not None and not empty + if self.tool_calls.as_ref().unwrap().len() > 1 { + warn!( + "multiple tool calls not supported yet, tool_calls count found: {}", + self.tool_calls.as_ref().unwrap().len() + ); + } + // update prompt target name from the tool call response callout_context.prompt_target_name = Some(self.tool_calls.as_ref().unwrap()[0].function.name.clone()); diff --git a/tests/archgw/test_prompt_gateway.py b/tests/archgw/test_prompt_gateway.py index b207ebe0..0e1b4317 100644 --- a/tests/archgw/test_prompt_gateway.py +++ b/tests/archgw/test_prompt_gateway.py @@ -22,6 +22,28 @@ from common import ( ) +def normalize_tool_call_arguments(tool_call): + """ + Normalize tool call arguments to ensure they are always a dict. + + According to OpenAI API spec, the 'arguments' field should be a JSON string, + but for easier testing we parse it into a dict here. + + Args: + tool_call: A tool call dict that may have 'arguments' as either a string or dict + + Returns: + A tool call dict with 'arguments' guaranteed to be a dict + """ + if "arguments" in tool_call and isinstance(tool_call["arguments"], str): + try: + tool_call["arguments"] = json.loads(tool_call["arguments"]) + except (json.JSONDecodeError, TypeError): + # If parsing fails, keep it as is + pass + return tool_call + + def test_prompt_gateway(httpserver: HTTPServer): simple_fixture = TEST_CASE_FIXTURES["SIMPLE"] input = simple_fixture["input"] @@ -67,7 +89,7 @@ def test_prompt_gateway(httpserver: HTTPServer): tool_calls_message = arch_messages[0] tool_calls = tool_calls_message.get("tool_calls", []) assert len(tool_calls) > 0 - tool_call = tool_calls[0]["function"] + tool_call = normalize_tool_call_arguments(tool_calls[0]["function"]) diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True) assert not diff diff --git a/tests/e2e/test_prompt_gateway.py b/tests/e2e/test_prompt_gateway.py index 2edab55d..a55e740c 100644 --- a/tests/e2e/test_prompt_gateway.py +++ b/tests/e2e/test_prompt_gateway.py @@ -24,6 +24,28 @@ def cleanup_tool_call(tool_call): return tool_call.strip() +def normalize_tool_call_arguments(tool_call): + """ + Normalize tool call arguments to ensure they are always a dict. + + According to OpenAI API spec, the 'arguments' field should be a JSON string, + but for easier testing we parse it into a dict here. + + Args: + tool_call: A tool call dict that may have 'arguments' as either a string or dict + + Returns: + A tool call dict with 'arguments' guaranteed to be a dict + """ + if "arguments" in tool_call and isinstance(tool_call["arguments"], str): + try: + tool_call["arguments"] = json.loads(tool_call["arguments"]) + except (json.JSONDecodeError, TypeError): + # If parsing fails, keep it as is + pass + return tool_call + + @pytest.mark.parametrize("stream", [True, False]) def test_prompt_gateway(stream): expected_tool_call = { @@ -62,7 +84,7 @@ def test_prompt_gateway(stream): print("cleaned_tool_call_str: ", cleaned_tool_call_str) tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", []) assert len(tool_calls) > 0 - tool_call = tool_calls[0] + tool_call = normalize_tool_call_arguments(tool_calls[0]) location = tool_call["arguments"]["location"] assert expected_tool_call["arguments"]["location"] in location.lower() del expected_tool_call["arguments"]["location"] @@ -106,7 +128,7 @@ def test_prompt_gateway(stream): print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json)) tool_calls_list = cleaned_tool_call_json.get("tool_calls", []) assert len(tool_calls_list) > 0 - tool_call = tool_calls_list[0] + tool_call = normalize_tool_call_arguments(tool_calls_list[0]) location = tool_call["arguments"]["location"] assert expected_tool_call["arguments"]["location"] in location.lower() del expected_tool_call["arguments"]["location"] @@ -241,7 +263,7 @@ def test_prompt_gateway_param_tool_call(stream): assert role == "assistant" tool_calls = choices[0].get("delta", {}).get("tool_calls", []) assert len(tool_calls) > 0 - tool_call = tool_calls[0]["function"] + tool_call = normalize_tool_call_arguments(tool_calls[0]["function"]) diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True) assert not diff @@ -275,7 +297,7 @@ def test_prompt_gateway_param_tool_call(stream): tool_calls_message = arch_messages[0] tool_calls = tool_calls_message.get("tool_calls", []) assert len(tool_calls) > 0 - tool_call = tool_calls[0]["function"] + tool_call = normalize_tool_call_arguments(tool_calls[0]["function"]) diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True) assert not diff