diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs index 27f2d7de..066f0072 100644 --- a/crates/prompt_gateway/src/filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -66,7 +66,7 @@ impl FilterContext { prompt_targets: Rc::new(HashMap::new()), overrides: Rc::new(None), prompt_guards: Rc::new(PromptGuards::default()), - embeddings_store: None, + embeddings_store: Some(Rc::new(HashMap::new())), temp_embeddings_store: HashMap::new(), active_embedding_calls_count: 0, } @@ -305,7 +305,9 @@ impl RootContext for FilterContext { } fn on_tick(&mut self) { - if self.embeddings_store.is_some() { + if self.embeddings_store.is_some() + && self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() + { info!("embeddings store initialized"); self.set_tick_period(Duration::from_secs(0)); } else { diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index feb3d616..c67cd11b 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -35,10 +35,10 @@ impl HttpContext for StreamContext { let request_path = self.get_http_request_header(":path").unwrap_or_default(); if request_path == HEALTHZ_PATH { - if self.embeddings_store.is_none() { - self.send_http_response(503, vec![], None); - } else { + if self.is_embedding_store_initialized() { self.send_http_response(200, vec![], None); + } else { + self.send_http_response(503, vec![], None); } return Action::Continue; } diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 9e8f8a60..65478a6c 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -61,7 +61,7 @@ pub struct StreamCallContext { pub struct StreamContext { system_prompt: Rc>, - prompt_targets: Rc>, + pub prompt_targets: Rc>, pub embeddings_store: Option>, overrides: Rc>, pub metrics: Rc, @@ -109,10 +109,21 @@ impl StreamContext { request_id: None, } } + fn embeddings_store(&self) -> &EmbeddingsStore { - self.embeddings_store - .as_ref() - .expect("embeddings store is not set") + self.embeddings_store.as_ref().unwrap() + } + + pub fn is_embedding_store_initialized(&self) -> bool { + if self.embeddings_store.as_ref().is_none() { + return false; + } + + if self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() { + return true; + } + + false } pub fn send_server_error(&self, error: ServerError, override_status_code: Option) { @@ -223,7 +234,7 @@ impl StreamContext { "embeddings not found for prompt target name: {}", prompt_name ); - return (prompt_name.clone(), f64::NAN); + return (prompt_name.clone(), 0.0); } }; @@ -234,7 +245,7 @@ impl StreamContext { "description embeddings not found for prompt target name: {}", prompt_name ); - return (prompt_name.clone(), f64::NAN); + return (prompt_name.clone(), 0.0); } }; let similarity_score_description = diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 1bf581c5..46f2dfd8 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -161,6 +161,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&embeddings_response_buffer)) .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call( @@ -244,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { module .call_proxy_on_tick(filter_context) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), @@ -262,7 +263,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { ) .returning(Some(101)) .expect_metric_increment("active_http_calls", 1) - .expect_set_tick_period_millis(Some(0)) + .expect_set_tick_period_millis(Some(5000)) .execute_and_expect(ReturnType::None) .unwrap(); @@ -289,7 +290,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { 0, ) .expect_log( - Some(LogLevel::Debug), + Some(LogLevel::Trace), Some( format!( "filter_context: on_http_call_response called with token_id: {:?}", @@ -332,7 +333,7 @@ llm_providers: overrides: # confidence threshold for prompt target intent matching - prompt_target_intent_matching_threshold: 0.6 + prompt_target_intent_matching_threshold: 0.0 system_prompt: | You are a helpful assistant.