diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 80b58610..6ababa70 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -3,8 +3,8 @@ use crate::hallucination::extract_messages_for_hallucination; use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest, - ChatCompletionsResponse, FunctionDefinition, FunctionParameter, - FunctionParameters, Message, ParameterType, ToolCall, ToolType, + ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message, + ParameterType, ToolCall, ToolType, }; use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, @@ -636,9 +636,10 @@ impl StreamContext { }; arch_fc_response.choices[0] - .message - .tool_calls - .clone_into(&mut self.tool_calls); + .message + .tool_calls + .clone_into(&mut self.tool_calls); + if self.tool_calls.as_ref().unwrap().len() > 1 { warn!( "multiple tool calls not supported yet, tool_calls count found: {}", @@ -653,10 +654,44 @@ impl StreamContext { //TODO: add resolver name to the response so the client can send the response back to the correct resolver + let direct_response_str = if self.streaming_response { + let chunks = [ + ChatCompletionStreamResponse::new( + None, + Some(ASSISTANT_ROLE.to_string()), + Some(ARCH_FC_MODEL_NAME.to_owned()), + ), + ChatCompletionStreamResponse::new( + Some( + arch_fc_response.choices[0] + .message + .content + .as_ref() + .unwrap() + .clone(), + ), + None, + Some(ARCH_FC_MODEL_NAME.to_owned()), + ), + ]; + + let mut response_str = String::new(); + for chunk in chunks.iter() { + response_str.push_str("data: "); + response_str.push_str(&serde_json::to_string(&chunk).unwrap()); + response_str.push_str("\n\n"); + } + response_str + } else { + body_str + }; + + if self.streaming_response {} + return self.send_http_response( StatusCode::OK.as_u16().into(), vec![("Powered-By", "Katanemo")], - Some(body_str.as_bytes()), + Some(direct_response_str.as_bytes()), ); } @@ -989,8 +1024,7 @@ impl StreamContext { response_str.push_str("\n\n"); } response_str - } - else { + } else { String::from_utf8(body).unwrap() }; diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py index 203965cf..00ab20ad 100644 --- a/e2e_tests/test_prompt_gateway.py +++ b/e2e_tests/test_prompt_gateway.py @@ -29,6 +29,34 @@ def test_prompt_gateway(stream): assert response_json.get("model").startswith("gpt-4o-mini") +@pytest.mark.parametrize("stream", [True, False]) +def test_prompt_gateway_arch_direct_response(stream): + body = { + "messages": [ + { + "role": "user", + "content": "how is the weather", + } + ], + "stream": stream, + } + response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream) + assert response.status_code == 200 + if stream: + chunks = get_data_chunks(response, n=3) + assert len(chunks) > 0 + response_json = json.loads(chunks[0]) + # if its streaming we return tool call and api call in first two chunks + assert response_json.get("model").startswith("Arch") + else: + response_json = response.json() + assert response_json.get("model").startswith("Arch") + choices = response_json.get("choices", []) + assert len(choices) > 0 + message = choices[0]["message"]["content"] + assert "Could you provide the following details days" not in message + + @pytest.mark.parametrize("stream", [True, False]) def test_prompt_gateway_param_gathering(stream): body = {