retry embeddings

This commit is contained in:
Adil Hafeez 2024-11-04 12:00:28 -08:00
parent e4d5293af4
commit a1258a3d26

View file

@ -11,7 +11,8 @@ use common::http::CallArgs;
use common::http::Client;
use common::stats::Gauge;
use common::stats::IncrementingMetric;
use log::debug;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::cell::RefCell;
@ -53,6 +54,7 @@ pub struct FilterContext {
prompt_guards: Rc<PromptGuards>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
temp_embeddings_store: EmbeddingsStore,
active_embedding_calls_count: u32,
}
impl FilterContext {
@ -64,24 +66,28 @@ impl FilterContext {
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
embeddings_store: Some(Rc::new(HashMap::new())),
embeddings_store: None,
temp_embeddings_store: HashMap::new(),
active_embedding_calls_count: 0,
}
}
fn process_prompt_targets(&self) {
for values in self.prompt_targets.iter() {
let prompt_target = values.1;
self.schedule_embeddings_call(
&prompt_target.name,
&prompt_target.description,
EmbeddingType::Description,
);
}
fn process_prompt_targets(&mut self) {
let prompt_target_description: Vec<(String, String)> = self
.prompt_targets
.iter()
.map(|(k, v)| (k.clone(), v.description.clone()))
.collect();
prompt_target_description
.iter()
.for_each(|(name, description)| {
self.schedule_embeddings_call(name, description, EmbeddingType::Description);
});
}
fn schedule_embeddings_call(
&self,
&mut self,
prompt_target_name: &str,
input: &str,
embedding_type: EmbeddingType,
@ -116,6 +122,7 @@ impl FilterContext {
embedding_type,
};
self.active_embedding_calls_count += 1;
if let Err(error) = self.http_call(call_args, call_context) {
panic!("{error}")
}
@ -218,13 +225,27 @@ impl Context for FilterContext {
.remove(&token_id)
.expect("invalid token_id");
let body_bytes = self.get_http_call_response_body(0, body_size).unwrap();
self.active_embedding_calls_count -= 1;
self.metrics.active_http_calls.increment(-1);
self.embedding_response_handler(
body_size,
callout_data.embedding_type,
callout_data.prompt_target_name,
)
if let Some(status_code) = self.get_http_call_response_header(":status") {
if status_code == StatusCode::OK.as_str() {
self.embedding_response_handler(
body_size,
callout_data.embedding_type,
callout_data.prompt_target_name,
);
} else {
warn!(
"Received non-200 status code: {} for callout with token_id: {}: body_str: {}",
status_code,
token_id,
String::from_utf8(body_bytes).unwrap()
);
}
}
}
}
@ -263,7 +284,7 @@ impl RootContext for FilterContext {
);
let embedding_store = match self.embeddings_store.as_ref() {
None => return None,
None => None,
Some(store) => Some(Rc::clone(store)),
};
Some(Box::new(StreamContext::new(
@ -288,7 +309,17 @@ impl RootContext for FilterContext {
fn on_tick(&mut self) {
debug!("starting up arch filter in mode: prompt gateway mode");
self.process_prompt_targets();
self.set_tick_period(Duration::from_secs(0));
if self.embeddings_store.is_some() {
info!("All embeddings have been fetched, disabling tick");
self.set_tick_period(Duration::from_secs(0));
} else {
info!("waiting for embeddings to be fetched, continuing to wait");
if self.active_embedding_calls_count == 0 {
info!("no active calls seen, it seems like embedding calls are done but embedding store is not yet populated, retrying");
self.process_prompt_targets();
}
self.set_tick_period(Duration::from_secs(1));
}
}
}