mirror of
https://github.com/katanemo/plano.git
synced 2026-05-18 13:45:15 +02:00
Improve prompt target intent matching (#51)
This commit is contained in:
parent
8565462ec4
commit
9e50957f22
14 changed files with 461 additions and 415 deletions
|
|
@ -1,6 +1,6 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
||||
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
|
||||
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
|
|
|
|||
|
|
@ -3,22 +3,20 @@ use crate::ratelimit;
|
|||
use crate::stats::{Counter, Gauge, RecordingMetric};
|
||||
use crate::stream_context::StreamContext;
|
||||
use log::debug;
|
||||
use md5::Digest;
|
||||
use open_message_format_embeddings::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::{
|
||||
CallContext, EmbeddingRequest, StoreVectorEmbeddingsRequest, VectorPoint,
|
||||
};
|
||||
use public_types::common_types::EmbeddingType;
|
||||
use public_types::configuration::{Configuration, PromptTarget};
|
||||
use serde_json::to_string;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::sync::{OnceLock, RwLock};
|
||||
use std::time::Duration;
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct WasmMetrics {
|
||||
pub active_http_calls: Gauge,
|
||||
pub ratelimited_rq: Counter,
|
||||
|
|
@ -33,11 +31,29 @@ impl WasmMetrics {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CallContext {
|
||||
prompt_target: String,
|
||||
embedding_type: EmbeddingType,
|
||||
}
|
||||
|
||||
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FilterContext {
|
||||
metrics: Rc<WasmMetrics>,
|
||||
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
config: Option<Configuration>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
}
|
||||
|
||||
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
|
||||
static EMBEDDINGS: OnceLock<RwLock<HashMap<String, EmbeddingTypeMap>>> = OnceLock::new();
|
||||
EMBEDDINGS.get_or_init(|| {
|
||||
let embeddings: HashMap<String, EmbeddingTypeMap> = HashMap::new();
|
||||
RwLock::new(embeddings)
|
||||
})
|
||||
}
|
||||
|
||||
impl FilterContext {
|
||||
|
|
@ -46,75 +62,95 @@ impl FilterContext {
|
|||
callouts: HashMap::new(),
|
||||
config: None,
|
||||
metrics: Rc::new(WasmMetrics::new()),
|
||||
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_prompt_targets(&mut self) {
|
||||
for prompt_target in &self.config.as_ref().unwrap().prompt_targets {
|
||||
for few_shot_example in &prompt_target.few_shot_examples {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(
|
||||
few_shot_example.to_string(),
|
||||
)),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
// TODO: Handle potential errors
|
||||
let json_data: String = to_string(&embeddings_input).unwrap();
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "20000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
let embedding_request = EmbeddingRequest {
|
||||
create_embedding_request: embeddings_input,
|
||||
// Need to clone prompt target to leave config string intact.
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
debug!(
|
||||
"dispatched HTTP call to embedding server token_id={}",
|
||||
token_id
|
||||
);
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext::EmbeddingRequest(embedding_request)
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
let prompt_targets = match self.prompt_targets.read() {
|
||||
Ok(prompt_targets) => prompt_targets,
|
||||
Err(e) => {
|
||||
panic!("Error reading prompt targets: {:?}", e);
|
||||
}
|
||||
};
|
||||
for values in prompt_targets.iter() {
|
||||
let prompt_target = &values.1;
|
||||
|
||||
// schedule embeddings call for prompt target name
|
||||
let token_id = self.schedule_embeddings_call(prompt_target.name.clone());
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext {
|
||||
prompt_target: prompt_target.name.clone(),
|
||||
embedding_type: EmbeddingType::Name,
|
||||
}
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
||||
// schedule embeddings call for prompt target description
|
||||
let token_id = self.schedule_embeddings_call(prompt_target.description.clone());
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
CallContext {
|
||||
prompt_target: prompt_target.name.clone(),
|
||||
embedding_type: EmbeddingType::Description,
|
||||
}
|
||||
})
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
fn embedding_request_handler(
|
||||
fn schedule_embeddings_call(&self, input: String) -> u32 {
|
||||
let embeddings_input = CreateEmbeddingRequest {
|
||||
input: Box::new(CreateEmbeddingRequestInput::String(input)),
|
||||
model: String::from(DEFAULT_EMBEDDING_MODEL),
|
||||
encoding_format: None,
|
||||
dimensions: None,
|
||||
user: None,
|
||||
};
|
||||
|
||||
let json_data = to_string(&embeddings_input).unwrap();
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(60),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
token_id
|
||||
}
|
||||
|
||||
fn embedding_response_handler(
|
||||
&mut self,
|
||||
body_size: usize,
|
||||
create_embedding_request: CreateEmbeddingRequest,
|
||||
prompt_target: PromptTarget,
|
||||
embedding_type: EmbeddingType,
|
||||
prompt_target_name: String,
|
||||
) {
|
||||
let prompt_targets = self.prompt_targets.read().unwrap();
|
||||
let prompt_target = prompt_targets.get(&prompt_target_name).unwrap();
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
|
|
@ -129,111 +165,24 @@ impl FilterContext {
|
|||
}
|
||||
};
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
"prompt-target".to_string(),
|
||||
to_string(&prompt_target).unwrap(),
|
||||
let embeddings = embedding_response.data.remove(0).embedding;
|
||||
log::info!(
|
||||
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
|
||||
prompt_target.name,
|
||||
prompt_target.description,
|
||||
embedding_type
|
||||
);
|
||||
let id: Option<Digest>;
|
||||
match *create_embedding_request.input {
|
||||
CreateEmbeddingRequestInput::String(input) => {
|
||||
id = Some(md5::compute(&input));
|
||||
payload.insert("input".to_string(), input);
|
||||
}
|
||||
CreateEmbeddingRequestInput::Array(_) => todo!(),
|
||||
}
|
||||
|
||||
let create_vector_store_points = StoreVectorEmbeddingsRequest {
|
||||
points: vec![VectorPoint {
|
||||
id: format!("{:x}", id.unwrap()),
|
||||
payload,
|
||||
vector: embedding_response.data.remove(0).embedding,
|
||||
}],
|
||||
};
|
||||
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store/points"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call: {:?}", e);
|
||||
}
|
||||
};
|
||||
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::StoreVectorEmbeddings(create_vector_store_points),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
embeddings_store().write().unwrap().insert(
|
||||
prompt_target.name.clone(),
|
||||
HashMap::from([(embedding_type, embeddings)]),
|
||||
);
|
||||
}
|
||||
} else {
|
||||
panic!("No body in response");
|
||||
}
|
||||
}
|
||||
|
||||
fn create_vector_store_points_handler(&self, body_size: usize) {
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
debug!(
|
||||
"response body: len {:?}",
|
||||
String::from_utf8(body).unwrap().len()
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: run once per envoy instance, right now it runs once per worker
|
||||
fn init_vector_store(&mut self) {
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
vec![
|
||||
(":method", "PUT"),
|
||||
(":path", "/collections/prompt_vector_store"),
|
||||
(":authority", "qdrant"),
|
||||
("content-type", "application/json"),
|
||||
],
|
||||
Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for init-vector-store: {:?}", e);
|
||||
}
|
||||
};
|
||||
if self
|
||||
.callouts
|
||||
.insert(
|
||||
token_id,
|
||||
CallContext::CreateVectorCollection("prompt_vector_store".to_string()),
|
||||
)
|
||||
.is_some()
|
||||
{
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for FilterContext {
|
||||
fn on_http_call_response(
|
||||
&mut self,
|
||||
|
|
@ -242,37 +191,18 @@ impl Context for FilterContext {
|
|||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_data = self
|
||||
.callouts
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id: {}");
|
||||
debug!("on_http_call_response called with token_id: {:?}", token_id);
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
|
||||
match callout_data {
|
||||
CallContext::EmbeddingRequest(EmbeddingRequest {
|
||||
create_embedding_request,
|
||||
prompt_target,
|
||||
}) => {
|
||||
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
|
||||
}
|
||||
CallContext::StoreVectorEmbeddings(_) => {
|
||||
self.create_vector_store_points_handler(body_size)
|
||||
}
|
||||
CallContext::CreateVectorCollection(_) => {
|
||||
let mut http_status_code = "Nil".to_string();
|
||||
self.get_http_call_response_headers()
|
||||
.iter()
|
||||
.for_each(|(k, v)| {
|
||||
if k == ":status" {
|
||||
http_status_code.clone_from(v);
|
||||
}
|
||||
});
|
||||
debug!("CreateVectorCollection response: {}", http_status_code);
|
||||
}
|
||||
}
|
||||
self.embedding_response_handler(
|
||||
body_size,
|
||||
callout_data.embedding_type,
|
||||
callout_data.prompt_target,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -282,6 +212,13 @@ impl RootContext for FilterContext {
|
|||
if let Some(config_bytes) = self.get_plugin_configuration() {
|
||||
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
|
||||
|
||||
for pt in self.config.clone().unwrap().prompt_targets {
|
||||
self.prompt_targets
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(pt.name.clone(), pt.clone());
|
||||
}
|
||||
|
||||
debug!("set configuration object: {:?}", self.config);
|
||||
|
||||
if let Some(ratelimits_config) = self
|
||||
|
|
@ -301,6 +238,7 @@ impl RootContext for FilterContext {
|
|||
ratelimit_selector: None,
|
||||
callouts: HashMap::new(),
|
||||
metrics: Rc::clone(&self.metrics),
|
||||
prompt_targets: Rc::clone(&self.prompt_targets),
|
||||
}))
|
||||
}
|
||||
|
||||
|
|
@ -314,7 +252,6 @@ impl RootContext for FilterContext {
|
|||
}
|
||||
|
||||
fn on_tick(&mut self) {
|
||||
self.init_vector_store();
|
||||
self.process_prompt_targets();
|
||||
self.set_tick_period(Duration::from_secs(0));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ pub trait RecordingMetric: Metric {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Counter {
|
||||
id: u32,
|
||||
}
|
||||
|
|
@ -55,7 +55,7 @@ impl Metric for Counter {
|
|||
|
||||
impl IncrementingMetric for Counter {}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Gauge {
|
||||
id: u32,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,14 @@
|
|||
use crate::consts::{
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL,
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE,
|
||||
USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::WasmMetrics;
|
||||
use crate::filter_context::{embeddings_store, WasmMetrics};
|
||||
use crate::ratelimit;
|
||||
use crate::ratelimit::Header;
|
||||
use crate::stats::IncrementingMetric;
|
||||
use crate::tokenizer;
|
||||
use acap::cos;
|
||||
use http::StatusCode;
|
||||
use log::{debug, error, info, warn};
|
||||
use open_message_format_embeddings::models::{
|
||||
|
|
@ -15,31 +16,31 @@ use open_message_format_embeddings::models::{
|
|||
};
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::open_ai::{ChatCompletions, Message};
|
||||
use public_types::common_types::{
|
||||
open_ai::{ChatCompletions, Message},
|
||||
SearchPointsRequest, SearchPointsResponse,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
BoltFCResponse, BoltFCToolsCall, ToolParameter, ToolParameters, ToolsDefinition,
|
||||
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::{PromptTarget, PromptType};
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
|
||||
enum RequestType {
|
||||
GetEmbedding,
|
||||
SearchPoints,
|
||||
enum ResponseHandlerType {
|
||||
GetEmbeddings,
|
||||
FunctionResolver,
|
||||
FunctionCallResponse,
|
||||
FunctionCall,
|
||||
ZeroShotIntent,
|
||||
}
|
||||
|
||||
pub struct CallContext {
|
||||
request_type: RequestType,
|
||||
response_handler_type: ResponseHandlerType,
|
||||
user_message: Option<String>,
|
||||
prompt_target: Option<PromptTarget>,
|
||||
request_body: ChatCompletions,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
|
|
@ -47,6 +48,7 @@ pub struct StreamContext {
|
|||
pub ratelimit_selector: Option<Header>,
|
||||
pub callouts: HashMap<u32, CallContext>,
|
||||
pub metrics: Rc<WasmMetrics>,
|
||||
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -61,7 +63,6 @@ impl StreamContext {
|
|||
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
|
||||
// manipulate the body in benign ways e.g., compression.
|
||||
self.set_http_request_header("content-length", None);
|
||||
// self.set_http_request_header("authorization", None);
|
||||
}
|
||||
|
||||
fn modify_path_header(&mut self) {
|
||||
|
|
@ -85,7 +86,7 @@ impl StreamContext {
|
|||
});
|
||||
}
|
||||
|
||||
fn send_server_error(&mut self, error: String) {
|
||||
fn send_server_error(&self, error: String) {
|
||||
debug!("server error occurred: {}", error);
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
|
|
@ -103,30 +104,85 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let search_points_request = SearchPointsRequest {
|
||||
vector: embedding_response.data[0].embedding.clone(),
|
||||
limit: 10,
|
||||
with_payload: true,
|
||||
};
|
||||
let embeddings_vector = &embedding_response.data[0].embedding;
|
||||
|
||||
let json_data: String = match serde_json::to_string(&search_points_request) {
|
||||
Ok(json_data) => json_data,
|
||||
debug!(
|
||||
"embedding model: {}, vector length: {:?}",
|
||||
embedding_response.model,
|
||||
embeddings_vector.len()
|
||||
);
|
||||
|
||||
let prompt_target_embeddings = match embeddings_store().read() {
|
||||
Ok(embeddings) => embeddings,
|
||||
Err(e) => {
|
||||
self.send_server_error(format!("Error serializing search_points_request: {:?}", e));
|
||||
let error_message = format!("Error reading embeddings store: {:?}", e);
|
||||
warn!("{}", error_message);
|
||||
self.send_server_error(error_message);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME);
|
||||
let prompt_targets = match self.prompt_targets.read() {
|
||||
Ok(prompt_targets) => prompt_targets,
|
||||
Err(e) => {
|
||||
let error_message = format!("Error reading prompt targets: {:?}", e);
|
||||
warn!("{}", error_message);
|
||||
self.send_server_error(error_message);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let prompt_target_names = prompt_targets
|
||||
.iter()
|
||||
.map(|(name, _)| name.clone())
|
||||
.collect();
|
||||
|
||||
let similarity_scores: Vec<(String, f64)> = prompt_targets
|
||||
.iter()
|
||||
.map(|(prompt_name, _prompt_target)| {
|
||||
let default_embeddings = HashMap::new();
|
||||
let pte = prompt_target_embeddings
|
||||
.get(prompt_name)
|
||||
.unwrap_or(&default_embeddings);
|
||||
let description_embeddings = pte.get(&EmbeddingType::Description);
|
||||
let similarity_score_description = cos::cosine_similarity(
|
||||
&embeddings_vector,
|
||||
&description_embeddings.unwrap_or(&vec![0.0]),
|
||||
);
|
||||
(prompt_name.clone(), similarity_score_description)
|
||||
})
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
"similarity scores based on description embeddings match: {:?}",
|
||||
similarity_scores
|
||||
);
|
||||
|
||||
callout_context.similarity_scores = Some(similarity_scores);
|
||||
|
||||
let zero_shot_classification_request = ZeroShotClassificationRequest {
|
||||
// Need to clone into input because user_message is used below.
|
||||
input: callout_context.user_message.as_ref().unwrap().clone(),
|
||||
model: String::from(DEFAULT_INTENT_MODEL),
|
||||
labels: prompt_target_names,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing zero shot request: {}", error);
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"qdrant",
|
||||
"embeddingserver",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", &path),
|
||||
(":authority", "qdrant"),
|
||||
(":path", "/zeroshot"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
|
|
@ -134,39 +190,60 @@ impl StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
|
||||
panic!(
|
||||
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"dispatched HTTP call to embedding server for zero-shot-intent-detection token_id={}",
|
||||
token_id
|
||||
);
|
||||
|
||||
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
|
||||
|
||||
callout_context.request_type = RequestType::SearchPoints;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
panic!(
|
||||
"duplicate token_id={} in embedding server requests",
|
||||
token_id
|
||||
)
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) {
|
||||
Ok(search_points_response) => search_points_response,
|
||||
Err(e) => {
|
||||
self.send_server_error(format!(
|
||||
"Error deserializing search_points_response: {:?}",
|
||||
e
|
||||
));
|
||||
fn zero_shot_intent_detection_resp_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: CallContext,
|
||||
) {
|
||||
let zeroshot_intent_response: ZeroShotClassificationResponse =
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(zeroshot_response) => zeroshot_response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Error deserializing zeroshot intent detection response: {:?}",
|
||||
e
|
||||
);
|
||||
info!("body: {:?}", String::from_utf8(body).unwrap());
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
return;
|
||||
}
|
||||
};
|
||||
debug!("zeroshot intent response: {:?}", zeroshot_intent_response);
|
||||
|
||||
let search_results = &search_points_response.result;
|
||||
let prompt_target_similarity_score = zeroshot_intent_response.predicted_class_score * 0.7
|
||||
+ callout_context.similarity_scores.as_ref().unwrap()[0].1 * 0.3;
|
||||
|
||||
if search_results.is_empty() {
|
||||
info!("No prompt target matched");
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
debug!(
|
||||
"similarity score: {}, intent score: {}, description embedding score: {}",
|
||||
prompt_target_similarity_score,
|
||||
zeroshot_intent_response.predicted_class_score,
|
||||
callout_context.similarity_scores.as_ref().unwrap()[0].1
|
||||
);
|
||||
|
||||
let prompt_target_name = zeroshot_intent_response.predicted_class.clone();
|
||||
|
||||
info!("similarity score: {}", search_results[0].score);
|
||||
// Check to see who responded to user message. This will help us identify if control should be passed to Bolt FC or not.
|
||||
// If the last message was from Bolt FC, then Bolt FC is handling the conversation (possibly for parameter collection).
|
||||
let mut bolt_assistant = false;
|
||||
|
|
@ -175,7 +252,6 @@ impl StreamContext {
|
|||
let latest_assistant_message = &messages[messages.len() - 2];
|
||||
if let Some(model) = latest_assistant_message.model.as_ref() {
|
||||
if model.starts_with("Bolt") {
|
||||
info!("Bolt assistant message found");
|
||||
bolt_assistant = true;
|
||||
}
|
||||
}
|
||||
|
|
@ -183,23 +259,30 @@ impl StreamContext {
|
|||
info!("no assistant message found, probably first interaction");
|
||||
}
|
||||
|
||||
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
|
||||
info!(
|
||||
"prompt target below threshold: {}",
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap();
|
||||
let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes())
|
||||
{
|
||||
Ok(prompt_target) => prompt_target,
|
||||
Err(e) => {
|
||||
self.send_server_error(format!("Error deserializing prompt_target: {:?}", e));
|
||||
// check to ensure that the prompt target similarity score is above the threshold
|
||||
if prompt_target_similarity_score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
|
||||
// if bolt fc responded to the user message, then we don't need to check the similarity score
|
||||
// it may be that bolt fc is handling the conversation for parameter collection
|
||||
if bolt_assistant {
|
||||
info!("bolt assistant is handling the conversation");
|
||||
} else {
|
||||
info!(
|
||||
"prompt target below threshold: {}, continue conversation with user",
|
||||
prompt_target_similarity_score,
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(&prompt_target_name)
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
info!(
|
||||
"prompt_target name: {:?}, type: {:?}",
|
||||
prompt_target.name, prompt_target.prompt_type
|
||||
|
|
@ -231,7 +314,7 @@ impl StreamContext {
|
|||
|
||||
let tools_defintion: ToolsDefinition = ToolsDefinition {
|
||||
name: prompt_target.name.clone(),
|
||||
description: prompt_target.description.clone().unwrap_or("".to_string()),
|
||||
description: prompt_target.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
};
|
||||
|
||||
|
|
@ -283,7 +366,7 @@ impl StreamContext {
|
|||
BOLT_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
callout_context.request_type = RequestType::FunctionResolver;
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
|
||||
callout_context.prompt_target = Some(prompt_target);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
|
|
@ -342,7 +425,7 @@ impl StreamContext {
|
|||
{
|
||||
warn!("boltfc did not extract required parameter: {}", param.name);
|
||||
return self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
Some("missing required parameter".as_bytes()),
|
||||
);
|
||||
|
|
@ -380,7 +463,7 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
callout_context.request_type = RequestType::FunctionCallResponse;
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
|
|
@ -468,7 +551,6 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
debug!("sending request to openai: msg {}", json_string);
|
||||
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
|
@ -582,10 +664,11 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
|
||||
let call_context = CallContext {
|
||||
request_type: RequestType::GetEmbedding,
|
||||
response_handler_type: ResponseHandlerType::GetEmbeddings,
|
||||
user_message: Some(user_message),
|
||||
prompt_target: None,
|
||||
request_body: deserialized_body,
|
||||
similarity_scores: None,
|
||||
};
|
||||
if self.callouts.insert(token_id, call_context).is_some() {
|
||||
panic!(
|
||||
|
|
@ -593,6 +676,7 @@ impl HttpContext for StreamContext {
|
|||
token_id
|
||||
)
|
||||
}
|
||||
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
|
||||
Action::Pause
|
||||
|
|
@ -611,18 +695,24 @@ impl Context for StreamContext {
|
|||
self.metrics.active_http_calls.increment(-1);
|
||||
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
match callout_context.request_type {
|
||||
RequestType::GetEmbedding => self.embeddings_handler(body, callout_context),
|
||||
RequestType::SearchPoints => self.search_points_handler(body, callout_context),
|
||||
RequestType::FunctionResolver => {
|
||||
match callout_context.response_handler_type {
|
||||
ResponseHandlerType::GetEmbeddings => {
|
||||
self.embeddings_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::FunctionResolver => {
|
||||
self.function_resolver_handler(body, callout_context)
|
||||
}
|
||||
RequestType::FunctionCallResponse => {
|
||||
ResponseHandlerType::FunctionCall => {
|
||||
self.function_call_response_handler(body, callout_context)
|
||||
}
|
||||
ResponseHandlerType::ZeroShotIntent => {
|
||||
self.zero_shot_intent_detection_resp_handler(body, callout_context)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!("No response body in inline HTTP request");
|
||||
let error_message = "No response body in inline HTTP request";
|
||||
warn!("{}", error_message);
|
||||
self.send_server_error(error_message.to_owned());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue