diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index d846666a..ac33c76c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -19,12 +19,12 @@ jobs: - name: Setup | Install wasm toolchain run: rustup target add wasm32-wasi - - name: Build wasm module for prompt_gateway - run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi - - name: Run Tests on common crate run: cd crates/common && cargo test + - name: Build wasm module for prompt_gateway + run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi + - name: Run Tests on prompt_gateway crate run: cd crates/prompt_gateway && cargo test diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index a362da9c..4651c610 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -14,6 +14,7 @@ derivative = "2.2.0" thiserror = "1.0.64" tiktoken-rs = "0.5.9" rand = "0.8.5" +serde_json = "1.0" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 63ab156c..293dad09 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -229,20 +229,6 @@ mod test { let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); assert_eq!(config.version, "v0.1"); - let open_ai_provider = config - .llm_providers - .iter() - .find(|p| p.name.to_lowercase() == "openai") - .unwrap(); - assert_eq!(open_ai_provider.name.to_lowercase(), "openai"); - assert_eq!( - open_ai_provider.access_key, - Some("OPENAI_API_KEY".to_string()) - ); - assert_eq!(open_ai_provider.model, "gpt-4o"); - assert_eq!(open_ai_provider.default, Some(true)); - assert_eq!(open_ai_provider.stream, Some(true)); - let prompt_guards = config.prompt_guards.as_ref().unwrap(); let input_guards = &prompt_guards.input_guards; let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap(); diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 76244f6b..ce119eab 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -12,7 +12,7 @@ pub const MODEL_SERVER_NAME: &str = "model_server"; pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider"; pub const ARCH_MESSAGES_KEY: &str = "arch_messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; -pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions"; +pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const ARCH_STATE_HEADER: &str = "x-arch-state"; pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B"; pub const REQUEST_ID_HEADER: &str = "x-request-id"; diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs new file mode 100644 index 00000000..fd634915 --- /dev/null +++ b/crates/common/src/errors.rs @@ -0,0 +1,39 @@ +use proxy_wasm::types::Status; + +use crate::ratelimit; + +#[derive(thiserror::Error, Debug)] +pub enum ClientError { + #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")] + DispatchError { + upstream_name: String, + path: String, + internal_status: Status, + }, +} + +#[derive(thiserror::Error, Debug)] +pub enum ServerError { + #[error(transparent)] + HttpDispatch(ClientError), + #[error(transparent)] + Deserialization(serde_json::Error), + #[error(transparent)] + Serialization(serde_json::Error), + #[error("{0}")] + LogicError(String), + #[error("upstream error response authority={authority}, path={path}, status={status}")] + Upstream { + authority: String, + path: String, + status: String, + }, + #[error("jailbreak detected: {0}")] + Jailbreak(String), + #[error("{why}")] + NoMessagesFound { why: String }, + #[error(transparent)] + ExceededRatelimit(ratelimit::Error), + #[error("{why}")] + BadRequest { why: String }, +} diff --git a/crates/common/src/http.rs b/crates/common/src/http.rs index 21380b0f..842818e2 100644 --- a/crates/common/src/http.rs +++ b/crates/common/src/http.rs @@ -1,4 +1,4 @@ -use crate::stats::{Gauge, IncrementingMetric}; +use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}}; use derivative::Derivative; use log::debug; use proxy_wasm::{traits::Context, types::Status}; @@ -37,16 +37,6 @@ impl<'a> CallArgs<'a> { } } -#[derive(thiserror::Error, Debug)] -pub enum ClientError { - #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")] - DispatchError { - upstream_name: String, - path: String, - internal_status: Status, - }, -} - pub trait Client: Context { type CallContext: Debug; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 27a51803..c23443ca 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -10,3 +10,4 @@ pub mod ratelimit; pub mod routing; pub mod stats; pub mod tokenizer; +pub mod errors; diff --git a/crates/llm_gateway/Cargo.lock b/crates/llm_gateway/Cargo.lock index 35182863..19ce3747 100644 --- a/crates/llm_gateway/Cargo.lock +++ b/crates/llm_gateway/Cargo.lock @@ -228,6 +228,7 @@ dependencies = [ "proxy-wasm", "rand", "serde", + "serde_json", "serde_yaml", "thiserror", "tiktoken-rs", diff --git a/crates/llm_gateway/src/llm_filter_context.rs b/crates/llm_gateway/src/filter_context.rs similarity index 80% rename from crates/llm_gateway/src/llm_filter_context.rs rename to crates/llm_gateway/src/filter_context.rs index e1ed2620..be80c390 100644 --- a/crates/llm_gateway/src/llm_filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -1,4 +1,4 @@ -use crate::llm_stream_context::LlmGatewayStreamContext; +use crate::stream_context::StreamContext; use common::configuration::Configuration; use common::http::Client; use common::llm_providers::LlmProviders; @@ -28,19 +28,19 @@ impl WasmMetrics { } #[derive(Debug)] -pub struct FilterCallContext {} +pub struct CallContext {} #[derive(Debug)] -pub struct LlmGatewayFilterContext { +pub struct FilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. - callouts: RefCell>, + callouts: RefCell>, llm_providers: Option>, } -impl LlmGatewayFilterContext { - pub fn new() -> LlmGatewayFilterContext { - LlmGatewayFilterContext { +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), llm_providers: None, @@ -48,8 +48,8 @@ impl LlmGatewayFilterContext { } } -impl Client for LlmGatewayFilterContext { - type CallContext = FilterCallContext; +impl Client for FilterContext { + type CallContext = CallContext; fn callouts(&self) -> &RefCell> { &self.callouts @@ -60,10 +60,10 @@ impl Client for LlmGatewayFilterContext { } } -impl Context for LlmGatewayFilterContext {} +impl Context for FilterContext {} // RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for LlmGatewayFilterContext { +impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { let config_bytes = self .get_plugin_configuration() @@ -90,8 +90,7 @@ impl RootContext for LlmGatewayFilterContext { context_id ); - // No StreamContext can be created until the Embedding Store is fully initialized. - Some(Box::new(LlmGatewayStreamContext::new( + Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone( diff --git a/crates/llm_gateway/src/lib.rs b/crates/llm_gateway/src/lib.rs index 766d32bb..e2ad9025 100644 --- a/crates/llm_gateway/src/lib.rs +++ b/crates/llm_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use llm_filter_context::LlmGatewayFilterContext; +use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod llm_filter_context; -mod llm_stream_context; +mod filter_context; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(LlmGatewayFilterContext::new()) + Box::new(FilterContext::new()) }); }} diff --git a/crates/llm_gateway/src/llm_stream_context.rs b/crates/llm_gateway/src/stream_context.rs similarity index 95% rename from crates/llm_gateway/src/llm_stream_context.rs rename to crates/llm_gateway/src/stream_context.rs index 6c585a72..655f76ff 100644 --- a/crates/llm_gateway/src/llm_stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,4 +1,4 @@ -use crate::llm_filter_context::WasmMetrics; +use crate::filter_context::WasmMetrics; use common::common_types::open_ai::{ ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message, ToolCall, ToolCallState, @@ -8,6 +8,7 @@ use common::consts::{ ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE, }; +use common::errors::ServerError; use common::llm_providers::LlmProviders; use common::ratelimit::Header; use common::{ratelimit, routing, tokenizer}; @@ -22,25 +23,12 @@ use std::rc::Rc; use common::stats::IncrementingMetric; -#[derive(thiserror::Error, Debug)] -pub enum ServerError { - #[error(transparent)] - Deserialization(serde_json::Error), - #[error("{0}")] - LogicError(String), - #[error(transparent)] - ExceededRatelimit(ratelimit::Error), - #[error("{why}")] - BadRequest { why: String }, -} - -pub struct LlmGatewayStreamContext { +pub struct StreamContext { context_id: u32, metrics: Rc, tool_calls: Option>, tool_call_response: Option, arch_state: Option>, - request_body_size: usize, ratelimit_selector: Option
, streaming_response: bool, user_prompt: Option, @@ -52,17 +40,15 @@ pub struct LlmGatewayStreamContext { request_id: Option, } -impl LlmGatewayStreamContext { - #[allow(clippy::too_many_arguments)] +impl StreamContext { pub fn new(context_id: u32, metrics: Rc, llm_providers: Rc) -> Self { - LlmGatewayStreamContext { + StreamContext { context_id, metrics, chat_completions_request: None, tool_calls: None, tool_call_response: None, arch_state: None, - request_body_size: 0, ratelimit_selector: None, streaming_response: false, user_prompt: None, @@ -160,7 +146,7 @@ impl LlmGatewayStreamContext { } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for LlmGatewayStreamContext { +impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { @@ -198,8 +184,6 @@ impl HttpContext for LlmGatewayStreamContext { return Action::Continue; } - self.request_body_size = body_size; - // Deserialize body into spec. // Currently OpenAI API. let mut deserialized_body: ChatCompletionsRequest = @@ -225,7 +209,6 @@ impl HttpContext for LlmGatewayStreamContext { return Action::Pause; } }; - self.is_chat_completions_request = true; // remove metadata from the request body deserialized_body.metadata = None; @@ -418,4 +401,4 @@ impl HttpContext for LlmGatewayStreamContext { } } -impl Context for LlmGatewayStreamContext {} +impl Context for StreamContext {} diff --git a/crates/prompt_gateway/Cargo.lock b/crates/prompt_gateway/Cargo.lock index 63de3b3f..7679b301 100644 --- a/crates/prompt_gateway/Cargo.lock +++ b/crates/prompt_gateway/Cargo.lock @@ -228,6 +228,7 @@ dependencies = [ "proxy-wasm", "rand", "serde", + "serde_json", "serde_yaml", "thiserror", "tiktoken-rs", diff --git a/crates/prompt_gateway/src/prompt_filter_context.rs b/crates/prompt_gateway/src/filter_context.rs similarity index 86% rename from crates/prompt_gateway/src/prompt_filter_context.rs rename to crates/prompt_gateway/src/filter_context.rs index 0c25ee5c..655b391f 100644 --- a/crates/prompt_gateway/src/prompt_filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -1,6 +1,6 @@ -use crate::prompt_stream_context::PromptStreamContext; +use crate::stream_context::StreamContext; use common::common_types::EmbeddingType; -use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget}; +use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget}; use common::consts::ARCH_INTERNAL_CLUSTER_NAME; use common::consts::ARCH_UPSTREAM_HOST_HEADER; use common::consts::DEFAULT_EMBEDDING_MODEL; @@ -10,7 +10,6 @@ use common::embeddings::{ }; use common::http::CallArgs; use common::http::Client; -use common::llm_providers::LlmProviders; use common::stats::Gauge; use common::stats::IncrementingMetric; use log::debug; @@ -45,31 +44,27 @@ pub struct FilterCallContext { } #[derive(Debug)] -pub struct PromptGatewayFilterContext { +pub struct FilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. callouts: RefCell>, overrides: Rc>, system_prompt: Rc>, prompt_targets: Rc>, - mode: GatewayMode, prompt_guards: Rc, - llm_providers: Option>, embeddings_store: Option>, temp_embeddings_store: EmbeddingsStore, } -impl PromptGatewayFilterContext { - pub fn new() -> PromptGatewayFilterContext { - PromptGatewayFilterContext { +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), system_prompt: Rc::new(None), prompt_targets: Rc::new(HashMap::new()), overrides: Rc::new(None), prompt_guards: Rc::new(PromptGuards::default()), - mode: GatewayMode::Prompt, - llm_providers: None, embeddings_store: Some(Rc::new(HashMap::new())), temp_embeddings_store: HashMap::new(), } @@ -117,7 +112,7 @@ impl PromptGatewayFilterContext { Duration::from_secs(60), ); - let call_context = crate::prompt_filter_context::FilterCallContext { + let call_context = crate::filter_context::FilterCallContext { prompt_target_name: String::from(prompt_target_name), embedding_type, }; @@ -194,7 +189,7 @@ impl PromptGatewayFilterContext { } } -impl Client for PromptGatewayFilterContext { +impl Client for FilterContext { type CallContext = FilterCallContext; fn callouts(&self) -> &RefCell> { @@ -206,7 +201,7 @@ impl Client for PromptGatewayFilterContext { } } -impl Context for PromptGatewayFilterContext { +impl Context for FilterContext { fn on_http_call_response( &mut self, token_id: u32, @@ -235,7 +230,7 @@ impl Context for PromptGatewayFilterContext { } // RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for PromptGatewayFilterContext { +impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { let config_bytes = self .get_plugin_configuration() @@ -254,17 +249,11 @@ impl RootContext for PromptGatewayFilterContext { } self.system_prompt = Rc::new(config.system_prompt); self.prompt_targets = Rc::new(prompt_targets); - self.mode = config.mode.unwrap_or_default(); if let Some(prompt_guards) = config.prompt_guards { self.prompt_guards = Rc::new(prompt_guards) } - match config.llm_providers.try_into() { - Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)), - Err(err) => panic!("{err}"), - } - true } @@ -274,12 +263,11 @@ impl RootContext for PromptGatewayFilterContext { context_id ); - // No StreamContext can be created until the Embedding Store is fully initialized. - let embedding_store = match self.mode { - GatewayMode::Llm => None, - GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())), + let embedding_store = match self.embeddings_store.as_ref() { + None => return None, + Some(store) => Some(Rc::clone(store)), }; - Some(Box::new(PromptStreamContext::new( + Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone(&self.system_prompt), @@ -300,11 +288,8 @@ impl RootContext for PromptGatewayFilterContext { } fn on_tick(&mut self) { - debug!("starting up arch filter in mode: {:?}", self.mode); - if self.mode == GatewayMode::Prompt { - self.process_prompt_targets(); - } - + debug!("starting up arch filter in mode: prompt gateway mode"); + self.process_prompt_targets(); self.set_tick_period(Duration::from_secs(0)); } } diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index 75edea5d..e2ad9025 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use prompt_filter_context::PromptGatewayFilterContext; +use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod prompt_filter_context; -mod prompt_stream_context; +mod filter_context; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(PromptGatewayFilterContext::new()) + Box::new(FilterContext::new()) }); }} diff --git a/crates/prompt_gateway/src/prompt_stream_context.rs b/crates/prompt_gateway/src/stream_context.rs similarity index 99% rename from crates/prompt_gateway/src/prompt_stream_context.rs rename to crates/prompt_gateway/src/stream_context.rs index d208f5e8..602f1629 100644 --- a/crates/prompt_gateway/src/prompt_stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -1,4 +1,4 @@ -use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics}; +use crate::filter_context::{EmbeddingsStore, WasmMetrics}; use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, @@ -21,7 +21,8 @@ use common::consts::{ use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use common::http::{CallArgs, Client, ClientError}; +use common::errors::ClientError; +use common::http::{CallArgs, Client}; use common::stats::Gauge; use http::StatusCode; use log::{debug, info, warn}; @@ -81,7 +82,7 @@ pub enum ServerError { NoMessagesFound { why: String }, } -pub struct PromptStreamContext { +pub struct StreamContext { context_id: u32, metrics: Rc, system_prompt: Rc>, @@ -102,8 +103,7 @@ pub struct PromptStreamContext { request_id: Option, } -impl PromptStreamContext { - #[allow(clippy::too_many_arguments)] +impl StreamContext { pub fn new( context_id: u32, metrics: Rc, @@ -113,7 +113,7 @@ impl PromptStreamContext { overrides: Rc>, embeddings_store: Option>, ) -> Self { - PromptStreamContext { + StreamContext { context_id, metrics, system_prompt, @@ -1031,7 +1031,7 @@ impl PromptStreamContext { } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for PromptStreamContext { +impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { @@ -1094,7 +1094,6 @@ impl HttpContext for PromptStreamContext { return Action::Pause; } }; - self.is_chat_completions_request = true; self.arch_state = match deserialized_body.metadata { Some(ref metadata) => { @@ -1346,7 +1345,7 @@ impl HttpContext for PromptStreamContext { } } -impl Context for PromptStreamContext { +impl Context for StreamContext { fn on_http_call_response( &mut self, token_id: u32, @@ -1392,7 +1391,7 @@ impl Context for PromptStreamContext { } } -impl Client for PromptStreamContext { +impl Client for StreamContext { type CallContext = StreamCallContext; fn callouts(&self) -> &RefCell> {