Serialize tool calls for Arch FC (#131)

* Serialize tool calls

* fix int tests
This commit is contained in:
Adil Hafeez 2024-10-07 00:03:25 -07:00 committed by GitHub
parent b43f687b85
commit 96686dc606
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 166 additions and 57 deletions

View file

@ -1,7 +1,7 @@
use crate::consts::{
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
@ -15,9 +15,9 @@ use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{
ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
ParameterType, StreamOptions, ToolType,
ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
};
use public_types::common_types::{
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
@ -28,6 +28,8 @@ use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
use public_types::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
@ -59,10 +61,16 @@ pub struct StreamContext {
embeddings_store: Rc<EmbeddingsStore>,
overrides: Rc<Option<Overrides>>,
callouts: HashMap<u32, CallContext>,
tool_calls: Option<Vec<ToolCall>>,
tool_call_response: Option<String>,
arch_state: Option<Vec<ArchState>>,
request_body_size: usize,
ratelimit_selector: Option<Header>,
streaming_response: bool,
user_prompt: Option<Message>,
response_tokens: usize,
chat_completions_request: bool,
is_chat_completions_request: bool,
chat_completions_request: Option<ChatCompletionsRequest>,
prompt_guards: Rc<PromptGuards>,
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
@ -83,11 +91,17 @@ impl StreamContext {
metrics,
prompt_targets,
embeddings_store,
chat_completions_request: None,
callouts: HashMap::new(),
tool_calls: None,
tool_call_response: None,
arch_state: None,
request_body_size: 0,
ratelimit_selector: None,
streaming_response: false,
user_prompt: None,
response_tokens: 0,
chat_completions_request: false,
is_chat_completions_request: false,
llm_providers,
llm_provider: None,
prompt_guards,
@ -463,13 +477,20 @@ impl StreamContext {
});
}
// archfc handler needs state so it can expand tool calls
let mut metadata = HashMap::new();
metadata.insert(
ARCH_STATE_HEADER.to_string(),
serde_json::to_string(&self.arch_state).unwrap(),
);
let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(),
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
metadata: None,
metadata: Some(metadata),
};
let msg_body = match serde_json::to_string(&chat_completions) {
@ -521,10 +542,8 @@ impl StreamContext {
}
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
debug!("response received for function resolver");
let body_str = String::from_utf8(body).unwrap();
debug!("function_resolver response str: {}", body_str);
debug!("arch <= app response body: {}", body_str);
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
Ok(arch_fc_response) => arch_fc_response,
@ -559,7 +578,6 @@ impl StreamContext {
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
debug!("tool_call_details: {:?}", tool_calls);
// extract all tool names
let tool_names: Vec<String> = tool_calls
.iter()
@ -581,8 +599,10 @@ impl StreamContext {
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
debug!("prompt_target_name: {}", prompt_target.name);
debug!("tool_name(s): {:?}", tool_names);
debug!(
"prompt_target_name: {}, tool_name(s): {:?}",
prompt_target.name, tool_names
);
debug!("tool_params: {}", tool_params_json_str);
let endpoint = prompt_target.endpoint.unwrap();
@ -611,6 +631,7 @@ impl StreamContext {
}
};
self.tool_calls = Some(tool_calls.clone());
callout_context.upstream_cluster = Some(endpoint.name);
callout_context.upstream_cluster_path = Some(path);
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
@ -635,9 +656,9 @@ impl StreamContext {
} else {
warn!("http status code not found in api response");
}
debug!("response received for function call response");
let body_str: String = String::from_utf8(body).unwrap();
debug!("function_call_response response str: {}", body_str);
self.tool_call_response = Some(body_str.clone());
debug!("arch <= app response body: {}", body_str);
let prompt_target_name = callout_context.prompt_target_name.unwrap();
let prompt_target = self
.prompt_targets
@ -697,10 +718,7 @@ impl StreamContext {
.send_server_error(format!("Error serializing request_body: {:?}", e), None);
}
};
debug!(
"function_calling sending request to openai: msg {}",
json_string
);
debug!("arch => openai request body: {}", json_string);
// Tokenize and Ratelimit.
if let Some(selector) = self.ratelimit_selector.take() {
@ -725,7 +743,7 @@ impl StreamContext {
}
}
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
self.resume_http_request();
}
@ -881,7 +899,7 @@ impl StreamContext {
};
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
debug!("sending response back to default llm: {}", json_resp);
self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
self.resume_http_request();
}
}
@ -899,7 +917,7 @@ impl HttpContext for StreamContext {
self.delete_content_length_header();
self.save_ratelimit_header();
self.chat_completions_request =
self.is_chat_completions_request =
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
debug!(
@ -922,6 +940,8 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
self.request_body_size = body_size;
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest =
@ -948,6 +968,20 @@ impl HttpContext for StreamContext {
}
};
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
if metadata.contains_key(ARCH_STATE_HEADER) {
let arch_state_str = metadata[ARCH_STATE_HEADER].clone();
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
Some(arch_state)
} else {
None
}
}
None => None,
};
self.is_chat_completions_request = true;
// Set the model based on the chosen LLM Provider
deserialized_body.model = String::from(&self.llm_provider().model);
@ -958,10 +992,11 @@ impl HttpContext for StreamContext {
});
}
let user_message = match deserialized_body
let last_user_prompt = match deserialized_body
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.last()
.and_then(|last_message| last_message.content.clone())
{
Some(content) => content,
None => {
@ -970,17 +1005,24 @@ impl HttpContext for StreamContext {
}
};
self.user_prompt = Some(last_user_prompt.clone());
let user_message_str = self.user_prompt.as_ref().unwrap().content.clone();
let prompt_guard_jailbreak_task = self
.prompt_guards
.input_guards
.contains_key(&public_types::configuration::GuardType::Jailbreak);
self.chat_completions_request = Some(deserialized_body);
if !prompt_guard_jailbreak_task {
debug!("Missing input guard. Making inline call to retrieve");
let callout_context = CallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
user_message: user_message_str.clone(),
prompt_target_name: None,
request_body: deserialized_body,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
@ -990,7 +1032,14 @@ impl HttpContext for StreamContext {
}
let get_prompt_guards_request = PromptGuardRequest {
input: user_message.clone(),
input: self
.user_prompt
.as_ref()
.unwrap()
.content
.as_ref()
.unwrap()
.clone(),
task: PromptGuardTask::Jailbreak,
};
@ -1032,9 +1081,9 @@ impl HttpContext for StreamContext {
let call_context = CallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: Some(user_message),
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
prompt_target_name: None,
request_body: deserialized_body,
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
similarity_scores: None,
upstream_cluster: None,
upstream_cluster_path: None,
@ -1057,7 +1106,7 @@ impl HttpContext for StreamContext {
self.context_id, body_size, end_of_stream
);
if !self.chat_completions_request {
if !self.is_chat_completions_request {
if let Some(body_str) = self
.get_http_response_body(0, body_size)
.and_then(|bytes| String::from_utf8(bytes).ok())
@ -1067,7 +1116,7 @@ impl HttpContext for StreamContext {
return Action::Continue;
}
if !end_of_stream && !self.streaming_response {
if !end_of_stream {
return Action::Pause;
}
@ -1075,9 +1124,8 @@ impl HttpContext for StreamContext {
.get_http_response_body(0, body_size)
.expect("cant get response body");
let body_str = String::from_utf8(body).expect("body is not utf-8");
if self.streaming_response {
let body_str = String::from_utf8(body).expect("body is not utf-8");
debug!("streaming response");
let chat_completions_data = match body_str.split_once("data: ") {
Some((_, chat_completions_data)) => chat_completions_data,
@ -1117,13 +1165,14 @@ impl HttpContext for StreamContext {
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_str(&body_str) {
match serde_json::from_slice(&body) {
Ok(de) => de,
Err(e) => {
self.send_server_error(
format!(
"error in non-streaming response: {}\n response was={}",
e, body_str
e,
String::from_utf8(body).unwrap()
),
None,
);
@ -1132,6 +1181,65 @@ impl HttpContext for StreamContext {
};
self.response_tokens += chat_completions_response.usage.completion_tokens;
if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
if self.arch_state.is_none() {
self.arch_state = Some(Vec::new());
}
// compute sha hash from message history
let mut hasher = Sha256::new();
let prompts: Vec<String> = self
.chat_completions_request
.as_ref()
.unwrap()
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.map(|msg| msg.content.clone().unwrap())
.collect();
let prompts_merged = prompts.join("#.#");
hasher.update(prompts_merged.clone());
let hash_key = hasher.finalize();
// convert hash to hex string
let hash_key_str = format!("{:x}", hash_key);
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
// create new tool call state
let tool_call_state = ToolCallState {
key: hash_key_str,
message: self.user_prompt.clone(),
tool_call: tool_calls[0].function.clone(),
tool_response: self.tool_call_response.clone().unwrap(),
};
// push tool call state to arch state
self.arch_state
.as_mut()
.unwrap()
.push(ArchState::ToolCall(vec![tool_call_state]));
let mut data: Value = serde_json::from_slice(&body).unwrap();
// use serde_json::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
debug!("arch_state: {}", arch_state_str);
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
let data_serialized = serde_json::to_string(&data).unwrap();
debug!("arch => user: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
};
}
}
}
debug!(