Code refactor and some improvements - see description

- this is a follow-up to PR #190
- revert the file renames
- bring in the fix for a panic from https://github.com/katanemo/arch/pull/183
This commit is contained in:
Adil Hafeez 2024-10-17 17:59:59 -07:00
parent 6cd05572c4
commit 3e7f7be838
7 changed files with 49 additions and 79 deletions

View file

@ -229,20 +229,6 @@ mod test {
let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap();
assert_eq!(config.version, "v0.1"); assert_eq!(config.version, "v0.1");
let open_ai_provider = config
.llm_providers
.iter()
.find(|p| p.name.to_lowercase() == "openai")
.unwrap();
assert_eq!(open_ai_provider.name.to_lowercase(), "openai");
assert_eq!(
open_ai_provider.access_key,
Some("OPENAI_API_KEY".to_string())
);
assert_eq!(open_ai_provider.model, "gpt-4o");
assert_eq!(open_ai_provider.default, Some(true));
assert_eq!(open_ai_provider.stream, Some(true));
let prompt_guards = config.prompt_guards.as_ref().unwrap(); let prompt_guards = config.prompt_guards.as_ref().unwrap();
let input_guards = &prompt_guards.input_guards; let input_guards = &prompt_guards.input_guards;
let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap(); let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap();

View file

@ -1,4 +1,4 @@
use crate::llm_stream_context::LlmGatewayStreamContext; use crate::stream_context::StreamContext;
use common::configuration::Configuration; use common::configuration::Configuration;
use common::http::Client; use common::http::Client;
use common::llm_providers::LlmProviders; use common::llm_providers::LlmProviders;
@ -28,19 +28,19 @@ impl WasmMetrics {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct FilterCallContext {} pub struct CallContext {}
#[derive(Debug)] #[derive(Debug)]
pub struct LlmGatewayFilterContext { pub struct FilterContext {
metrics: Rc<WasmMetrics>, metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>, callouts: RefCell<HashMap<u32, CallContext>>,
llm_providers: Option<Rc<LlmProviders>>, llm_providers: Option<Rc<LlmProviders>>,
} }
impl LlmGatewayFilterContext { impl FilterContext {
pub fn new() -> LlmGatewayFilterContext { pub fn new() -> FilterContext {
LlmGatewayFilterContext { FilterContext {
callouts: RefCell::new(HashMap::new()), callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()), metrics: Rc::new(WasmMetrics::new()),
llm_providers: None, llm_providers: None,
@ -48,8 +48,8 @@ impl LlmGatewayFilterContext {
} }
} }
impl Client for LlmGatewayFilterContext { impl Client for FilterContext {
type CallContext = FilterCallContext; type CallContext = CallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> { fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
&self.callouts &self.callouts
@ -60,10 +60,10 @@ impl Client for LlmGatewayFilterContext {
} }
} }
impl Context for LlmGatewayFilterContext {} impl Context for FilterContext {}
// RootContext allows the Rust code to reach into the Envoy Config // RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for LlmGatewayFilterContext { impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool { fn on_configure(&mut self, _: usize) -> bool {
let config_bytes = self let config_bytes = self
.get_plugin_configuration() .get_plugin_configuration()
@ -90,8 +90,7 @@ impl RootContext for LlmGatewayFilterContext {
context_id context_id
); );
// No StreamContext can be created until the Embedding Store is fully initialized. Some(Box::new(StreamContext::new(
Some(Box::new(LlmGatewayStreamContext::new(
context_id, context_id,
Rc::clone(&self.metrics), Rc::clone(&self.metrics),
Rc::clone( Rc::clone(

View file

@ -1,13 +1,13 @@
use llm_filter_context::LlmGatewayFilterContext; use filter_context::FilterContext;
use proxy_wasm::traits::*; use proxy_wasm::traits::*;
use proxy_wasm::types::*; use proxy_wasm::types::*;
mod llm_filter_context; mod filter_context;
mod llm_stream_context; mod stream_context;
proxy_wasm::main! {{ proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> { proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(LlmGatewayFilterContext::new()) Box::new(FilterContext::new())
}); });
}} }}

View file

@ -1,4 +1,4 @@
use crate::llm_filter_context::WasmMetrics; use crate::filter_context::WasmMetrics;
use common::common_types::open_ai::{ use common::common_types::open_ai::{
ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse,
Message, ToolCall, ToolCallState, Message, ToolCall, ToolCallState,
@ -34,7 +34,7 @@ pub enum ServerError {
BadRequest { why: String }, BadRequest { why: String },
} }
pub struct LlmGatewayStreamContext { pub struct StreamContext {
context_id: u32, context_id: u32,
metrics: Rc<WasmMetrics>, metrics: Rc<WasmMetrics>,
tool_calls: Option<Vec<ToolCall>>, tool_calls: Option<Vec<ToolCall>>,
@ -52,10 +52,10 @@ pub struct LlmGatewayStreamContext {
request_id: Option<String>, request_id: Option<String>,
} }
impl LlmGatewayStreamContext { impl StreamContext {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new(context_id: u32, metrics: Rc<WasmMetrics>, llm_providers: Rc<LlmProviders>) -> Self { pub fn new(context_id: u32, metrics: Rc<WasmMetrics>, llm_providers: Rc<LlmProviders>) -> Self {
LlmGatewayStreamContext { StreamContext {
context_id, context_id,
metrics, metrics,
chat_completions_request: None, chat_completions_request: None,
@ -160,7 +160,7 @@ impl LlmGatewayStreamContext {
} }
// HttpContext is the trait that allows the Rust code to interact with HTTP objects. // HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for LlmGatewayStreamContext { impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response. // the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
@ -418,4 +418,4 @@ impl HttpContext for LlmGatewayStreamContext {
} }
} }
impl Context for LlmGatewayStreamContext {} impl Context for StreamContext {}

View file

@ -1,6 +1,6 @@
use crate::prompt_stream_context::PromptStreamContext; use crate::stream_context::StreamContext;
use common::common_types::EmbeddingType; use common::common_types::EmbeddingType;
use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget}; use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
use common::consts::ARCH_INTERNAL_CLUSTER_NAME; use common::consts::ARCH_INTERNAL_CLUSTER_NAME;
use common::consts::ARCH_UPSTREAM_HOST_HEADER; use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use common::consts::DEFAULT_EMBEDDING_MODEL; use common::consts::DEFAULT_EMBEDDING_MODEL;
@ -10,7 +10,6 @@ use common::embeddings::{
}; };
use common::http::CallArgs; use common::http::CallArgs;
use common::http::Client; use common::http::Client;
use common::llm_providers::LlmProviders;
use common::stats::Gauge; use common::stats::Gauge;
use common::stats::IncrementingMetric; use common::stats::IncrementingMetric;
use log::debug; use log::debug;
@ -45,31 +44,27 @@ pub struct FilterCallContext {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct PromptGatewayFilterContext { pub struct FilterContext {
metrics: Rc<WasmMetrics>, metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>, callouts: RefCell<HashMap<u32, FilterCallContext>>,
overrides: Rc<Option<Overrides>>, overrides: Rc<Option<Overrides>>,
system_prompt: Rc<Option<String>>, system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>, prompt_targets: Rc<HashMap<String, PromptTarget>>,
mode: GatewayMode,
prompt_guards: Rc<PromptGuards>, prompt_guards: Rc<PromptGuards>,
llm_providers: Option<Rc<LlmProviders>>,
embeddings_store: Option<Rc<EmbeddingsStore>>, embeddings_store: Option<Rc<EmbeddingsStore>>,
temp_embeddings_store: EmbeddingsStore, temp_embeddings_store: EmbeddingsStore,
} }
impl PromptGatewayFilterContext { impl FilterContext {
pub fn new() -> PromptGatewayFilterContext { pub fn new() -> FilterContext {
PromptGatewayFilterContext { FilterContext {
callouts: RefCell::new(HashMap::new()), callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()), metrics: Rc::new(WasmMetrics::new()),
system_prompt: Rc::new(None), system_prompt: Rc::new(None),
prompt_targets: Rc::new(HashMap::new()), prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None), overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()), prompt_guards: Rc::new(PromptGuards::default()),
mode: GatewayMode::Prompt,
llm_providers: None,
embeddings_store: Some(Rc::new(HashMap::new())), embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(), temp_embeddings_store: HashMap::new(),
} }
@ -117,7 +112,7 @@ impl PromptGatewayFilterContext {
Duration::from_secs(60), Duration::from_secs(60),
); );
let call_context = crate::prompt_filter_context::FilterCallContext { let call_context = crate::filter_context::FilterCallContext {
prompt_target_name: String::from(prompt_target_name), prompt_target_name: String::from(prompt_target_name),
embedding_type, embedding_type,
}; };
@ -194,7 +189,7 @@ impl PromptGatewayFilterContext {
} }
} }
impl Client for PromptGatewayFilterContext { impl Client for FilterContext {
type CallContext = FilterCallContext; type CallContext = FilterCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> { fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
@ -206,7 +201,7 @@ impl Client for PromptGatewayFilterContext {
} }
} }
impl Context for PromptGatewayFilterContext { impl Context for FilterContext {
fn on_http_call_response( fn on_http_call_response(
&mut self, &mut self,
token_id: u32, token_id: u32,
@ -235,7 +230,7 @@ impl Context for PromptGatewayFilterContext {
} }
// RootContext allows the Rust code to reach into the Envoy Config // RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for PromptGatewayFilterContext { impl RootContext for FilterContext {
fn on_configure(&mut self, _: usize) -> bool { fn on_configure(&mut self, _: usize) -> bool {
let config_bytes = self let config_bytes = self
.get_plugin_configuration() .get_plugin_configuration()
@ -254,17 +249,11 @@ impl RootContext for PromptGatewayFilterContext {
} }
self.system_prompt = Rc::new(config.system_prompt); self.system_prompt = Rc::new(config.system_prompt);
self.prompt_targets = Rc::new(prompt_targets); self.prompt_targets = Rc::new(prompt_targets);
self.mode = config.mode.unwrap_or_default();
if let Some(prompt_guards) = config.prompt_guards { if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(prompt_guards) self.prompt_guards = Rc::new(prompt_guards)
} }
match config.llm_providers.try_into() {
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
Err(err) => panic!("{err}"),
}
true true
} }
@ -274,12 +263,11 @@ impl RootContext for PromptGatewayFilterContext {
context_id context_id
); );
// No StreamContext can be created until the Embedding Store is fully initialized. let embedding_store = match self.embeddings_store.as_ref() {
let embedding_store = match self.mode { None => return None,
GatewayMode::Llm => None, Some(store) => Some(Rc::clone(store)),
GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())),
}; };
Some(Box::new(PromptStreamContext::new( Some(Box::new(StreamContext::new(
context_id, context_id,
Rc::clone(&self.metrics), Rc::clone(&self.metrics),
Rc::clone(&self.system_prompt), Rc::clone(&self.system_prompt),
@ -300,11 +288,8 @@ impl RootContext for PromptGatewayFilterContext {
} }
fn on_tick(&mut self) { fn on_tick(&mut self) {
debug!("starting up arch filter in mode: {:?}", self.mode); debug!("starting up arch filter in mode: prompt gateway mode");
if self.mode == GatewayMode::Prompt { self.process_prompt_targets();
self.process_prompt_targets();
}
self.set_tick_period(Duration::from_secs(0)); self.set_tick_period(Duration::from_secs(0));
} }
} }

View file

@ -1,13 +1,13 @@
use prompt_filter_context::PromptGatewayFilterContext; use filter_context::FilterContext;
use proxy_wasm::traits::*; use proxy_wasm::traits::*;
use proxy_wasm::types::*; use proxy_wasm::types::*;
mod prompt_filter_context; mod filter_context;
mod prompt_stream_context; mod stream_context;
proxy_wasm::main! {{ proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> { proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(PromptGatewayFilterContext::new()) Box::new(FilterContext::new())
}); });
}} }}

