diff --git a/arch/Cargo.lock b/arch/Cargo.lock index 388a329d..6e9a2e5d 100644 --- a/arch/Cargo.lock +++ b/arch/Cargo.lock @@ -410,6 +410,17 @@ dependencies = [ "uuid", ] +[[package]] +name = "derivative" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcc3dd5e9e9c0b295d6e1e4d811fb6f157d5ffd784b8d202fc62eac8035a770b" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "digest" version = "0.10.7" @@ -747,6 +758,7 @@ name = "intelligent-prompt-gateway" version = "0.1.0" dependencies = [ "acap", + "derivative", "governor", "http", "log", diff --git a/arch/Cargo.toml b/arch/Cargo.toml index 15fe482b..430703b8 100644 --- a/arch/Cargo.toml +++ b/arch/Cargo.toml @@ -21,6 +21,7 @@ tiktoken-rs = "0.5.9" acap = "0.3.0" rand = "0.8.5" thiserror = "1.0.64" +derivative = "2.2.0" sha2 = "0.10.8" [dev-dependencies] diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs index cb2eb732..fa0f29fc 100644 --- a/arch/src/filter_context.rs +++ b/arch/src/filter_context.rs @@ -97,6 +97,7 @@ impl FilterContext { let call_args = CallArgs::new( MODEL_SERVER_NAME, + "/embeddings", vec![ (":method", "POST"), (":path", "/embeddings"), diff --git a/arch/src/http.rs b/arch/src/http.rs index 592e7c5f..93cf5118 100644 --- a/arch/src/http.rs +++ b/arch/src/http.rs @@ -1,12 +1,16 @@ use crate::stats::{Gauge, IncrementingMetric}; +use derivative::Derivative; use log::debug; use proxy_wasm::{traits::Context, types::Status}; use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration}; -#[derive(Debug)] +#[derive(Derivative)] +#[derivative(Debug)] pub struct CallArgs<'a> { upstream: &'a str, + path: &'a str, headers: Vec<(&'a str, &'a str)>, + #[derivative(Debug = "ignore")] body: Option<&'a [u8]>, trailers: Vec<(&'a str, &'a str)>, timeout: Duration, @@ -15,6 +19,7 @@ pub struct CallArgs<'a> { impl<'a> CallArgs<'a> { pub fn new( upstream: &'a str, + path: &'a str, headers: Vec<(&'a str, &'a 
str)>, body: Option<&'a [u8]>, trailers: Vec<(&'a str, &'a str)>, @@ -22,6 +27,7 @@ impl<'a> CallArgs<'a> { ) -> Self { CallArgs { upstream, + path, headers, body, trailers, @@ -32,9 +38,10 @@ impl<'a> CallArgs<'a> { #[derive(thiserror::Error, Debug)] pub enum ClientError { - #[error("Error dispatching HTTP call to `{upstream_name}`, error: {internal_status:?}")] + #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")] DispatchError { upstream_name: String, + path: String, internal_status: Status, }, } @@ -46,7 +53,7 @@ pub trait Client: Context { &self, call_args: CallArgs, call_context: Self::CallContext, - ) -> Result<(), ClientError> { + ) -> Result<u32, ClientError> { debug!( "dispatching http call with args={:?} context={:?}", call_args, call_context ); @@ -61,10 +68,11 @@ pub trait Client: Context { ) { Ok(id) => { self.add_call_context(id, call_context); - Ok(()) + Ok(id) } Err(status) => Err(ClientError::DispatchError { upstream_name: String::from(call_args.upstream), + path: String::from(call_args.path), internal_status: status, }), } diff --git a/arch/src/ratelimit.rs b/arch/src/ratelimit.rs index 42554bbe..311ceb48 100644 --- a/arch/src/ratelimit.rs +++ b/arch/src/ratelimit.rs @@ -2,6 +2,7 @@ use governor::{DefaultKeyedRateLimiter, InsufficientCapacity, Quota}; use log::debug; use public_types::configuration; use public_types::configuration::{Limit, Ratelimit, TimeUnit}; +use std::fmt::Display; use std::num::{NonZero, NonZeroU32}; use std::sync::RwLock; use std::{collections::HashMap, sync::OnceLock}; @@ -28,13 +29,18 @@ pub struct RatelimitMap { } // This version of Header demands that the user passes a header value to match on. -#[allow(unused)] -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Header { pub key: String, pub value: String, } +impl Display for Header { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + impl From
<Header> for configuration::Header { fn from(header: Header) -> Self { Self { @@ -44,6 +50,16 @@ impl From<Header>
for configuration::Header { } } +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("exceeded limit provider={provider}, selector={selector}, tokens_used={tokens_used}")] + ExceededLimit { + provider: String, + selector: Header, + tokens_used: NonZeroU32, + }, +} + impl RatelimitMap { // n.b new is private so that the only access to the Ratelimits can be done via the static // reference inside a RwLock via ratelimit::ratelimits(). @@ -82,7 +98,7 @@ impl RatelimitMap { provider: String, selector: Header, tokens_used: NonZeroU32, - ) -> Result<(), String> { + ) -> Result<(), Error> { debug!( "Checking limit for provider={}, with selector={:?}, consuming tokens={:?}", provider, selector, tokens_used @@ -96,7 +112,7 @@ impl RatelimitMap { Some(limit) => limit, }; - let mut config_selector = configuration::Header::from(selector); + let mut config_selector = configuration::Header::from(selector.clone()); let (limit, limit_key) = match provider_limits.get(&config_selector) { // This is a specific limit, i.e one that was configured with both key, and value. 
@@ -119,8 +135,11 @@ impl RatelimitMap { match limit.check_key_n(&limit_key, tokens_used) { Ok(Ok(())) => Ok(()), - Ok(Err(_)) => Err(String::from("Not allowed")), - Err(InsufficientCapacity(_)) => Err(String::from("Not allowed")), + Ok(Err(_)) | Err(InsufficientCapacity(_)) => Err(Error::ExceededLimit { + provider, + selector, + tokens_used, + }), } } } diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs index 26ba4858..7cd02734 100644 --- a/arch/src/stream_context.rs +++ b/arch/src/stream_context.rs @@ -5,6 +5,7 @@ use crate::consts::{ RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE, }; use crate::filter_context::{EmbeddingsStore, WasmMetrics}; +use crate::http::{CallArgs, Client, ClientError}; use crate::llm_providers::LlmProviders; use crate::ratelimit::Header; use crate::stats::IncrementingMetric; @@ -30,11 +31,13 @@ use public_types::embeddings::{ }; use serde_json::Value; use sha2::{Digest, Sha256}; +use std::cell::RefCell; use std::collections::HashMap; use std::num::NonZero; use std::rc::Rc; use std::time::Duration; +#[derive(Debug)] enum ResponseHandlerType { GetEmbeddings, FunctionResolver, @@ -44,7 +47,8 @@ enum ResponseHandlerType { DefaultTarget, } -pub struct CallContext { +#[derive(Debug)] +pub struct StreamCallContext { response_handler_type: ResponseHandlerType, user_message: Option, prompt_target_name: Option, @@ -54,13 +58,37 @@ pub struct CallContext { upstream_cluster_path: Option, } +#[derive(thiserror::Error, Debug)] +pub enum ServerError { + #[error(transparent)] + HttpDispatch(ClientError), + #[error(transparent)] + Deserialization(serde_json::Error), + #[error(transparent)] + Serialization(serde_json::Error), + #[error("{0}")] + LogicError(String), + #[error("upstream error response authority={authority}, path={path}, status={status}")] + Upstream { + authority: String, + path: String, + status: String, + }, + #[error(transparent)] + ExceededRatelimit(ratelimit::Error), + #[error("jailbreak detected: {0}")] + 
Jailbreak(String), + #[error("{why}")] + BadRequest { why: String }, +} + pub struct StreamContext { context_id: u32, metrics: Rc, prompt_targets: Rc>, embeddings_store: Rc, overrides: Rc>, - callouts: HashMap, + callouts: RefCell>, tool_calls: Option>, tool_call_response: Option, arch_state: Option>, @@ -91,8 +119,8 @@ impl StreamContext { metrics, prompt_targets, embeddings_store, + callouts: RefCell::new(HashMap::new()), chat_completions_request: None, - callouts: HashMap::new(), tool_calls: None, tool_call_response: None, arch_state: None, @@ -129,11 +157,17 @@ impl StreamContext { self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name); } - fn modify_auth_headers(&mut self) -> Result<(), String> { - let llm_provider_api_key_value = self.llm_provider().access_key.as_ref().ok_or(format!( - "No access key configured for selected LLM Provider \"{}\"", + fn modify_auth_headers(&mut self) -> Result<(), ServerError> { + let llm_provider_api_key_value = self.llm_provider() - ))?; + .access_key + .as_ref() + .ok_or(ServerError::BadRequest { + why: format!( + "No access key configured for selected LLM Provider \"{}\"", + self.llm_provider() + ), + })?; let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value); @@ -159,7 +193,7 @@ impl StreamContext { }); } - fn send_server_error(&self, error: String, override_status_code: Option) { + fn send_server_error(&self, error: ServerError, override_status_code: Option) { debug!("server error occurred: {}", error); self.send_http_response( override_status_code @@ -167,18 +201,15 @@ impl StreamContext { .as_u16() .into(), vec![], - Some(error.as_bytes()), + Some(format!("{error}").as_bytes()), ); } - fn embeddings_handler(&mut self, body: Vec, mut callout_context: CallContext) { + fn embeddings_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { Ok(embedding_response) => 
embedding_response, Err(e) => { - return self.send_server_error( - format!("Error deserializing embedding response: {:?}", e), - None, - ); + return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -248,13 +279,13 @@ impl StreamContext { let json_data: String = match serde_json::to_string(&zero_shot_classification_request) { Ok(json_data) => json_data, Err(error) => { - let error = format!("Error serializing zero shot request: {}", error); - return self.send_server_error(error, None); + return self.send_server_error(ServerError::Serialization(error), None); } }; - let token_id = match self.dispatch_http_call( + let call_args = CallArgs::new( MODEL_SERVER_NAME, + "/zeroshot", vec![ (":method", "POST"), (":path", "/zeroshot"), @@ -266,49 +297,24 @@ impl StreamContext { Some(json_data.as_bytes()), vec![], Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - let error_msg = format!( - "Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}", - e - ); - return self.send_server_error(error_msg, None); - } - }; - debug!( - "dispatched call to model_server/zeroshot token_id={}", - token_id ); - - self.metrics.active_http_calls.increment(1); callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; - if self.callouts.insert(token_id, callout_context).is_some() { - panic!( - "duplicate token_id={} in embedding server requests", - token_id - ) + if let Err(e) = self.http_call(call_args, callout_context) { + self.send_server_error(ServerError::HttpDispatch(e), None); } } fn zero_shot_intent_detection_resp_handler( &mut self, body: Vec, - mut callout_context: CallContext, + mut callout_context: StreamCallContext, ) { let zeroshot_intent_response: ZeroShotClassificationResponse = match serde_json::from_slice(&body) { Ok(zeroshot_response) => zeroshot_response, Err(e) => { - self.send_server_error( - format!( - "Error deserializing zeroshot intent detection response: {:?}", - e - ), - None, - 
); - return; + return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -390,40 +396,34 @@ impl StreamContext { ); let arch_messages_json = serde_json::to_string(¶ms).unwrap(); debug!("no prompt target found with similarity score above threshold, using default prompt target"); - let token_id = match self.dispatch_http_call( + + let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); + let call_args = CallArgs::new( &upstream_endpoint, + &upstream_path, vec![ (":method", "POST"), (":path", &upstream_path), (":authority", &upstream_endpoint), ("content-type", "application/json"), ("x-envoy-max-retries", "3"), - ( - "x-envoy-upstream-rq-timeout-ms", - ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(), - ), + ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), ], Some(arch_messages_json.as_bytes()), vec![], Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - let error_msg = - format!("Error dispatching HTTP call for default-target: {:?}", e); - return self - .send_server_error(error_msg, Some(StatusCode::BAD_REQUEST)); - } - }; - - self.metrics.active_http_calls.increment(1); + ); callout_context.response_handler_type = ResponseHandlerType::DefaultTarget; callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); - if self.callouts.insert(token_id, callout_context).is_some() { - panic!("duplicate token_id") + + if let Err(e) = self.http_call(call_args, callout_context) { + return self.send_server_error( + ServerError::HttpDispatch(e), + Some(StatusCode::BAD_REQUEST), + ); } - return; } + self.resume_http_request(); return; } @@ -433,7 +433,9 @@ impl StreamContext { Some(prompt_target) => prompt_target.clone(), None => { return self.send_server_error( - format!("Prompt target not found: {}", prompt_target_name), + ServerError::LogicError(format!( + "Prompt target not found: {prompt_target_name}" + )), None, ); } @@ -499,62 +501,42 @@ impl StreamContext { msg_body } Err(e) => { - return self - 
.send_server_error(format!("Error serializing request_params: {:?}", e), None); + return self.send_server_error(ServerError::Serialization(e), None); } }; - let token_id = match self.dispatch_http_call( + let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); + let call_args = CallArgs::new( ARC_FC_CLUSTER, + "/v1/chat/completions", vec![ (":method", "POST"), (":path", "/v1/chat/completions"), (":authority", ARC_FC_CLUSTER), ("content-type", "application/json"), ("x-envoy-max-retries", "3"), - ( - "x-envoy-upstream-rq-timeout-ms", - ARCH_FC_REQUEST_TIMEOUT_MS.to_string().as_str(), - ), + ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), ], Some(msg_body.as_bytes()), vec![], Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - let error_msg = format!("Error dispatching HTTP call for function-call: {:?}", e); - return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST)); - } - }; - - debug!( - "dispatched call to function {} token_id={}", - ARC_FC_CLUSTER, token_id ); - - self.metrics.active_http_calls.increment(1); callout_context.response_handler_type = ResponseHandlerType::FunctionResolver; callout_context.prompt_target_name = Some(prompt_target.name); - if self.callouts.insert(token_id, callout_context).is_some() { - panic!("duplicate token_id") + + if let Err(e) = self.http_call(call_args, callout_context) { + self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); } } - fn function_resolver_handler(&mut self, body: Vec, mut callout_context: CallContext) { + fn function_resolver_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { let body_str = String::from_utf8(body).unwrap(); debug!("arch <= app response body: {}", body_str); let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { Ok(arch_fc_response) => arch_fc_response, Err(e) => { - return self.send_server_error( - format!( - "Error deserializing function resolver response into 
ChatCompletion: {:?}", - e - ), - None, - ); + return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -607,11 +589,12 @@ impl StreamContext { let endpoint = prompt_target.endpoint.unwrap(); let path: String = endpoint.path.unwrap_or(String::from("/")); - let token_id = match self.dispatch_http_call( + let call_args = CallArgs::new( &endpoint.name, + &path, vec![ (":method", "POST"), - (":path", path.as_ref()), + (":path", &path), (":authority", endpoint.name.as_str()), ("content-type", "application/json"), ("x-envoy-max-retries", "3"), @@ -619,39 +602,33 @@ impl StreamContext { Some(tool_params_json_str.as_bytes()), vec![], Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - let error_msg = format!( - "Error dispatching call to cluster: {}, path: {}, err: {:?}", - &endpoint.name, path, e - ); - debug!("{}", error_msg); - return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST)); - } - }; + ); + callout_context.upstream_cluster = Some(endpoint.name.clone()); + callout_context.upstream_cluster_path = Some(path.clone()); + callout_context.response_handler_type = ResponseHandlerType::FunctionCall; self.tool_calls = Some(tool_calls.clone()); - callout_context.upstream_cluster = Some(endpoint.name); - callout_context.upstream_cluster_path = Some(path); - callout_context.response_handler_type = ResponseHandlerType::FunctionCall; - if self.callouts.insert(token_id, callout_context).is_some() { - panic!("duplicate token_id") + + if let Err(e) = self.http_call(call_args, callout_context) { + self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); } - self.metrics.active_http_calls.increment(1); } - fn function_call_response_handler(&mut self, body: Vec, callout_context: CallContext) { - let headers = self.get_http_call_response_headers(); - if let Some(http_status) = headers.iter().find(|(key, _)| key == ":status") { - if http_status.1 != StatusCode::OK.as_str() { - let error_msg = 
format!( - "Error in function call response: cluster: {}, path: {}, status code: {}", - callout_context.upstream_cluster.unwrap(), - callout_context.upstream_cluster_path.unwrap(), - http_status.1 + fn function_call_response_handler( + &mut self, + body: Vec, + callout_context: StreamCallContext, + ) { + if let Some(http_status) = self.get_http_call_response_header(":status") { + if http_status != StatusCode::OK.as_str() { + return self.send_server_error( + ServerError::Upstream { + authority: callout_context.upstream_cluster.unwrap(), + path: callout_context.upstream_cluster_path.unwrap(), + status: http_status, + }, + None, ); - return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST)); } } else { warn!("http status code not found in api response"); @@ -714,8 +691,7 @@ impl StreamContext { let json_string = match serde_json::to_string(&chat_completions_request) { Ok(json_string) => json_string, Err(e) => { - return self - .send_server_error(format!("Error serializing request_body: {:?}", e), None); + return self.send_server_error(ServerError::Serialization(e), None); } }; debug!("arch => openai request body: {}", json_string); @@ -733,7 +709,7 @@ impl StreamContext { Ok(_) => (), Err(err) => { self.send_server_error( - format!("Exceeded Ratelimit: {}", err), + ServerError::ExceededRatelimit(err), Some(StatusCode::TOO_MANY_REQUESTS), ); self.metrics.ratelimited_rq.increment(1); @@ -747,7 +723,7 @@ impl StreamContext { self.resume_http_request(); } - fn arch_guard_handler(&mut self, body: Vec, callout_context: CallContext) { + fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { debug!("response received for arch guard"); let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); debug!("prompt_guard_resp: {:?}", prompt_guard_resp); @@ -757,14 +733,17 @@ impl StreamContext { let msg = self .prompt_guards .jailbreak_on_exception_message() - .unwrap_or("Jailbreak detected. 
Please refrain from discussing jailbreaking."); - return self.send_server_error(msg.to_string(), Some(StatusCode::BAD_REQUEST)); + .unwrap_or("refrain from discussing jailbreaking."); + return self.send_server_error( + ServerError::Jailbreak(String::from(msg)), + Some(StatusCode::BAD_REQUEST), + ); } self.get_embeddings(callout_context); } - fn get_embeddings(&mut self, callout_context: CallContext) { + fn get_embeddings(&mut self, callout_context: StreamCallContext) { let user_message = callout_context.user_message.unwrap(); let get_embeddings_input = CreateEmbeddingRequest { // Need to clone into input because user_message is used below. @@ -778,13 +757,13 @@ impl StreamContext { let json_data: String = match serde_json::to_string(&get_embeddings_input) { Ok(json_data) => json_data, Err(error) => { - let error_msg = format!("Error serializing embeddings input: {}", error); - return self.send_server_error(error_msg, None); + return self.send_server_error(ServerError::Serialization(error), None); } }; - let token_id = match self.dispatch_http_call( + let call_args = CallArgs::new( MODEL_SERVER_NAME, + "/embeddings", vec![ (":method", "POST"), (":path", "/embeddings"), @@ -796,19 +775,8 @@ impl StreamContext { Some(json_data.as_bytes()), vec![], Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - let error_msg = format!("dispatched call to model_server/embeddings: {:?}", e); - return self.send_server_error(error_msg, None); - } - }; - debug!( - "dispatched call to model_server/embeddings token_id={}", - token_id ); - - let call_context = CallContext { + let call_context = StreamCallContext { response_handler_type: ResponseHandlerType::GetEmbeddings, user_message: Some(user_message), prompt_target_name: None, @@ -817,17 +785,13 @@ impl StreamContext { upstream_cluster: None, upstream_cluster_path: None, }; - if self.callouts.insert(token_id, call_context).is_some() { - panic!( - "duplicate token_id={} in embedding server requests", - token_id 
- ) - } - self.metrics.active_http_calls.increment(1); + if let Err(e) = self.http_call(call_args, call_context) { + self.send_server_error(ServerError::HttpDispatch(e), None); + } } - fn default_target_handler(&self, body: Vec, callout_context: CallContext) { + fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { let prompt_target = self .prompt_targets .get(callout_context.prompt_target_name.as_ref().unwrap()) @@ -856,10 +820,7 @@ impl StreamContext { let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(chat_completions_resp) => chat_completions_resp, Err(e) => { - return self.send_server_error( - format!("Error deserializing default target response: {:?}", e), - None, - ); + return self.send_server_error(ServerError::Deserialization(e), None); } }; let api_resp = chat_completions_resp.choices[0] @@ -948,9 +909,9 @@ impl HttpContext for StreamContext { match self.get_http_request_body(0, body_size) { Some(body_bytes) => match serde_json::from_slice(&body_bytes) { Ok(deserialized) => deserialized, - Err(msg) => { + Err(e) => { self.send_server_error( - format!("Failed to deserialize: {}", msg), + ServerError::Deserialization(e), Some(StatusCode::BAD_REQUEST), ); return Action::Pause; @@ -958,10 +919,10 @@ impl HttpContext for StreamContext { }, None => { self.send_server_error( - format!( + ServerError::LogicError(format!( "Failed to obtain body bytes even though body_size is {}", body_size - ), + )), None, ); return Action::Pause; @@ -1018,7 +979,7 @@ impl HttpContext for StreamContext { if !prompt_guard_jailbreak_task { debug!("Missing input guard. 
Making inline call to retrieve"); - let callout_context = CallContext { + let callout_context = StreamCallContext { response_handler_type: ResponseHandlerType::ArchGuard, user_message: user_message_str.clone(), prompt_target_name: None, @@ -1046,14 +1007,14 @@ impl HttpContext for StreamContext { let json_data: String = match serde_json::to_string(&get_prompt_guards_request) { Ok(json_data) => json_data, Err(error) => { - let error_msg = format!("Error serializing prompt guard request: {}", error); - self.send_server_error(error_msg, None); + self.send_server_error(ServerError::Serialization(error), None); return Action::Pause; } }; - let token_id = match self.dispatch_http_call( + let call_args = CallArgs::new( MODEL_SERVER_NAME, + "/guard", vec![ (":method", "POST"), (":path", "/guard"), @@ -1065,21 +1026,8 @@ impl HttpContext for StreamContext { Some(json_data.as_bytes()), vec![], Duration::from_secs(5), - ) { - Ok(token_id) => token_id, - Err(e) => { - let error_msg = format!( - "Error dispatching embedding server HTTP call for prompt-guard: {:?}", - e - ); - self.send_server_error(error_msg, None); - return Action::Pause; - } - }; - - debug!("dispatched HTTP call to arch_guard token_id={}", token_id); - - let call_context = CallContext { + ); + let call_context = StreamCallContext { response_handler_type: ResponseHandlerType::ArchGuard, user_message: self.user_prompt.as_ref().unwrap().content.clone(), prompt_target_name: None, @@ -1088,14 +1036,10 @@ impl HttpContext for StreamContext { upstream_cluster: None, upstream_cluster_path: None, }; - if self.callouts.insert(token_id, call_context).is_some() { - panic!( - "duplicate token_id={} in embedding server requests", - token_id - ) - } - self.metrics.active_http_calls.increment(1); + if let Err(e) = self.http_call(call_args, call_context) { + self.send_server_error(ServerError::HttpDispatch(e), None); + } Action::Pause } @@ -1130,7 +1074,10 @@ impl HttpContext for StreamContext { let chat_completions_data = 
match body_str.split_once("data: ") { Some((_, chat_completions_data)) => chat_completions_data, None => { - self.send_server_error(String::from("parsing error in streaming data"), None); + self.send_server_error( + ServerError::LogicError(String::from("parsing error in streaming data")), + None, + ); return Action::Pause; } }; @@ -1141,7 +1088,9 @@ impl HttpContext for StreamContext { Err(_) => { if chat_completions_data != "[NONE]" { self.send_server_error( - String::from("error in streaming response"), + ServerError::LogicError(String::from( + "error in streaming response", + )), None, ); return Action::Continue; @@ -1168,14 +1117,7 @@ impl HttpContext for StreamContext { match serde_json::from_slice(&body) { Ok(de) => de, Err(e) => { - self.send_server_error( - format!( - "error in non-streaming response: {}\n response was={}", - e, - String::from_utf8(body).unwrap() - ), - None, - ); + self.send_server_error(ServerError::Deserialization(e), None); return Action::Pause; } }; @@ -1260,7 +1202,11 @@ impl Context for StreamContext { body_size: usize, _num_trailers: usize, ) { - let callout_context = self.callouts.remove(&token_id).expect("invalid token_id"); + let callout_context = self + .callouts + .get_mut() + .remove(&token_id) + .expect("invalid token_id"); self.metrics.active_http_calls.increment(-1); if let Some(body) = self.get_http_call_response_body(0, body_size) { @@ -1284,9 +1230,21 @@ impl Context for StreamContext { } } else { self.send_server_error( - String::from("No response body in inline HTTP request"), + ServerError::LogicError(String::from("No response body in inline HTTP request")), None, ); } } } + +impl Client for StreamContext { + type CallContext = StreamCallContext; + + fn callouts(&self) -> &RefCell> { + &self.callouts + } + + fn active_http_calls(&self) -> &crate::stats::Gauge { + &self.metrics.active_http_calls + } +} diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs index 6b36347f..a7f7ae1d 100644 --- 
a/arch/tests/integration.rs +++ b/arch/tests/integration.rs @@ -571,6 +571,7 @@ fn request_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call( Some("api_server"), Some(vec![ @@ -589,15 +590,14 @@ fn request_ratelimited() { .execute_and_expect(ReturnType::None) .unwrap(); - let response_headers_with_200 = vec![(":status", "200"), ("content-type", "application/json")]; let body_text = String::from("test body"); module .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) - .expect_get_header_map_pairs(Some(MapType::HttpCallResponseHeaders)) - .returning(Some(response_headers_with_200)) + .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) + .returning(Some("200")) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) @@ -679,6 +679,7 @@ fn request_not_ratelimited() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call( Some("api_server"), Some(vec![ @@ -697,16 +698,14 @@ fn request_not_ratelimited() { .execute_and_expect(ReturnType::None) .unwrap(); - let response_headers_with_200 = vec![(":status", "200"), ("content-type", "application/json")]; - let body_text = String::from("test body"); module .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) .expect_metric_increment("active_http_calls", -1) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&body_text)) - .expect_get_header_map_pairs(Some(MapType::HttpCallResponseHeaders)) - 
.returning(Some(response_headers_with_200)) + .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) + .returning(Some("200")) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)