mirror of
https://github.com/katanemo/plano.git
synced 2026-05-18 13:45:15 +02:00
Add function calling support using bolt-fc-1b (#35)
This commit is contained in:
parent
fdfad87347
commit
7b5203a2ce
39 changed files with 1763 additions and 416 deletions
|
|
@ -1,8 +1,9 @@
|
|||
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
|
||||
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
|
||||
pub const DEFAULT_NER_MODEL: &str = "urchade/gliner_large-v2.1";
|
||||
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
|
||||
pub const DEFAULT_NER_THRESHOLD: f64 = 0.6;
|
||||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
|
||||
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ use crate::consts::DEFAULT_EMBEDDING_MODEL;
|
|||
use crate::ratelimit;
|
||||
use crate::stats::{Counter, Gauge, RecordingMetric};
|
||||
use crate::stream_context::StreamContext;
|
||||
use log::{debug, info};
|
||||
use log::debug;
|
||||
use md5::Digest;
|
||||
use open_message_format_embeddings::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
|
|
@ -72,6 +72,8 @@ impl FilterContext {
|
|||
(":path", "/embeddings"),
|
||||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "20000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
|
|
@ -87,6 +89,10 @@ impl FilterContext {
|
|||
// Need to clone prompt target to leave config string intact.
|
||||
prompt_target: prompt_target.clone(),
|
||||
};
|
||||
debug!(
|
||||
"dispatched HTTP call to embedding server token_id={}",
|
||||
token_id
|
||||
);
|
||||
if self
|
||||
.callouts
|
||||
.insert(token_id, {
|
||||
|
|
@ -112,7 +118,16 @@ impl FilterContext {
|
|||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
let mut embedding_response: CreateEmbeddingResponse =
|
||||
serde_json::from_slice(&body).unwrap();
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error deserializing embedding response. body: {:?}: {:?}",
|
||||
String::from_utf8(body).unwrap(),
|
||||
e
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
let mut payload: HashMap<String, String> = HashMap::new();
|
||||
payload.insert(
|
||||
|
|
@ -168,13 +183,15 @@ impl FilterContext {
|
|||
.active_http_calls
|
||||
.record(self.callouts.len().try_into().unwrap());
|
||||
}
|
||||
} else {
|
||||
panic!("No body in response");
|
||||
}
|
||||
}
|
||||
|
||||
fn create_vector_store_points_handler(&self, body_size: usize) {
|
||||
if let Some(body) = self.get_http_call_response_body(0, body_size) {
|
||||
if !body.is_empty() {
|
||||
info!(
|
||||
debug!(
|
||||
"response body: len {:?}",
|
||||
String::from_utf8(body).unwrap().len()
|
||||
);
|
||||
|
|
@ -225,7 +242,10 @@ impl Context for FilterContext {
|
|||
body_size: usize,
|
||||
_num_trailers: usize,
|
||||
) {
|
||||
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
|
||||
let callout_data = self
|
||||
.callouts
|
||||
.remove(&token_id)
|
||||
.expect("invalid token_id: {}");
|
||||
|
||||
self.metrics
|
||||
.active_http_calls
|
||||
|
|
@ -250,7 +270,7 @@ impl Context for FilterContext {
|
|||
http_status_code.clone_from(v);
|
||||
}
|
||||
});
|
||||
info!("CreateVectorCollection response: {}", http_status_code);
|
||||
debug!("CreateVectorCollection response: {}", http_status_code);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
use crate::consts::{
|
||||
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE,
|
||||
USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::WasmMetrics;
|
||||
use crate::ratelimit;
|
||||
|
|
@ -16,9 +17,12 @@ use proxy_wasm::traits::*;
|
|||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::{
|
||||
open_ai::{ChatCompletions, Message},
|
||||
NERRequest, NERResponse, SearchPointsRequest, SearchPointsResponse,
|
||||
SearchPointsRequest, SearchPointsResponse,
|
||||
};
|
||||
use public_types::configuration::{Entity, PromptTarget};
|
||||
use public_types::common_types::{
|
||||
BoltFCResponse, BoltFCToolsCall, ToolParameter, ToolParameters, ToolsDefinition,
|
||||
};
|
||||
use public_types::configuration::{PromptTarget, PromptType};
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
|
|
@ -27,8 +31,8 @@ use std::time::Duration;
|
|||
enum RequestType {
|
||||
GetEmbedding,
|
||||
SearchPoints,
|
||||
Ner,
|
||||
ContextResolver,
|
||||
FunctionResolver,
|
||||
FunctionCallResponse,
|
||||
}
|
||||
|
||||
pub struct CallContext {
|
||||
|
|
@ -153,8 +157,23 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
info!("similarity score: {}", search_results[0].score);
|
||||
// Check to see who responded to user message. This will help us identify if control should be passed to Bolt FC or not.
|
||||
// If the last message was from Bolt FC, then Bolt FC is handling the conversation (possibly for parameter collection).
|
||||
let mut bolt_assistant = false;
|
||||
let messages = &callout_context.request_body.messages;
|
||||
if messages.len() >= 2 {
|
||||
let latest_assistant_message = &messages[messages.len() - 2];
|
||||
if let Some(model) = latest_assistant_message.model.as_ref() {
|
||||
if model.starts_with("Bolt") {
|
||||
info!("Bolt assistant message found");
|
||||
bolt_assistant = true;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
info!("no assistant message found, probably first interaction");
|
||||
}
|
||||
|
||||
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD {
|
||||
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
|
||||
info!(
|
||||
"prompt target below threshold: {}",
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD
|
||||
|
|
@ -172,177 +191,237 @@ impl StreamContext {
|
|||
return;
|
||||
}
|
||||
};
|
||||
info!("prompt_target name: {:?}", prompt_target.name);
|
||||
info!(
|
||||
"prompt_target name: {:?}, type: {:?}",
|
||||
prompt_target.name, prompt_target.prompt_type
|
||||
);
|
||||
|
||||
// only extract entity names
|
||||
let entity_names: Vec<String> = match prompt_target.entities {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => entities.iter().map(|entity| entity.name.clone()).collect(),
|
||||
None => vec![],
|
||||
};
|
||||
match prompt_target.prompt_type {
|
||||
PromptType::FunctionResolver => {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, ToolParameter> = match prompt_target.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = ToolParameter {
|
||||
parameter_type: entity.parameter_type.clone(),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
};
|
||||
properties.insert(entity.name.clone(), param);
|
||||
}
|
||||
properties
|
||||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = ToolParameters {
|
||||
parameters_type: "dict".to_string(),
|
||||
properties,
|
||||
};
|
||||
|
||||
let ner_request = NERRequest {
|
||||
input: callout_context.user_message.take().unwrap(),
|
||||
labels: entity_names,
|
||||
model: DEFAULT_NER_MODEL.to_string(),
|
||||
};
|
||||
let tools_defintion: ToolsDefinition = ToolsDefinition {
|
||||
name: prompt_target.name.clone(),
|
||||
description: prompt_target.description.clone().unwrap_or("".to_string()),
|
||||
parameters: tools_parameters,
|
||||
};
|
||||
|
||||
let json_data: String = match serde_json::to_string(&ner_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(e) => {
|
||||
warn!("Error serializing ner_request: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
let chat_completions = ChatCompletions {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(vec![tools_defintion]),
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
Ok(msg_body) => {
|
||||
debug!("msg_body: {}", msg_body);
|
||||
msg_body
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error serializing request_params: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
BOLT_FC_CLUSTER,
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/v1/chat/completions"),
|
||||
(":authority", BOLT_FC_CLUSTER),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
(
|
||||
"x-envoy-upstream-rq-timeout-ms",
|
||||
BOLT_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
|
||||
),
|
||||
],
|
||||
Some(msg_body.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
}
|
||||
};
|
||||
|
||||
debug!(
|
||||
"dispatched call to function {} token_id={}",
|
||||
BOLT_FC_CLUSTER, token_id
|
||||
);
|
||||
|
||||
callout_context.request_type = RequestType::FunctionResolver;
|
||||
callout_context.prompt_target = Some(prompt_target);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let token_id = match self.dispatch_http_call(
|
||||
"nerhost",
|
||||
vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/ner"),
|
||||
(":authority", "nerhost"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
|
||||
}
|
||||
};
|
||||
callout_context.request_type = RequestType::Ner;
|
||||
callout_context.prompt_target = Some(prompt_target);
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let ner_response: NERResponse = match serde_json::from_slice(&body) {
|
||||
Ok(ner_response) => ner_response,
|
||||
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
debug!("response received for function resolver");
|
||||
|
||||
let body_str = String::from_utf8(body).unwrap();
|
||||
debug!("function_resolver response str: {:?}", body_str);
|
||||
|
||||
let mut boltfc_response: BoltFCResponse = serde_json::from_str(&body_str).unwrap();
|
||||
|
||||
let boltfc_response_str = boltfc_response.message.content.as_ref().unwrap();
|
||||
|
||||
let tools_call_response: BoltFCToolsCall = match serde_json::from_str(boltfc_response_str) {
|
||||
Ok(fc_resp) => fc_resp,
|
||||
Err(e) => {
|
||||
warn!("Error deserializing ner_response: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
info!("ner_response: {:?}", ner_response);
|
||||
// This means that Bolt FC did not have enough information to resolve the function call
|
||||
// Bolt FC probably responded with a message asking for more information.
|
||||
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
|
||||
|
||||
let mut request_params: HashMap<String, String> = HashMap::new();
|
||||
for entity in ner_response.data.into_iter() {
|
||||
if entity.score < DEFAULT_NER_THRESHOLD {
|
||||
warn!(
|
||||
"score of entity was too low entity name: {}, score: {}",
|
||||
entity.label, entity.score
|
||||
// add resolver name to the response so the client can send the response back to the correct resolver
|
||||
boltfc_response.resolver_name = Some(callout_context.prompt_target.unwrap().name);
|
||||
info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
|
||||
let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(bolt_fc_dialogue_message.as_bytes()),
|
||||
);
|
||||
continue;
|
||||
}
|
||||
request_params.insert(entity.label, entity.text);
|
||||
}
|
||||
|
||||
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
|
||||
|
||||
let empty_vec: Vec<Entity> = vec![];
|
||||
for entity in prompt_target.entities.as_ref().unwrap_or(&empty_vec) {
|
||||
if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) {
|
||||
warn!(
|
||||
"required entity missing or score of entity was too low: {}",
|
||||
entity.name
|
||||
);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
let req_param_bytes = match serde_json::to_string(&request_params) {
|
||||
Ok(req_param_str) => req_param_str.as_bytes().to_owned(),
|
||||
Err(e) => {
|
||||
warn!("Error serializing request_params: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let endpoint = callout_context
|
||||
// verify required parameters are present
|
||||
callout_context
|
||||
.prompt_target
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.endpoint
|
||||
.parameters
|
||||
.as_ref()
|
||||
.unwrap();
|
||||
.unwrap()
|
||||
.iter()
|
||||
.for_each(|param| match param.required {
|
||||
None => {}
|
||||
Some(required) => {
|
||||
if required
|
||||
&& !tools_call_response.tool_calls[0]
|
||||
.arguments
|
||||
.contains_key(¶m.name)
|
||||
{
|
||||
warn!("boltfc did not extract required parameter: {}", param.name);
|
||||
return self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
Some("missing required parameter".as_bytes()),
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let http_path = match &endpoint.path {
|
||||
Some(path) => path,
|
||||
None => "/",
|
||||
};
|
||||
debug!("tool_call_details: {:?}", tools_call_response);
|
||||
let tool_name = &tools_call_response.tool_calls[0].name;
|
||||
let tool_params = &tools_call_response.tool_calls[0].arguments;
|
||||
debug!("tool_name: {:?}", tool_name);
|
||||
debug!("tool_params: {:?}", tool_params);
|
||||
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
|
||||
debug!("prompt_target: {:?}", prompt_target);
|
||||
|
||||
let http_method = match &endpoint.method {
|
||||
Some(method) => method,
|
||||
None => http::Method::POST.as_str(),
|
||||
};
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
|
||||
let endpoint = prompt_target.endpoint.as_ref().unwrap();
|
||||
let token_id = match self.dispatch_http_call(
|
||||
&endpoint.cluster,
|
||||
vec![
|
||||
(":method", http_method),
|
||||
(":path", http_path),
|
||||
(":method", "POST"),
|
||||
(":path", endpoint.path.as_ref().unwrap_or(&"/".to_string())),
|
||||
(":authority", endpoint.cluster.as_str()),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
],
|
||||
Some(&req_param_bytes),
|
||||
Some(tool_params_json_str.as_bytes()),
|
||||
vec![],
|
||||
Duration::from_secs(5),
|
||||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for context_resolver: {:?}", e);
|
||||
panic!("Error dispatching HTTP call for function_resolver: {:?}", e);
|
||||
}
|
||||
};
|
||||
callout_context.request_type = RequestType::ContextResolver;
|
||||
|
||||
callout_context.request_type = RequestType::FunctionCallResponse;
|
||||
if self.callouts.insert(token_id, callout_context).is_some() {
|
||||
panic!("duplicate token_id")
|
||||
}
|
||||
self.metrics.active_http_calls.increment(1);
|
||||
}
|
||||
|
||||
fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
debug!("response received for context_resolver");
|
||||
let body_string = String::from_utf8(body);
|
||||
let prompt_target = callout_context.prompt_target.unwrap();
|
||||
let mut request_body = callout_context.request_body;
|
||||
match prompt_target.system_prompt {
|
||||
fn function_call_response_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
|
||||
debug!("response received for function call response");
|
||||
let body_str: String = String::from_utf8(body).unwrap();
|
||||
debug!("function_call_response response str: {:?}", body_str);
|
||||
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
|
||||
|
||||
let mut messages: Vec<Message> = callout_context.request_body.messages.clone();
|
||||
|
||||
// add system prompt
|
||||
match prompt_target.system_prompt.as_ref() {
|
||||
None => {}
|
||||
Some(system_prompt) => {
|
||||
let system_prompt_message: Message = Message {
|
||||
let system_prompt_message = Message {
|
||||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt),
|
||||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
};
|
||||
request_body.messages.push(system_prompt_message);
|
||||
}
|
||||
}
|
||||
match body_string {
|
||||
Ok(body_string) => {
|
||||
info!("context_resolver response: {}", body_string);
|
||||
let context_resolver_response = Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(body_string),
|
||||
};
|
||||
request_body.messages.push(context_resolver_response);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Error converting response to string: {:?}", e);
|
||||
self.resume_http_request();
|
||||
return;
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
}
|
||||
|
||||
let json_string = match serde_json::to_string(&request_body) {
|
||||
// add data from function call response
|
||||
messages.push({
|
||||
Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(body_str),
|
||||
model: None,
|
||||
}
|
||||
});
|
||||
|
||||
// add original user prompt
|
||||
messages.push({
|
||||
Message {
|
||||
role: USER_ROLE.to_string(),
|
||||
content: Some(callout_context.user_message.unwrap()),
|
||||
model: None,
|
||||
}
|
||||
});
|
||||
|
||||
let request_message: ChatCompletions = ChatCompletions {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages,
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let json_string = match serde_json::to_string(&request_message) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
warn!("Error serializing request_body: {:?}", e);
|
||||
|
|
@ -350,6 +429,12 @@ impl StreamContext {
|
|||
return;
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"function_calling sending request to openai: msg {}",
|
||||
json_string
|
||||
);
|
||||
|
||||
let request_body = callout_context.request_body;
|
||||
|
||||
// Tokenize and Ratelimit.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
|
|
@ -405,8 +490,7 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size)
|
||||
{
|
||||
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
|
|
@ -434,12 +518,12 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let user_message = match deserialized_body
|
||||
.messages
|
||||
.pop()
|
||||
.and_then(|last_message| last_message.content)
|
||||
.last()
|
||||
.and_then(|last_message| last_message.content.clone())
|
||||
{
|
||||
Some(content) => content,
|
||||
None => {
|
||||
info!("No messages in the request body");
|
||||
warn!("No messages in the request body");
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
|
@ -468,6 +552,7 @@ impl HttpContext for StreamContext {
|
|||
(":authority", "embeddingserver"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
],
|
||||
Some(json_data.as_bytes()),
|
||||
vec![],
|
||||
|
|
@ -481,6 +566,11 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"dispatched HTTP call to embedding server token_id={}",
|
||||
token_id
|
||||
);
|
||||
|
||||
let call_context = CallContext {
|
||||
request_type: RequestType::GetEmbedding,
|
||||
user_message: Some(user_message),
|
||||
|
|
@ -530,8 +620,10 @@ impl Context for StreamContext {
|
|||
match callout_context.request_type {
|
||||
RequestType::GetEmbedding => self.embeddings_handler(body, callout_context),
|
||||
RequestType::SearchPoints => self.search_points_handler(body, callout_context),
|
||||
RequestType::Ner => self.ner_handler(body, callout_context),
|
||||
RequestType::ContextResolver => self.context_resolver_handler(body, callout_context),
|
||||
RequestType::FunctionResolver => self.function_resolver_handler(body, callout_context),
|
||||
RequestType::FunctionCallResponse => {
|
||||
self.function_call_response_handler(body, callout_context)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue