mirror of
https://github.com/katanemo/plano.git
synced 2026-05-10 16:22:42 +02:00
Pass tool call and app function response back in metadata (#193)
This commit is contained in:
parent
62a000036e
commit
dd1c7be706
8 changed files with 169 additions and 112 deletions
|
|
@ -3,7 +3,7 @@ use acap::cos;
|
|||
use common::common_types::open_ai::{
|
||||
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
|
||||
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType,
|
||||
StreamOptions, ToolCall, ToolCallState, ToolType,
|
||||
StreamOptions, ToolCall, ToolType,
|
||||
};
|
||||
use common::common_types::{
|
||||
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
||||
|
|
@ -14,9 +14,9 @@ use common::configuration::{Overrides, PromptGuards, PromptTarget};
|
|||
use common::consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
|
||||
ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
|
||||
CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
|
||||
ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
|
||||
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
|
||||
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
|
||||
};
|
||||
use common::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
|
|
@ -29,12 +29,12 @@ use log::{debug, info, warn};
|
|||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use serde_json::Value;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use derivative::Derivative;
|
||||
|
||||
use common::stats::IncrementingMetric;
|
||||
|
||||
|
|
@ -49,11 +49,13 @@ enum ResponseHandlerType {
|
|||
DefaultTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone, Derivative)]
|
||||
#[derivative(Debug)]
|
||||
pub struct StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType,
|
||||
user_message: Option<String>,
|
||||
prompt_target_name: Option<String>,
|
||||
#[derivative(Debug = "ignore")]
|
||||
request_body: ChatCompletionsRequest,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
|
|
@ -306,6 +308,7 @@ impl StreamContext {
|
|||
content: Some(response),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
|
||||
let chat_completion_response = ChatCompletionsResponse {
|
||||
|
|
@ -797,7 +800,7 @@ impl StreamContext {
|
|||
fn function_call_response_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: StreamCallContext,
|
||||
callout_context: StreamCallContext,
|
||||
) {
|
||||
if let Some(http_status) = self.get_http_call_response_header(":status") {
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
|
|
@ -841,11 +844,18 @@ impl StreamContext {
|
|||
content: system_prompt,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
||||
messages.append(callout_context.request_body.messages.as_mut());
|
||||
// don't send tools message and api response to chat gpt
|
||||
for m in callout_context.request_body.messages.iter() {
|
||||
if m.role == TOOL_ROLE || m.content.is_none() {
|
||||
continue;
|
||||
}
|
||||
messages.push(m.clone());
|
||||
}
|
||||
|
||||
let user_message = match messages.pop() {
|
||||
Some(user_message) => user_message,
|
||||
|
|
@ -872,6 +882,7 @@ impl StreamContext {
|
|||
content: Some(final_prompt),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -1022,6 +1033,7 @@ impl StreamContext {
|
|||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
|
@ -1032,6 +1044,7 @@ impl StreamContext {
|
|||
content: Some(api_resp.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
|
|
@ -1296,55 +1309,42 @@ impl HttpContext for StreamContext {
|
|||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
// compute sha hash from message history
|
||||
let mut hasher = Sha256::new();
|
||||
let prompts: Vec<String> = self
|
||||
.chat_completions_request
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == USER_ROLE)
|
||||
.map(|msg| msg.content.clone().unwrap())
|
||||
.collect();
|
||||
let prompts_merged = prompts.join("#.#");
|
||||
hasher.update(prompts_merged.clone());
|
||||
let hash_key = hasher.finalize();
|
||||
// conver hash to hex string
|
||||
let hash_key_str = format!("{:x}", hash_key);
|
||||
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
|
||||
|
||||
// create new tool call state
|
||||
let tool_call_state = ToolCallState {
|
||||
key: hash_key_str,
|
||||
message: self.user_prompt.clone(),
|
||||
tool_call: tool_calls[0].function.clone(),
|
||||
tool_response: self.tool_call_response.clone().unwrap(),
|
||||
};
|
||||
|
||||
// push tool call state to arch state
|
||||
self.arch_state
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.push(ArchState::ToolCall(vec![tool_call_state]));
|
||||
|
||||
let mut data: Value = serde_json::from_slice(&body).unwrap();
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
if let Value::Object(ref mut map) = data {
|
||||
// serialize arch state and add to metadata
|
||||
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
|
||||
debug!("arch_state: {}", arch_state_str);
|
||||
let metadata = map
|
||||
.entry("metadata")
|
||||
.or_insert(Value::Object(serde_json::Map::new()));
|
||||
if metadata == &Value::Null {
|
||||
*metadata = Value::Object(serde_json::Map::new());
|
||||
}
|
||||
|
||||
// since arch gateway generates tool calls (using arch-fc) and calls upstream api to
|
||||
// get response, we will send these back to developer so they can see the api response
|
||||
// and tool call arch-fc generated
|
||||
let mut fc_messages = Vec::new();
|
||||
fc_messages.push(Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
});
|
||||
fc_messages.push(Message {
|
||||
role: TOOL_ROLE.to_string(),
|
||||
content: self.tool_call_response.clone(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
|
||||
});
|
||||
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
|
||||
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
|
||||
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::Value::String(arch_state_str),
|
||||
);
|
||||
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
debug!("arch => user: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue