From 9a6ae2efee34ca1301fb206d1a871ea2251b4408 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 5 Nov 2024 12:04:36 -0600 Subject: [PATCH] retry embeddings fetch (#245) --- crates/prompt_gateway/src/filter_context.rs | 82 ++++++++++++++------- crates/prompt_gateway/src/http_context.rs | 6 +- crates/prompt_gateway/src/stream_context.rs | 23 ++++-- crates/prompt_gateway/tests/integration.rs | 9 ++- 4 files changed, 79 insertions(+), 41 deletions(-) diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs index de120369..0b44fac9 100644 --- a/crates/prompt_gateway/src/filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -11,7 +11,8 @@ use common::http::CallArgs; use common::http::Client; use common::stats::Gauge; use common::stats::IncrementingMetric; -use log::debug; +use http::StatusCode; +use log::{debug, info, trace, warn}; use proxy_wasm::traits::*; use proxy_wasm::types::*; use std::cell::RefCell; @@ -53,6 +54,7 @@ pub struct FilterContext { prompt_guards: Rc, embeddings_store: Option>, temp_embeddings_store: EmbeddingsStore, + active_embedding_calls_count: u32, } impl FilterContext { @@ -66,22 +68,26 @@ impl FilterContext { prompt_guards: Rc::new(PromptGuards::default()), embeddings_store: Some(Rc::new(HashMap::new())), temp_embeddings_store: HashMap::new(), + active_embedding_calls_count: 0, } } - fn process_prompt_targets(&self) { - for values in self.prompt_targets.iter() { - let prompt_target = values.1; - self.schedule_embeddings_call( - &prompt_target.name, - &prompt_target.description, - EmbeddingType::Description, - ); - } + fn process_prompt_targets(&mut self) { + let prompt_target_description: Vec<(String, String)> = self + .prompt_targets + .iter() + .map(|(k, v)| (k.clone(), v.description.clone())) + .collect(); + + prompt_target_description + .iter() + .for_each(|(name, description)| { + self.schedule_embeddings_call(name, description, EmbeddingType::Description); + }); } fn schedule_embeddings_call( - &self, + &mut self, prompt_target_name: &str, input: &str, embedding_type: EmbeddingType, @@ -116,6 +122,7 @@ impl FilterContext { embedding_type, }; + self.active_embedding_calls_count += 1; if let Err(error) = self.http_call(call_args, call_context) { panic!("{error}") } @@ -123,9 +130,9 @@ impl FilterContext { fn embedding_response_handler( &mut self, - body_size: usize, embedding_type: EmbeddingType, prompt_target_name: String, + body: Vec, ) { let prompt_target = self .prompt_targets @@ -137,9 +144,6 @@ impl FilterContext { ) }); - let body = self - .get_http_call_response_body(0, body_size) - .expect("No body in response"); if !body.is_empty() { let mut embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { @@ -208,7 +212,7 @@ impl Context for FilterContext { body_size: usize, _num_trailers: usize, ) { - debug!( + trace!( "filter_context: on_http_call_response called with token_id: {:?}", token_id ); @@ -218,13 +222,26 @@ impl Context for FilterContext { .remove(&token_id) .expect("invalid token_id"); + self.active_embedding_calls_count -= 1; self.metrics.active_http_calls.increment(-1); + let body_bytes = self.get_http_call_response_body(0, body_size).unwrap(); - self.embedding_response_handler( - body_size, - callout_data.embedding_type, - callout_data.prompt_target_name, - ) + if let Some(status_code) = self.get_http_call_response_header(":status") { + if status_code == StatusCode::OK.as_str() { + self.embedding_response_handler( + callout_data.embedding_type, + callout_data.prompt_target_name, + body_bytes, + ); + } else { + warn!( + "Received non-200 status code: {} for callout with token_id: {}: body_str: {}", + status_code, + token_id, + String::from_utf8(body_bytes).unwrap() + ); + } + } } } @@ -262,10 +279,7 @@ impl RootContext for FilterContext { context_id ); - let embedding_store = match self.embeddings_store.as_ref() { - None => return None, - Some(store) => Some(Rc::clone(store)), - }; + let embedding_store = self.embeddings_store.as_ref().map(Rc::clone); Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), @@ -287,8 +301,20 @@ impl RootContext for FilterContext { } fn on_tick(&mut self) { - debug!("starting up arch filter in mode: prompt gateway mode"); - self.process_prompt_targets(); - self.set_tick_period(Duration::from_secs(0)); + if self.embeddings_store.is_some() + && self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() + { + info!("embeddings store initialized"); + self.set_tick_period(Duration::from_secs(0)); + } else { + if self.active_embedding_calls_count == 0 { + info!("retrieving embeddings from embedding server"); + self.process_prompt_targets(); + } else { + info!("waiting for embeddings store to be initialized"); + } + + self.set_tick_period(Duration::from_secs(5)); + } } } diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index feb3d616..c67cd11b 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -35,10 +35,10 @@ impl HttpContext for StreamContext { let request_path = self.get_http_request_header(":path").unwrap_or_default(); if request_path == HEALTHZ_PATH { - if self.embeddings_store.is_none() { - self.send_http_response(503, vec![], None); - } else { + if self.is_embedding_store_initialized() { self.send_http_response(200, vec![], None); + } else { + self.send_http_response(503, vec![], None); } return Action::Continue; } diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 9e8f8a60..65478a6c 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -61,7 +61,7 @@ pub struct StreamCallContext { pub struct StreamContext { system_prompt: Rc>, - prompt_targets: Rc>, + pub prompt_targets: Rc>, pub embeddings_store: Option>, overrides: Rc>, pub metrics: Rc, @@ -109,10 +109,21 @@ impl StreamContext { request_id: None, } } + fn embeddings_store(&self) -> &EmbeddingsStore { - self.embeddings_store - .as_ref() - .expect("embeddings store is not set") + self.embeddings_store.as_ref().unwrap() + } + + pub fn is_embedding_store_initialized(&self) -> bool { + if self.embeddings_store.as_ref().is_none() { + return false; + } + + if self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() { + return true; + } + + false } pub fn send_server_error(&self, error: ServerError, override_status_code: Option) { @@ -223,7 +234,7 @@ impl StreamContext { "embeddings not found for prompt target name: {}", prompt_name ); - return (prompt_name.clone(), f64::NAN); + return (prompt_name.clone(), 0.0); } }; @@ -234,7 +245,7 @@ impl StreamContext { "description embeddings not found for prompt target name: {}", prompt_name ); - return (prompt_name.clone(), f64::NAN); + return (prompt_name.clone(), 0.0); } }; let similarity_score_description = diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 1bf581c5..46f2dfd8 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -161,6 +161,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&embeddings_response_buffer)) .expect_log(Some(LogLevel::Trace), None) + .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call( @@ -244,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { module .call_proxy_on_tick(filter_context) - .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Info), None) .expect_log(Some(LogLevel::Trace), None) .expect_http_call( Some("arch_internal"), @@ -262,7 +263,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { ) .returning(Some(101)) .expect_metric_increment("active_http_calls", 1) - .expect_set_tick_period_millis(Some(0)) + .expect_set_tick_period_millis(Some(5000)) .execute_and_expect(ReturnType::None) .unwrap(); @@ -289,7 +290,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { 0, ) .expect_log( - Some(LogLevel::Debug), + Some(LogLevel::Trace), Some( format!( "filter_context: on_http_call_response called with token_id: {:?}", @@ -332,7 +333,7 @@ llm_providers: overrides: # confidence threshold for prompt target intent matching - prompt_target_intent_matching_threshold: 0.6 + prompt_target_intent_matching_threshold: 0.0 system_prompt: | You are a helpful assistant.