Implement Client trait for StreamContext (#134)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-10-07 19:50:15 -04:00 committed by GitHub
parent 5bfccd3959
commit c1cfbcd44d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 216 additions and 218 deletions

View file

@ -5,6 +5,7 @@ use crate::consts::{
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::http::{CallArgs, Client, ClientError};
use crate::llm_providers::LlmProviders;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
@ -30,11 +31,13 @@ use public_types::embeddings::{
};
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::cell::RefCell;
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
use std::time::Duration;
#[derive(Debug)]
enum ResponseHandlerType {
GetEmbeddings,
FunctionResolver,
@ -44,7 +47,8 @@ enum ResponseHandlerType {
DefaultTarget,
}
pub struct CallContext {
#[derive(Debug)]
pub struct StreamCallContext {
response_handler_type: ResponseHandlerType,
user_message: Option<String>,
prompt_target_name: Option<String>,
@ -54,13 +58,37 @@ pub struct CallContext {
upstream_cluster_path: Option<String>,
}
#[derive(thiserror::Error, Debug)]
pub enum ServerError {
#[error(transparent)]
HttpDispatch(ClientError),
#[error(transparent)]
Deserialization(serde_json::Error),
#[error(transparent)]
Serialization(serde_json::Error),
#[error("{0}")]
LogicError(String),
#[error("upstream error response authority={authority}, path={path}, status={status}")]
Upstream {
authority: String,
path: String,
status: String,
},
#[error(transparent)]
ExceededRatelimit(ratelimit::Error),
#[error("jailbreak detected: {0}")]
Jailbreak(String),
#[error("{why}")]
BadRequest { why: String },
}
pub struct StreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
embeddings_store: Rc<EmbeddingsStore>,
overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
callouts: RefCell<HashMap<u32, StreamCallContext>>,
tool_calls: Option<Vec<ToolCall>>,
tool_call_response: Option<String>,
arch_state: Option<Vec<ArchState>>,
@ -91,8 +119,8 @@ impl StreamContext {
metrics,
prompt_targets,
embeddings_store,
callouts: RefCell::new(HashMap::new()),
chat_completions_request: None,
callouts: HashMap::new(),
tool_calls: None,
tool_call_response: None,
arch_state: None,
@ -129,11 +157,17 @@ impl StreamContext {
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
}
fn modify_auth_headers(&mut self) -> Result<(), String> {
let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!(
"No access key configured for selected LLM Provider \"{}\"",
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
let llm_provider_api_key_value =
self.llm_provider()
))?;
.access_key
.as_ref()
.ok_or(ServerError::BadRequest {
why: format!(
"No access key configured for selected LLM Provider \"{}\"",
self.llm_provider()
),
})?;
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
@ -159,7 +193,7 @@ impl StreamContext {
});
}
fn send_server_error(&self, error: String, override_status_code: Option<StatusCode>) {
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
debug!("server error occurred: {}", error);
self.send_http_response(
override_status_code
@ -167,18 +201,15 @@ impl StreamContext {
.as_u16()
.into(),
vec![],
Some(error.as_bytes()),
Some(format!("{error}").as_bytes()),
);
}
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response,
Err(e) => {
return self.send_server_error(
format!("Error deserializing embedding response: {:?}", e),
None,
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -248,13 +279,13 @@ impl StreamContext {
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
let error = format!("Error serializing zero shot request: {}", error);
return self.send_server_error(error, None);
return self.send_server_error(ServerError::Serialization(error), None);
}
};
let token_id = match self.dispatch_http_call(
let call_args = CallArgs::new(
MODEL_SERVER_NAME,
"/zeroshot",
vec![
(":method", "POST"),
(":path", "/zeroshot"),
@ -266,49 +297,24 @@ impl StreamContext {
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
let error_msg = format!(
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
e
);
return self.send_server_error(error_msg, None);
}
};
debug!(
"dispatched call to model_server/zeroshot token_id={}",
token_id
);
self.metrics.active_http_calls.increment(1);
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
fn zero_shot_intent_detection_resp_handler(
&mut self,
body: Vec<u8>,
mut callout_context: CallContext,
mut callout_context: StreamCallContext,
) {
let zeroshot_intent_response: ZeroShotClassificationResponse =
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
self.send_server_error(
format!(
"Error deserializing zeroshot intent detection response: {:?}",
e
),
None,
);
return;
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -390,40 +396,34 @@ impl StreamContext {
);
let arch_messages_json = serde_json::to_string(&params).unwrap();
debug!("no prompt target found with similarity score above threshold, using default prompt target");
let token_id = match self.dispatch_http_call(
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let call_args = CallArgs::new(
&upstream_endpoint,
&upstream_path,
vec![
(":method", "POST"),
(":path", &upstream_path),
(":authority", &upstream_endpoint),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
(
"x-envoy-upstream-rq-timeout-ms",
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
],
Some(arch_messages_json.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
let error_msg =
format!("Error dispatching HTTP call for default-target: {:?}", e);
return self
.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
};
self.metrics.active_http_calls.increment(1);
);
callout_context.response_handler_type = ResponseHandlerType::DefaultTarget;
callout_context.prompt_target_name = Some(default_prompt_target.name.clone());
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
if let Err(e) = self.http_call(call_args, callout_context) {
return self.send_server_error(
ServerError::HttpDispatch(e),
Some(StatusCode::BAD_REQUEST),
);
}
return;
}
self.resume_http_request();
return;
}
@ -433,7 +433,9 @@ impl StreamContext {
Some(prompt_target) => prompt_target.clone(),
None => {
return self.send_server_error(
format!("Prompt target not found: {}", prompt_target_name),
ServerError::LogicError(format!(
"Prompt target not found: {prompt_target_name}"
)),
None,
);
}
@ -499,62 +501,42 @@ impl StreamContext {
msg_body
}
Err(e) => {
return self
.send_server_error(format!("Error serializing request_params: {:?}", e), None);
return self.send_server_error(ServerError::Serialization(e), None);
}
};
let token_id = match self.dispatch_http_call(
let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string();
let call_args = CallArgs::new(
ARC_FC_CLUSTER,
"/v1/chat/completions",
vec![
(":method", "POST"),
(":path", "/v1/chat/completions"),
(":authority", ARC_FC_CLUSTER),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
(
"x-envoy-upstream-rq-timeout-ms",
ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
),
("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()),
],
Some(msg_body.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
let error_msg = format!("Error dispatching HTTP call for function-call: {:?}", e);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
};
debug!(
"dispatched call to function {} token_id={}",
ARC_FC_CLUSTER, token_id
);
self.metrics.active_http_calls.increment(1);
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
callout_context.prompt_target_name = Some(prompt_target.name);
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
}
}
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: StreamCallContext) {
let body_str = String::from_utf8(body).unwrap();
debug!("arch <= app response body: {}", body_str);
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
Err(e) => {
return self.send_server_error(
format!(
"Error deserializing function resolver response into ChatCompletion: {:?}",
e
),
None,
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
@ -607,11 +589,12 @@ impl StreamContext {
let endpoint = prompt_target.endpoint.unwrap();
let path: String = endpoint.path.unwrap_or(String::from("/"));
let token_id = match self.dispatch_http_call(
let call_args = CallArgs::new(
&endpoint.name,
&path,
vec![
(":method", "POST"),
(":path", path.as_ref()),
(":path", &path),
(":authority", endpoint.name.as_str()),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
@ -619,39 +602,33 @@ impl StreamContext {
Some(tool_params_json_str.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
let error_msg = format!(
"Error dispatching call to cluster: {}, path: {}, err: {:?}",
&endpoint.name, path, e
);
debug!("{}", error_msg);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
};
);
callout_context.upstream_cluster = Some(endpoint.name.clone());
callout_context.upstream_cluster_path = Some(path.clone());
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
self.tool_calls = Some(tool_calls.clone());
callout_context.upstream_cluster = Some(endpoint.name);
callout_context.upstream_cluster_path = Some(path);
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
if let Err(e) = self.http_call(call_args, callout_context) {
self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST));
}
self.metrics.active_http_calls.increment(1);
}
fn function_call_response_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
let headers = self.get_http_call_response_headers();
if let Some(http_status) = headers.iter().find(|(key, _)| key == ":status") {
if http_status.1 != StatusCode::OK.as_str() {
let error_msg = format!(
"Error in function call response: cluster: {}, path: {}, status code: {}",
callout_context.upstream_cluster.unwrap(),
callout_context.upstream_cluster_path.unwrap(),
http_status.1
fn function_call_response_handler(
&mut self,
body: Vec<u8>,
callout_context: StreamCallContext,
) {
if let Some(http_status) = self.get_http_call_response_header(":status") {
if http_status != StatusCode::OK.as_str() {
return self.send_server_error(
ServerError::Upstream {
authority: callout_context.upstream_cluster.unwrap(),
path: callout_context.upstream_cluster_path.unwrap(),
status: http_status,
},
None,
);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
} else {
warn!("http status code not found in api response");
@ -714,8 +691,7 @@ impl StreamContext {
let json_string = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
return self
.send_server_error(format!("Error serializing request_body: {:?}", e), None);
return self.send_server_error(ServerError::Serialization(e), None);
}
};
debug!("arch => openai request body: {}", json_string);
@ -733,7 +709,7 @@ impl StreamContext {
Ok(_) => (),
Err(err) => {
self.send_server_error(
format!("Exceeded Ratelimit: {}", err),
ServerError::ExceededRatelimit(err),
Some(StatusCode::TOO_MANY_REQUESTS),
);
self.metrics.ratelimited_rq.increment(1);
@ -747,7 +723,7 @@ impl StreamContext {
self.resume_http_request();
}
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
debug!("response received for arch guard");
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
debug!("prompt_guard_resp: {:?}", prompt_guard_resp);
@ -757,14 +733,17 @@ impl StreamContext {
let msg = self
.prompt_guards
.jailbreak_on_exception_message()
.unwrap_or("Jailbreak detected. Please refrain from discussing jailbreaking.");
return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST));
.unwrap_or("refrain from discussing jailbreaking.");
return self.send_server_error(
ServerError::Jailbreak(String::from(msg)),
Some(StatusCode::BAD_REQUEST),
);
}
self.get_embeddings(callout_context);
}
fn get_embeddings(&mut self, callout_context: CallContext) {
fn get_embeddings(&mut self, callout_context: StreamCallContext) {
let user_message = callout_context.user_message.unwrap();
let get_embeddings_input = CreateEmbeddingRequest {
// Need to clone into input because user_message is used below.
@ -778,13 +757,13 @@ impl StreamContext {
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
let error_msg = format!("Error serializing embeddings input: {}", error);
return self.send_server_error(error_msg, None);
return self.send_server_error(ServerError::Deserialization(error), None);
}
};
let token_id = match self.dispatch_http_call(
let call_args = CallArgs::new(
MODEL_SERVER_NAME,
"/embeddings",
vec![
(":method", "POST"),
(":path", "/embeddings"),
@ -796,19 +775,8 @@ impl StreamContext {
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
let error_msg = format!("dispatched call to model_server/embeddings: {:?}", e);
return self.send_server_error(error_msg, None);
}
};
debug!(
"dispatched call to model_server/embeddings token_id={}",
token_id
);
let call_context = CallContext {
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::GetEmbeddings,
user_message: Some(user_message),
prompt_target_name: None,
@ -817,17 +785,13 @@ impl StreamContext {
upstream_cluster: None,
upstream_cluster_path: None,
};
if self.callouts.insert(token_id, call_context).is_some() {
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
}
self.metrics.active_http_calls.increment(1);
if let Err(e) = self.http_call(call_args, call_context) {
self.send_server_error(ServerError::HttpDispatch(e), None);
}
}
fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
fn default_target_handler(&self, body: Vec<u8>, callout_context: StreamCallContext) {
let prompt_target = self
.prompt_targets
.get(callout_context.prompt_target_name.as_ref().unwrap())
@ -856,10 +820,7 @@ impl StreamContext {
let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) {
Ok(chat_completions_resp) => chat_completions_resp,
Err(e) => {
return self.send_server_error(
format!("Error deserializing default target response: {:?}", e),
None,
);
return self.send_server_error(ServerError::Deserialization(e), None);
}
};
let api_resp = chat_completions_resp.choices[0]
@ -948,9 +909,9 @@ impl HttpContext for StreamContext {
match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => {
Err(e) => {
self.send_server_error(
format!("Failed to deserialize: {}", msg),
ServerError::Deserialization(e),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
@ -958,10 +919,10 @@ impl HttpContext for StreamContext {
},
None => {
self.send_server_error(
format!(
ServerError::LogicError(format!(
"Failed to obtain body bytes even though body_size is {}",
body_size
),
)),
None,
);
return Action::Pause;
@ -1018,7 +979,7 @@ impl HttpContext for StreamContext {
if !prompt_guard_jailbreak_task {
debug!("Missing input guard. Making inline call to retrieve");
let callout_context = CallContext {
let callout_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: user_message_str.clone(),
prompt_target_name: None,
@ -1046,14 +1007,14 @@ impl HttpContext for StreamContext {
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
Ok(json_data) => json_data,
Err(error) => {
let error_msg = format!("Error serializing prompt guard request: {}", error);
self.send_server_error(error_msg, None);
self.send_server_error(ServerError::Serialization(error), None);
return Action::Pause;
}
};
let token_id = match self.dispatch_http_call(
let call_args = CallArgs::new(
MODEL_SERVER_NAME,
"/guard",
vec![
(":method", "POST"),
(":path", "/guard"),
@ -1065,21 +1026,8 @@ impl HttpContext for StreamContext {
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
let error_msg = format!(
"Error dispatching embedding server HTTP call for prompt-guard: {:?}",
e
);
self.send_server_error(error_msg, None);
return Action::Pause;
}
};
debug!("dispatched HTTP call to arch_guard token_id={}", token_id);
let call_context = CallContext {
);
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
prompt_target_name: None,
@ -1088,14 +1036,10 @@ impl HttpContext for StreamContext {
upstream_cluster: None,
upstream_cluster_path: None,
};
if self.callouts.insert(token_id, call_context).is_some() {
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
}
self.metrics.active_http_calls.increment(1);
if let Err(e) = self.http_call(call_args, call_context) {
self.send_server_error(ServerError::HttpDispatch(e), None);
}
Action::Pause
}
@ -1130,7 +1074,10 @@ impl HttpContext for StreamContext {
let chat_completions_data = match body_str.split_once("data: ") {
Some((_, chat_completions_data)) => chat_completions_data,
None => {
self.send_server_error(String::from("parsing error in streaming data"), None);
self.send_server_error(
ServerError::LogicError(String::from("parsing error in streaming data")),
None,
);
return Action::Pause;
}
};
@ -1141,7 +1088,9 @@ impl HttpContext for StreamContext {
Err(_) => {
if chat_completions_data != "[NONE]" {
self.send_server_error(
String::from("error in streaming response"),
ServerError::LogicError(String::from(
"error in streaming response",
)),
None,
);
return Action::Continue;
@ -1168,14 +1117,7 @@ impl HttpContext for StreamContext {
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
self.send_server_error(
format!(
"error in non-streaming response: {}\n response was={}",
e,
String::from_utf8(body).unwrap()
),
None,
);
self.send_server_error(ServerError::Deserialization(e), None);
return Action::Pause;
}
};
@ -1260,7 +1202,11 @@ impl Context for StreamContext {
body_size: usize,
_num_trailers: usize,
) {
let callout_context = self.callouts.remove(&token_id).expect("invalid token_id");
let callout_context = self
.callouts
.get_mut()
.remove(&token_id)
.expect("invalid token_id");
self.metrics.active_http_calls.increment(-1);
if let Some(body) = self.get_http_call_response_body(0, body_size) {
@ -1284,9 +1230,21 @@ impl Context for StreamContext {
}
} else {
self.send_server_error(
String::from("No response body in inline HTTP request"),
ServerError::LogicError(String::from("No response body in inline HTTP request")),
None,
);
}
}
}
// Wire StreamContext into the shared HTTP-callout machinery: the `Client`
// trait (crate::http) supplies the generic `http_call` dispatch path, and
// this impl hands it the per-stream state it needs.
impl Client for StreamContext {
// Context carried from dispatch time to the matching response handler.
type CallContext = StreamCallContext;
// In-flight callouts keyed by the token id of each dispatched HTTP call.
// Wrapped in RefCell so the map can be mutated through the `&self` the
// trait exposes (proxy-wasm contexts are single-threaded).
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts
}
// Gauge of outstanding callouts — presumably incremented/decremented by the
// trait's dispatch/response bookkeeping; confirm against crate::http.
fn active_http_calls(&self) -> &crate::stats::Gauge {
&self.metrics.active_http_calls
}
}