This commit is contained in:
Adil Hafeez 2024-10-27 23:30:15 -07:00
parent b580dce9cc
commit bf06bf358d
3 changed files with 54 additions and 74 deletions

View file

@ -244,7 +244,6 @@ pub mod open_ai {
pub metadata: Option<HashMap<String, String>>, pub metadata: Option<HashMap<String, String>>,
} }
// create constructor for ChatCompletionsResponse
impl ChatCompletionsResponse { impl ChatCompletionsResponse {
pub fn new(message: String) -> Self { pub fn new(message: String) -> Self {
ChatCompletionsResponse { ChatCompletionsResponse {
@ -279,14 +278,19 @@ pub mod open_ai {
} }
impl ChatCompletionStreamResponse { impl ChatCompletionStreamResponse {
pub fn new(response: Option<String>, role: Option<String>, model: Option<String>) -> Self { pub fn new(
response: Option<String>,
role: Option<String>,
model: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
) -> Self {
ChatCompletionStreamResponse { ChatCompletionStreamResponse {
model, model,
choices: vec![ChunkChoice { choices: vec![ChunkChoice {
delta: Delta { delta: Delta {
role, role,
content: response, content: response,
tool_calls: None, tool_calls,
model: None, model: None,
tool_call_id: None, tool_call_id: None,
}, },
@ -374,6 +378,16 @@ pub mod open_ai {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>, pub tool_call_id: Option<String>,
} }
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
}
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -3,7 +3,7 @@ use std::{collections::HashMap, time::Duration};
use common::{ use common::{
common_types::{ common_types::{
open_ai::{ open_ai::{
ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, ChunkChoice, Delta, to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
}, },
PromptGuardRequest, PromptGuardTask, PromptGuardRequest, PromptGuardTask,
}, },
@ -250,7 +250,7 @@ impl HttpContext for StreamContext {
Some(chunk) => chunk, Some(chunk) => chunk,
None => { None => {
warn!( warn!(
"response body empy, chunk_start: {}, chunk_size: {}", "response body empty, chunk_start: {}, chunk_size: {}",
0, body_size 0, body_size
); );
return Action::Continue; return Action::Continue;
@ -288,44 +288,24 @@ impl HttpContext for StreamContext {
if self.streaming_response { if self.streaming_response {
trace!("streaming response"); trace!("streaming response");
if self.tool_calls.is_some() { if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
let tool_call_chunk = ChatCompletionStreamResponse { let chunks = vec![
model: Some(ARCH_FC_MODEL_NAME.to_string()), ChatCompletionStreamResponse::new(
choices: vec![ChunkChoice { None,
delta: Delta { Some(ASSISTANT_ROLE.to_string()),
role: Some(ASSISTANT_ROLE.to_string()), Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: self.tool_calls.to_owned(), self.tool_calls.to_owned(),
content: None, ),
model: Some(ARCH_FC_MODEL_NAME.to_string()), ChatCompletionStreamResponse::new(
tool_call_id: None, self.tool_call_response.clone(),
}, Some(TOOL_ROLE.to_string()),
finish_reason: None, Some(ARCH_FC_MODEL_NAME.to_string()),
}], None,
}; ),
];
let tool_call_chunk_str = serde_json::to_string(&tool_call_chunk).unwrap(); let response_str = to_server_events(chunks);
self.set_http_response_body(0, body_size, response_str.as_bytes());
let api_call_chunk = ChatCompletionStreamResponse {
model: None,
choices: vec![ChunkChoice {
delta: Delta {
role: Some(TOOL_ROLE.to_string()),
tool_calls: None,
content: self.tool_call_response.clone(),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_call_id: None,
},
finish_reason: None,
}],
};
let api_call_chunk_str = serde_json::to_string(&api_call_chunk).unwrap();
let chunk_str = format!(
"data: {}\n\ndata: {}\n\n{}",
tool_call_chunk_str, api_call_chunk_str, body_utf8
);
self.set_http_response_body(0, body_size, chunk_str.as_bytes());
self.tool_calls = None; self.tool_calls = None;
} }
} else if let Some(tool_calls) = self.tool_calls.as_ref() { } else if let Some(tool_calls) = self.tool_calls.as_ref() {

View file

@ -2,9 +2,9 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::hallucination::extract_messages_for_hallucination; use crate::hallucination::extract_messages_for_hallucination;
use acap::cos; use acap::cos;
use common::common_types::open_ai::{ use common::common_types::open_ai::{
ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest, to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message, ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
ParameterType, ToolCall, ToolType, FunctionParameters, Message, ParameterType, ToolCall, ToolType,
}; };
use common::common_types::{ use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
@ -333,26 +333,22 @@ impl StreamContext {
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?"; HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
let response_str = if self.streaming_response { let response_str = if self.streaming_response {
let chunks = [ let chunks = vec![
ChatCompletionStreamResponse::new( ChatCompletionStreamResponse::new(
None, None,
Some(ASSISTANT_ROLE.to_string()), Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()), Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
), ),
ChatCompletionStreamResponse::new( ChatCompletionStreamResponse::new(
Some(response), Some(response),
None, None,
Some(ARCH_FC_MODEL_NAME.to_owned()), Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
), ),
]; ];
let mut response_str = String::new(); to_server_events(chunks)
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
} else { } else {
let chat_completion_response = ChatCompletionsResponse::new(response); let chat_completion_response = ChatCompletionsResponse::new(response);
serde_json::to_string(&chat_completion_response).unwrap() serde_json::to_string(&chat_completion_response).unwrap()
@ -636,9 +632,9 @@ impl StreamContext {
}; };
arch_fc_response.choices[0] arch_fc_response.choices[0]
.message .message
.tool_calls .tool_calls
.clone_into(&mut self.tool_calls); .clone_into(&mut self.tool_calls);
if self.tool_calls.as_ref().unwrap().len() > 1 { if self.tool_calls.as_ref().unwrap().len() > 1 {
warn!( warn!(
@ -655,11 +651,12 @@ impl StreamContext {
//TODO: add resolver name to the response so the client can send the response back to the correct resolver //TODO: add resolver name to the response so the client can send the response back to the correct resolver
let direct_response_str = if self.streaming_response { let direct_response_str = if self.streaming_response {
let chunks = [ let chunks = vec![
ChatCompletionStreamResponse::new( ChatCompletionStreamResponse::new(
None, None,
Some(ASSISTANT_ROLE.to_string()), Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()), Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
), ),
ChatCompletionStreamResponse::new( ChatCompletionStreamResponse::new(
Some( Some(
@ -672,22 +669,15 @@ impl StreamContext {
), ),
None, None,
Some(ARCH_FC_MODEL_NAME.to_owned()), Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
), ),
]; ];
let mut response_str = String::new(); to_server_events(chunks)
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
} else { } else {
body_str body_str
}; };
if self.streaming_response {}
self.tool_calls = None; self.tool_calls = None;
return self.send_http_response( return self.send_http_response(
StatusCode::OK.as_u16().into(), StatusCode::OK.as_u16().into(),
@ -1005,26 +995,22 @@ impl StreamContext {
let chat_completion_response = let chat_completion_response =
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap(); serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
let chunks = [ let chunks = vec![
ChatCompletionStreamResponse::new( ChatCompletionStreamResponse::new(
None, None,
Some(ASSISTANT_ROLE.to_string()), Some(ASSISTANT_ROLE.to_string()),
Some(chat_completion_response.model.clone()), Some(chat_completion_response.model.clone()),
None,
), ),
ChatCompletionStreamResponse::new( ChatCompletionStreamResponse::new(
chat_completion_response.choices[0].message.content.clone(), chat_completion_response.choices[0].message.content.clone(),
None, None,
Some(chat_completion_response.model.clone()), Some(chat_completion_response.model.clone()),
None,
), ),
]; ];
let mut response_str = String::new(); to_server_events(chunks)
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
} else { } else {
String::from_utf8(body).unwrap() String::from_utf8(body).unwrap()
}; };