This commit is contained in:
Adil Hafeez 2024-10-27 23:30:15 -07:00
parent b580dce9cc
commit bf06bf358d
3 changed files with 54 additions and 74 deletions

View file

@ -244,7 +244,6 @@ pub mod open_ai {
pub metadata: Option<HashMap<String, String>>,
}
// create constructor for ChatCompletionsResponse
impl ChatCompletionsResponse {
pub fn new(message: String) -> Self {
ChatCompletionsResponse {
@ -279,14 +278,19 @@ pub mod open_ai {
}
impl ChatCompletionStreamResponse {
pub fn new(response: Option<String>, role: Option<String>, model: Option<String>) -> Self {
pub fn new(
response: Option<String>,
role: Option<String>,
model: Option<String>,
tool_calls: Option<Vec<ToolCall>>,
) -> Self {
ChatCompletionStreamResponse {
model,
choices: vec![ChunkChoice {
delta: Delta {
role,
content: response,
tool_calls: None,
tool_calls,
model: None,
tool_call_id: None,
},
@ -374,6 +378,16 @@ pub mod open_ai {
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
/// Renders a list of streaming chat-completion chunks as one string of
/// `data:`-prefixed frames (the framing used for server-sent events).
///
/// Each chunk is serialized to JSON and emitted as `data: <json>` followed by
/// a blank line; frames are concatenated in the order the chunks were given.
/// An empty `chunks` vector yields an empty string.
///
/// # Panics
///
/// Panics if `serde_json` fails to serialize any chunk.
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
    chunks
        .iter()
        .map(|event| {
            let payload = serde_json::to_string(event).unwrap();
            format!("data: {}\n\n", payload)
        })
        .collect()
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]

View file

@ -3,7 +3,7 @@ use std::{collections::HashMap, time::Duration};
use common::{
common_types::{
open_ai::{
ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, ChunkChoice, Delta,
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
},
PromptGuardRequest, PromptGuardTask,
},
@ -250,7 +250,7 @@ impl HttpContext for StreamContext {
Some(chunk) => chunk,
None => {
warn!(
"response body empy, chunk_start: {}, chunk_size: {}",
"response body empty, chunk_start: {}, chunk_size: {}",
0, body_size
);
return Action::Continue;
@ -288,44 +288,24 @@ impl HttpContext for StreamContext {
if self.streaming_response {
trace!("streaming response");
if self.tool_calls.is_some() {
let tool_call_chunk = ChatCompletionStreamResponse {
model: Some(ARCH_FC_MODEL_NAME.to_string()),
choices: vec![ChunkChoice {
delta: Delta {
role: Some(ASSISTANT_ROLE.to_string()),
tool_calls: self.tool_calls.to_owned(),
content: None,
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_call_id: None,
},
finish_reason: None,
}],
};
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_string()),
self.tool_calls.to_owned(),
),
ChatCompletionStreamResponse::new(
self.tool_call_response.clone(),
Some(TOOL_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_string()),
None,
),
];
let tool_call_chunk_str = serde_json::to_string(&tool_call_chunk).unwrap();
let api_call_chunk = ChatCompletionStreamResponse {
model: None,
choices: vec![ChunkChoice {
delta: Delta {
role: Some(TOOL_ROLE.to_string()),
tool_calls: None,
content: self.tool_call_response.clone(),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_call_id: None,
},
finish_reason: None,
}],
};
let api_call_chunk_str = serde_json::to_string(&api_call_chunk).unwrap();
let chunk_str = format!(
"data: {}\n\ndata: {}\n\n{}",
tool_call_chunk_str, api_call_chunk_str, body_utf8
);
self.set_http_response_body(0, body_size, chunk_str.as_bytes());
let response_str = to_server_events(chunks);
self.set_http_response_body(0, body_size, response_str.as_bytes());
self.tool_calls = None;
}
} else if let Some(tool_calls) = self.tool_calls.as_ref() {

View file

@ -2,9 +2,9 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::hallucination::extract_messages_for_hallucination;
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
ParameterType, ToolCall, ToolType,
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
@ -333,26 +333,22 @@ impl StreamContext {
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
let response_str = if self.streaming_response {
let chunks = [
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
ChatCompletionStreamResponse::new(
Some(response),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
];
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
to_server_events(chunks)
} else {
let chat_completion_response = ChatCompletionsResponse::new(response);
serde_json::to_string(&chat_completion_response).unwrap()
@ -636,9 +632,9 @@ impl StreamContext {
};
arch_fc_response.choices[0]
.message
.tool_calls
.clone_into(&mut self.tool_calls);
.message
.tool_calls
.clone_into(&mut self.tool_calls);
if self.tool_calls.as_ref().unwrap().len() > 1 {
warn!(
@ -655,11 +651,12 @@ impl StreamContext {
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
let direct_response_str = if self.streaming_response {
let chunks = [
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
ChatCompletionStreamResponse::new(
Some(
@ -672,22 +669,15 @@ impl StreamContext {
),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
];
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
to_server_events(chunks)
} else {
body_str
};
if self.streaming_response {}
self.tool_calls = None;
return self.send_http_response(
StatusCode::OK.as_u16().into(),
@ -1005,26 +995,22 @@ impl StreamContext {
let chat_completion_response =
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
let chunks = [
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(chat_completion_response.model.clone()),
None,
),
ChatCompletionStreamResponse::new(
chat_completion_response.choices[0].message.content.clone(),
None,
Some(chat_completion_response.model.clone()),
None,
),
];
let mut response_str = String::new();
for chunk in chunks.iter() {
response_str.push_str("data: ");
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
response_str.push_str("\n\n");
}
response_str
to_server_events(chunks)
} else {
String::from_utf8(body).unwrap()
};