diff --git a/arch/src/stream_context.rs b/arch/src/stream_context.rs index 4ef74e55..6d15cc41 100644 --- a/arch/src/stream_context.rs +++ b/arch/src/stream_context.rs @@ -165,12 +165,12 @@ impl StreamContext { } }; - let embeddings_vector = &embedding_response.data[0].embedding; + let prompt_embeddings_vector = &embedding_response.data[0].embedding; debug!( "embedding model: {}, vector length: {:?}", embedding_response.model, - embeddings_vector.len() + prompt_embeddings_vector.len() ); let prompt_target_embeddings = match embeddings_store().read() { @@ -201,15 +201,29 @@ impl StreamContext { // exclude default prompt target .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false)) .map(|(prompt_name, _)| { - let default_embeddings = HashMap::new(); - let pte = prompt_target_embeddings - .get(prompt_name) - .unwrap_or(&default_embeddings); - let description_embeddings = pte.get(&EmbeddingType::Description); - let similarity_score_description = cos::cosine_similarity( - &embeddings_vector, - &description_embeddings.unwrap_or(&vec![0.0]), - ); + let pte = match prompt_target_embeddings.get(prompt_name) { + Some(embeddings) => embeddings, + None => { + warn!( + "embeddings not found for prompt target name: {}", + prompt_name + ); + return (prompt_name.clone(), f64::NAN); + } + }; + + let description_embeddings = match pte.get(&EmbeddingType::Description) { + Some(embeddings) => embeddings, + None => { + warn!( + "description embeddings not found for prompt target name: {}", + prompt_name + ); + return (prompt_name.clone(), f64::NAN); + } + }; + let similarity_score_description = + cos::cosine_similarity(&prompt_embeddings_vector, &description_embeddings); (prompt_name.clone(), similarity_score_description) }) .collect(); diff --git a/arch/tests/integration.rs b/arch/tests/integration.rs index ca467734..d16c5be7 100644 --- a/arch/tests/integration.rs +++ b/arch/tests/integration.rs @@ -175,6 +175,8 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) .returning(Some(&embeddings_response_buffer)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Warn), None) + .expect_log(Some(LogLevel::Warn), None) .expect_log(Some(LogLevel::Debug), None) .expect_http_call( Some("model_server"),