retry embeddings fetch (#245)

This commit is contained in:
Adil Hafeez 2024-11-05 12:04:36 -06:00 committed by GitHub
parent 9a5c5cc3a3
commit 9a6ae2efee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 79 additions and 41 deletions

View file

@ -11,7 +11,8 @@ use common::http::CallArgs;
use common::http::Client; use common::http::Client;
use common::stats::Gauge; use common::stats::Gauge;
use common::stats::IncrementingMetric; use common::stats::IncrementingMetric;
use log::debug; use http::StatusCode;
use log::{debug, info, trace, warn};
use proxy_wasm::traits::*; use proxy_wasm::traits::*;
use proxy_wasm::types::*; use proxy_wasm::types::*;
use std::cell::RefCell; use std::cell::RefCell;
@ -53,6 +54,7 @@ pub struct FilterContext {
prompt_guards: Rc<PromptGuards>, prompt_guards: Rc<PromptGuards>,
embeddings_store: Option<Rc<EmbeddingsStore>>, embeddings_store: Option<Rc<EmbeddingsStore>>,
temp_embeddings_store: EmbeddingsStore, temp_embeddings_store: EmbeddingsStore,
active_embedding_calls_count: u32,
} }
impl FilterContext { impl FilterContext {
@ -66,22 +68,26 @@ impl FilterContext {
prompt_guards: Rc::new(PromptGuards::default()), prompt_guards: Rc::new(PromptGuards::default()),
embeddings_store: Some(Rc::new(HashMap::new())), embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(), temp_embeddings_store: HashMap::new(),
active_embedding_calls_count: 0,
} }
} }
fn process_prompt_targets(&self) { fn process_prompt_targets(&mut self) {
for values in self.prompt_targets.iter() { let prompt_target_description: Vec<(String, String)> = self
let prompt_target = values.1; .prompt_targets
self.schedule_embeddings_call( .iter()
&prompt_target.name, .map(|(k, v)| (k.clone(), v.description.clone()))
&prompt_target.description, .collect();
EmbeddingType::Description,
); prompt_target_description
} .iter()
.for_each(|(name, description)| {
self.schedule_embeddings_call(name, description, EmbeddingType::Description);
});
} }
fn schedule_embeddings_call( fn schedule_embeddings_call(
&self, &mut self,
prompt_target_name: &str, prompt_target_name: &str,
input: &str, input: &str,
embedding_type: EmbeddingType, embedding_type: EmbeddingType,
@ -116,6 +122,7 @@ impl FilterContext {
embedding_type, embedding_type,
}; };
self.active_embedding_calls_count += 1;
if let Err(error) = self.http_call(call_args, call_context) { if let Err(error) = self.http_call(call_args, call_context) {
panic!("{error}") panic!("{error}")
} }
@ -123,9 +130,9 @@ impl FilterContext {
fn embedding_response_handler( fn embedding_response_handler(
&mut self, &mut self,
body_size: usize,
embedding_type: EmbeddingType, embedding_type: EmbeddingType,
prompt_target_name: String, prompt_target_name: String,
body: Vec<u8>,
) { ) {
let prompt_target = self let prompt_target = self
.prompt_targets .prompt_targets
@ -137,9 +144,6 @@ impl FilterContext {
) )
}); });
let body = self
.get_http_call_response_body(0, body_size)
.expect("No body in response");
if !body.is_empty() { if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse = let mut embedding_response: CreateEmbeddingResponse =
match serde_json::from_slice(&body) { match serde_json::from_slice(&body) {
@ -208,7 +212,7 @@ impl Context for FilterContext {
body_size: usize, body_size: usize,
_num_trailers: usize, _num_trailers: usize,
) { ) {
debug!( trace!(
"filter_context: on_http_call_response called with token_id: {:?}", "filter_context: on_http_call_response called with token_id: {:?}",
token_id token_id
); );
@ -218,13 +222,26 @@ impl Context for FilterContext {
.remove(&token_id) .remove(&token_id)
.expect("invalid token_id"); .expect("invalid token_id");
self.active_embedding_calls_count -= 1;
self.metrics.active_http_calls.increment(-1); self.metrics.active_http_calls.increment(-1);
let body_bytes = self.get_http_call_response_body(0, body_size).unwrap();
self.embedding_response_handler( if let Some(status_code) = self.get_http_call_response_header(":status") {
body_size, if status_code == StatusCode::OK.as_str() {
callout_data.embedding_type, self.embedding_response_handler(
callout_data.prompt_target_name, callout_data.embedding_type,
) callout_data.prompt_target_name,
body_bytes,
);
} else {
warn!(
"Received non-200 status code: {} for callout with token_id: {}: body_str: {}",
status_code,
token_id,
String::from_utf8(body_bytes).unwrap()
);
}
}
} }
} }
@ -262,10 +279,7 @@ impl RootContext for FilterContext {
context_id context_id
); );
let embedding_store = match self.embeddings_store.as_ref() { let embedding_store = self.embeddings_store.as_ref().map(Rc::clone);
None => return None,
Some(store) => Some(Rc::clone(store)),
};
Some(Box::new(StreamContext::new( Some(Box::new(StreamContext::new(
context_id, context_id,
Rc::clone(&self.metrics), Rc::clone(&self.metrics),
@ -287,8 +301,20 @@ impl RootContext for FilterContext {
} }
fn on_tick(&mut self) { fn on_tick(&mut self) {
debug!("starting up arch filter in mode: prompt gateway mode"); if self.embeddings_store.is_some()
self.process_prompt_targets(); && self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len()
self.set_tick_period(Duration::from_secs(0)); {
info!("embeddings store initialized");
self.set_tick_period(Duration::from_secs(0));
} else {
if self.active_embedding_calls_count == 0 {
info!("retrieving embeddings from embedding server");
self.process_prompt_targets();
} else {
info!("waiting for embeddings store to be initialized");
}
self.set_tick_period(Duration::from_secs(5));
}
} }
} }

View file

@ -35,10 +35,10 @@ impl HttpContext for StreamContext {
let request_path = self.get_http_request_header(":path").unwrap_or_default(); let request_path = self.get_http_request_header(":path").unwrap_or_default();
if request_path == HEALTHZ_PATH { if request_path == HEALTHZ_PATH {
if self.embeddings_store.is_none() { if self.is_embedding_store_initialized() {
self.send_http_response(503, vec![], None);
} else {
self.send_http_response(200, vec![], None); self.send_http_response(200, vec![], None);
} else {
self.send_http_response(503, vec![], None);
} }
return Action::Continue; return Action::Continue;
} }

View file

@ -61,7 +61,7 @@ pub struct StreamCallContext {
pub struct StreamContext { pub struct StreamContext {
system_prompt: Rc<Option<String>>, system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>, pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
pub embeddings_store: Option<Rc<EmbeddingsStore>>, pub embeddings_store: Option<Rc<EmbeddingsStore>>,
overrides: Rc<Option<Overrides>>, overrides: Rc<Option<Overrides>>,
pub metrics: Rc<WasmMetrics>, pub metrics: Rc<WasmMetrics>,
@ -109,10 +109,21 @@ impl StreamContext {
request_id: None, request_id: None,
} }
} }
fn embeddings_store(&self) -> &EmbeddingsStore { fn embeddings_store(&self) -> &EmbeddingsStore {
self.embeddings_store self.embeddings_store.as_ref().unwrap()
.as_ref() }
.expect("embeddings store is not set")
pub fn is_embedding_store_initialized(&self) -> bool {
if self.embeddings_store.as_ref().is_none() {
return false;
}
if self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len() {
return true;
}
false
} }
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) { pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
@ -223,7 +234,7 @@ impl StreamContext {
"embeddings not found for prompt target name: {}", "embeddings not found for prompt target name: {}",
prompt_name prompt_name
); );
return (prompt_name.clone(), f64::NAN); return (prompt_name.clone(), 0.0);
} }
}; };
@ -234,7 +245,7 @@ impl StreamContext {
"description embeddings not found for prompt target name: {}", "description embeddings not found for prompt target name: {}",
prompt_name prompt_name
); );
return (prompt_name.clone(), f64::NAN); return (prompt_name.clone(), 0.0);
} }
}; };
let similarity_score_description = let similarity_score_description =

