Add support for streaming and fix a few issues (see description) (#202)

This commit is contained in:
José Ulises Niño Rivera 2024-10-28 20:05:06 -04:00 committed by GitHub
parent 29ff8da60f
commit 662a840ac5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
45 changed files with 2266 additions and 477 deletions

View file

@ -2,9 +2,9 @@ use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::hallucination::extract_messages_for_hallucination;
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, ToolCall,
ToolType,
to_server_events, ArchState, ChatCompletionStreamResponse, ChatCompletionTool,
ChatCompletionsRequest, ChatCompletionsResponse, FunctionDefinition, FunctionParameter,
FunctionParameters, Message, ParameterType, ToolCall, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
@ -12,7 +12,12 @@ use common::common_types::{
};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, HALLUCINATION_TEMPLATE, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST
ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS,
ARCH_INTERNAL_CLUSTER_NAME, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER,
ASSISTANT_ROLE, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST,
HALLUCINATION_TEMPLATE, MESSAGES_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
ZEROSHOT_INTERNAL_HOST,
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -57,7 +62,7 @@ pub struct StreamCallContext {
pub struct StreamContext {
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
pub embeddings_store: Option<Rc<EmbeddingsStore>>,
overrides: Rc<Option<Overrides>>,
pub metrics: Rc<WasmMetrics>,
pub callouts: RefCell<HashMap<u32, StreamCallContext>>,
@ -66,9 +71,8 @@ pub struct StreamContext {
pub tool_call_response: Option<String>,
pub arch_state: Option<Vec<ArchState>>,
pub request_body_size: usize,
pub streaming_response: bool,
pub user_prompt: Option<Message>,
pub response_tokens: usize,
pub streaming_response: bool,
pub is_chat_completions_request: bool,
pub chat_completions_request: Option<ChatCompletionsRequest>,
pub prompt_guards: Rc<PromptGuards>,
@ -99,7 +103,6 @@ impl StreamContext {
request_body_size: 0,
streaming_response: false,
user_prompt: None,
response_tokens: 0,
is_chat_completions_request: false,
prompt_guards,
overrides,
@ -300,13 +303,17 @@ impl StreamContext {
body: Vec<u8>,
callout_context: StreamCallContext,
) {
let boyd_str = String::from_utf8(body).expect("could not convert body to string");
debug!("archgw <= hallucination response: {}", boyd_str);
let body_str = String::from_utf8(body).expect("could not convert body to string");
debug!("archgw <= hallucination response: {}", body_str);
let hallucination_response: HallucinationClassificationResponse =
match serde_json::from_str(boyd_str.as_str()) {
match serde_json::from_str(body_str.as_str()) {
Ok(hallucination_response) => hallucination_response,
Err(e) => {
warn!("error deserializing hallucination response: {}", e);
warn!(
"error deserializing hallucination response: {}, body: {}",
e,
body_str.as_str()
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -323,37 +330,36 @@ impl StreamContext {
if !keys_with_low_score.is_empty() {
let response =
HALLUCINATION_TEMPLATE.to_string()
+ &keys_with_low_score.join(", ")
+ " ?";
let message = Message {
role: ASSISTANT_ROLE.to_string(),
content: Some(response),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
};
HALLUCINATION_TEMPLATE.to_string() + &keys_with_low_score.join(", ") + " ?";
let chat_completion_response = ChatCompletionsResponse {
choices: vec![Choice {
message,
index: 0,
finish_reason: "done".to_string(),
}],
usage: None,
model: ARCH_FC_MODEL_NAME.to_string(),
metadata: None,
};
let response_str = if self.streaming_response {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
ChatCompletionStreamResponse::new(
Some(response),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
];
trace!("hallucination response: {:?}", chat_completion_response);
to_server_events(chunks)
} else {
let chat_completion_response = ChatCompletionsResponse::new(response);
serde_json::to_string(&chat_completion_response).unwrap()
};
debug!("hallucination response: {:?}", response_str);
// make sure on_http_response_body does not attach tool calls and tool response to the response
self.tool_calls = None;
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(
serde_json::to_string(&chat_completion_response)
.unwrap()
.as_bytes(),
),
Some(response_str.as_bytes()),
);
} else {
// not a hallucination, resume the flow
@ -629,6 +635,7 @@ impl StreamContext {
.message
.tool_calls
.clone_into(&mut self.tool_calls);
if self.tool_calls.as_ref().unwrap().len() > 1 {
warn!(
"multiple tool calls not supported yet, tool_calls count found: {}",
@ -643,10 +650,39 @@ impl StreamContext {
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
let direct_response_str = if self.streaming_response {
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
ChatCompletionStreamResponse::new(
Some(
arch_fc_response.choices[0]
.message
.content
.as_ref()
.unwrap()
.clone(),
),
None,
Some(ARCH_FC_MODEL_NAME.to_owned()),
None,
),
];
to_server_events(chunks)
} else {
body_str
};
self.tool_calls = None;
return self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(body_str.as_bytes()),
Some(direct_response_str.as_bytes()),
);
}
@ -943,7 +979,7 @@ impl StreamContext {
self.get_embeddings(callout_context);
}
pub fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
pub fn default_target_handler(&self, body: Vec<u8>, mut callout_context: StreamCallContext) {
let prompt_target = self
.prompt_targets
.get(callout_context.prompt_target_name.as_ref().unwrap())
@ -951,8 +987,34 @@ impl StreamContext {
.clone();
// check if the default target should be dispatched to the LLM provider
if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) {
let default_target_response_str = String::from_utf8(body).unwrap();
if !prompt_target
.auto_llm_dispatch_on_response
.unwrap_or_default()
{
let default_target_response_str = if self.streaming_response {
let chat_completion_response =
serde_json::from_slice::<ChatCompletionsResponse>(&body).unwrap();
let chunks = vec![
ChatCompletionStreamResponse::new(
None,
Some(ASSISTANT_ROLE.to_string()),
Some(chat_completion_response.model.clone()),
None,
),
ChatCompletionStreamResponse::new(
chat_completion_response.choices[0].message.content.clone(),
None,
Some(chat_completion_response.model.clone()),
None,
),
];
to_server_events(chunks)
} else {
String::from_utf8(body).unwrap()
};
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
@ -960,20 +1022,20 @@ impl StreamContext {
);
return;
}
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
Ok(chat_completions_resp) => chat_completions_resp,
Err(e) => {
warn!("error deserializing default target response: {}", e);
warn!(
"error deserializing default target response: {}, body str: {}",
e,
String::from_utf8(body).unwrap()
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let api_resp = chat_completions_resp.choices[0]
.message
.content
.as_ref()
.unwrap();
let mut messages = callout_context.request_body.messages;
let mut messages = Vec::new();
// add system prompt
match prompt_target.system_prompt.as_ref() {
None => {}
@ -989,13 +1051,24 @@ impl StreamContext {
}
}
messages.append(&mut callout_context.request_body.messages);
let api_resp = chat_completions_resp.choices[0]
.message
.content
.as_ref()
.unwrap();
let user_message = messages.pop().unwrap();
let message = format!("{}\ncontext: {}", user_message.content.unwrap(), api_resp);
messages.push(Message {
role: USER_ROLE.to_string(),
content: Some(api_resp.clone()),
content: Some(message),
model: None,
tool_calls: None,
tool_call_id: None,
});
let chat_completion_request = ChatCompletionsRequest {
model: self
.chat_completions_request
@ -1009,11 +1082,32 @@ impl StreamContext {
stream_options: callout_context.request_body.stream_options,
metadata: None,
};
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("archgw => (default target) llm request: {}", json_resp);
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
self.resume_http_request();
}
/// Builds the assistant-role message that echoes the pending tool calls
/// (`self.tool_calls`) back into the chat history, attributed to the Arch
/// function-calling model (`ARCH_FC_MODEL_NAME`).
///
/// Takes `&self` (not `&mut self`): the method only reads state, and a
/// shared borrow is strictly more permissive for callers.
///
/// NOTE(review): "toll" looks like a typo for "tool"; the name is kept
/// unchanged so existing callers keep compiling.
pub fn generate_toll_call_message(&self) -> Message {
    Message {
        role: ASSISTANT_ROLE.to_string(),
        content: None,
        model: Some(ARCH_FC_MODEL_NAME.to_string()),
        // Clone rather than take: the pending calls are still read later
        // (e.g. `tool_call_id` lookup when building the tool response).
        tool_calls: self.tool_calls.clone(),
        tool_call_id: None,
    }
}
/// Wraps the stored API/tool response (`self.tool_call_response`) in a
/// TOOL-role message, tagged with the id of the first pending tool call so
/// the upstream model can correlate the response with its call.
///
/// Takes `&self` (not `&mut self`): the method only reads state.
///
/// # Panics
/// Panics if `self.tool_calls` is `None` or empty — callers must invoke
/// this only after a tool call has been recorded.
pub fn generate_api_response_message(&self) -> Message {
    // `first().expect(...)` keeps the original panic-on-missing semantics
    // (previously `unwrap()[0]`) but with an actionable message, and also
    // covers the empty-vec case that bare indexing would hit.
    let call_id = self
        .tool_calls
        .as_ref()
        .and_then(|calls| calls.first())
        .expect("generate_api_response_message called with no recorded tool calls")
        .id
        .clone();
    Message {
        role: TOOL_ROLE.to_string(),
        content: self.tool_call_response.clone(),
        model: None,
        tool_calls: None,
        tool_call_id: Some(call_id),
    }
}
}
impl Client for StreamContext {