mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
Pass tool call and app function response back in metadata (#193)
This commit is contained in:
parent
62a000036e
commit
dd1c7be706
8 changed files with 169 additions and 112 deletions
1
chatbot_ui/.vscode/launch.json
vendored
1
chatbot_ui/.vscode/launch.json
vendored
|
|
@ -5,6 +5,7 @@
|
|||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"python": "${workspaceFolder}/venv/bin/python",
|
||||
"name": "chatbot-ui",
|
||||
"cwd": "${workspaceFolder}/app",
|
||||
"type": "debugpy",
|
||||
|
|
|
|||
|
|
@ -2,14 +2,21 @@ import json
|
|||
import os
|
||||
from openai import OpenAI, DefaultHttpxClient
|
||||
import gradio as gr
|
||||
import logging as log
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
|
||||
ARCH_STATE_HEADER = "x-arch-state"
|
||||
log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
|
||||
log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}")
|
||||
|
||||
client = OpenAI(
|
||||
api_key="--",
|
||||
|
|
@ -23,23 +30,19 @@ def predict(message, state):
|
|||
state["history"] = []
|
||||
history = state.get("history")
|
||||
history.append({"role": "user", "content": message})
|
||||
log.info("history: ", history)
|
||||
log.info(f"history: {history}")
|
||||
|
||||
# Custom headers
|
||||
custom_headers = {
|
||||
"x-arch-deterministic-provider": "openai",
|
||||
}
|
||||
|
||||
metadata = None
|
||||
if "arch_state" in state:
|
||||
metadata = {ARCH_STATE_HEADER: state["arch_state"]}
|
||||
|
||||
try:
|
||||
raw_response = client.chat.completions.with_raw_response.create(
|
||||
model="--",
|
||||
messages=history,
|
||||
temperature=1.0,
|
||||
metadata=metadata,
|
||||
# metadata=metadata,
|
||||
extra_headers=custom_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -49,26 +52,35 @@ def predict(message, state):
|
|||
log.info("Error calling gateway API: {}".format(e.message))
|
||||
raise gr.Error("Error calling gateway API: {}".format(e.message))
|
||||
|
||||
log.info("raw_response: ", raw_response.text)
|
||||
log.error(f"raw_response: {raw_response.text}")
|
||||
response = raw_response.parse()
|
||||
|
||||
# extract arch_state from metadata and store it in gradio session state
|
||||
# this state must be passed back to the gateway in the next request
|
||||
response_json = json.loads(raw_response.text)
|
||||
arch_state = None
|
||||
if response_json:
|
||||
metadata = response_json.get("metadata", {})
|
||||
if metadata:
|
||||
arch_state = metadata.get(ARCH_STATE_HEADER, None)
|
||||
if arch_state:
|
||||
state["arch_state"] = arch_state
|
||||
# load arch_state from metadata
|
||||
arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
|
||||
# parse arch_state into json object
|
||||
arch_state = json.loads(arch_state_str)
|
||||
# load messages from arch_state
|
||||
arch_messages_str = arch_state.get("messages", "[]")
|
||||
# parse messages into json object
|
||||
arch_messages = json.loads(arch_messages_str)
|
||||
# append messages from arch gateway to history
|
||||
for message in arch_messages:
|
||||
history.append(message)
|
||||
|
||||
content = response.choices[0].message.content
|
||||
|
||||
history.append({"role": "assistant", "content": content, "model": response.model})
|
||||
|
||||
# for gradio UI we don't want to show raw tool calls and messages from developer application
|
||||
# so we're filtering those out
|
||||
history_view = [h for h in history if h["role"] != "tool" and "content" in h]
|
||||
messages = [
|
||||
(history[i]["content"], history[i + 1]["content"])
|
||||
for i in range(0, len(history) - 1, 2)
|
||||
(history_view[i]["content"], history_view[i + 1]["content"])
|
||||
for i in range(0, len(history_view) - 1, 2)
|
||||
]
|
||||
return messages, state
|
||||
|
||||
|
|
|
|||
|
|
@ -188,6 +188,8 @@ pub mod open_ai {
|
|||
pub model: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -381,6 +383,7 @@ mod test {
|
|||
content: Some("What city do you want to know the weather for?".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}],
|
||||
tools: Some(vec![super::open_ai::ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
|
|||
pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
|
||||
pub const SYSTEM_ROLE: &str = "system";
|
||||
pub const USER_ROLE: &str = "user";
|
||||
pub const TOOL_ROLE: &str = "tool";
|
||||
pub const ASSISTANT_ROLE: &str = "assistant";
|
||||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||
pub const ARC_FC_CLUSTER: &str = "arch_fc";
|
||||
pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use acap::cos;
|
|||
use common::common_types::open_ai::{
|
||||
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
|
||||
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType,
|
||||
StreamOptions, ToolCall, ToolCallState, ToolType,
|
||||
StreamOptions, ToolCall, ToolType,
|
||||
};
|
||||
use common::common_types::{
|
||||
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
|
||||
|
|
@ -14,9 +14,9 @@ use common::configuration::{Overrides, PromptGuards, PromptTarget};
|
|||
use common::consts::{
|
||||
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
|
||||
ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
|
||||
CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
|
||||
ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
|
||||
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
|
||||
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
|
||||
};
|
||||
use common::embeddings::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
|
|
@ -29,12 +29,12 @@ use log::{debug, info, warn};
|
|||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use serde_json::Value;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
use std::str::FromStr;
|
||||
use std::time::Duration;
|
||||
use derivative::Derivative;
|
||||
|
||||
use common::stats::IncrementingMetric;
|
||||
|
||||
|
|
@ -49,11 +49,13 @@ enum ResponseHandlerType {
|
|||
DefaultTarget,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Clone, Derivative)]
|
||||
#[derivative(Debug)]
|
||||
pub struct StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType,
|
||||
user_message: Option<String>,
|
||||
prompt_target_name: Option<String>,
|
||||
#[derivative(Debug = "ignore")]
|
||||
request_body: ChatCompletionsRequest,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
|
|
@ -306,6 +308,7 @@ impl StreamContext {
|
|||
content: Some(response),
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
|
||||
let chat_completion_response = ChatCompletionsResponse {
|
||||
|
|
@ -797,7 +800,7 @@ impl StreamContext {
|
|||
fn function_call_response_handler(
|
||||
&mut self,
|
||||
body: Vec<u8>,
|
||||
mut callout_context: StreamCallContext,
|
||||
callout_context: StreamCallContext,
|
||||
) {
|
||||
if let Some(http_status) = self.get_http_call_response_header(":status") {
|
||||
if http_status != StatusCode::OK.as_str() {
|
||||
|
|
@ -841,11 +844,18 @@ impl StreamContext {
|
|||
content: system_prompt,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
||||
messages.append(callout_context.request_body.messages.as_mut());
|
||||
// don't send tools message and api response to chat gpt
|
||||
for m in callout_context.request_body.messages.iter() {
|
||||
if m.role == TOOL_ROLE || m.content.is_none() {
|
||||
continue;
|
||||
}
|
||||
messages.push(m.clone());
|
||||
}
|
||||
|
||||
let user_message = match messages.pop() {
|
||||
Some(user_message) => user_message,
|
||||
|
|
@ -872,6 +882,7 @@ impl StreamContext {
|
|||
content: Some(final_prompt),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -1022,6 +1033,7 @@ impl StreamContext {
|
|||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
|
@ -1032,6 +1044,7 @@ impl StreamContext {
|
|||
content: Some(api_resp.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
});
|
||||
let chat_completion_request = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
|
|
@ -1296,55 +1309,42 @@ impl HttpContext for StreamContext {
|
|||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
// compute sha hash from message history
|
||||
let mut hasher = Sha256::new();
|
||||
let prompts: Vec<String> = self
|
||||
.chat_completions_request
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.messages
|
||||
.iter()
|
||||
.filter(|msg| msg.role == USER_ROLE)
|
||||
.map(|msg| msg.content.clone().unwrap())
|
||||
.collect();
|
||||
let prompts_merged = prompts.join("#.#");
|
||||
hasher.update(prompts_merged.clone());
|
||||
let hash_key = hasher.finalize();
|
||||
// conver hash to hex string
|
||||
let hash_key_str = format!("{:x}", hash_key);
|
||||
debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
|
||||
|
||||
// create new tool call state
|
||||
let tool_call_state = ToolCallState {
|
||||
key: hash_key_str,
|
||||
message: self.user_prompt.clone(),
|
||||
tool_call: tool_calls[0].function.clone(),
|
||||
tool_response: self.tool_call_response.clone().unwrap(),
|
||||
};
|
||||
|
||||
// push tool call state to arch state
|
||||
self.arch_state
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.push(ArchState::ToolCall(vec![tool_call_state]));
|
||||
|
||||
let mut data: Value = serde_json::from_slice(&body).unwrap();
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
if let Value::Object(ref mut map) = data {
|
||||
// serialize arch state and add to metadata
|
||||
let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
|
||||
debug!("arch_state: {}", arch_state_str);
|
||||
let metadata = map
|
||||
.entry("metadata")
|
||||
.or_insert(Value::Object(serde_json::Map::new()));
|
||||
if metadata == &Value::Null {
|
||||
*metadata = Value::Object(serde_json::Map::new());
|
||||
}
|
||||
|
||||
// since arch gateway generates tool calls (using arch-fc) and calls upstream api to
|
||||
// get response, we will send these back to developer so they can see the api response
|
||||
// and tool call arch-fc generated
|
||||
let mut fc_messages = Vec::new();
|
||||
fc_messages.push(Message {
|
||||
role: ASSISTANT_ROLE.to_string(),
|
||||
content: None,
|
||||
model: Some(ARCH_FC_MODEL_NAME.to_string()),
|
||||
tool_calls: self.tool_calls.clone(),
|
||||
tool_call_id: None,
|
||||
});
|
||||
fc_messages.push(Message {
|
||||
role: TOOL_ROLE.to_string(),
|
||||
content: self.tool_call_response.clone(),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
|
||||
});
|
||||
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
|
||||
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
|
||||
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::Value::String(arch_state_str),
|
||||
);
|
||||
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
debug!("arch => user: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
|
|
|
|||
|
|
@ -546,6 +546,7 @@ fn request_to_llm_gateway() {
|
|||
},
|
||||
}]),
|
||||
model: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
|
|
@ -647,6 +648,7 @@ fn request_to_llm_gateway() {
|
|||
content: Some("hello from fake llm gateway".to_string()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
},
|
||||
}],
|
||||
model: String::from("test"),
|
||||
|
|
@ -665,8 +667,6 @@ fn request_to_llm_gateway() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Continue))
|
||||
|
|
|
|||
|
|
@ -13,62 +13,38 @@ logger = get_model_server_logger()
|
|||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
content: str = ""
|
||||
tool_calls: List[Dict[str, Any]] = []
|
||||
tool_call_id: str = ""
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
||||
# TODO: make it default none
|
||||
metadata: Dict[str, str] = {}
|
||||
|
||||
|
||||
def process_state(arch_state, history: list[Message]):
|
||||
logger.info("state: {}".format(arch_state))
|
||||
state_json = json.loads(arch_state)
|
||||
|
||||
state_map = {}
|
||||
if state_json:
|
||||
for tools_state in state_json:
|
||||
for tool_state in tools_state:
|
||||
state_map[tool_state["key"]] = tool_state
|
||||
|
||||
logger.info(f"state_map: {json.dumps(state_map)}")
|
||||
|
||||
sha_history = []
|
||||
def process_messages(history: list[Message]):
|
||||
updated_history = []
|
||||
for hist in history:
|
||||
updated_history.append({"role": hist.role, "content": hist.content})
|
||||
if hist.role == "user":
|
||||
sha_history.append(hist.content)
|
||||
sha256_hash = hashlib.sha256()
|
||||
joined_key_str = ("#.#").join(sha_history)
|
||||
sha256_hash.update(joined_key_str.encode())
|
||||
sha_key = sha256_hash.hexdigest()
|
||||
logger.info(f"sha_key: {sha_key}")
|
||||
if sha_key in state_map:
|
||||
tool_call_state = state_map[sha_key]
|
||||
if "tool_call" in tool_call_state:
|
||||
tool_call_str = json.dumps(tool_call_state["tool_call"])
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
|
||||
}
|
||||
)
|
||||
if "tool_response" in tool_call_state:
|
||||
tool_resp = tool_call_state["tool_response"]
|
||||
# TODO: try with role = user as well
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>\n{tool_resp}\n</tool_response>",
|
||||
}
|
||||
)
|
||||
# we dont want to match this state with any other messages
|
||||
del state_map[sha_key]
|
||||
|
||||
if hist.tool_calls:
|
||||
if len(hist.tool_calls) > 1:
|
||||
raise ValueError("Only one tool call is supported")
|
||||
tool_call_str = json.dumps(hist.tool_calls[0]["function"])
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
|
||||
}
|
||||
)
|
||||
elif hist.role == "tool":
|
||||
updated_history.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"<tool_response>\n{hist.content}\n</tool_response>",
|
||||
}
|
||||
)
|
||||
else:
|
||||
updated_history.append({"role": hist.role, "content": hist.content})
|
||||
return updated_history
|
||||
|
||||
|
||||
|
|
@ -79,10 +55,7 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
|
||||
messages = [{"role": "system", "content": tools_encoded}]
|
||||
|
||||
metadata = req.metadata
|
||||
arch_state = metadata.get("x-arch-state", "[]")
|
||||
|
||||
updated_history = process_state(arch_state, req.messages)
|
||||
updated_history = process_messages(req.messages)
|
||||
for message in updated_history:
|
||||
messages.append({"role": message["role"], "content": message["content"]})
|
||||
|
||||
|
|
|
|||
66
model_server/app/tests/test_state.py
Normal file
66
model_server/app/tests/test_state.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
from typing import List
|
||||
import pytest
|
||||
import json
|
||||
from app.function_calling.model_utils import Message, process_messages
|
||||
|
||||
test_input_history = """
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in chicago for next 5 days?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"model": "Arch-Function-1.5B",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3394",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": { "city": "Chicago", "days": 5 }
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": "--",
|
||||
"tool_call_id": "call_3394"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "--",
|
||||
"model": "gpt-3.5-turbo-0125"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "how is the weather in chicago for next 5 days?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_5306",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": { "city": "Chicago", "days": 5 }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
|
||||
|
||||
def test_update_fc_history():
|
||||
history = json.loads(test_input_history)
|
||||
message_history = []
|
||||
for h in history:
|
||||
message_history.append(Message(**h))
|
||||
|
||||
updated_history = process_messages(message_history)
|
||||
assert len(updated_history) == 6
|
||||
# ensure that tool role does not exist anymore
|
||||
assert all([h["role"] != "tool" for h in updated_history])
|
||||
Loading…
Add table
Add a link
Reference in a new issue