diff --git a/arch/src/filter_context.rs b/arch/src/filter_context.rs
index 09314ff5..ff6342b3 100644
--- a/arch/src/filter_context.rs
+++ b/arch/src/filter_context.rs
@@ -259,7 +259,7 @@ impl RootContext for FilterContext {
         self.prompt_targets = Rc::new(prompt_targets);
         self.mode = config.mode.unwrap_or_default();
 
-        ratelimit::ratelimits(config.ratelimits);
+        ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
 
         if let Some(prompt_guards) = config.prompt_guards {
             self.prompt_guards = Rc::new(prompt_guards)
@@ -280,15 +280,10 @@ impl RootContext for FilterContext {
         );
 
         // No StreamContext can be created until the Embedding Store is fully initialized.
-        let embedding_store;
-        match self.mode {
-            GatewayMode::Llm => {
-                embedding_store = None;
-            }
-            GatewayMode::Prompt => {
-                embedding_store = Some(Rc::clone(self.embeddings_store.as_ref().unwrap()))
-            }
-        }
+        let embedding_store = match self.mode {
+            GatewayMode::Llm => None,
+            GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())),
+        };
         Some(Box::new(StreamContext::new(
             context_id,
             Rc::clone(&self.metrics),
diff --git a/arch/src/ratelimit.rs b/arch/src/ratelimit.rs
index 311ceb48..83a85e6c 100644
--- a/arch/src/ratelimit.rs
+++ b/arch/src/ratelimit.rs
@@ -404,6 +404,14 @@ mod test {
     use std::num::NonZero;
     use std::thread;
 
+    #[test]
+    fn make_ratelimits_optional() {
+        let ratelimits_config = Vec::new();
+
+        // Initialize in the main thread.
+        ratelimits(Some(ratelimits_config));
+    }
+
     #[test]
     fn different_threads_have_same_ratelimit_data_structure() {
         let ratelimits_config = Some(vec![Ratelimit {
diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs
index f9c4be65..bc9e62fa 100644
--- a/arch/src/stream_context.rs
+++ b/arch/src/stream_context.rs
@@ -944,7 +944,7 @@ impl StreamContext {
     ) -> Result<(), ratelimit::Error> {
         if let Some(selector) = self.ratelimit_selector.take() {
             // Tokenize and Ratelimit.
-            if let Ok(token_count) = tokenizer::token_count(model, &json_string) {
+            if let Ok(token_count) = tokenizer::token_count(model, json_string) {
                 ratelimit::ratelimits(None).read().unwrap().check_limit(
                     model.to_owned(),
                     selector,
diff --git a/demos/function_calling/arch_config.yaml b/demos/function_calling/arch_config.yaml
index 7b69c031..f409d753 100644
--- a/demos/function_calling/arch_config.yaml
+++ b/demos/function_calling/arch_config.yaml
@@ -72,14 +72,5 @@ prompt_targets:
     # if true arch will forward the response to the default LLM
     auto_llm_dispatch_on_response: true
 
-ratelimits:
-  - model: gpt-4
-    selector:
-      key: selector-key
-      value: selector-value
-    limit:
-      tokens: 1
-      unit: minute
-
 tracing:
   random_sampling: 100
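
Not part of the patch: a minimal standalone sketch of the init-once/read-later pattern that the ratelimits change relies on, where startup calls ratelimits(Some(config)) and the request path calls ratelimits(None). The OnceLock-backed global, the stub Ratelimit struct, and the exact function signature here are assumptions for illustration only; the real ratelimit module stores richer limit data and exposes check_limit.

use std::sync::{OnceLock, RwLock};

// Stub standing in for the Ratelimit entries parsed from arch_config.yaml (hypothetical shape).
#[allow(dead_code)]
#[derive(Debug, Default, Clone)]
struct Ratelimit {
    model: String,
}

// Global store: built exactly once from Some(config), then fetched with None afterwards.
static RATELIMITS: OnceLock<RwLock<Vec<Ratelimit>>> = OnceLock::new();

fn ratelimits(config: Option<Vec<Ratelimit>>) -> &'static RwLock<Vec<Ratelimit>> {
    RATELIMITS.get_or_init(|| RwLock::new(config.unwrap_or_default()))
}

fn main() {
    // Startup (filter_context): a missing `ratelimits:` section becomes an empty Vec
    // via unwrap_or_default() instead of passing None through.
    let configured: Option<Vec<Ratelimit>> = None;
    ratelimits(Some(configured.unwrap_or_default()));

    // Request path (stream_context): later callers pass None and read the shared store.
    assert!(ratelimits(None).read().unwrap().is_empty());
}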