fix embeddings not found bug (#120)

This commit is contained in:
Adil Hafeez 2024-10-04 17:07:59 -07:00 committed by GitHub
parent 701187474f
commit 07ef1af24f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 27 additions and 11 deletions

View file

@ -165,12 +165,12 @@ impl StreamContext {
}
};
let embeddings_vector = &embedding_response.data[0].embedding;
let prompt_embeddings_vector = &embedding_response.data[0].embedding;
debug!(
"embedding model: {}, vector length: {:?}",
embedding_response.model,
embeddings_vector.len()
prompt_embeddings_vector.len()
);
let prompt_target_embeddings = match embeddings_store().read() {
@ -201,15 +201,29 @@ impl StreamContext {
// exclude default prompt target
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
.map(|(prompt_name, _)| {
let default_embeddings = HashMap::new();
let pte = prompt_target_embeddings
.get(prompt_name)
.unwrap_or(&default_embeddings);
let description_embeddings = pte.get(&EmbeddingType::Description);
let similarity_score_description = cos::cosine_similarity(
&embeddings_vector,
&description_embeddings.unwrap_or(&vec![0.0]),
);
let pte = match prompt_target_embeddings.get(prompt_name) {
Some(embeddings) => embeddings,
None => {
warn!(
"embeddings not found for prompt target name: {}",
prompt_name
);
return (prompt_name.clone(), f64::NAN);
}
};
let description_embeddings = match pte.get(&EmbeddingType::Description) {
Some(embeddings) => embeddings,
None => {
warn!(
"description embeddings not found for prompt target name: {}",
prompt_name
);
return (prompt_name.clone(), f64::NAN);
}
};
let similarity_score_description =
cos::cosine_similarity(&prompt_embeddings_vector, &description_embeddings);
(prompt_name.clone(), similarity_score_description)
})
.collect();