View file

@ -161,6 +161,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer)) .returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_http_call( .expect_http_call(
@ -244,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module module
.call_proxy_on_tick(filter_context) .call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Trace), None) .expect_log(Some(LogLevel::Trace), None)
.expect_http_call( .expect_http_call(
Some("arch_internal"), Some("arch_internal"),
@ -262,7 +263,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
) )
.returning(Some(101)) .returning(Some(101))
.expect_metric_increment("active_http_calls", 1) .expect_metric_increment("active_http_calls", 1)
.expect_set_tick_period_millis(Some(0)) .expect_set_tick_period_millis(Some(5000))
.execute_and_expect(ReturnType::None) .execute_and_expect(ReturnType::None)
.unwrap(); .unwrap();
@ -289,7 +290,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
0, 0,
) )
.expect_log( .expect_log(
Some(LogLevel::Debug), Some(LogLevel::Trace),
Some( Some(
format!( format!(
"filter_context: on_http_call_response called with token_id: {:?}", "filter_context: on_http_call_response called with token_id: {:?}",
@ -332,7 +333,7 @@ llm_providers:
overrides: overrides:
# confidence threshold for prompt target intent matching # confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.6 prompt_target_intent_matching_threshold: 0.0
system_prompt: | system_prompt: |
You are a helpful assistant. You are a helpful assistant.