Improve prompt target intent matching (#51)

This commit is contained in:
Adil Hafeez 2024-09-16 19:20:07 -07:00 committed by GitHub
parent 8565462ec4
commit 9e50957f22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 461 additions and 415 deletions

View file

@ -1,6 +1,6 @@
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
pub const DEFAULT_INTENT_MODEL: &str = "tasksource/deberta-base-long-nli";
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.8;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";

View file

@ -3,22 +3,20 @@ use crate::ratelimit;
use crate::stats::{Counter, Gauge, RecordingMetric};
use crate::stream_context::StreamContext;
use log::debug;
use md5::Digest;
use open_message_format_embeddings::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::{
CallContext, EmbeddingRequest, StoreVectorEmbeddingsRequest, VectorPoint,
};
use public_types::common_types::EmbeddingType;
use public_types::configuration::{Configuration, PromptTarget};
use serde_json::to_string;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::{OnceLock, RwLock};
use std::time::Duration;
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub struct WasmMetrics {
pub active_http_calls: Gauge,
pub ratelimited_rq: Counter,
@ -33,11 +31,29 @@ impl WasmMetrics {
}
}
#[derive(Debug)]
struct CallContext {
prompt_target: String,
embedding_type: EmbeddingType,
}
pub type EmbeddingTypeMap = HashMap<EmbeddingType, Vec<f64>>;
#[derive(Debug)]
pub struct FilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: HashMap<u32, CallContext>,
config: Option<Configuration>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
}
pub fn embeddings_store() -> &'static RwLock<HashMap<String, EmbeddingTypeMap>> {
static EMBEDDINGS: OnceLock<RwLock<HashMap<String, EmbeddingTypeMap>>> = OnceLock::new();
EMBEDDINGS.get_or_init(|| {
let embeddings: HashMap<String, EmbeddingTypeMap> = HashMap::new();
RwLock::new(embeddings)
})
}
impl FilterContext {
@ -46,75 +62,95 @@ impl FilterContext {
callouts: HashMap::new(),
config: None,
metrics: Rc::new(WasmMetrics::new()),
prompt_targets: Rc::new(RwLock::new(HashMap::new())),
}
}
fn process_prompt_targets(&mut self) {
for prompt_target in &self.config.as_ref().unwrap().prompt_targets {
for few_shot_example in &prompt_target.few_shot_examples {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(
few_shot_example.to_string(),
)),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
// TODO: Handle potential errors
let json_data: String = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
"embeddingserver",
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "20000"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
let embedding_request = EmbeddingRequest {
create_embedding_request: embeddings_input,
// Need to clone prompt target to leave config string intact.
prompt_target: prompt_target.clone(),
};
debug!(
"dispatched HTTP call to embedding server token_id={}",
token_id
);
if self
.callouts
.insert(token_id, {
CallContext::EmbeddingRequest(embedding_request)
})
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets,
Err(e) => {
panic!("Error reading prompt targets: {:?}", e);
}
};
for values in prompt_targets.iter() {
let prompt_target = &values.1;
// schedule embeddings call for prompt target name
let token_id = self.schedule_embeddings_call(prompt_target.name.clone());
if self
.callouts
.insert(token_id, {
CallContext {
prompt_target: prompt_target.name.clone(),
embedding_type: EmbeddingType::Name,
}
})
.is_some()
{
panic!("duplicate token_id")
}
// schedule embeddings call for prompt target description
let token_id = self.schedule_embeddings_call(prompt_target.description.clone());
if self
.callouts
.insert(token_id, {
CallContext {
prompt_target: prompt_target.name.clone(),
embedding_type: EmbeddingType::Description,
}
})
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
fn embedding_request_handler(
fn schedule_embeddings_call(&self, input: String) -> u32 {
let embeddings_input = CreateEmbeddingRequest {
input: Box::new(CreateEmbeddingRequestInput::String(input)),
model: String::from(DEFAULT_EMBEDDING_MODEL),
encoding_format: None,
dimensions: None,
user: None,
};
let json_data = to_string(&embeddings_input).unwrap();
let token_id = match self.dispatch_http_call(
"embeddingserver",
vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(60),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
token_id
}
fn embedding_response_handler(
&mut self,
body_size: usize,
create_embedding_request: CreateEmbeddingRequest,
prompt_target: PromptTarget,
embedding_type: EmbeddingType,
prompt_target_name: String,
) {
let prompt_targets = self.prompt_targets.read().unwrap();
let prompt_target = prompt_targets.get(&prompt_target_name).unwrap();
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse =
@ -129,111 +165,24 @@ impl FilterContext {
}
};
let mut payload: HashMap<String, String> = HashMap::new();
payload.insert(
"prompt-target".to_string(),
to_string(&prompt_target).unwrap(),
let embeddings = embedding_response.data.remove(0).embedding;
log::info!(
"Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}",
prompt_target.name,
prompt_target.description,
embedding_type
);
let id: Option<Digest>;
match *create_embedding_request.input {
CreateEmbeddingRequestInput::String(input) => {
id = Some(md5::compute(&input));
payload.insert("input".to_string(), input);
}
CreateEmbeddingRequestInput::Array(_) => todo!(),
}
let create_vector_store_points = StoreVectorEmbeddingsRequest {
points: vec![VectorPoint {
id: format!("{:x}", id.unwrap()),
payload,
vector: embedding_response.data.remove(0).embedding,
}],
};
let json_data = to_string(&create_vector_store_points).unwrap(); // Handle potential errors
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store/points"),
(":authority", "qdrant"),
("content-type", "application/json"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call: {:?}", e);
}
};
if self
.callouts
.insert(
token_id,
CallContext::StoreVectorEmbeddings(create_vector_store_points),
)
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
embeddings_store().write().unwrap().insert(
prompt_target.name.clone(),
HashMap::from([(embedding_type, embeddings)]),
);
}
} else {
panic!("No body in response");
}
}
fn create_vector_store_points_handler(&self, body_size: usize) {
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
debug!(
"response body: len {:?}",
String::from_utf8(body).unwrap().len()
);
}
}
}
//TODO: run once per envoy instance, right now it runs once per worker
fn init_vector_store(&mut self) {
let token_id = match self.dispatch_http_call(
"qdrant",
vec![
(":method", "PUT"),
(":path", "/collections/prompt_vector_store"),
(":authority", "qdrant"),
("content-type", "application/json"),
],
Some(b"{ \"vectors\": { \"size\": 1024, \"distance\": \"Cosine\"}}"),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for init-vector-store: {:?}", e);
}
};
if self
.callouts
.insert(
token_id,
CallContext::CreateVectorCollection("prompt_vector_store".to_string()),
)
.is_some()
{
panic!("duplicate token_id")
}
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
}
impl Context for FilterContext {
fn on_http_call_response(
&mut self,
@ -242,37 +191,18 @@ impl Context for FilterContext {
body_size: usize,
_num_trailers: usize,
) {
let callout_data = self
.callouts
.remove(&token_id)
.expect("invalid token_id: {}");
debug!("on_http_call_response called with token_id: {:?}", token_id);
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
self.metrics
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
match callout_data {
CallContext::EmbeddingRequest(EmbeddingRequest {
create_embedding_request,
prompt_target,
}) => {
self.embedding_request_handler(body_size, create_embedding_request, prompt_target)
}
CallContext::StoreVectorEmbeddings(_) => {
self.create_vector_store_points_handler(body_size)
}
CallContext::CreateVectorCollection(_) => {
let mut http_status_code = "Nil".to_string();
self.get_http_call_response_headers()
.iter()
.for_each(|(k, v)| {
if k == ":status" {
http_status_code.clone_from(v);
}
});
debug!("CreateVectorCollection response: {}", http_status_code);
}
}
self.embedding_response_handler(
body_size,
callout_data.embedding_type,
callout_data.prompt_target,
)
}
}
@ -282,6 +212,13 @@ impl RootContext for FilterContext {
if let Some(config_bytes) = self.get_plugin_configuration() {
self.config = serde_yaml::from_slice(&config_bytes).unwrap();
for pt in self.config.clone().unwrap().prompt_targets {
self.prompt_targets
.write()
.unwrap()
.insert(pt.name.clone(), pt.clone());
}
debug!("set configuration object: {:?}", self.config);
if let Some(ratelimits_config) = self
@ -301,6 +238,7 @@ impl RootContext for FilterContext {
ratelimit_selector: None,
callouts: HashMap::new(),
metrics: Rc::clone(&self.metrics),
prompt_targets: Rc::clone(&self.prompt_targets),
}))
}
@ -314,7 +252,6 @@ impl RootContext for FilterContext {
}
fn on_tick(&mut self) {
self.init_vector_store();
self.process_prompt_targets();
self.set_tick_period(Duration::from_secs(0));
}

