mirror of
https://github.com/katanemo/plano.git
synced 2026-05-03 21:02:56 +02:00
Serialize tool calls for Arch FC (#131)
* Serialize tool calls * fix int tests
This commit is contained in:
parent
b43f687b85
commit
96686dc606
10 changed files with 166 additions and 57 deletions
|
|
@ -1,7 +1,7 @@
|
|||
use crate::consts::{
|
||||
ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_MESSAGES_KEY, ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER,
|
||||
ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
ARCH_STATE_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL,
|
||||
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
|
||||
|
|
@ -15,9 +15,9 @@ use log::{debug, info, warn};
|
|||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::open_ai::{
|
||||
ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
|
||||
ParameterType, StreamOptions, ToolType,
|
||||
ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
|
|
@ -28,6 +28,8 @@ use public_types::configuration::{Overrides, PromptGuards, PromptTarget};
|
|||
use public_types::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use serde_json::Value;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
use std::num::NonZero;
|
||||
use std::rc::Rc;
|
||||
|
|
@ -59,10 +61,16 @@ pub struct StreamContext {
|
|||
embeddings_store: Rc<EmbeddingsStore>,
|
||||
overrides: Rc<Option<Overrides>>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
tool_call_response: Option<String>,
|
||||
arch_state: Option<Vec<ArchState>>,
|
||||
request_body_size: usize,
|
||||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
user_prompt: Option<Message>,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
is_chat_completions_request: bool,
|
||||
chat_completions_request: Option<ChatCompletionsRequest>,
|
||||
prompt_guards: Rc<PromptGuards>,
|
||||
llm_providers: Rc<LlmProviders>,
|
||||
llm_provider: Option<Rc<LlmProvider>>,
|
||||
|
|
@ -83,11 +91,17 @@ impl StreamContext {
|
|||
metrics,
|
||||
prompt_targets,
|
||||
embeddings_store,
|
||||
chat_completions_request: None,
|
||||
callouts: HashMap::new(),
|
||||
tool_calls: None,
|
||||
tool_call_response: None,
|
||||
arch_state: None,
|
||||
request_body_size: 0,
|
||||
ratelimit_selector: None,
|
||||
streaming_response: false,
|
||||
user_prompt: None,
|
||||
response_tokens: 0,
|
||||
chat_completions_request: false,
|
||||
is_chat_completions_request: false,
|
||||
llm_providers,
|
||||
llm_provider: None,
|
||||
prompt_guards,
|
||||
|
|
@ -463,13 +477,20 @@ impl StreamContext {
|
|||
});
|
||||
}
|
||||
|
||||
// archfc handler needs state so it can expand tool calls
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::to_string(&self.arch_state).unwrap(),
|
||||
);
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
metadata: None,
|
||||
metadata: Some(metadata),
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
|
|
@ -521,10 +542,8 @@ impl StreamContext {
|
|||
}
|
||||
|
||||
fn function_resolver_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
debug!("response received for function resolver");
|
||||
|
||||
let body_str = String::from_utf8(body).unwrap();
|
||||
debug!("function_resolver response str: {}", body_str);
|
||||
debug!("arch <= app response body: {}", body_str);
|
||||
|
||||
let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
|
||||
Ok(arch_fc_response) => arch_fc_response,
|
||||
|
|
@ -559,7 +578,6 @@ impl StreamContext {
|
|||
|
||||
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
|
||||
|
||||
debug!("tool_call_details: {:?}", tool_calls);
|
||||
// extract all tool names
|
||||
let tool_names: Vec<String> = tool_calls
|
||||
.iter()
|
||||
|
|
@ -581,8 +599,10 @@ impl StreamContext {
|
|||
|
||||
let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone();
|
||||
|
||||
debug!("prompt_target_name: {}", prompt_target.name);
|
||||
debug!("tool_name(s): {:?}", tool_names);
|
||||
debug!(
|
||||
"prompt_target_name: {}, tool_name(s): {:?}",
|
||||
prompt_target.name, tool_names
|
||||
);
|
||||
debug!("tool_params: {}", tool_params_json_str);
|
||||
|
||||
let endpoint = prompt_target.endpoint.unwrap();
|
||||
|
|
@ -611,6 +631,7 @@ impl StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
self.tool_calls = Some(tool_calls.clone());
|
||||
callout_context.upstream_cluster = Some(endpoint.name);
|
||||
callout_context.upstream_cluster_path = Some(path);
|
||||
callout_context.response_handler_type = ResponseHandlerType::FunctionCall;
|
||||
|
|
@ -635,9 +656,9 @@ impl StreamContext {
|
|||
} else {
|
||||
warn!("http status code not found in api response");
|
||||
}
|
||||
debug!("response received for function call response");
|
||||
let body_str: String = String::from_utf8(body).unwrap();
|
||||
debug!("function_call_response response str: {}", body_str);
|
||||
self.tool_call_response = Some(body_str.clone());
|
||||
debug!("arch <= app response body: {}", body_str);
|
||||
let prompt_target_name = callout_context.prompt_target_name.unwrap();
|
||||
let prompt_target = self
|
||||
.prompt_targets
|
||||
|
|
@ -697,10 +718,7 @@ impl StreamContext {
|
|||
.send_server_error(format!("Error serializing request_body: {:?}", e), None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
"function_calling sending request to openai: msg {}",
|
||||
json_string
|
||||
);
|
||||
debug!("arch => openai request body: {}", json_string);
|
||||
|
||||
// Tokenize and Ratelimit.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
|
|
@ -725,7 +743,7 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
|
||||
self.set_http_request_body(0, json_string.len(), &json_string.into_bytes());
|
||||
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
||||
|
|
@ -881,7 +899,7 @@ impl StreamContext {
|
|||
};
|
||||
let json_resp = serde_json::to_string(&chat_completion_request).unwrap();
|
||||
debug!("sending response back to default llm: {}", json_resp);
|
||||
self.set_http_request_body(0, json_resp.len(), json_resp.as_bytes());
|
||||
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
}
|
||||
|
|
@ -899,7 +917,7 @@ impl HttpContext for StreamContext {
|
|||
self.delete_content_length_header();
|
||||
self.save_ratelimit_header();
|
||||
|
||||
self.chat_completions_request =
|
||||
self.is_chat_completions_request =
|
||||
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
|
||||
|
||||
debug!(
|
||||
|
|
@ -922,6 +940,8 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
self.request_body_size = body_size;
|
||||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let mut deserialized_body: ChatCompletionsRequest =
|
||||
|
|
@ -948,6 +968,20 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
self.arch_state = match deserialized_body.metadata {
|
||||
Some(ref metadata) => {
|
||||
if metadata.contains_key(ARCH_STATE_HEADER) {
|
||||
let arch_state_str = metadata[ARCH_STATE_HEADER].clone();
|
||||
let arch_state: Vec<ArchState> = serde_json::from_str(&arch_state_str).unwrap();
|
||||
Some(arch_state)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
self.is_chat_completions_request = true;
|
||||
// Set the model based on the chosen LLM Provider
|
||||
deserialized_body.model = String::from(&self.llm_provider().model);
|
||||
|
||||
|
|
@ -958,10 +992,11 @@ impl HttpContext for StreamContext {
|
|||
});
|
||||
}
|
||||
|
||||
let user_message = match deserialized_body
|
||||
let last_user_prompt = match deserialized_body
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == USER_ROLE)
|
||||
.last()
|
||||
.and_then(|last_message| last_message.content.clone())
|
||||
{
|
||||
Some(content) => content,
|
||||
None => {
|
||||
|
|
@ -970,17 +1005,24 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
self.user_prompt = Some(last_user_prompt.clone());
|
||||
|
||||
let user_message_str = self.user_prompt.as_ref().unwrap().content.clone();
|
||||
|
||||
let prompt_guard_jailbreak_task = self
|
||||
.prompt_guards
|
||||
.input_guards
|
||||
.contains_key(&public_types::configuration::GuardType::Jailbreak);
|
||||
|
||||
self.chat_completions_request = Some(deserialized_body);
|
||||
|
||||
if !prompt_guard_jailbreak_task {
|
||||
debug!("Missing input guard. Making inline call to retrieve");
|
||||
let callout_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
user_message: user_message_str.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
|
|
@ -990,7 +1032,14 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
let get_prompt_guards_request = PromptGuardRequest {
|
||||
input: user_message.clone(),
|
||||
input: self
|
||||
.user_prompt
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.content
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone(),
|
||||
task: PromptGuardTask::Jailbreak,
|
||||
};
|
||||
|
||||
|
|
@ -1032,9 +1081,9 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let call_context = CallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchGuard,
|
||||
user_message: Some(user_message),
|
||||
user_message: self.user_prompt.as_ref().unwrap().content.clone(),
|
||||
prompt_target_name: None,
|
||||
request_body: deserialized_body,
|
||||
request_body: self.chat_completions_request.as_ref().unwrap().clone(),
|
||||
similarity_scores: None,
|
||||
upstream_cluster: None,
|
||||
upstream_cluster_path: None,
|
||||
|
|
@ -1057,7 +1106,7 @@ impl HttpContext for StreamContext {
|
|||
self.context_id, body_size, end_of_stream
|
||||
);
|
||||
|
||||
if !self.chat_completions_request {
|
||||
if !self.is_chat_completions_request {
|
||||
if let Some(body_str) = self
|
||||
.get_http_response_body(0, body_size)
|
||||
.and_then(|bytes| String::from_utf8(bytes).ok())
|
||||
|
|
@ -1067,7 +1116,7 @@ impl HttpContext for StreamContext {
|
|||
return Action::Continue;
|
||||
}
|
||||
|
||||
if !end_of_stream && !self.streaming_response {
|
||||
if !end_of_stream {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
|
|
@ -1075,9 +1124,8 @@ impl HttpContext for StreamContext {
|
|||
.get_http_response_body(0, body_size)
|
||||
.expect("cant get response body");
|
||||
|
||||
let body_str = String::from_utf8(body).expect("body is not utf-8");
|
||||
|
||||
if self.streaming_response {
|
||||
let body_str = String::from_utf8(body).expect("body is not utf-8");
|
||||
debug!("streaming response");
|
||||
let chat_completions_data = match body_str.split_once("data: ") {
|
||||
Some((_, chat_completions_data)) => chat_completions_data,
|
||||
|
|
@ -1117,13 +1165,14 @@ impl HttpContext for StreamContext {
|
|||
} else {
|
||||
debug!("non streaming response");
|
||||
let chat_completions_response: ChatCompletionsResponse =
|
||||
match serde_json::from_str(&body_str) {
|
||||
match serde_json::from_slice(&body) {
|
||||
Ok(de) => de,
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
format!(
|
||||
"error in non-streaming response: {}\n response was={}",
|
||||
e, body_str
|
||||
e,
|
||||
String::from_utf8(body).unwrap()
|
||||
),
|
||||
None,
|
||||
);
|
||||
|
|
@ -1132,6 +1181,65 @@ impl HttpContext for StreamContext {
|
|||
};
|
||||
|
||||
self.response_tokens += chat_completions_response.usage.completion_tokens;
|
||||
|
||||
if let Some(tool_calls) = self.tool_calls.as_ref() {
|
||||
if !tool_calls.is_empty() {
|
||||
if self.arch_state.is_none() {
|
||||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
// compute sha hash from message history
|
||||
let mut hasher = Sha256::new();
|
||||
let prompts: Vec<String> = self
|
||||
.chat_completions_request
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == USER_ROLE)
|
||||
.map(|msg| msg.content.clone().unwrap())
|
||||
.collect();
|
||||
let prompts_merged = prompts.join("#.#");
|
||||
hasher.update(prompts_merged.clone());
|
||||
let hash_key = hasher.finalize();
|
||||
// conver hash to hex string
|
||||
let hash_key_str = format!("{:x}", hash_key);
|
||||
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
|
||||
|
||||
// create new tool call state
|
||||
let tool_call_state = ToolCallState {
|
||||
key: hash_key_str,
|
||||
message: self.user_prompt.clone(),
|
||||
tool_call: tool_calls[0].function.clone(),
|
||||
tool_response: self.tool_call_response.clone().unwrap(),
|
||||
};
|
||||
|
||||
// push tool call state to arch state
|
||||
self.arch_state
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.push(ArchState::ToolCall(vec![tool_call_state]));
|
||||
|
||||
let mut data: Value = serde_json::from_slice(&body).unwrap();
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
if let Value::Object(ref mut map) = data {
|
||||
// serialize arch state and add to metadata
|
||||
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
|
||||
debug!("arch_state: {}", arch_state_str);
|
||||
let metadata = map
|
||||
.entry("metadata")
|
||||
.or_insert(Value::Object(serde_json::Map::new()));
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::Value::String(arch_state_str),
|
||||
);
|
||||
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
debug!("arch => user: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue