Add function calling support using bolt-fc-1b (#35)

This commit is contained in:
Adil Hafeez 2024-09-10 14:24:46 -07:00 committed by GitHub
parent fdfad87347
commit 7b5203a2ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
39 changed files with 1763 additions and 416 deletions

View file

@ -1,8 +1,9 @@
/// Name of the default embedding model.
/// NOTE(review): presumably sent as the `model` of embedding requests; the usage is not visible here, confirm.
pub const DEFAULT_EMBEDDING_MODEL: &str = "BAAI/bge-large-en-v1.5";
/// Default collection name for the prompt vector store.
/// NOTE(review): usage is not visible here; verify against the vector-store callers.
pub const DEFAULT_COLLECTION_NAME: &str = "prompt_vector_store";
/// Name of the default NER model; passed as the `model` field of an `NERRequest`.
pub const DEFAULT_NER_MODEL: &str = "urchade/gliner_large-v2.1";
/// Minimum similarity score for the top search result; scores below this are
/// reported as "prompt target below threshold".
pub const DEFAULT_PROMPT_TARGET_THRESHOLD: f64 = 0.6;
/// Minimum score for an NER entity; entities scoring below this are skipped.
pub const DEFAULT_NER_THRESHOLD: f64 = 0.6;
/// HTTP header key for the rate-limit selector.
/// NOTE(review): presumably read from the incoming request headers; usage is not visible here, confirm.
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-katanemo-ratelimit-selector";
/// `role` value used for system-prompt chat messages.
pub const SYSTEM_ROLE: &str = "system";
/// `role` value used for user chat messages.
pub const USER_ROLE: &str = "user";
/// Model name placed in the `model` field of outgoing `ChatCompletions` requests.
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
/// Upstream cluster name for Bolt FC; used both as the dispatch target and as
/// the `:authority` header of the function-calling request.
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
/// Value sent in the `x-envoy-upstream-rq-timeout-ms` header on Bolt FC requests.
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes

View file

@ -2,7 +2,7 @@ use crate::consts::DEFAULT_EMBEDDING_MODEL;
use crate::ratelimit;
use crate::stats::{Counter, Gauge, RecordingMetric};
use crate::stream_context::StreamContext;
use log::{debug, info};
use log::debug;
use md5::Digest;
use open_message_format_embeddings::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -72,6 +72,8 @@ impl FilterContext {
(":path", "/embeddings"),
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "20000"),
],
Some(json_data.as_bytes()),
vec![],
@ -87,6 +89,10 @@ impl FilterContext {
// Need to clone prompt target to leave config string intact.
prompt_target: prompt_target.clone(),
};
debug!(
"dispatched HTTP call to embedding server token_id={}",
token_id
);
if self
.callouts
.insert(token_id, {
@ -112,7 +118,16 @@ impl FilterContext {
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
let mut embedding_response: CreateEmbeddingResponse =
serde_json::from_slice(&body).unwrap();
match serde_json::from_slice(&body) {
Ok(response) => response,
Err(e) => {
panic!(
"Error deserializing embedding response. body: {:?}: {:?}",
String::from_utf8(body).unwrap(),
e
);
}
};
let mut payload: HashMap<String, String> = HashMap::new();
payload.insert(
@ -168,13 +183,15 @@ impl FilterContext {
.active_http_calls
.record(self.callouts.len().try_into().unwrap());
}
} else {
panic!("No body in response");
}
}
fn create_vector_store_points_handler(&self, body_size: usize) {
if let Some(body) = self.get_http_call_response_body(0, body_size) {
if !body.is_empty() {
info!(
debug!(
"response body: len {:?}",
String::from_utf8(body).unwrap().len()
);
@ -225,7 +242,10 @@ impl Context for FilterContext {
body_size: usize,
_num_trailers: usize,
) {
let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
let callout_data = self
.callouts
.remove(&token_id)
.expect("invalid token_id: {}");
self.metrics
.active_http_calls
@ -250,7 +270,7 @@ impl Context for FilterContext {
http_status_code.clone_from(v);
}
});
info!("CreateVectorCollection response: {}", http_status_code);
debug!("CreateVectorCollection response: {}", http_status_code);
}
}
}

View file

@ -1,6 +1,7 @@
use crate::consts::{
DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL, DEFAULT_NER_MODEL, DEFAULT_NER_THRESHOLD,
DEFAULT_PROMPT_TARGET_THRESHOLD, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_COLLECTION_NAME, DEFAULT_EMBEDDING_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE,
USER_ROLE,
};
use crate::filter_context::WasmMetrics;
use crate::ratelimit;
@ -16,9 +17,12 @@ use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::{
open_ai::{ChatCompletions, Message},
NERRequest, NERResponse, SearchPointsRequest, SearchPointsResponse,
SearchPointsRequest, SearchPointsResponse,
};
use public_types::configuration::{Entity, PromptTarget};
use public_types::common_types::{
BoltFCResponse, BoltFCToolsCall, ToolParameter, ToolParameters, ToolsDefinition,
};
use public_types::configuration::{PromptTarget, PromptType};
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
@ -27,8 +31,8 @@ use std::time::Duration;
/// Tags an outstanding HTTP callout with the kind of request that was sent, so
/// that the response can be routed to the matching handler when it arrives.
enum RequestType {
/// Embedding-server call; the response is routed to `embeddings_handler`.
GetEmbedding,
/// The response is routed to `search_points_handler`.
SearchPoints,
/// Call to the NER host; the response is routed to `ner_handler`.
Ner,
/// Call to the prompt target's endpoint made from `ner_handler`; the response
/// is routed to `context_resolver_handler`.
ContextResolver,
/// Call to the Bolt FC cluster; the response is routed to
/// `function_resolver_handler`.
FunctionResolver,
/// Call to the prompt target's endpoint made from `function_resolver_handler`;
/// the response is routed to `function_call_response_handler`.
FunctionCallResponse,
}
pub struct CallContext {
@ -153,8 +157,23 @@ impl StreamContext {
}
info!("similarity score: {}", search_results[0].score);
// Check to see who responded to user message. This will help us identify if control should be passed to Bolt FC or not.
// If the last message was from Bolt FC, then Bolt FC is handling the conversation (possibly for parameter collection).
let mut bolt_assistant = false;
let messages = &callout_context.request_body.messages;
if messages.len() >= 2 {
let latest_assistant_message = &messages[messages.len() - 2];
if let Some(model) = latest_assistant_message.model.as_ref() {
if model.starts_with("Bolt") {
info!("Bolt assistant message found");
bolt_assistant = true;
}
}
} else {
info!("no assistant message found, probably first interaction");
}
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD {
if search_results[0].score < DEFAULT_PROMPT_TARGET_THRESHOLD && !bolt_assistant {
info!(
"prompt target below threshold: {}",
DEFAULT_PROMPT_TARGET_THRESHOLD
@ -172,177 +191,237 @@ impl StreamContext {
return;
}
};
info!("prompt_target name: {:?}", prompt_target.name);
info!(
"prompt_target name: {:?}, type: {:?}",
prompt_target.name, prompt_target.prompt_type
);
// only extract entity names
let entity_names: Vec<String> = match prompt_target.entities {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => entities.iter().map(|entity| entity.name.clone()).collect(),
None => vec![],
};
match prompt_target.prompt_type {
PromptType::FunctionResolver => {
// build the tool parameter map (name -> type, description, required) from the prompt target's parameters
let properties: HashMap<String, ToolParameter> = match prompt_target.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => {
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
for entity in entities.iter() {
let param = ToolParameter {
parameter_type: entity.parameter_type.clone(),
description: entity.description.clone(),
required: entity.required,
};
properties.insert(entity.name.clone(), param);
}
properties
}
None => HashMap::new(),
};
let tools_parameters = ToolParameters {
parameters_type: "dict".to_string(),
properties,
};
let ner_request = NERRequest {
input: callout_context.user_message.take().unwrap(),
labels: entity_names,
model: DEFAULT_NER_MODEL.to_string(),
};
let tools_defintion: ToolsDefinition = ToolsDefinition {
name: prompt_target.name.clone(),
description: prompt_target.description.clone().unwrap_or("".to_string()),
parameters: tools_parameters,
};
let json_data: String = match serde_json::to_string(&ner_request) {
Ok(json_data) => json_data,
Err(e) => {
warn!("Error serializing ner_request: {:?}", e);
self.resume_http_request();
return;
let chat_completions = ChatCompletions {
model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(),
tools: Some(vec![tools_defintion]),
};
let msg_body = match serde_json::to_string(&chat_completions) {
Ok(msg_body) => {
debug!("msg_body: {}", msg_body);
msg_body
}
Err(e) => {
warn!("Error serializing request_params: {:?}", e);
self.resume_http_request();
return;
}
};
let token_id = match self.dispatch_http_call(
BOLT_FC_CLUSTER,
vec![
(":method", "POST"),
(":path", "/v1/chat/completions"),
(":authority", BOLT_FC_CLUSTER),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
(
"x-envoy-upstream-rq-timeout-ms",
BOLT_FC_REQUEST_TIMEOUT_MS.to_string().as_str(),
),
],
Some(msg_body.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for function-call: {:?}", e);
}
};
debug!(
"dispatched call to function {} token_id={}",
BOLT_FC_CLUSTER, token_id
);
callout_context.request_type = RequestType::FunctionResolver;
callout_context.prompt_target = Some(prompt_target);
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
}
};
let token_id = match self.dispatch_http_call(
"nerhost",
vec![
(":method", "POST"),
(":path", "/ner"),
(":authority", "nerhost"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(json_data.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for get-embeddings: {:?}", e);
}
};
callout_context.request_type = RequestType::Ner;
callout_context.prompt_target = Some(prompt_target);
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
self.metrics.active_http_calls.increment(1);
}
fn ner_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let ner_response: NERResponse = match serde_json::from_slice(&body) {
Ok(ner_response) => ner_response,
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
debug!("response received for function resolver");
let body_str = String::from_utf8(body).unwrap();
debug!("function_resolver response str: {:?}", body_str);
let mut boltfc_response: BoltFCResponse = serde_json::from_str(&body_str).unwrap();
let boltfc_response_str = boltfc_response.message.content.as_ref().unwrap();
let tools_call_response: BoltFCToolsCall = match serde_json::from_str(boltfc_response_str) {
Ok(fc_resp) => fc_resp,
Err(e) => {
warn!("Error deserializing ner_response: {:?}", e);
self.resume_http_request();
return;
}
};
info!("ner_response: {:?}", ner_response);
// This means that Bolt FC did not have enough information to resolve the function call
// Bolt FC probably responded with a message asking for more information.
// Let's send the response back to the user to initialize a lightweight dialog for parameter collection
let mut request_params: HashMap<String, String> = HashMap::new();
for entity in ner_response.data.into_iter() {
if entity.score < DEFAULT_NER_THRESHOLD {
warn!(
"score of entity was too low entity name: {}, score: {}",
entity.label, entity.score
// add resolver name to the response so the client can send the response back to the correct resolver
boltfc_response.resolver_name = Some(callout_context.prompt_target.unwrap().name);
info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(bolt_fc_dialogue_message.as_bytes()),
);
continue;
}
request_params.insert(entity.label, entity.text);
}
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
let empty_vec: Vec<Entity> = vec![];
for entity in prompt_target.entities.as_ref().unwrap_or(&empty_vec) {
if entity.required.unwrap_or(false) && !request_params.contains_key(&entity.name) {
warn!(
"required entity missing or score of entity was too low: {}",
entity.name
);
self.resume_http_request();
return;
}
}
let req_param_bytes = match serde_json::to_string(&request_params) {
Ok(req_param_str) => req_param_str.as_bytes().to_owned(),
Err(e) => {
warn!("Error serializing request_params: {:?}", e);
self.resume_http_request();
return;
}
};
let endpoint = callout_context
// verify required parameters are present
callout_context
.prompt_target
.as_ref()
.unwrap()
.endpoint
.parameters
.as_ref()
.unwrap();
.unwrap()
.iter()
.for_each(|param| match param.required {
None => {}
Some(required) => {
if required
&& !tools_call_response.tool_calls[0]
.arguments
.contains_key(&param.name)
{
warn!("boltfc did not extract required parameter: {}", param.name);
return self.send_http_response(
StatusCode::BAD_REQUEST.as_u16().into(),
vec![],
Some("missing required parameter".as_bytes()),
);
}
}
});
let http_path = match &endpoint.path {
Some(path) => path,
None => "/",
};
debug!("tool_call_details: {:?}", tools_call_response);
let tool_name = &tools_call_response.tool_calls[0].name;
let tool_params = &tools_call_response.tool_calls[0].arguments;
debug!("tool_name: {:?}", tool_name);
debug!("tool_params: {:?}", tool_params);
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
debug!("prompt_target: {:?}", prompt_target);
let http_method = match &endpoint.method {
Some(method) => method,
None => http::Method::POST.as_str(),
};
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let endpoint = prompt_target.endpoint.as_ref().unwrap();
let token_id = match self.dispatch_http_call(
&endpoint.cluster,
vec![
(":method", http_method),
(":path", http_path),
(":method", "POST"),
(":path", endpoint.path.as_ref().unwrap_or(&"/".to_string())),
(":authority", endpoint.cluster.as_str()),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
],
Some(&req_param_bytes),
Some(tool_params_json_str.as_bytes()),
vec![],
Duration::from_secs(5),
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for context_resolver: {:?}", e);
panic!("Error dispatching HTTP call for function_resolver: {:?}", e);
}
};
callout_context.request_type = RequestType::ContextResolver;
callout_context.request_type = RequestType::FunctionCallResponse;
if self.callouts.insert(token_id, callout_context).is_some() {
panic!("duplicate token_id")
}
self.metrics.active_http_calls.increment(1);
}
fn context_resolver_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
debug!("response received for context_resolver");
let body_string = String::from_utf8(body);
let prompt_target = callout_context.prompt_target.unwrap();
let mut request_body = callout_context.request_body;
match prompt_target.system_prompt {
fn function_call_response_handler(&mut self, body: Vec<u8>, callout_context: CallContext) {
debug!("response received for function call response");
let body_str: String = String::from_utf8(body).unwrap();
debug!("function_call_response response str: {:?}", body_str);
let prompt_target = callout_context.prompt_target.as_ref().unwrap();
let mut messages: Vec<Message> = callout_context.request_body.messages.clone();
// add system prompt
match prompt_target.system_prompt.as_ref() {
None => {}
Some(system_prompt) => {
let system_prompt_message: Message = Message {
let system_prompt_message = Message {
role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt),
content: Some(system_prompt.clone()),
model: None,
};
request_body.messages.push(system_prompt_message);
}
}
match body_string {
Ok(body_string) => {
info!("context_resolver response: {}", body_string);
let context_resolver_response = Message {
role: USER_ROLE.to_string(),
content: Some(body_string),
};
request_body.messages.push(context_resolver_response);
}
Err(e) => {
warn!("Error converting response to string: {:?}", e);
self.resume_http_request();
return;
messages.push(system_prompt_message);
}
}
let json_string = match serde_json::to_string(&request_body) {
// add data from function call response
messages.push({
Message {
role: USER_ROLE.to_string(),
content: Some(body_str),
model: None,
}
});
// add original user prompt
messages.push({
Message {
role: USER_ROLE.to_string(),
content: Some(callout_context.user_message.unwrap()),
model: None,
}
});
let request_message: ChatCompletions = ChatCompletions {
model: GPT_35_TURBO.to_string(),
messages,
tools: None,
};
let json_string = match serde_json::to_string(&request_message) {
Ok(json_string) => json_string,
Err(e) => {
warn!("Error serializing request_body: {:?}", e);
@ -350,6 +429,12 @@ impl StreamContext {
return;
}
};
debug!(
"function_calling sending request to openai: msg {}",
json_string
);
let request_body = callout_context.request_body;
// Tokenize and Ratelimit.
if let Some(selector) = self.ratelimit_selector.take() {
@ -405,8 +490,7 @@ impl HttpContext for StreamContext {
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size)
{
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => {
@ -434,12 +518,12 @@ impl HttpContext for StreamContext {
let user_message = match deserialized_body
.messages
.pop()
.and_then(|last_message| last_message.content)
.last()
.and_then(|last_message| last_message.content.clone())
{
Some(content) => content,
None => {
info!("No messages in the request body");
warn!("No messages in the request body");
return Action::Continue;
}
};
@ -468,6 +552,7 @@ impl HttpContext for StreamContext {
(":authority", "embeddingserver"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
],
Some(json_data.as_bytes()),
vec![],
@ -481,6 +566,11 @@ impl HttpContext for StreamContext {
);
}
};
debug!(
"dispatched HTTP call to embedding server token_id={}",
token_id
);
let call_context = CallContext {
request_type: RequestType::GetEmbedding,
user_message: Some(user_message),
@ -530,8 +620,10 @@ impl Context for StreamContext {
match callout_context.request_type {
RequestType::GetEmbedding => self.embeddings_handler(body, callout_context),
RequestType::SearchPoints => self.search_points_handler(body, callout_context),
RequestType::Ner => self.ner_handler(body, callout_context),
RequestType::ContextResolver => self.context_resolver_handler(body, callout_context),
RequestType::FunctionResolver => self.function_resolver_handler(body, callout_context),
RequestType::FunctionCallResponse => {
self.function_call_response_handler(body, callout_context)
}
}
}
}