mirror of
https://github.com/katanemo/plano.git
synced 2026-06-23 15:38:07 +02:00
refactor
This commit is contained in:
parent
b580dce9cc
commit
bf06bf358d
3 changed files with 54 additions and 74 deletions
|
|
@ -244,7 +244,6 @@ pub mod open_ai {
|
||||||
pub metadata: Option<HashMap<String, String>>,
|
pub metadata: Option<HashMap<String, String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// create constructor for ChatCompletionsResponse
|
|
||||||
impl ChatCompletionsResponse {
|
impl ChatCompletionsResponse {
|
||||||
pub fn new(message: String) -> Self {
|
pub fn new(message: String) -> Self {
|
||||||
ChatCompletionsResponse {
|
ChatCompletionsResponse {
|
||||||
|
|
@ -279,14 +278,19 @@ pub mod open_ai {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ChatCompletionStreamResponse {
|
impl ChatCompletionStreamResponse {
|
||||||
pub fn new(response: Option<String>, role: Option<String>, model: Option<String>) -> Self {
|
pub fn new(
|
||||||
|
response: Option<String>,
|
||||||
|
role: Option<String>,
|
||||||
|
model: Option<String>,
|
||||||
|
tool_calls: Option<Vec<ToolCall>>,
|
||||||
|
) -> Self {
|
||||||
ChatCompletionStreamResponse {
|
ChatCompletionStreamResponse {
|
||||||
model,
|
model,
|
||||||
choices: vec![ChunkChoice {
|
choices: vec![ChunkChoice {
|
||||||
delta: Delta {
|
delta: Delta {
|
||||||
role,
|
role,
|
||||||
content: response,
|
content: response,
|
||||||
tool_calls: None,
|
tool_calls,
|
||||||
model: None,
|
model: None,
|
||||||
tool_call_id: None,
|
tool_call_id: None,
|
||||||
},
|
},
|
||||||
|
|
@ -374,6 +378,16 @@ pub mod open_ai {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_call_id: Option<String>,
|
pub tool_call_id: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn to_server_events(chunks: Vec<ChatCompletionStreamResponse>) -> String {
|
||||||
|
let mut response_str = String::new();
|
||||||
|
for chunk in chunks.iter() {
|
||||||
|
response_str.push_str("data: ");
|
||||||
|
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
||||||
|
response_str.push_str("\n\n");
|
||||||
|
}
|
||||||
|
response_str
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ use std::{collections::HashMap, time::Duration};
|
||||||
use common::{
|
use common::{
|
||||||
common_types::{
|
common_types::{
|
||||||
open_ai::{
|
open_ai::{
|
||||||
ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest, ChunkChoice, Delta,
|
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionsRequest,
|
||||||
},
|
},
|
||||||
PromptGuardRequest, PromptGuardTask,
|
PromptGuardRequest, PromptGuardTask,
|
||||||
},
|
},
|
||||||
|
|
@ -250,7 +250,7 @@ impl HttpContext for StreamContext {
|
||||||
Some(chunk) => chunk,
|
Some(chunk) => chunk,
|
||||||
None => {
|
None => {
|
||||||
warn!(
|
warn!(
|
||||||
"response body empy, chunk_start: {}, chunk_size: {}",
|
"response body empty, chunk_start: {}, chunk_size: {}",
|
||||||
0, body_size
|
0, body_size
|
||||||
);
|
);
|
||||||
return Action::Continue;
|
return Action::Continue;
|
||||||
|
|
@ -288,44 +288,24 @@ impl HttpContext for StreamContext {
|
||||||
if self.streaming_response {
|
if self.streaming_response {
|
||||||
trace!("streaming response");
|
trace!("streaming response");
|
||||||
|
|
||||||
if self.tool_calls.is_some() {
|
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
|
||||||
let tool_call_chunk = ChatCompletionStreamResponse {
|
let chunks = vec![
|
||||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
ChatCompletionStreamResponse::new(
|
||||||
choices: vec![ChunkChoice {
|
None,
|
||||||
delta: Delta {
|
Some(ASSISTANT_ROLE.to_string()),
|
||||||
role: Some(ASSISTANT_ROLE.to_string()),
|
Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||||
tool_calls: self.tool_calls.to_owned(),
|
self.tool_calls.to_owned(),
|
||||||
content: None,
|
),
|
||||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
ChatCompletionStreamResponse::new(
|
||||||
tool_call_id: None,
|
self.tool_call_response.clone(),
|
||||||
},
|
Some(TOOL_ROLE.to_string()),
|
||||||
finish_reason: None,
|
Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||||
}],
|
None,
|
||||||
};
|
),
|
||||||
|
];
|
||||||
|
|
||||||
let tool_call_chunk_str = serde_json::to_string(&tool_call_chunk).unwrap();
|
let response_str = to_server_events(chunks);
|
||||||
|
self.set_http_response_body(0, body_size, response_str.as_bytes());
|
||||||
let api_call_chunk = ChatCompletionStreamResponse {
|
|
||||||
model: None,
|
|
||||||
choices: vec![ChunkChoice {
|
|
||||||
delta: Delta {
|
|
||||||
role: Some(TOOL_ROLE.to_string()),
|
|
||||||
tool_calls: None,
|
|
||||||
content: self.tool_call_response.clone(),
|
|
||||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
|
||||||
tool_call_id: None,
|
|
||||||
},
|
|
||||||
finish_reason: None,
|
|
||||||
}],
|
|
||||||
};
|
|
||||||
|
|
||||||
let api_call_chunk_str = serde_json::to_string(&api_call_chunk).unwrap();
|
|
||||||
let chunk_str = format!(
|
|
||||||
"data: {}\n\ndata: {}\n\n{}",
|
|
||||||
tool_call_chunk_str, api_call_chunk_str, body_utf8
|
|
||||||
);
|
|
||||||
|
|
||||||
self.set_http_response_body(0, body_size, chunk_str.as_bytes());
|
|
||||||
self.tool_calls = None;
|
self.tool_calls = None;
|
||||||
}
|
}
|
||||||
} else if let Some(tool_calls) = self.tool_calls.as_ref() {
|
} else if let Some(tool_calls) = self.tool_calls.as_ref() {
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
||||||
use crate::hallucination::extract_messages_for_hallucination;
|
use crate::hallucination::extract_messages_for_hallucination;
|
||||||
use acap::cos;
|
use acap::cos;
|
||||||
use common::common_types::open_ai::{
|
use common::common_types::open_ai::{
|
||||||
ArchState, ChatCompletionStreamResponse, ChatCompletionTool, ChatCompletionsRequest,
|
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
|
||||||
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
|
ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
|
||||||
ParameterType, ToolCall, ToolType,
|
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
|
||||||
};
|
};
|
||||||
use common::common_types::{
|
use common::common_types::{
|
||||||
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
||||||
|
|
@ -333,26 +333,22 @@ impl StreamContext {
|
||||||
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
|
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
|
||||||
|
|
||||||
let response_str = if self.streaming_response {
|
let response_str = if self.streaming_response {
|
||||||
let chunks = [
|
let chunks = vec![
|
||||||
ChatCompletionStreamResponse::new(
|
ChatCompletionStreamResponse::new(
|
||||||
None,
|
None,
|
||||||
Some(ASSISTANT_ROLE.to_string()),
|
Some(ASSISTANT_ROLE.to_string()),
|
||||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||||
|
None,
|
||||||
),
|
),
|
||||||
ChatCompletionStreamResponse::new(
|
ChatCompletionStreamResponse::new(
|
||||||
Some(response),
|
Some(response),
|
||||||
None,
|
None,
|
||||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||||
|
None,
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut response_str = String::new();
|
to_server_events(chunks)
|
||||||
for chunk in chunks.iter() {
|
|
||||||
response_str.push_str("data: ");
|
|
||||||
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
|
||||||
response_str.push_str("\n\n");
|
|
||||||
}
|
|
||||||
response_str
|
|
||||||
} else {
|
} else {
|
||||||
let chat_completion_response = ChatCompletionsResponse::new(response);
|
let chat_completion_response = ChatCompletionsResponse::new(response);
|
||||||
serde_json::to_string(&chat_completion_response).unwrap()
|
serde_json::to_string(&chat_completion_response).unwrap()
|
||||||
|
|
@ -636,9 +632,9 @@ impl StreamContext {
|
||||||
};
|
};
|
||||||
|
|
||||||
arch_fc_response.choices[0]
|
arch_fc_response.choices[0]
|
||||||
.message
|
.message
|
||||||
.tool_calls
|
.tool_calls
|
||||||
.clone_into(&mut self.tool_calls);
|
.clone_into(&mut self.tool_calls);
|
||||||
|
|
||||||
if self.tool_calls.as_ref().unwrap().len() > 1 {
|
if self.tool_calls.as_ref().unwrap().len() > 1 {
|
||||||
warn!(
|
warn!(
|
||||||
|
|
@ -655,11 +651,12 @@ impl StreamContext {
|
||||||
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
|
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
|
||||||
|
|
||||||
let direct_response_str = if self.streaming_response {
|
let direct_response_str = if self.streaming_response {
|
||||||
let chunks = [
|
let chunks = vec![
|
||||||
ChatCompletionStreamResponse::new(
|
ChatCompletionStreamResponse::new(
|
||||||
None,
|
None,
|
||||||
Some(ASSISTANT_ROLE.to_string()),
|
Some(ASSISTANT_ROLE.to_string()),
|
||||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||||
|
None,
|
||||||
),
|
),
|
||||||
ChatCompletionStreamResponse::new(
|
ChatCompletionStreamResponse::new(
|
||||||
Some(
|
Some(
|
||||||
|
|
@ -672,22 +669,15 @@ impl StreamContext {
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||||
|
None,
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut response_str = String::new();
|
to_server_events(chunks)
|
||||||
for chunk in chunks.iter() {
|
|
||||||
response_str.push_str("data: ");
|
|
||||||
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
|
||||||
response_str.push_str("\n\n");
|
|
||||||
}
|
|
||||||
response_str
|
|
||||||
} else {
|
} else {
|
||||||
body_str
|
body_str
|
||||||
};
|
};
|
||||||
|
|
||||||
if self.streaming_response {}
|
|
||||||
|
|
||||||
self.tool_calls = None;
|
self.tool_calls = None;
|
||||||
return self.send_http_response(
|
return self.send_http_response(
|
||||||
StatusCode::OK.as_u16().into(),
|
StatusCode::OK.as_u16().into(),
|
||||||
|
|
@ -1005,26 +995,22 @@ impl StreamContext {
|
||||||
let chat_completion_response =
|
let chat_completion_response =
|
||||||
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
|
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
|
||||||
|
|
||||||
let chunks = [
|
let chunks = vec![
|
||||||
ChatCompletionStreamResponse::new(
|
ChatCompletionStreamResponse::new(
|
||||||
None,
|
None,
|
||||||
Some(ASSISTANT_ROLE.to_string()),
|
Some(ASSISTANT_ROLE.to_string()),
|
||||||
Some(chat_completion_response.model.clone()),
|
Some(chat_completion_response.model.clone()),
|
||||||
|
None,
|
||||||
),
|
),
|
||||||
ChatCompletionStreamResponse::new(
|
ChatCompletionStreamResponse::new(
|
||||||
chat_completion_response.choices[0].message.content.clone(),
|
chat_completion_response.choices[0].message.content.clone(),
|
||||||
None,
|
None,
|
||||||
Some(chat_completion_response.model.clone()),
|
Some(chat_completion_response.model.clone()),
|
||||||
|
None,
|
||||||
),
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut response_str = String::new();
|
to_server_events(chunks)
|
||||||
for chunk in chunks.iter() {
|
|
||||||
response_str.push_str("data: ");
|
|
||||||
response_str.push_str(&serde_json::to_string(&chunk).unwrap());
|
|
||||||
response_str.push_str("\n\n");
|
|
||||||
}
|
|
||||||
response_str
|
|
||||||
} else {
|
} else {
|
||||||
String::from_utf8(body).unwrap()
|
String::from_utf8(body).unwrap()
|
||||||
};
|
};
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue