From 3e7f7be838e242fc97a157fe7d31522075462b5b Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 17 Oct 2024 17:59:59 -0700 Subject: [PATCH] Code refactor and some improvements - see description - this is follow up to PR#190 - revert rename of files - bring in fix for panic from https://github.com/katanemo/arch/pull/183 - --- crates/common/src/configuration.rs | 14 ------ ...lm_filter_context.rs => filter_context.rs} | 25 +++++----- crates/llm_gateway/src/lib.rs | 8 ++-- ...lm_stream_context.rs => stream_context.rs} | 12 ++--- ...pt_filter_context.rs => filter_context.rs} | 47 +++++++------------ crates/prompt_gateway/src/lib.rs | 8 ++-- ...pt_stream_context.rs => stream_context.rs} | 14 +++--- 7 files changed, 49 insertions(+), 79 deletions(-) rename crates/llm_gateway/src/{llm_filter_context.rs => filter_context.rs} (80%) rename crates/llm_gateway/src/{llm_stream_context.rs => stream_context.rs} (98%) rename crates/prompt_gateway/src/{prompt_filter_context.rs => filter_context.rs} (86%) rename crates/prompt_gateway/src/{prompt_stream_context.rs => stream_context.rs} (99%) diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 63ab156c..293dad09 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -229,20 +229,6 @@ mod test { let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); assert_eq!(config.version, "v0.1"); - let open_ai_provider = config - .llm_providers - .iter() - .find(|p| p.name.to_lowercase() == "openai") - .unwrap(); - assert_eq!(open_ai_provider.name.to_lowercase(), "openai"); - assert_eq!( - open_ai_provider.access_key, - Some("OPENAI_API_KEY".to_string()) - ); - assert_eq!(open_ai_provider.model, "gpt-4o"); - assert_eq!(open_ai_provider.default, Some(true)); - assert_eq!(open_ai_provider.stream, Some(true)); - let prompt_guards = config.prompt_guards.as_ref().unwrap(); let input_guards = &prompt_guards.input_guards; let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap(); diff --git a/crates/llm_gateway/src/llm_filter_context.rs b/crates/llm_gateway/src/filter_context.rs similarity index 80% rename from crates/llm_gateway/src/llm_filter_context.rs rename to crates/llm_gateway/src/filter_context.rs index e1ed2620..be80c390 100644 --- a/crates/llm_gateway/src/llm_filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -1,4 +1,4 @@ -use crate::llm_stream_context::LlmGatewayStreamContext; +use crate::stream_context::StreamContext; use common::configuration::Configuration; use common::http::Client; use common::llm_providers::LlmProviders; @@ -28,19 +28,19 @@ impl WasmMetrics { } #[derive(Debug)] -pub struct FilterCallContext {} +pub struct CallContext {} #[derive(Debug)] -pub struct LlmGatewayFilterContext { +pub struct FilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. - callouts: RefCell>, + callouts: RefCell>, llm_providers: Option>, } -impl LlmGatewayFilterContext { - pub fn new() -> LlmGatewayFilterContext { - LlmGatewayFilterContext { +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), llm_providers: None, @@ -48,8 +48,8 @@ impl LlmGatewayFilterContext { } } -impl Client for LlmGatewayFilterContext { - type CallContext = FilterCallContext; +impl Client for FilterContext { + type CallContext = CallContext; fn callouts(&self) -> &RefCell> { &self.callouts @@ -60,10 +60,10 @@ impl Client for LlmGatewayFilterContext { } } -impl Context for LlmGatewayFilterContext {} +impl Context for FilterContext {} // RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for LlmGatewayFilterContext { +impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { let config_bytes = self .get_plugin_configuration() @@ -90,8 +90,7 @@ impl RootContext for LlmGatewayFilterContext { context_id ); - // No StreamContext can be created until the Embedding Store is fully initialized. - Some(Box::new(LlmGatewayStreamContext::new( + Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone( diff --git a/crates/llm_gateway/src/lib.rs b/crates/llm_gateway/src/lib.rs index 766d32bb..e2ad9025 100644 --- a/crates/llm_gateway/src/lib.rs +++ b/crates/llm_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use llm_filter_context::LlmGatewayFilterContext; +use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod llm_filter_context; -mod llm_stream_context; +mod filter_context; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(LlmGatewayFilterContext::new()) + Box::new(FilterContext::new()) }); }} diff --git a/crates/llm_gateway/src/llm_stream_context.rs b/crates/llm_gateway/src/stream_context.rs similarity index 98% rename from crates/llm_gateway/src/llm_stream_context.rs rename to crates/llm_gateway/src/stream_context.rs index 6c585a72..e1790552 100644 --- a/crates/llm_gateway/src/llm_stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,4 +1,4 @@ -use crate::llm_filter_context::WasmMetrics; +use crate::filter_context::WasmMetrics; use common::common_types::open_ai::{ ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message, ToolCall, ToolCallState, @@ -34,7 +34,7 @@ pub enum ServerError { BadRequest { why: String }, } -pub struct LlmGatewayStreamContext { +pub struct StreamContext { context_id: u32, metrics: Rc, tool_calls: Option>, @@ -52,10 +52,10 @@ pub struct LlmGatewayStreamContext { request_id: Option, } -impl LlmGatewayStreamContext { +impl StreamContext { #[allow(clippy::too_many_arguments)] pub fn new(context_id: u32, metrics: Rc, llm_providers: Rc) -> Self { - LlmGatewayStreamContext { + StreamContext { context_id, metrics, chat_completions_request: None, @@ -160,7 +160,7 @@ impl LlmGatewayStreamContext { } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for LlmGatewayStreamContext { +impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { @@ -418,4 +418,4 @@ impl HttpContext for LlmGatewayStreamContext { } } -impl Context for LlmGatewayStreamContext {} +impl Context for StreamContext {} diff --git a/crates/prompt_gateway/src/prompt_filter_context.rs b/crates/prompt_gateway/src/filter_context.rs similarity index 86% rename from crates/prompt_gateway/src/prompt_filter_context.rs rename to crates/prompt_gateway/src/filter_context.rs index 0c25ee5c..655b391f 100644 --- a/crates/prompt_gateway/src/prompt_filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -1,6 +1,6 @@ -use crate::prompt_stream_context::PromptStreamContext; +use crate::stream_context::StreamContext; use common::common_types::EmbeddingType; -use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget}; +use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget}; use common::consts::ARCH_INTERNAL_CLUSTER_NAME; use common::consts::ARCH_UPSTREAM_HOST_HEADER; use common::consts::DEFAULT_EMBEDDING_MODEL; @@ -10,7 +10,6 @@ use common::embeddings::{ }; use common::http::CallArgs; use common::http::Client; -use common::llm_providers::LlmProviders; use common::stats::Gauge; use common::stats::IncrementingMetric; use log::debug; @@ -45,31 +44,27 @@ pub struct FilterCallContext { } #[derive(Debug)] -pub struct PromptGatewayFilterContext { +pub struct FilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. callouts: RefCell>, overrides: Rc>, system_prompt: Rc>, prompt_targets: Rc>, - mode: GatewayMode, prompt_guards: Rc, - llm_providers: Option>, embeddings_store: Option>, temp_embeddings_store: EmbeddingsStore, } -impl PromptGatewayFilterContext { - pub fn new() -> PromptGatewayFilterContext { - PromptGatewayFilterContext { +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), system_prompt: Rc::new(None), prompt_targets: Rc::new(HashMap::new()), overrides: Rc::new(None), prompt_guards: Rc::new(PromptGuards::default()), - mode: GatewayMode::Prompt, - llm_providers: None, embeddings_store: Some(Rc::new(HashMap::new())), temp_embeddings_store: HashMap::new(), } @@ -117,7 +112,7 @@ impl PromptGatewayFilterContext { Duration::from_secs(60), ); - let call_context = crate::prompt_filter_context::FilterCallContext { + let call_context = crate::filter_context::FilterCallContext { prompt_target_name: String::from(prompt_target_name), embedding_type, }; @@ -194,7 +189,7 @@ impl PromptGatewayFilterContext { } } -impl Client for PromptGatewayFilterContext { +impl Client for FilterContext { type CallContext = FilterCallContext; fn callouts(&self) -> &RefCell> { @@ -206,7 +201,7 @@ impl Client for PromptGatewayFilterContext { } } -impl Context for PromptGatewayFilterContext { +impl Context for FilterContext { fn on_http_call_response( &mut self, token_id: u32, @@ -235,7 +230,7 @@ impl Context for PromptGatewayFilterContext { } // RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for PromptGatewayFilterContext { +impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { let config_bytes = self .get_plugin_configuration() @@ -254,17 +249,11 @@ impl RootContext for PromptGatewayFilterContext { } self.system_prompt = Rc::new(config.system_prompt); self.prompt_targets = Rc::new(prompt_targets); - self.mode = config.mode.unwrap_or_default(); if let Some(prompt_guards) = config.prompt_guards { self.prompt_guards = Rc::new(prompt_guards) } - match config.llm_providers.try_into() { - Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)), - Err(err) => panic!("{err}"), - } - true } @@ -274,12 +263,11 @@ impl RootContext for PromptGatewayFilterContext { context_id ); - // No StreamContext can be created until the Embedding Store is fully initialized. - let embedding_store = match self.mode { - GatewayMode::Llm => None, - GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())), + let embedding_store = match self.embeddings_store.as_ref() { + None => return None, + Some(store) => Some(Rc::clone(store)), }; - Some(Box::new(PromptStreamContext::new( + Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone(&self.system_prompt), @@ -300,11 +288,8 @@ impl RootContext for PromptGatewayFilterContext { } fn on_tick(&mut self) { - debug!("starting up arch filter in mode: {:?}", self.mode); - if self.mode == GatewayMode::Prompt { - self.process_prompt_targets(); - } - + debug!("starting up arch filter in mode: prompt gateway mode"); + self.process_prompt_targets(); self.set_tick_period(Duration::from_secs(0)); } } diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index 75edea5d..e2ad9025 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use prompt_filter_context::PromptGatewayFilterContext; +use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod prompt_filter_context; -mod prompt_stream_context; +mod filter_context; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(PromptGatewayFilterContext::new()) + Box::new(FilterContext::new()) }); }} diff --git a/crates/prompt_gateway/src/prompt_stream_context.rs b/crates/prompt_gateway/src/stream_context.rs similarity index 99% rename from crates/prompt_gateway/src/prompt_stream_context.rs rename to crates/prompt_gateway/src/stream_context.rs index d208f5e8..97c67974 100644 --- a/crates/prompt_gateway/src/prompt_stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -1,4 +1,4 @@ -use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics}; +use crate::filter_context::{EmbeddingsStore, WasmMetrics}; use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, @@ -81,7 +81,7 @@ pub enum ServerError { NoMessagesFound { why: String }, } -pub struct PromptStreamContext { +pub struct StreamContext { context_id: u32, metrics: Rc, system_prompt: Rc>, @@ -102,7 +102,7 @@ pub struct PromptStreamContext { request_id: Option, } -impl PromptStreamContext { +impl StreamContext { #[allow(clippy::too_many_arguments)] pub fn new( context_id: u32, @@ -113,7 +113,7 @@ impl PromptStreamContext { overrides: Rc>, embeddings_store: Option>, ) -> Self { - PromptStreamContext { + StreamContext { context_id, metrics, system_prompt, @@ -1031,7 +1031,7 @@ impl PromptStreamContext { } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for PromptStreamContext { +impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { @@ -1346,7 +1346,7 @@ impl HttpContext for PromptStreamContext { } } -impl Context for PromptStreamContext { +impl Context for StreamContext { fn on_http_call_response( &mut self, token_id: u32, @@ -1392,7 +1392,7 @@ impl Context for PromptStreamContext { } } -impl Client for PromptStreamContext { +impl Client for StreamContext { type CallContext = StreamCallContext; fn callouts(&self) -> &RefCell> {