Pass tool call and app function response back in metadata (#193)

This commit is contained in:
Adil Hafeez 2024-10-18 13:25:39 -07:00 committed by GitHub
parent 62a000036e
commit dd1c7be706
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 169 additions and 112 deletions

View file

@ -5,6 +5,7 @@
"version": "0.2.0",
"configurations": [
{
"python": "${workspaceFolder}/venv/bin/python",
"name": "chatbot-ui",
"cwd": "${workspaceFolder}/app",
"type": "debugpy",

View file

@ -2,14 +2,21 @@ import json
import os
from openai import OpenAI, DefaultHttpxClient
import gradio as gr
import logging as log
import logging
from dotenv import load_dotenv
load_dotenv()
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
)
log = logging.getLogger(__name__)
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
ARCH_STATE_HEADER = "x-arch-state"
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}")
client = OpenAI(
api_key="--",
@ -23,23 +30,19 @@ def predict(message, state):
state["history"] = []
history = state.get("history")
history.append({"role": "user", "content": message})
log.info("history: ", history)
log.info(f"history: {history}")
# Custom headers
custom_headers = {
"x-arch-deterministic-provider": "openai",
}
metadata = None
if "arch_state" in state:
metadata = {ARCH_STATE_HEADER: state["arch_state"]}
try:
raw_response = client.chat.completions.with_raw_response.create(
model="--",
messages=history,
temperature=1.0,
metadata=metadata,
# metadata=metadata,
extra_headers=custom_headers,
)
except Exception as e:
@ -49,26 +52,35 @@ def predict(message, state):
log.info("Error calling gateway API: {}".format(e.message))
raise gr.Error("Error calling gateway API: {}".format(e.message))
log.info("raw_response: ", raw_response.text)
log.error(f"raw_response: {raw_response.text}")
response = raw_response.parse()
# extract arch_state from metadata and store it in gradio session state
# this state must be passed back to the gateway in the next request
response_json = json.loads(raw_response.text)
arch_state = None
if response_json:
metadata = response_json.get("metadata", {})
if metadata:
arch_state = metadata.get(ARCH_STATE_HEADER, None)
if arch_state:
state["arch_state"] = arch_state
# load arch_state from metadata
arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
# parse arch_state into json object
arch_state = json.loads(arch_state_str)
# load messages from arch_state
arch_messages_str = arch_state.get("messages", "[]")
# parse messages into json object
arch_messages = json.loads(arch_messages_str)
# append messages from arch gateway to history
for message in arch_messages:
history.append(message)
content = response.choices[0].message.content
history.append({"role": "assistant", "content": content, "model": response.model})
# for gradio UI we don't want to show raw tool calls and messages from developer application
# so we're filtering those out
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
messages = [
(history[i]["content"], history[i + 1]["content"])
for i in range(0, len(history) - 1, 2)
(history_view[i]["content"], history_view[i + 1]["content"])
for i in range(0, len(history_view) - 1, 2)
]
return messages, state

View file

@ -188,6 +188,8 @@ pub mod open_ai {
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -381,6 +383,7 @@ mod test {
content: Some("What city do you want to know the weather for?".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
}],
tools: Some(vec![super::open_ai::ChatCompletionTool {
tool_type: ToolType::Function,

View file

@ -5,6 +5,8 @@ pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
pub const SYSTEM_ROLE: &str = "system";
pub const USER_ROLE: &str = "user";
pub const TOOL_ROLE: &str = "tool";
pub const ASSISTANT_ROLE: &str = "assistant";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const ARC_FC_CLUSTER: &str = "arch_fc";
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes

View file

@ -3,7 +3,7 @@ use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType,
StreamOptions, ToolCall, ToolCallState, ToolType,
StreamOptions, ToolCall, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
@ -14,9 +14,9 @@ use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@ -29,12 +29,12 @@ use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use serde_json::Value;
use sha2::{Digest, Sha256};
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
use std::str::FromStr;
use std::time::Duration;
use derivative::Derivative;
use common::stats::IncrementingMetric;
@ -49,11 +49,13 @@ enum ResponseHandlerType {
DefaultTarget,
}
#[derive(Debug, Clone)]
#[derive(Clone, Derivative)]
#[derivative(Debug)]
pub struct StreamCallContext {
response_handler_type: ResponseHandlerType,
user_message: Option<String>,
prompt_target_name: Option<String>,
#[derivative(Debug = "ignore")]
request_body: ChatCompletionsRequest,
tool_calls: Option<Vec<ToolCall>>,
similarity_scores: Option<Vec<(String, f64)>>,
@ -306,6 +308,7 @@ impl StreamContext {
content: Some(response),
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: None,
tool_call_id: None,
};
let chat_completion_response = ChatCompletionsResponse {
@ -797,7 +800,7 @@ impl StreamContext {
fn function_call_response_handler(
&mut self,
body: Vec<u8>,
mut callout_context: StreamCallContext,
callout_context: StreamCallContext,
) {
if let Some(http_status) = self.get_http_call_response_header(":status") {
if http_status != StatusCode::OK.as_str() {
@ -841,11 +844,18 @@ impl StreamContext {
content: system_prompt,
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
messages.append(callout_context.request_body.messages.as_mut());
// don't send tools message and api response to chat gpt
for m in callout_context.request_body.messages.iter() {
if m.role == TOOL_ROLE || m.content.is_none() {
continue;
}
messages.push(m.clone());
}
let user_message = match messages.pop() {
Some(user_message) => user_message,
@ -872,6 +882,7 @@ impl StreamContext {
content: Some(final_prompt),
model: None,
tool_calls: None,
tool_call_id: None,
}
});
@ -1022,6 +1033,7 @@ impl StreamContext {
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
};
messages.push(system_prompt_message);
}
@ -1032,6 +1044,7 @@ impl StreamContext {
content: Some(api_resp.clone()),
model: None,
tool_calls: None,
tool_call_id: None,
});
let chat_completion_request = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
@ -1296,55 +1309,42 @@ impl HttpContext for StreamContext {
self.arch_state = Some(Vec::new());
}
// compute sha hash from message history
let mut hasher = Sha256::new();
let prompts: Vec<String> = self
.chat_completions_request
.as_ref()
.unwrap()
.messages
.iter()
.filter(|msg| msg.role == USER_ROLE)
.map(|msg| msg.content.clone().unwrap())
.collect();
let prompts_merged = prompts.join("#.#");
hasher.update(prompts_merged.clone());
let hash_key = hasher.finalize();
// conver hash to hex string
let hash_key_str = format!("{:x}", hash_key);
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
// create new tool call state
let tool_call_state = ToolCallState {
key: hash_key_str,
message: self.user_prompt.clone(),
tool_call: tool_calls[0].function.clone(),
tool_response: self.tool_call_response.clone().unwrap(),
};
// push tool call state to arch state
self.arch_state
.as_mut()
.unwrap()
.push(ArchState::ToolCall(vec![tool_call_state]));
let mut data: Value = serde_json::from_slice(&body).unwrap();
// use serde::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
debug!("arch_state: {}", arch_state_str);
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
if metadata == &Value::Null {
*metadata = Value::Object(serde_json::Map::new());
}
// since arch gateway generates tool calls (using arch-fc) and calls upstream api to
// get response, we will send these back to developer so they can see the api response
// and tool call arch-fc generated
let mut fc_messages = Vec::new();
fc_messages.push(Message {
role: ASSISTANT_ROLE.to_string(),
content: None,
model: Some(ARCH_FC_MODEL_NAME.to_string()),
tool_calls: self.tool_calls.clone(),
tool_call_id: None,
});
fc_messages.push(Message {
role: TOOL_ROLE.to_string(),
content: self.tool_call_response.clone(),
model: None,
tool_calls: None,
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
});
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
let data_serialized = serde_json::to_string(&data).unwrap();
debug!("arch => user: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());

View file

@ -546,6 +546,7 @@ fn request_to_llm_gateway() {
},
}]),
model: None,
tool_call_id: None,
},
}],
model: String::from("test"),
@ -647,6 +648,7 @@ fn request_to_llm_gateway() {
content: Some("hello from fake llm gateway".to_string()),
model: None,
tool_calls: None,
tool_call_id: None,
},
}],
model: String::from("test"),
@ -665,8 +667,6 @@ fn request_to_llm_gateway() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Continue))

View file

@ -13,62 +13,38 @@ logger = get_model_server_logger()
class Message(BaseModel):
role: str
content: str
content: str = ""
tool_calls: List[Dict[str, Any]] = []
tool_call_id: str = ""
class ChatMessage(BaseModel):
messages: list[Message]
tools: List[Dict[str, Any]]
# TODO: make it default none
metadata: Dict[str, str] = {}
def process_state(arch_state, history: list[Message]):
logger.info("state: {}".format(arch_state))
state_json = json.loads(arch_state)
state_map = {}
if state_json:
for tools_state in state_json:
for tool_state in tools_state:
state_map[tool_state["key"]] = tool_state
logger.info(f"state_map: {json.dumps(state_map)}")
sha_history = []
def process_messages(history: list[Message]):
updated_history = []
for hist in history:
updated_history.append({"role": hist.role, "content": hist.content})
if hist.role == "user":
sha_history.append(hist.content)
sha256_hash = hashlib.sha256()
joined_key_str = ("#.#").join(sha_history)
sha256_hash.update(joined_key_str.encode())
sha_key = sha256_hash.hexdigest()
logger.info(f"sha_key: {sha_key}")
if sha_key in state_map:
tool_call_state = state_map[sha_key]
if "tool_call" in tool_call_state:
tool_call_str = json.dumps(tool_call_state["tool_call"])
updated_history.append(
{
"role": "assistant",
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
}
)
if "tool_response" in tool_call_state:
tool_resp = tool_call_state["tool_response"]
# TODO: try with role = user as well
updated_history.append(
{
"role": "user",
"content": f"<tool_response>\n{tool_resp}\n</tool_response>",
}
)
# we dont want to match this state with any other messages
del state_map[sha_key]
if hist.tool_calls:
if len(hist.tool_calls) > 1:
raise ValueError("Only one tool call is supported")
tool_call_str = json.dumps(hist.tool_calls[0]["function"])
updated_history.append(
{
"role": "assistant",
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
}
)
elif hist.role == "tool":
updated_history.append(
{
"role": "user",
"content": f"<tool_response>\n{hist.content}\n</tool_response>",
}
)
else:
updated_history.append({"role": hist.role, "content": hist.content})
return updated_history
@ -79,10 +55,7 @@ async def chat_completion(req: ChatMessage, res: Response):
messages = [{"role": "system", "content": tools_encoded}]
metadata = req.metadata
arch_state = metadata.get("x-arch-state", "[]")
updated_history = process_state(arch_state, req.messages)
updated_history = process_messages(req.messages)
for message in updated_history:
messages.append({"role": message["role"], "content": message["content"]})

View file

@ -0,0 +1,66 @@
from typing import List
import pytest
import json
from app.function_calling.model_utils import Message, process_messages
test_input_history = """
[
{
"role": "user",
"content": "how is the weather in chicago for next 5 days?"
},
{
"role": "assistant",
"model": "Arch-Function-1.5B",
"tool_calls": [
{
"id": "call_3394",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": { "city": "Chicago", "days": 5 }
}
}
]
},
{
"role": "tool",
"content": "--",
"tool_call_id": "call_3394"
},
{
"role": "assistant",
"content": "--",
"model": "gpt-3.5-turbo-0125"
},
{
"role": "user",
"content": "how is the weather in chicago for next 5 days?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "call_5306",
"type": "function",
"function": {
"name": "weather_forecast",
"arguments": { "city": "Chicago", "days": 5 }
}
}
]
}
]
"""
def test_update_fc_history():
history = json.loads(test_input_history)
message_history = []
for h in history:
message_history.append(Message(**h))
updated_history = process_messages(message_history)
assert len(updated_history) == 6
# ensure that tool role does not exist anymore
assert all([h["role"] != "tool" for h in updated_history])