This commit is contained in:
Adil Hafeez 2024-11-04 15:02:23 -08:00
parent 0a655ba5c4
commit fded75c35a
4 changed files with 29 additions and 15 deletions

View file

@ -66,7 +66,7 @@ impl FilterContext {
prompt_targets: Rc::new(HashMap::new()),
overrides: Rc::new(None),
prompt_guards: Rc::new(PromptGuards::default()),
embeddings_store: None,
embeddings_store: Some(Rc::new(HashMap::new())),
temp_embeddings_store: HashMap::new(),
active_embedding_calls_count: 0,
}
@ -305,7 +305,9 @@ impl RootContext for FilterContext {
}
fn on_tick(&mut self) {
if self.embeddings_store.is_some() {
if self.embeddings_store.is_some()
&& self.embeddings_store.as_ref().unwrap().len() == self.prompt_targets.len()
{
info!("embeddings store initialized");
self.set_tick_period(Duration::from_secs(0));
} else {

View file

@ -35,10 +35,10 @@ impl HttpContext for StreamContext {
let request_path = self.get_http_request_header(":path").unwrap_or_default();
if request_path == HEALTHZ_PATH {
if self.embeddings_store.is_none() {
self.send_http_response(503, vec![], None);
} else {
if self.is_embedding_store_initialized() {
self.send_http_response(200, vec![], None);
} else {
self.send_http_response(503, vec![], None);
}
return Action::Continue;
}

View file

@ -61,7 +61,7 @@ pub struct StreamCallContext {
pub struct StreamContext {
system_prompt: Rc<Option<String>>,
prompt_targets: Rc<HashMap<String, PromptTarget>>,
pub prompt_targets: Rc<HashMap<String, PromptTarget>>,
pub embeddings_store: Option<Rc<EmbeddingsStore>>,
overrides: Rc<Option<Overrides>>,
pub metrics: Rc<WasmMetrics>,
@ -109,10 +109,21 @@ impl StreamContext {
request_id: None,
}
}
/// Returns a reference to the embeddings store held by this context.
///
/// # Panics
///
/// Panics if `embeddings_store` is `None`.
// NOTE(review): the body below appears to show both the removed (`expect`) and
// the added (`unwrap`) variants from a diff view; presumably only one of them
// belongs in the real file — confirm against the repository.
fn embeddings_store(&self) -> &EmbeddingsStore {
self.embeddings_store
.as_ref()
.expect("embeddings store is not set")
self.embeddings_store.as_ref().unwrap()
}
/// Reports whether the embeddings store is considered initialized.
///
/// Returns `true` only when `embeddings_store` is set and contains exactly as
/// many entries as `prompt_targets`; returns `false` when the store is `None`
/// or the two lengths differ.
pub fn is_embedding_store_initialized(&self) -> bool {
    // Match on the option directly instead of `is_none()` + `unwrap()`, so
    // there is no panic path and the result is a single expression.
    match self.embeddings_store.as_ref() {
        Some(store) => store.len() == self.prompt_targets.len(),
        None => false,
    }
}
pub fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
@ -223,7 +234,7 @@ impl StreamContext {
"embeddings not found for prompt target name: {}",
prompt_name
);
return (prompt_name.clone(), f64::NAN);
return (prompt_name.clone(), 0.0);
}
};
@ -234,7 +245,7 @@ impl StreamContext {
"description embeddings not found for prompt target name: {}",
prompt_name
);
return (prompt_name.clone(), f64::NAN);
return (prompt_name.clone(), 0.0);
}
};
let similarity_score_description =

View file

@ -161,6 +161,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Trace), None)
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
@ -244,7 +245,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module
.call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Trace), None)
.expect_http_call(
Some("arch_internal"),
@ -262,7 +263,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
)
.returning(Some(101))
.expect_metric_increment("active_http_calls", 1)
.expect_set_tick_period_millis(Some(0))
.expect_set_tick_period_millis(Some(5000))
.execute_and_expect(ReturnType::None)
.unwrap();
@ -289,7 +290,7 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
0,
)
.expect_log(
Some(LogLevel::Debug),
Some(LogLevel::Trace),
Some(
format!(
"filter_context: on_http_call_response called with token_id: {:?}",
@ -332,7 +333,7 @@ llm_providers:
overrides:
# confidence threshold for prompt target intent matching
prompt_target_intent_matching_threshold: 0.6
prompt_target_intent_matching_threshold: 0.0
system_prompt: |
You are a helpful assistant.