View file

@ -1,4 +1,4 @@
use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics}; use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use acap::cos; use acap::cos;
use common::common_types::open_ai::{ use common::common_types::open_ai::{
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
@ -81,7 +81,7 @@ pub enum ServerError {
NoMessagesFound { why: String }, NoMessagesFound { why: String },
} }
pub struct PromptStreamContext { pub struct StreamContext {
context_id: u32, context_id: u32,
metrics: Rc<WasmMetrics>, metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>, system_prompt: Rc<Option<String>>,
@ -102,7 +102,7 @@ pub struct PromptStreamContext {
request_id: Option<String>, request_id: Option<String>,
} }
impl PromptStreamContext { impl StreamContext {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
context_id: u32, context_id: u32,
@ -113,7 +113,7 @@ impl PromptStreamContext {
overrides: Rc<Option<Overrides>>, overrides: Rc<Option<Overrides>>,
embeddings_store: Option<Rc<EmbeddingsStore>>, embeddings_store: Option<Rc<EmbeddingsStore>>,
) -> Self { ) -> Self {
PromptStreamContext { StreamContext {
context_id, context_id,
metrics, metrics,
system_prompt, system_prompt,
@ -1031,7 +1031,7 @@ impl PromptStreamContext {
} }
// HttpContext is the trait that allows the Rust code to interact with HTTP objects. // HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for PromptStreamContext { impl HttpContext for StreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response. // the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
@ -1346,7 +1346,7 @@ impl HttpContext for PromptStreamContext {
} }
} }
impl Context for PromptStreamContext { impl Context for StreamContext {
fn on_http_call_response( fn on_http_call_response(
&mut self, &mut self,
token_id: u32, token_id: u32,
@ -1392,7 +1392,7 @@ impl Context for PromptStreamContext {
} }
} }
impl Client for PromptStreamContext { impl Client for StreamContext {
type CallContext = StreamCallContext; type CallContext = StreamCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> { fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {