diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs
index b4a21a7c..82ab4213 100644
--- a/arch/src/filter_context.rs
+++ b/arch/src/filter_context.rs
@@ -1,7 +1,8 @@
 use crate::consts::{DEFAULT_EMBEDDING_MODEL, MODEL_SERVER_NAME};
+use crate::http::{CallArgs, Client};
 use crate::llm_providers::LlmProviders;
 use crate::ratelimit;
-use crate::stats::{Counter, Gauge, RecordingMetric};
+use crate::stats::{Counter, Gauge, IncrementingMetric};
 use crate::stream_context::StreamContext;
 use log::debug;
 use proxy_wasm::traits::*;
@@ -11,10 +12,10 @@ use public_types::configuration::{Configuration, Overrides, PromptGuards, PromptTarget};
 use public_types::embeddings::{
     CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
 };
-use serde_json::to_string;
+use std::cell::RefCell;
+use std::collections::hash_map::Entry;
 use std::collections::HashMap;
 use std::rc::Rc;
-use std::sync::{OnceLock, RwLock};
 use std::time::Duration;
 
 #[derive(Copy, Clone, Debug)]
@@ -32,102 +33,74 @@ impl WasmMetrics {
     }
 }
 
-#[derive(Debug)]
-struct CallContext {
-    prompt_target: String,
-    embedding_type: EmbeddingType,
-}
-
 pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
+pub type EmbeddingsStore = HashMap<String, EmbeddingTypeMap>;
+
+#[derive(Debug)]
+pub struct FilterCallContext {
+    pub prompt_target_name: String,
+    pub embedding_type: EmbeddingType,
+}
 
 #[derive(Debug)]
 pub struct FilterContext {
     metrics: Rc<WasmMetrics>,
     // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
-    callouts: HashMap<u32, CallContext>,
+    callouts: RefCell<HashMap<u32, FilterCallContext>>,
     overrides: Rc<Option<Overrides>>,
-    prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
+    prompt_targets: Rc<HashMap<String, PromptTarget>>,
     prompt_guards: Rc<PromptGuards>,
     llm_providers: Option<Rc<LlmProviders>>,
-}
-
-pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
-    static EMBEDDINGS: OnceLock<RwLock<HashMap<String, EmbeddingTypeMap>>> = OnceLock::new();
-    EMBEDDINGS.get_or_init(|| {
-        let embeddings: HashMap<String, EmbeddingTypeMap> = HashMap::new();
-        RwLock::new(embeddings)
-    })
+    embeddings_store: Option<Rc<EmbeddingsStore>>,
+    temp_embeddings_store: EmbeddingsStore,
 }
 
 impl FilterContext {
     pub fn new() -> FilterContext {
         FilterContext {
-            callouts: HashMap::new(),
+            callouts: RefCell::new(HashMap::new()),
             metrics: Rc::new(WasmMetrics::new()),
-            prompt_targets: Rc::new(RwLock::new(HashMap::new())),
+            prompt_targets: Rc::new(HashMap::new()),
             overrides: Rc::new(None),
             prompt_guards: Rc::new(PromptGuards::default()),
             llm_providers: None,
+            embeddings_store: None,
+            temp_embeddings_store: HashMap::new(),
         }
     }
 
-    fn process_prompt_targets(&mut self) {
-        let prompt_targets = match self.prompt_targets.read() {
-            Ok(prompt_targets) => prompt_targets,
-            Err(e) => {
-                panic!("Error reading prompt targets: {:?}", e);
-            }
-        };
-        for values in prompt_targets.iter() {
-            let prompt_target = &values.1;
-
-            // schedule embeddings call for prompt target name
-            let token_id = self.schedule_embeddings_call(prompt_target.name.clone());
-            if self
-                .callouts
-                .insert(token_id, {
-                    CallContext {
-                        prompt_target: prompt_target.name.clone(),
-                        embedding_type: EmbeddingType::Name,
-                    }
-                })
-                .is_some()
-            {
-                panic!("duplicate token_id")
-            }
-
-            // schedule embeddings call for prompt target description
-            let token_id = self.schedule_embeddings_call(prompt_target.description.clone());
-            if self
-                .callouts
-                .insert(token_id, {
-                    CallContext {
-                        prompt_target: prompt_target.name.clone(),
-                        embedding_type: EmbeddingType::Description,
-                    }
-                })
-                .is_some()
-            {
-                panic!("duplicate token_id")
-            }
-
-            self.metrics
-                .active_http_calls
-                .record(self.callouts.len().try_into().unwrap());
+    fn process_prompt_targets(&self) {
+        for values in self.prompt_targets.iter() {
+            let prompt_target = values.1;
+            self.schedule_embeddings_call(
+                &prompt_target.name,
+                &prompt_target.name,
+                EmbeddingType::Name,
+            );
+            self.schedule_embeddings_call(
+                &prompt_target.name,
+                &prompt_target.description,
+                EmbeddingType::Description,
+            );
         }
     }
 
-    fn schedule_embeddings_call(&self, input: String) -> u32 {
+    fn schedule_embeddings_call(
+        &self,
+        prompt_target_name: &str,
+        input: &str,
+        embedding_type: EmbeddingType,
+    ) {
        let embeddings_input = CreateEmbeddingRequest {
-            input: Box::new(CreateEmbeddingRequestInput::String(input)),
+            input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))),
             model: String::from(DEFAULT_EMBEDDING_MODEL),
             encoding_format: None,
             dimensions: None,
             user: None,
         };
+        let json_data = serde_json::to_string(&embeddings_input).unwrap();
 
-        let json_data = to_string(&embeddings_input).unwrap();
-        let token_id = match self.dispatch_http_call(
+        let call_args = CallArgs::new(
             MODEL_SERVER_NAME,
             vec![
                 (":method", "POST"),
@@ -139,16 +112,16 @@ impl FilterContext {
             Some(json_data.as_bytes()),
             vec![],
             Duration::from_secs(60),
-        ) {
-            Ok(token_id) => token_id,
-            Err(e) => {
-                panic!(
-                    "Error dispatching HTTP call: {}, error: {:?}",
-                    MODEL_SERVER_NAME, e
-                );
-            }
+        );
+
+        let call_context = crate::filter_context::FilterCallContext {
+            prompt_target_name: String::from(prompt_target_name),
+            embedding_type,
         };
-        token_id
+
+        if let Err(error) = self.http_call(call_args, call_context) {
+            panic!("{error}")
+        }
     }
 
     fn embedding_response_handler(
@@ -157,40 +130,79 @@ impl FilterContext {
         embedding_type: EmbeddingType,
         prompt_target_name: String,
     ) {
-        let prompt_targets = self.prompt_targets.read().unwrap();
-        let prompt_target = prompt_targets.get(&prompt_target_name).unwrap();
-        if let Some(body) = self.get_http_call_response_body(0, body_size) {
-            if !body.is_empty() {
-                let mut embedding_response: CreateEmbeddingResponse =
-                    match serde_json::from_slice(&body) {
-                        Ok(response) => response,
-                        Err(e) => {
-                            panic!(
-                                "Error deserializing embedding response. body: {:?}: {:?}",
-                                String::from_utf8(body).unwrap(),
-                                e
-                            );
-                        }
-                    };
+        let prompt_target = self
+            .prompt_targets
+            .get(&prompt_target_name)
+            .unwrap_or_else(|| {
+                panic!(
+                    "Received embeddings response for unknown prompt target name={}",
+                    prompt_target_name
+                )
+            });
 
-                let embeddings = embedding_response.data.remove(0).embedding;
-                log::info!(
+        let body = self
+            .get_http_call_response_body(0, body_size)
+            .expect("No body in response");
+        if !body.is_empty() {
+            let mut embedding_response: CreateEmbeddingResponse =
+                match serde_json::from_slice(&body) {
+                    Ok(response) => response,
+                    Err(e) => {
+                        panic!(
+                            "Error deserializing embedding response. body: {:?}: {:?}",
+                            String::from_utf8(body).unwrap(),
+                            e
+                        );
+                    }
+                };
+
+            let embeddings = embedding_response.data.remove(0).embedding;
+            debug!(
                 "Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
                 prompt_target.name, prompt_target.description, embedding_type
             );
-                embeddings_store().write().unwrap().insert(
-                    prompt_target.name.clone(),
-                    HashMap::from([(embedding_type, embeddings)]),
-                );
+            let entry = self.temp_embeddings_store.entry(prompt_target_name);
+            match entry {
+                Entry::Occupied(_) => {
+                    entry.and_modify(|e| {
+                        if let Entry::Vacant(e) = e.entry(embedding_type) {
+                            e.insert(embeddings);
+                        } else {
+                            panic!(
+                                "Duplicate {:?} for prompt target with name=\"{}\"",
+                                &embedding_type, prompt_target.name
+                            )
+                        }
+                    });
+                }
+                Entry::Vacant(_) => {
+                    entry.or_insert(HashMap::from([(embedding_type, embeddings)]));
+                }
+            }
+
+            if self.prompt_targets.len() == self.temp_embeddings_store.len() {
+                self.embeddings_store =
+                    Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store)))
             }
-        } else {
-            panic!("No body in response");
         }
     }
 }
+
+impl Client for FilterContext {
+    type CallContext = FilterCallContext;
+
+    fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
+        &self.callouts
+    }
+
+    fn active_http_calls(&self) -> &Gauge {
+        &self.metrics.active_http_calls
+    }
+}
+
 impl Context for FilterContext {
     fn on_http_call_response(
         &mut self,
@@ -203,16 +215,18 @@ impl Context for FilterContext {
             "filter_context: on_http_call_response called with token_id: {:?}",
             token_id
         );
-        let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
+        let callout_data = self
+            .callouts
+            .borrow_mut()
+            .remove(&token_id)
+            .expect("invalid token_id");
 
-        self.metrics
-            .active_http_calls
-            .record(self.callouts.len().try_into().unwrap());
+        self.metrics.active_http_calls.increment(-1);
 
         self.embedding_response_handler(
             body_size,
            callout_data.embedding_type,
-            callout_data.prompt_target,
+            callout_data.prompt_target_name,
         )
     }
 }
@@ -231,12 +245,11 @@ impl RootContext for FilterContext {
         self.overrides = Rc::new(config.overrides);
 
+        let mut prompt_targets = HashMap::new();
         for pt in config.prompt_targets {
-            self.prompt_targets
-                .write()
-                .unwrap()
-                .insert(pt.name.clone(), pt.clone());
+            prompt_targets.insert(pt.name.clone(), pt.clone());
         }
+        self.prompt_targets = Rc::new(prompt_targets);
 
         ratelimit::ratelimits(config.ratelimits);
 
@@ -257,6 +270,10 @@ impl RootContext for FilterContext {
             "||| create_http_context called with context_id: {:?} |||",
             context_id
         );
+
+        // No StreamContext can be created until the Embedding Store is fully initialized.
+        self.embeddings_store.as_ref()?;
+
         Some(Box::new(StreamContext::new(
             context_id,
             Rc::clone(&self.metrics),
@@ -268,6 +285,11 @@ impl RootContext for FilterContext {
                     .as_ref()
                     .expect("LLM Providers must exist when Streams are being created"),
             ),
+            Rc::clone(
+                self.embeddings_store
+                    .as_ref()
+                    .expect("Embeddings Store must exist when StreamContext is being constructed"),
+            ),
         )))
     }
 
diff --git a/arch/src/http.rs b/arch/src/http.rs
new file mode 100644
index 00000000..dfa683f0
--- /dev/null
+++ b/arch/src/http.rs
@@ -0,0 +1,84 @@
+use crate::stats::{Gauge, IncrementingMetric};
+use log::debug;
+use proxy_wasm::{traits::Context, types::Status};
+use std::{cell::RefCell, collections::HashMap, fmt::Debug, time::Duration};
+
+#[derive(Debug)]
+pub struct CallArgs<'a> {
+    upstream: &'a str,
+    headers: Vec<(&'a str, &'a str)>,
+    body: Option<&'a [u8]>,
+    trailers: Vec<(&'a str, &'a str)>,
+    timeout: Duration,
+}
+
+impl<'a> CallArgs<'a> {
+    pub fn new(
+        upstream: &'a str,
+        headers: Vec<(&'a str, &'a str)>,
+        body: Option<&'a [u8]>,
+        trailers: Vec<(&'a str, &'a str)>,
+        timeout: Duration,
+    ) -> Self {
+        CallArgs {
+            upstream,
+            headers,
+            body,
+            trailers,
+            timeout,
+        }
+    }
+}
+
+#[derive(thiserror::Error, Debug)]
+pub enum ClientError {
+    #[error("Error dispatching HTTP call to `{upstream_name}`, error: {internal_status:?}")]
+    DispatchError {
+        upstream_name: String,
+        internal_status: Status,
+    },
+}
+
+pub trait Client: Context {
+    type CallContext: Debug;
+
+    fn http_call(
+        &self,
+        call_args: CallArgs,
+        call_context: Self::CallContext,
+    ) -> Result<(), ClientError> {
+        debug!(
+            "dispatching http call with args={:?} context={:?}",
+            call_args, call_context
+        );
+
+        match self.dispatch_http_call(
+            call_args.upstream,
+            call_args.headers,
+            call_args.body,
+            call_args.trailers,
+            call_args.timeout,
+        ) {
+            Ok(id) => {
+                self.add_call_context(id, call_context);
+                Ok(())
+            }
+            Err(status) => Err(ClientError::DispatchError {
+                upstream_name: String::from(call_args.upstream),
+                internal_status: status.clone(),
+            }),
+        }
+    }
+
+    fn add_call_context(&self, id: u32, call_context: Self::CallContext) {
+        let callouts = self.callouts();
+        if callouts.borrow_mut().insert(id, call_context).is_some() {
+            panic!("Duplicate http call with id={}", id);
+        }
+        self.active_http_calls().increment(1);
+    }
+
+    fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>>;
+
+    fn active_http_calls(&self) -> &Gauge;
+}
diff --git a/arch/src/lib.rs b/arch/src/lib.rs
index a6449695..8d8c0b90 100644
--- a/arch/src/lib.rs
+++ b/arch/src/lib.rs
@@ -4,6 +4,7 @@ use proxy_wasm::types::*;
 
 mod consts;
 mod filter_context;
+mod http;
 mod llm_providers;
 mod ratelimit;
 mod routing;
diff --git a/arch/src/llm_providers.rs b/arch/src/llm_providers.rs
index 75d57817..65cd0d04 100644
--- a/arch/src/llm_providers.rs
+++ b/arch/src/llm_providers.rs
@@ -18,7 +18,7 @@ impl LlmProviders {
     }
 
     pub fn get(&self, name: &str) -> Option<Rc<LlmProvider>> {
-        self.providers.get(name).map(|rc| rc.clone())
+        self.providers.get(name).cloned()
     }
 }
 
diff --git a/arch/src/stats.rs b/arch/src/stats.rs
index 250e9017..527713f3 100644
--- a/arch/src/stats.rs
+++ b/arch/src/stats.rs
@@ -24,6 +24,7 @@ pub trait IncrementingMetric: Metric {
     }
 }
 
+#[allow(unused)]
 pub trait RecordingMetric: Metric {
     fn record(&self, value: u64) {
         match hostcalls::record_metric(self.id(), value) {
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index e91720b7..c6a356c5 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -4,7 +4,7 @@ use crate::consts::{
     DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
     RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
 };
-use crate::filter_context::{embeddings_store, WasmMetrics};
+use crate::filter_context::{EmbeddingsStore, WasmMetrics};
 use crate::llm_providers::LlmProviders;
 use crate::ratelimit::Header;
 use crate::stats::IncrementingMetric;
@@ -31,7 +31,6 @@ use public_types::embeddings::{
 use std::collections::HashMap;
 use std::num::NonZero;
 use std::rc::Rc;
-use std::sync::RwLock;
 use std::time::Duration;
 
 enum ResponseHandlerType {
@@ -56,7 +55,8 @@ pub struct CallContext {
 pub struct StreamContext {
     context_id: u32,
     metrics: Rc<WasmMetrics>,
-    prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
+    prompt_targets: Rc<HashMap<String, PromptTarget>>,
+    embeddings_store: Rc<EmbeddingsStore>,
     overrides: Rc<Option<Overrides>>,
     callouts: HashMap<u32, CallContext>,
     ratelimit_selector: Option<Header>,
@@ -72,15 +72,17 @@ impl StreamContext {
     pub fn new(
         context_id: u32,
         metrics: Rc<WasmMetrics>,
-        prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
+        prompt_targets: Rc<HashMap<String, PromptTarget>>,
         prompt_guards: Rc<PromptGuards>,
         overrides: Rc<Option<Overrides>>,
         llm_providers: Rc<LlmProviders>,
+        embeddings_store: Rc<EmbeddingsStore>,
     ) -> Self {
         StreamContext {
             context_id,
             metrics,
             prompt_targets,
+            embeddings_store,
             callouts: HashMap::new(),
             ratelimit_selector: None,
             streaming_response: false,
@@ -174,35 +176,21 @@ impl StreamContext {
             prompt_embeddings_vector.len()
         );
 
-        let prompt_target_embeddings = match embeddings_store().read() {
-            Ok(embeddings) => embeddings,
-            Err(e) => {
-                return self
-                    .send_server_error(format!("Error reading embeddings store: {:?}", e), None);
-            }
-        };
-
-        let prompt_targets = match self.prompt_targets.read() {
-            Ok(prompt_targets) => prompt_targets,
-            Err(e) => {
-                self.send_server_error(format!("Error reading prompt targets: {:?}", e), None);
-                return;
-            }
-        };
-
-        let prompt_target_names = prompt_targets
+        let prompt_target_names = self
+            .prompt_targets
             .iter()
             // exclude default target
             .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
             .map(|(name, _)| name.clone())
             .collect();
 
-        let similarity_scores: Vec<(String, f64)> = prompt_targets
+        let similarity_scores: Vec<(String, f64)> = self
+            .prompt_targets
             .iter()
             // exclude default prompt target
             .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
             .map(|(prompt_name, _)| {
-                let pte = match prompt_target_embeddings.get(prompt_name) {
+                let pte = match self.embeddings_store.get(prompt_name) {
                     Some(embeddings) => embeddings,
                     None => {
                         warn!(
@@ -373,8 +361,6 @@ impl StreamContext {
         debug!("checking for default prompt target");
         if let Some(default_prompt_target) = self
             .prompt_targets
-            .read()
-            .unwrap()
             .values()
             .find(|pt| pt.default.unwrap_or(false))
         {
@@ -429,7 +415,7 @@ impl StreamContext {
             }
         }
 
-        let prompt_target = match self.prompt_targets.read().unwrap().get(&prompt_target_name) {
+        let prompt_target = match self.prompt_targets.get(&prompt_target_name) {
            Some(prompt_target) => prompt_target.clone(),
             None => {
                 return self.send_server_error(
@@ -441,7 +427,7 @@ impl StreamContext {
         info!("prompt_target name: {:?}", prompt_target_name);
 
         let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
-        for pt in self.prompt_targets.read().unwrap().values() {
+        for pt in self.prompt_targets.values() {
             // only extract entity names
             let properties: HashMap<String, FunctionParameter> = match pt.parameters {
                 // Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
@@ -592,13 +578,7 @@ impl StreamContext {
         let tools_call_name = tool_calls[0].function.name.clone();
         let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
 
-        let prompt_target = self
-            .prompt_targets
-            .read()
-            .unwrap()
-            .get(&tools_call_name)
-            .unwrap()
-            .clone();
+        let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
 
         debug!("prompt_target_name: {}", prompt_target.name);
         debug!("tool_name(s): {:?}", tool_names);
@@ -660,8 +640,6 @@ impl StreamContext {
         let prompt_target_name = callout_context.prompt_target_name.unwrap();
         let prompt_target = self
             .prompt_targets
-            .read()
-            .unwrap()
             .get(&prompt_target_name)
             .unwrap()
             .clone();
@@ -832,8 +810,6 @@ impl StreamContext {
     fn default_target_handler(&self, body: Vec<u8>, callout_context: CallContext) {
         let prompt_target = self
             .prompt_targets
-            .read()
-            .unwrap()
             .get(callout_context.prompt_target_name.as_ref().unwrap())
             .unwrap()
             .clone();
diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs
index db0ca962..7e1249cf 100644
--- a/arch/tests/integration.rs
+++ b/arch/tests/integration.rs
@@ -6,9 +6,9 @@ use proxy_wasm_test_framework::types::{
 use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
 use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
 use public_types::common_types::PromptGuardResponse;
-use public_types::embeddings::embedding::Object;
 use public_types::embeddings::{
-    create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
+    create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
+    Embedding,
 };
 use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
 use serde_yaml::Value;
@@ -158,7 +158,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         data: vec![Embedding {
             index: 0,
             embedding: vec![],
-            object: Object::default(),
+            object: embedding::Object::default(),
         }],
         model: String::from("test"),
         object: create_embedding_response::Object::default(),
@@ -177,8 +177,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
         .returning(Some(&embeddings_response_buffer))
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Warn), None)
-        .expect_log(Some(LogLevel::Warn), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_http_call(
             Some("model_server"),
@@ -243,8 +241,130 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .unwrap();
 }
 
-fn default_config() -> Configuration {
-    let config: &str = r#"
+fn setup_filter(module: &mut Tester, config: &str) -> i32 {
+    let filter_context = 1;
+
+    module
+        .call_proxy_on_context_create(filter_context, 0)
+        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
+        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
+        .execute_and_expect(ReturnType::None)
+        .unwrap();
+
+    module
+        .call_proxy_on_configure(filter_context, config.len() as i32)
+        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
+        .returning(Some(&config))
+        .execute_and_expect(ReturnType::Bool(true))
+        .unwrap();
+
+    module
+        .call_proxy_on_tick(filter_context)
+        .expect_log(Some(LogLevel::Debug), None)
+        .expect_http_call(
+            Some("model_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/embeddings"),
+                (":authority", "model_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-upstream-rq-timeout-ms", "60000"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(101))
+        .expect_metric_increment("active_http_calls", 1)
+        .expect_log(Some(LogLevel::Debug), None)
+        .expect_http_call(
+            Some("model_server"),
+            Some(vec![
+                (":method", "POST"),
+                (":path", "/embeddings"),
+                (":authority", "model_server"),
+                ("content-type", "application/json"),
+                ("x-envoy-upstream-rq-timeout-ms", "60000"),
+            ]),
+            None,
+            None,
+            None,
+        )
+        .returning(Some(102))
+        .expect_metric_increment("active_http_calls", 1)
+        .expect_set_tick_period_millis(Some(0))
+        .execute_and_expect(ReturnType::None)
+        .unwrap();
+
+    let embedding_response = CreateEmbeddingResponse {
+        data: vec![Embedding {
+            embedding: vec![],
+            index: 0,
+            object: embedding::Object::default(),
+        }],
+        model: String::from("test"),
+        object: create_embedding_response::Object::default(),
+        usage: Box::new(CreateEmbeddingResponseUsage {
+            prompt_tokens: 0,
+            total_tokens: 0,
+        }),
+    };
+    let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
+    module
+        .call_proxy_on_http_call_response(
+            filter_context,
+            101,
+            0,
+            embedding_response_str.len() as i32,
+            0,
+        )
+        .expect_log(
+            Some(LogLevel::Debug),
+            Some(
+                format!(
+                    "filter_context: on_http_call_response called with token_id: {:?}",
+                    101
+                )
+                .as_str(),
+            ),
+        )
+        .expect_metric_increment("active_http_calls", -1)
+        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
+        .returning(Some(&embedding_response_str))
+        .expect_log(Some(LogLevel::Debug), None)
+        .execute_and_expect(ReturnType::None)
+        .unwrap();
+
+    module
+        .call_proxy_on_http_call_response(
+            filter_context,
+            102,
+            0,
+            embedding_response_str.len() as i32,
+            0,
+        )
+        .expect_log(
+            Some(LogLevel::Debug),
+            Some(
+                format!(
+                    "filter_context: on_http_call_response called with token_id: {:?}",
+                    102
+                )
+                .as_str(),
+            ),
+        )
+        .expect_metric_increment("active_http_calls", -1)
+        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
+        .returning(Some(&embedding_response_str))
+        .expect_log(Some(LogLevel::Debug), None)
+        .execute_and_expect(ReturnType::None)
+        .unwrap();
+
+    filter_context
+}
+
+fn default_config() -> &'static str {
+    r#"
 version: "0.1-beta"
 
 listener:
@@ -297,24 +417,6 @@ prompt_targets:
       - Use farenheight for temperature
       - Use miles per hour for wind speed
 
-  - name: insurance_claim_details
-    type: function_resolver
-    description: This function resolver provides insurance claim details for a given policy number.
-    parameters:
-      - name: policy_number
-        required: true
-        description: The policy number for which the insurance claim details are requested.
-        type: string
-      - name: include_expired
-        description: whether to include expired insurance claims in the response.
-        type: bool
-        required: true
-    endpoint:
-      name: api_server
-      path: /insurance_claim_details
-    system_prompt: |
-      You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
-      - Use policy number to retrieve insurance claim details
 
 ratelimits:
   - model: gpt-4
     selector:
@@ -323,8 +425,7 @@ ratelimits:
     limit:
       tokens: 1
       unit: minute
-"#;
-    serde_yaml::from_str(config).unwrap()
+"#
 }
 
 #[test]
@@ -343,22 +444,7 @@ fn successful_request_to_open_ai_chat_completions() {
        .unwrap();
 
     // Setup Filter
-    let filter_context = 1;
-    let config = serde_json::to_string(&default_config()).unwrap();
-
-    module
-        .call_proxy_on_context_create(filter_context, 0)
-        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
-        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
-        .execute_and_expect(ReturnType::None)
-        .unwrap();
-
-    module
-        .call_proxy_on_configure(filter_context, config.len() as i32)
-        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
-        .returning(Some(&config))
-        .execute_and_expect(ReturnType::Bool(true))
-        .unwrap();
+    let filter_context = setup_filter(&mut module, default_config());
 
     // Setup HTTP Stream
     let http_context = 2;
@@ -419,22 +505,7 @@ fn bad_request_to_open_ai_chat_completions() {
         .unwrap();
 
     // Setup Filter
-    let filter_context = 1;
-    let config = serde_json::to_string(&default_config()).unwrap();
-
-    module
-        .call_proxy_on_context_create(filter_context, 0)
-        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
-        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
-        .execute_and_expect(ReturnType::None)
-        .unwrap();
-
-    module
-        .call_proxy_on_configure(filter_context, config.len() as i32)
-        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
-        .returning(Some(&config))
-        .execute_and_expect(ReturnType::Bool(true))
-        .unwrap();
+    let filter_context = setup_filter(&mut module, default_config());
 
     // Setup HTTP Stream
     let http_context = 2;
@@ -496,21 +567,7 @@ fn request_ratelimited() {
         .unwrap();
 
     // Setup Filter
-    let filter_context = 1;
-    let config = serde_json::to_string(&default_config()).unwrap();
-
-    module
-        .call_proxy_on_context_create(filter_context, 0)
-        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
-        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
-        .execute_and_expect(ReturnType::None)
-        .unwrap();
-    module
-        .call_proxy_on_configure(filter_context, config.len() as i32)
-        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
-        .returning(Some(&config))
-        .execute_and_expect(ReturnType::Bool(true))
-        .unwrap();
+    let filter_context = setup_filter(&mut module, default_config());
 
     // Setup HTTP Stream
     let http_context = 2;
@@ -619,24 +676,11 @@ fn request_not_ratelimited() {
         .unwrap();
 
     // Setup Filter
-    let filter_context = 1;
-
-    let mut config = default_config();
+    let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
     config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
     let config_str = serde_json::to_string(&config).unwrap();
 
-    module
-        .call_proxy_on_context_create(filter_context, 0)
-        .expect_metric_creation(MetricType::Gauge, "active_http_calls")
-        .expect_metric_creation(MetricType::Counter, "ratelimited_rq")
-        .execute_and_expect(ReturnType::None)
-        .unwrap();
-    module
-        .call_proxy_on_configure(filter_context, config_str.len() as i32)
-        .expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
-        .returning(Some(&config_str))
-        .execute_and_expect(ReturnType::Bool(true))
-        .unwrap();
+    let filter_context = setup_filter(&mut module, &config_str);
 
     // Setup HTTP Stream
     let http_context = 2;
diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs
index 9b3e3968..5b6bd794 100644
--- a/public_types/src/common_types.rs
+++ b/public_types/src/common_types.rs
@@ -7,7 +7,7 @@ pub struct EmbeddingRequest {
     pub prompt_target: PromptTarget,
 }
 
-#[derive(Debug, Clone, Hash, PartialEq, Eq)]
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
 pub enum EmbeddingType {
     Name,
     Description,
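Note (not part of the diff): a minimal sketch of how another proxy-wasm context could adopt the new `Client` trait from `arch/src/http.rs`. The `ProbeContext` type, its fields, and the `/healthz` call below are hypothetical illustrations, not code from this change set.

    use std::cell::RefCell;
    use std::collections::HashMap;
    use std::time::Duration;

    use proxy_wasm::traits::Context;

    use crate::http::{CallArgs, Client};
    use crate::stats::Gauge;

    // Hypothetical context that issues tracked HTTP callouts.
    struct ProbeContext {
        // token_id -> per-call context, mirroring FilterContext::callouts.
        callouts: RefCell<HashMap<u32, String>>,
        active_http_calls: Gauge,
    }

    impl Context for ProbeContext {}

    impl Client for ProbeContext {
        // Any Debug type works as the per-call context; a String is enough here.
        type CallContext = String;

        fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
            &self.callouts
        }

        fn active_http_calls(&self) -> &Gauge {
            &self.active_http_calls
        }
    }

    impl ProbeContext {
        fn probe(&self) {
            // CallArgs bundles what dispatch_http_call needs; http_call() then
            // records the token id and bumps the active_http_calls gauge.
            let args = CallArgs::new(
                "model_server",
                vec![(":method", "GET"), (":path", "/healthz"), (":authority", "model_server")],
                None,
                vec![],
                Duration::from_secs(5),
            );
            if let Err(e) = self.http_call(args, String::from("healthz probe")) {
                panic!("{e}")
            }
        }
    }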