diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs
index de120369..b1546f0e 100644
--- a/crates/prompt_gateway/src/filter_context.rs
+++ b/crates/prompt_gateway/src/filter_context.rs
@@ -11,7 +11,8 @@ use common::http::CallArgs;
 use common::http::Client;
 use common::stats::Gauge;
 use common::stats::IncrementingMetric;
-use log::debug;
+use http::StatusCode;
+use log::{debug, info, warn};
 use proxy_wasm::traits::*;
 use proxy_wasm::types::*;
 use std::cell::RefCell;
@@ -53,6 +54,7 @@ pub struct FilterContext {
     prompt_guards: Rc<PromptGuards>,
     embeddings_store: Option<Rc<EmbeddingsStore>>,
     temp_embeddings_store: EmbeddingsStore,
+    active_embedding_calls_count: u32,
 }
 
 impl FilterContext {
@@ -64,24 +66,28 @@ impl FilterContext {
             prompt_targets: Rc::new(HashMap::new()),
             overrides: Rc::new(None),
             prompt_guards: Rc::new(PromptGuards::default()),
-            embeddings_store: Some(Rc::new(HashMap::new())),
+            embeddings_store: None,
             temp_embeddings_store: HashMap::new(),
+            active_embedding_calls_count: 0,
         }
     }
 
-    fn process_prompt_targets(&self) {
-        for values in self.prompt_targets.iter() {
-            let prompt_target = values.1;
-            self.schedule_embeddings_call(
-                &prompt_target.name,
-                &prompt_target.description,
-                EmbeddingType::Description,
-            );
-        }
+    fn process_prompt_targets(&mut self) {
+        let prompt_target_description: Vec<(String, String)> = self
+            .prompt_targets
+            .iter()
+            .map(|(k, v)| (k.clone(), v.description.clone()))
+            .collect();
+
+        prompt_target_description
+            .iter()
+            .for_each(|(name, description)| {
+                self.schedule_embeddings_call(name, description, EmbeddingType::Description);
+            });
     }
 
     fn schedule_embeddings_call(
-        &self,
+        &mut self,
         prompt_target_name: &str,
         input: &str,
         embedding_type: EmbeddingType,
@@ -116,6 +122,7 @@ impl FilterContext {
             embedding_type,
         };
 
+        self.active_embedding_calls_count += 1;
         if let Err(error) = self.http_call(call_args, call_context) {
             panic!("{error}")
         }
@@ -218,13 +225,34 @@ impl Context for FilterContext {
             .remove(&token_id)
             .expect("invalid token_id");
 
+        // The response body is optional; fall back to an empty buffer rather than panicking
+        // when the callout produced none.
+        let body_bytes = self.get_http_call_response_body(0, body_size).unwrap_or_default();
+
+        self.active_embedding_calls_count -= 1;
         self.metrics.active_http_calls.increment(-1);
 
-        self.embedding_response_handler(
-            body_size,
-            callout_data.embedding_type,
-            callout_data.prompt_target_name,
-        )
+        if let Some(status_code) = self.get_http_call_response_header(":status") {
+            if status_code == StatusCode::OK.as_str() {
+                self.embedding_response_handler(
+                    body_size,
+                    callout_data.embedding_type,
+                    callout_data.prompt_target_name,
+                );
+            } else {
+                // Error bodies are not guaranteed to be valid UTF-8; decode lossily so that
+                // logging can never panic.
+                warn!(
+                    "Received non-200 status code: {} for callout with token_id: {}: body_str: {}",
+                    status_code,
+                    token_id,
+                    String::from_utf8_lossy(&body_bytes)
+                );
+            }
+        } else {
+            // No status header at all: log it instead of dropping the response silently.
+            warn!("Missing :status header for callout with token_id: {}", token_id);
+        }
     }
 }
 
@@ -263,7 +291,4 @@ impl RootContext for FilterContext {
         );
 
-        let embedding_store = match self.embeddings_store.as_ref() {
-            None => return None,
-            Some(store) => Some(Rc::clone(store)),
-        };
+        let embedding_store = self.embeddings_store.as_ref().map(Rc::clone);
         Some(Box::new(StreamContext::new(
@@ -288,7 +313,17 @@ impl RootContext for FilterContext {
 
     fn on_tick(&mut self) {
         debug!("starting up arch filter in mode: prompt gateway mode");
-        self.process_prompt_targets();
-        self.set_tick_period(Duration::from_secs(0));
+        if self.embeddings_store.is_some() {
+            info!("All embeddings have been fetched, disabling tick");
+            self.set_tick_period(Duration::from_secs(0));
+        } else {
+            info!("waiting for embeddings to be fetched, continuing to wait");
+            if self.active_embedding_calls_count == 0 {
+                info!("no active calls seen, it seems like embedding calls are done but embedding store is not yet populated, retrying");
+                self.process_prompt_targets();
+            }
+
+            self.set_tick_period(Duration::from_secs(1));
+        }
     }
 }