mirror of
https://github.com/katanemo/plano.git
synced 2026-05-10 08:12:48 +02:00
Add support for streaming and fixes few issues (see description) (#202)
This commit is contained in:
parent
29ff8da60f
commit
662a840ac5
45 changed files with 2266 additions and 477 deletions
|
|
@ -2,9 +2,9 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
|||
use crate::hallucination::extract_messages_for_hallucination;
|
||||
use acap::cos;
|
||||
use common::common_types::open_ai::{
|
||||
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
|
||||
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, ToolCall,
|
||||
ToolType,
|
||||
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
|
||||
ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
|
||||
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
|
||||
};
|
||||
use common::common_types::{
|
||||
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
||||
|
|
@ -12,7 +12,12 @@ use common::common_types::{
|
|||
};
|
||||
use common::configuration::{Overrides, PromptGuards, PromptTarget};
|
||||
use common::consts::{
|
||||
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, HALLUCINATION_TEMPLATE, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
|
||||
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
|
||||
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
|
||||
ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST,
|
||||
HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
|
||||
ZEROSHOT_INTERNAL_HOST,
|
||||
};
|
||||
use common::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
|
|
@ -57,7 +62,7 @@ pub struct StreamCallContext {
|
|||
pub struct StreamContext {
|
||||
system_prompt: Rc<Option<String>>,
|
||||
prompt_targets: Rc<HashMap<String, PromptTarget>>,
|
||||
embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
pub embeddings_store: Option<Rc<EmbeddingsStore>>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
pub metrics: Rc<WasmMetrics>,
|
||||
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
|
||||
|
|
@ -66,9 +71,8 @@ pub struct StreamContext {
|
|||
pub tool_call_response: Option<String>,
|
||||
pub arch_state: Option<Vec<ArchState>>,
|
||||
pub request_body_size: usize,
|
||||
pub streaming_response: bool,
|
||||
pub user_prompt: Option<Message>,
|
||||
pub response_tokens: usize,
|
||||
pub streaming_response: bool,
|
||||
pub is_chat_completions_request: bool,
|
||||
pub chat_completions_request: Option<ChatCompletionsRequest>,
|
||||
pub prompt_guards: Rc<PromptGuards>,
|
||||
|
|
@ -99,7 +103,6 @@ impl StreamContext {
|
|||
request_body_size: 0,
|
||||
streaming_response: false,
|
||||
user_prompt: None,
|
||||
response_tokens: 0,
|
||||
is_chat_completions_request: false,
|
||||
prompt_guards,
|
||||
overrides,
|
||||
|
|
@ -300,13 +303,17 @@ impl StreamContext {
|
|||
body: Vec<u8>,
|
||||
callout_context: StreamCallContext,
|
||||
) {
|
||||
let boyd_str = String::from_utf8(body).expect("could not convert body to string");
|
||||
debug!("archgw <= hallucination response: {}", boyd_str);
|
||||
let body_str = String::from_utf8(body).expect("could not convert body to string");
|
||||
debug!("archgw <= hallucination response: {}", body_str);
|
||||
let hallucination_response: HallucinationClassificationResponse =
|
||||
match serde_json::from_str(boyd_str.as_str()) {
|
||||
match serde_json::from_str(body_str.as_str()) {
|
||||
Ok(hallucination_response) => hallucination_response,
|
||||
Err(e) => {
|
||||
warn!("error deserializing hallucination response: {}", e);
|
||||
warn!(
|
||||
"error deserializing hallucination response: {}, body: {}",
|
||||
e,
|
||||
body_str.as_str()
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
|
|
@ -323,37 +330,36 @@ impl StreamContext {
|
|||
|
||||
if !keys_with_low_score.is_empty() {
|
||||
let response =
|
||||
HALLUCINATION_TEMPLATE.to_string()
|
||||
+ &keys_with_low_score.join(", ")
|
||||
+ " ?";
|
||||
let message = Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: Some(response),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
|
||||
|
||||
let chat_completion_response = ChatCompletionsResponse {
|
||||
choices: vec![Choice {
|
||||
message,
|
||||
index: 0,
|
||||
finish_reason: "done".to_string(),
|
||||
}],
|
||||
usage: None,
|
||||
model: ARCH_FC_MODEL_NAME.to_string(),
|
||||
metadata: None,
|
||||
};
|
||||
let response_str = if self.streaming_response {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(response),
|
||||
None,
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
trace!("hallucination response: {:?}", chat_completion_response);
|
||||
to_server_events(chunks)
|
||||
} else {
|
||||
let chat_completion_response = ChatCompletionsResponse::new(response);
|
||||
serde_json::to_string(&chat_completion_response).unwrap()
|
||||
};
|
||||
debug!("hallucination response: {:?}", response_str);
|
||||
// make sure on_http_response_body does not attach tool calls and tool response to the response
|
||||
self.tool_calls = None;
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(
|
||||
serde_json::to_string(&chat_completion_response)
|
||||
.unwrap()
|
||||
.as_bytes(),
|
||||
),
|
||||
Some(response_str.as_bytes()),
|
||||
);
|
||||
} else {
|
||||
// not a hallucination, resume the flow
|
||||
|
|
@ -629,6 +635,7 @@ impl StreamContext {
|
|||
.message
|
||||
.tool_calls
|
||||
.clone_into(&mut self.tool_calls);
|
||||
|
||||
if self.tool_calls.as_ref().unwrap().len() > 1 {
|
||||
warn!(
|
||||
"multiple tool calls not supported yet, tool_calls count found: {}",
|
||||
|
|
@ -643,10 +650,39 @@ impl StreamContext {
|
|||
|
||||
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
|
||||
|
||||
let direct_response_str = if self.streaming_response {
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
Some(
|
||||
arch_fc_response.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
),
|
||||
None,
|
||||
Some(ARCH_FC_MODEL_NAME.to_owned()),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
to_server_events(chunks)
|
||||
} else {
|
||||
body_str
|
||||
};
|
||||
|
||||
self.tool_calls = None;
|
||||
return self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(body_str.as_bytes()),
|
||||
Some(direct_response_str.as_bytes()),
|
||||
);
|
||||
}
|
||||
|
||||
|
|
@ -943,7 +979,7 @@ impl StreamContext {
|
|||
self.get_embeddings(callout_context);
|
||||
}
|
||||
|
||||
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
|
||||
pub fn default_target_handler(&self, body: Vec<u8>, mut callout_context: StreamCallContext) {
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.get(callout_context.prompt_target_name.as_ref().unwrap())
|
||||
|
|
@ -951,8 +987,34 @@ impl StreamContext {
|
|||
.clone();
|
||||
|
||||
// check if the default target should be dispatched to the LLM provider
|
||||
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
|
||||
let default_target_response_str = String::from_utf8(body).unwrap();
|
||||
if !prompt_target
|
||||
.auto_llm_dispatch_on_response
|
||||
.unwrap_or_default()
|
||||
{
|
||||
let default_target_response_str = if self.streaming_response {
|
||||
let chat_completion_response =
|
||||
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
|
||||
|
||||
let chunks = vec![
|
||||
ChatCompletionStreamResponse::new(
|
||||
None,
|
||||
Some(ASSISTANT_ROLE.to_string()),
|
||||
Some(chat_completion_response.model.clone()),
|
||||
None,
|
||||
),
|
||||
ChatCompletionStreamResponse::new(
|
||||
chat_completion_response.choices[0].message.content.clone(),
|
||||
None,
|
||||
Some(chat_completion_response.model.clone()),
|
||||
None,
|
||||
),
|
||||
];
|
||||
|
||||
to_server_events(chunks)
|
||||
} else {
|
||||
String::from_utf8(body).unwrap()
|
||||
};
|
||||
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
|
|
@ -960,20 +1022,20 @@ impl StreamContext {
|
|||
);
|
||||
return;
|
||||
}
|
||||
|
||||
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(chat_completions_resp) => chat_completions_resp,
|
||||
Err(e) => {
|
||||
warn!("error deserializing default target response: {}", e);
|
||||
warn!(
|
||||
"error deserializing default target response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8(body).unwrap()
|
||||
);
|
||||
return self.send_server_error(ServerError::Deserialization(e), None);
|
||||
}
|
||||
};
|
||||
let api_resp = chat_completions_resp.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
let mut messages = callout_context.request_body.messages;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
// add system prompt
|
||||
match prompt_target.system_prompt.as_ref() {
|
||||
None => {}
|
||||
|
|
@ -989,13 +1051,24 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
messages.append(&mut callout_context.request_body.messages);
|
||||
|
||||
let api_resp = chat_completions_resp.choices[0]
|
||||
.message
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
|
||||
let user_message = messages.pop().unwrap();
|
||||
let message = format!("{}\ncontext: {}", user_message.content.unwrap(), api_resp);
|
||||
messages.push(Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(api_resp.clone()),
|
||||
content: Some(message),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: self
|
||||
.chat_completions_request
|
||||
|
|
@ -1009,11 +1082,32 @@ impl StreamContext {
|
|||
stream_options: callout_context.request_body.stream_options,
|
||||
metadata: None,
|
||||
};
|
||||
|
||||
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
|
||||
debug!("archgw => (default target) llm request: {}", json_resp);
|
||||
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
||||
pub fn generate_toll_call_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn generate_api_response_message(&mut self) -> Message {
|
||||
Message {
|
||||
role: TOOL_ROLE.to_string(),
|
||||
content: self.tool_call_response.clone(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Client for StreamContext {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue