Pass tool call and app function response back in metadata (#193)

This commit is contained in:
Adil Hafeez 2024-10-18 13:25:39 -07:00 committed by GitHub
parent 62a000036e
commit dd1c7be706
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 169 additions and 112 deletions

View file

@ -3,7 +3,7 @@ use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType,
StreamOptions, ToolCall, ToolCallState, ToolType,
StreamOptions, ToolCall, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
@ -14,9 +14,9 @@ use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -29,12 +29,12 @@ use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use std::time::Duration;
use derivative::Derivative;
use common::stats::IncrementingMetric;
@ -49,11 +49,13 @@ enum ResponseHandlerType {
DefaultTarget,
}
#[derive(Debug, Clone)]
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct StreamCallContext {
response_handler_type: ResponseHandlerType,
user_message: Option<String>,
prompt_target_name: Option<String>,
#[derivative(Debug = "ignore")]
request_body: ChatCompletionsRequest,
tool_calls: Option<Vec<ToolCall>>,
similarity_scores: Option<Vec<(String, f64)>>,
@ -306,6 +308,7 @@ impl StreamContext {
content: Some(response),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
};
let chat_completion_response = ChatCompletionsResponse {
@ -797,7 +800,7 @@ impl StreamContext {
fn function_call_response_handler(
&mut self,
body: Vec<u8>,
mut callout_context: StreamCallContext,
callout_context: StreamCallContext,
) {
if let Some(http_status) = self.get_http_call_response_header(":status") {
if http_status != StatusCode::OK.as_str() {
@ -841,11 +844,18 @@ impl StreamContext {
content: system_prompt,
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
messages.append(callout_context.request_body.messages.as_mut());
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
if m.role == TOOL_ROLE || m.content.is_none() {
continue;
}
messages.push(m.clone());
}
let user_message = match messages.pop() {
Some(user_message) => user_message,
@ -872,6 +882,7 @@ impl StreamContext {
content: Some(final_prompt),
model: None,
tool_calls: None,
tool_call_id: None,
}
});
@ -1022,6 +1033,7 @@ impl StreamContext {
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
@ -1032,6 +1044,7 @@ impl StreamContext {
content: Some(api_resp.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
});
let chat_completion_request = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
@ -1296,55 +1309,42 @@ impl HttpContext for StreamContext {
self.arch_state = Some(Vec::new());
}
// compute sha hash from message history
let mut hasher = Sha256::new();
let prompts: Vec<String> = self
.chat_completions_request
.as_ref()
.unwrap()
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.map(|msg| msg.content.clone().unwrap())
.collect();
let prompts_merged = prompts.join("#.#");
hasher.update(prompts_merged.clone());
let hash_key = hasher.finalize();
// conver hash to hex string
let hash_key_str = format!("{:x}", hash_key);
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
// create new tool call state
let tool_call_state = ToolCallState {
key: hash_key_str,
message: self.user_prompt.clone(),
tool_call: tool_calls[0].function.clone(),
tool_response: self.tool_call_response.clone().unwrap(),
};
// push tool call state to arch state
self.arch_state
.as_mut()
.unwrap()
.push(ArchState::ToolCall(vec![tool_call_state]));
let mut data: Value = serde_json::from_slice(&body).unwrap();
// use serde::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
debug!("arch_state: {}", arch_state_str);
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
if metadata == &Value::Null {
*metadata = Value::Object(serde_json::Map::new());
}
// since arch gateway generates tool calls (using arch-fc) and calls upstream api to
// get response, we will send these back to developer so they can see the api response
// and tool call arch-fc generated
let mut fc_messages = Vec::new();
fc_messages.push(Message {
role: ASSISTANT_ROLE.to_string(),
content: None,
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: self.tool_calls.clone(),
tool_call_id: None,
});
fc_messages.push(Message {
role: TOOL_ROLE.to_string(),
content: self.tool_call_response.clone(),
model: None,
tool_calls: None,
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
});
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
let data_serialized = serde_json::to_string(&data).unwrap();
debug!("arch => user: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());