View file

@ -33,7 +33,7 @@ pub trait RecordingMetric: Metric {
}
}
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub struct Counter {
id: u32,
}
@ -55,7 +55,7 @@ impl Metric for Counter {
impl IncrementingMetric for Counter {}
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
pub struct Gauge {
id: u32,
}

View file

@ -1,13 +1,14 @@
use crate::consts::{
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL,
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE,
USER_ROLE,
};
use crate::filter_context::WasmMetrics;
use crate::filter_context::{embeddings_store, WasmMetrics};
use crate::ratelimit;
use crate::ratelimit::Header;
use crate::stats::IncrementingMetric;
use crate::tokenizer;
use acap::cos;
use http::StatusCode;
use log::{debug, error, info, warn};
use open_message_format_embeddings::models::{
@ -15,31 +16,31 @@ use open_message_format_embeddings::models::{
};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{ChatCompletions, Message};
use public_types::common_types::{
open_ai::{ChatCompletions, Message},
SearchPointsRequest, SearchPointsResponse,
};
use public_types::common_types::{
BoltFCResponse, BoltFCToolsCall, ToolParameter, ToolParameters, ToolsDefinition,
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use public_types::configuration::{PromptTarget, PromptType};
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
use std::sync::RwLock;
use std::time::Duration;
enum RequestType {
GetEmbedding,
SearchPoints,
enum ResponseHandlerType {
GetEmbeddings,
FunctionResolver,
FunctionCallResponse,
FunctionCall,
ZeroShotIntent,
}
pub struct CallContext {
request_type: RequestType,
response_handler_type: ResponseHandlerType,
user_message: Option<String>,
prompt_target: Option<PromptTarget>,
request_body: ChatCompletions,
similarity_scores: Option<Vec<(String, f64)>>,
}
pub struct StreamContext {
@ -47,6 +48,7 @@ pub struct StreamContext {
pub ratelimit_selector: Option<Header>,
pub callouts: HashMap<u32, CallContext>,
pub metrics: Rc<WasmMetrics>,
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
}
impl StreamContext {
@ -61,7 +63,6 @@ impl StreamContext {
// However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could
// manipulate the body in benign ways e.g., compression.
self.set_http_request_header("content-length", None);
// self.set_http_request_header("authorization", None);
}
fn modify_path_header(&mut self) {
@ -85,7 +86,7 @@ impl StreamContext {
});
}
fn send_server_error(&mut self, error: String) {
fn send_server_error(&self, error: String) {
debug!("server error occurred: {}", error);
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
@ -103,30 +104,85 @@ impl StreamContext {
}
};
let search_points_request = SearchPointsRequest {
vector: embedding_response.data[0].embedding.clone(),
limit: 10,
with_payload: true,
};
let embeddings_vector = &embedding_response.data[0].embedding;
let json_data: String = match serde_json::to_string(&search_points_request) {
Ok(json_data) => json_data,
debug!(
"embedding model: {}, vector length: {:?}",
embedding_response.model,
embeddings_vector.len()
);
let prompt_target_embeddings = match embeddings_store().read() {
Ok(embeddings) => embeddings,
Err(e) => {
self.send_server_error(format!("Error serializing search_points_request: {:?}", e));
let error_message = format!("Error reading embeddings store: {:?}", e);
warn!("{}", error_message);
self.send_server_error(error_message);
return;
}
};
let path = format!("/collections/{}/points/search", DEFAULT_COLLECTION_NAME);
let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets,
Err(e) => {
let error_message = format!("Error reading prompt targets: {:?}", e);
warn!("{}", error_message);
self.send_server_error(error_message);
return;
}
};
let prompt_target_names = prompt_targets
.iter()
.map(|(name, _)| name.clone())
.collect();
let similarity_scores: Vec<(String, f64)> = prompt_targets
.iter()
.map(|(prompt_name, _prompt_target)| {
let default_embeddings = HashMap::new();
let pte = prompt_target_embeddings
.get(prompt_name)
.unwrap_or(&default_embeddings);
let description_embeddings = pte.get(&EmbeddingType::Description);
let similarity_score_description = cos::cosine_similarity(
&embeddings_vector,
&description_embeddings.unwrap_or(&vec![0.0]),
);
(prompt_name.clone(), similarity_score_description)
})
.collect();
debug!(
"similarity scores based on description embeddings match: {:?}",
similarity_scores
);
callout_context.similarity_scores = Some(similarity_scores);
let zero_shot_classification_request = ZeroShotClassificationRequest {
// Need to clone into input because user_message is used below.
input: callout_context.user_message.as_ref().unwrap().clone(),
model: String::from(DEFAULT_INTENT_MODEL),
labels: prompt_target_names,
};
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing zero shot request: {}", error);
}
};
let token_id = match self.dispatch_http_call(
"qdrant",
"embeddingserver",
vec![
(":method", "POST"),
(":path", &path),
(":authority", "qdrant"),
(":path", "/zeroshot"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
@ -134,39 +190,60 @@ impl StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
panic!(
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
e
);
}
};
debug!(
"dispatched HTTP call to embedding server for zero-shot-intent-detection token_id={}",
token_id
);
callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent;
callout_context.request_type = RequestType::SearchPoints;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
panic!(
"duplicate token_id={} in embedding server requests",
token_id
)
}
self.metrics.active_http_calls.increment(1);
}
fn search_points_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let search_points_response: SearchPointsResponse = match serde_json::from_slice(&body) {
Ok(search_points_response) => search_points_response,
Err(e) => {
self.send_server_error(format!(
"Error deserializing search_points_response: {:?}",
e
));
fn zero_shot_intent_detection_resp_handler(
&mut self,
body: Vec<u8>,
mut callout_context: CallContext,
) {
let zeroshot_intent_response: ZeroShotClassificationResponse =
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
warn!(
"Error deserializing zeroshot intent detection response: {:?}",
e
);
info!("body: {:?}", String::from_utf8(body).unwrap());
self.resume_http_request();
return;
}
};
return;
}
};
debug!("zeroshot intent response: {:?}", zeroshot_intent_response);
let search_results = &search_points_response.result;
let prompt_target_similarity_score = zeroshot_intent_response.predicted_class_score * 0.7
+ callout_context.similarity_scores.as_ref().unwrap()[0].1 * 0.3;
if search_results.is_empty() {
info!("No prompt target matched");
self.resume_http_request();
return;
}
debug!(
"similarity score: {}, intent score: {}, description embedding score: {}",
prompt_target_similarity_score,
zeroshot_intent_response.predicted_class_score,
callout_context.similarity_scores.as_ref().unwrap()[0].1
);
let prompt_target_name = zeroshot_intent_response.predicted_class.clone();
info!("similarity score: {}", search_results[0].score);
// Check to see who responded to user message. This will help us identify if control should be passed to Bolt FC or not.
// If the last message was from Bolt FC, then Bolt FC is handling the conversation (possibly for parameter collection).
let mut bolt_assistant = false;
@ -175,7 +252,6 @@ impl StreamContext {
let latest_assistant_message = &messages[messages.len() - 2];
if let Some(model) = latest_assistant_message.model.as_ref() {
if model.starts_with("Bolt") {
info!("Bolt assistant message found");
bolt_assistant = true;
}
}
@ -183,23 +259,30 @@ impl StreamContext {
info!("no assistant message found, probably first interaction");
}
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
info!(
"prompt target below threshold: {}",
DEFAULT_PROMPT_TARGET_THRESHOLD
);
self.resume_http_request();
return;
}
let prompt_target_str = search_results[0].payload.get("prompt-target").unwrap();
let prompt_target: PromptTarget = match serde_json::from_slice(prompt_target_str.as_bytes())
{
Ok(prompt_target) => prompt_target,
Err(e) => {
self.send_server_error(format!("Error deserializing prompt_target: {:?}", e));
// check to ensure that the prompt target similarity score is above the threshold
if prompt_target_similarity_score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
// if bolt fc responded to the user message, then we don't need to check the similarity score
// it may be that bolt fc is handling the conversation for parameter collection
if bolt_assistant {
info!("bolt assistant is handling the conversation");
} else {
info!(
"prompt target below threshold: {}, continue conversation with user",
prompt_target_similarity_score,
);
self.resume_http_request();
return;
}
};
}
let prompt_target = self
.prompt_targets
.read()
.unwrap()
.get(&prompt_target_name)
.unwrap()
.clone();
info!(
"prompt_target name: {:?}, type: {:?}",
prompt_target.name, prompt_target.prompt_type
@ -231,7 +314,7 @@ impl StreamContext {
let tools_defintion: ToolsDefinition = ToolsDefinition {
name: prompt_target.name.clone(),
description: prompt_target.description.clone().unwrap_or("".to_string()),
description: prompt_target.description.clone(),
parameters: tools_parameters,
};
@ -283,7 +366,7 @@ impl StreamContext {
BOLT_FC_CLUSTER, token_id
);
callout_context.request_type = RequestType::FunctionResolver;
callout_context.response_handler_type = ResponseHandlerType::FunctionResolver;
callout_context.prompt_target = Some(prompt_target);
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
@ -342,7 +425,7 @@ impl StreamContext {
{
warn!("boltfc did not extract required parameter: {}", param.name);
return self.send_http_response(
StatusCode::BAD_REQUEST.as_u16().into(),
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
Some("missing required parameter".as_bytes()),
);
@ -380,7 +463,7 @@ impl StreamContext {
}
};
callout_context.request_type = RequestType::FunctionCallResponse;
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
@ -468,7 +551,6 @@ impl StreamContext {
}
}
debug!("sending request to openai: msg {}", json_string);
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
self.resume_http_request();
}
@ -582,10 +664,11 @@ impl HttpContext for StreamContext {
);
let call_context = CallContext {
request_type: RequestType::GetEmbedding,
response_handler_type: ResponseHandlerType::GetEmbeddings,
user_message: Some(user_message),
prompt_target: None,
request_body: deserialized_body,
similarity_scores: None,
};
if self.callouts.insert(token_id, call_context).is_some() {
panic!(
@ -593,6 +676,7 @@ impl HttpContext for StreamContext {
token_id
)
}
self.metrics.active_http_calls.increment(1);
Action::Pause
@ -611,18 +695,24 @@ impl Context for StreamContext {
self.metrics.active_http_calls.increment(-1);
if let Some(body) = self.get_http_call_response_body(0, body_size) {
match callout_context.request_type {
RequestType::GetEmbedding => self.embeddings_handler(body, callout_context),
RequestType::SearchPoints => self.search_points_handler(body, callout_context),
RequestType::FunctionResolver => {
match callout_context.response_handler_type {
ResponseHandlerType::GetEmbeddings => {
self.embeddings_handler(body, callout_context)
}
ResponseHandlerType::FunctionResolver => {
self.function_resolver_handler(body, callout_context)
}
RequestType::FunctionCallResponse => {
ResponseHandlerType::FunctionCall => {
self.function_call_response_handler(body, callout_context)
}
ResponseHandlerType::ZeroShotIntent => {
self.zero_shot_intent_detection_resp_handler(body, callout_context)
}
}
} else {
warn!("No response body in inline HTTP request");
let error_message = "No response body in inline HTTP request";
warn!("{}", error_message);
self.send_server_error(error_message.to_owned());
}
}
}