diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs index d152e16e..35404096 100644 --- a/crates/common/src/common_types.rs +++ b/crates/common/src/common_types.rs @@ -244,7 +244,6 @@ pub mod open_ai { pub metadata: Option>, } - // create constructor for ChatCompletionsResponse impl ChatCompletionsResponse { pub fn new(message: String) -> Self { ChatCompletionsResponse { @@ -279,14 +278,19 @@ pub mod open_ai { } impl ChatCompletionStreamResponse { - pub fn new(response: Option, role: Option, model: Option) -> Self { + pub fn new( + response: Option, + role: Option, + model: Option, + tool_calls: Option>, + ) -> Self { ChatCompletionStreamResponse { model, choices: vec![ChunkChoice { delta: Delta { role, content: response, - tool_calls: None, + tool_calls, model: None, tool_call_id: None, }, @@ -374,6 +378,16 @@ pub mod open_ai { #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, } + + pub fn to_server_events(chunks: Vec) -> String { + let mut response_str = String::new(); + for chunk in chunks.iter() { + response_str.push_str("data: "); + response_str.push_str(&serde_json::to_string(&chunk).unwrap()); + response_str.push_str("\n\n"); + } + response_str + } } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index 69bdb531..9871cc63 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -3,7 +3,7 @@ use std::{collections::HashMap, time::Duration}; use common::{ common_types::{ open_ai::{ - ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, ChunkChoice, Delta, + to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, }, PromptGuardRequest, PromptGuardTask, }, @@ -250,7 +250,7 @@ impl HttpContext for StreamContext { Some(chunk) => chunk, None => { warn!( - "response body empy, chunk_start: {}, chunk_size: {}", + "response body empty, chunk_start: {}, chunk_size: {}", 0, body_size ); return Action::Continue; @@ -288,44 +288,24 @@ impl HttpContext for StreamContext { if self.streaming_response { trace!("streaming response"); - if self.tool_calls.is_some() { - let tool_call_chunk = ChatCompletionStreamResponse { - model: Some(ARCH_FC_MODEL_NAME.to_string()), - choices: vec![ChunkChoice { - delta: Delta { - role: Some(ASSISTANT_ROLE.to_string()), - tool_calls: self.tool_calls.to_owned(), - content: None, - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_call_id: None, - }, - finish_reason: None, - }], - }; + if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() { + let chunks = vec![ + ChatCompletionStreamResponse::new( + None, + Some(ASSISTANT_ROLE.to_string()), + Some(ARCH_FC_MODEL_NAME.to_string()), + self.tool_calls.to_owned(), + ), + ChatCompletionStreamResponse::new( + self.tool_call_response.clone(), + Some(TOOL_ROLE.to_string()), + Some(ARCH_FC_MODEL_NAME.to_string()), + None, + ), + ]; - let tool_call_chunk_str = serde_json::to_string(&tool_call_chunk).unwrap(); - - let api_call_chunk = ChatCompletionStreamResponse { - model: None, - choices: vec![ChunkChoice { - delta: Delta { - role: Some(TOOL_ROLE.to_string()), - tool_calls: None, - content: self.tool_call_response.clone(), - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_call_id: None, - }, - finish_reason: None, - }], - }; - - let api_call_chunk_str = serde_json::to_string(&api_call_chunk).unwrap(); - let chunk_str = format!( - "data: {}\n\ndata: {}\n\n{}", - tool_call_chunk_str, api_call_chunk_str, body_utf8 - ); - - self.set_http_response_body(0, body_size, chunk_str.as_bytes()); + let response_str = to_server_events(chunks); + self.set_http_response_body(0, body_size, response_str.as_bytes()); self.tool_calls = None; } } else if let Some(tool_calls) = self.tool_calls.as_ref() { diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index c4a15f47..5d79d181 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -2,9 +2,9 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics}; use crate::hallucination::extract_messages_for_hallucination; use acap::cos; use common::common_types::open_ai::{ - ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest, - ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message, - ParameterType, ToolCall, ToolType, + to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool, + ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter, + FunctionParameters, Message, ParameterType, ToolCall, ToolType, }; use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, @@ -333,26 +333,22 @@ impl StreamContext { HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?"; let response_str = if self.streaming_response { - let chunks = [ + let chunks = vec![ ChatCompletionStreamResponse::new( None, Some(ASSISTANT_ROLE.to_string()), Some(ARCH_FC_MODEL_NAME.to_owned()), + None, ), ChatCompletionStreamResponse::new( Some(response), None, Some(ARCH_FC_MODEL_NAME.to_owned()), + None, ), ]; - let mut response_str = String::new(); - for chunk in chunks.iter() { - response_str.push_str("data: "); - response_str.push_str(&serde_json::to_string(&chunk).unwrap()); - response_str.push_str("\n\n"); - } - response_str + to_server_events(chunks) } else { let chat_completion_response = ChatCompletionsResponse::new(response); serde_json::to_string(&chat_completion_response).unwrap() @@ -636,9 +632,9 @@ impl StreamContext { }; arch_fc_response.choices[0] - .message - .tool_calls - .clone_into(&mut self.tool_calls); + .message + .tool_calls + .clone_into(&mut self.tool_calls); if self.tool_calls.as_ref().unwrap().len() > 1 { warn!( @@ -655,11 +651,12 @@ impl StreamContext { //TODO: add resolver name to the response so the client can send the response back to the correct resolver let direct_response_str = if self.streaming_response { - let chunks = [ + let chunks = vec![ ChatCompletionStreamResponse::new( None, Some(ASSISTANT_ROLE.to_string()), Some(ARCH_FC_MODEL_NAME.to_owned()), + None, ), ChatCompletionStreamResponse::new( Some( @@ -672,22 +669,15 @@ impl StreamContext { ), None, Some(ARCH_FC_MODEL_NAME.to_owned()), + None, ), ]; - let mut response_str = String::new(); - for chunk in chunks.iter() { - response_str.push_str("data: "); - response_str.push_str(&serde_json::to_string(&chunk).unwrap()); - response_str.push_str("\n\n"); - } - response_str + to_server_events(chunks) } else { body_str }; - if self.streaming_response {} - self.tool_calls = None; return self.send_http_response( StatusCode::OK.as_u16().into(), @@ -1005,26 +995,22 @@ impl StreamContext { let chat_completion_response = serde_json::from_slice::(&body).unwrap(); - let chunks = [ + let chunks = vec![ ChatCompletionStreamResponse::new( None, Some(ASSISTANT_ROLE.to_string()), Some(chat_completion_response.model.clone()), + None, ), ChatCompletionStreamResponse::new( chat_completion_response.choices[0].message.content.clone(), None, Some(chat_completion_response.model.clone()), + None, ), ]; - let mut response_str = String::new(); - for chunk in chunks.iter() { - response_str.push_str("data: "); - response_str.push_str(&serde_json::to_string(&chunk).unwrap()); - response_str.push_str("\n\n"); - } - response_str + to_server_events(chunks) } else { String::from_utf8(body).unwrap() };