mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
fix embeddings not found bug (#120)
This commit is contained in:
parent
701187474f
commit
07ef1af24f
2 changed files with 27 additions and 11 deletions
|
|
@ -165,12 +165,12 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let embeddings_vector = &embedding_response.data[0].embedding;
|
||||
let prompt_embeddings_vector = &embedding_response.data[0].embedding;
|
||||
|
||||
debug!(
|
||||
"embedding model: {}, vector length: {:?}",
|
||||
embedding_response.model,
|
||||
embeddings_vector.len()
|
||||
prompt_embeddings_vector.len()
|
||||
);
|
||||
|
||||
let prompt_target_embeddings = match embeddings_store().read() {
|
||||
|
|
@ -201,15 +201,29 @@ impl StreamContext {
|
|||
// exclude default prompt target
|
||||
.filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false))
|
||||
.map(|(prompt_name, _)| {
|
||||
let default_embeddings = HashMap::new();
|
||||
let pte = prompt_target_embeddings
|
||||
.get(prompt_name)
|
||||
.unwrap_or(&default_embeddings);
|
||||
let description_embeddings = pte.get(&EmbeddingType::Description);
|
||||
let similarity_score_description = cos::cosine_similarity(
|
||||
&embeddings_vector,
|
||||
&description_embeddings.unwrap_or(&vec![0.0]),
|
||||
);
|
||||
let pte = match prompt_target_embeddings.get(prompt_name) {
|
||||
Some(embeddings) => embeddings,
|
||||
None => {
|
||||
warn!(
|
||||
"embeddings not found for prompt target name: {}",
|
||||
prompt_name
|
||||
);
|
||||
return (prompt_name.clone(), f64::NAN);
|
||||
}
|
||||
};
|
||||
|
||||
let description_embeddings = match pte.get(&EmbeddingType::Description) {
|
||||
Some(embeddings) => embeddings,
|
||||
None => {
|
||||
warn!(
|
||||
"description embeddings not found for prompt target name: {}",
|
||||
prompt_name
|
||||
);
|
||||
return (prompt_name.clone(), f64::NAN);
|
||||
}
|
||||
};
|
||||
let similarity_score_description =
|
||||
cos::cosine_similarity(&prompt_embeddings_vector, &description_embeddings);
|
||||
(prompt_name.clone(), similarity_score_description)
|
||||
})
|
||||
.collect();
|
||||
|
|
|
|||
|
|
@ -175,6 +175,8 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embeddings_response_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Warn), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("model_server"),
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue