diff --git a/chatbot_ui/.vscode/launch.json b/chatbot_ui/.vscode/launch.json
index 47ee5a58..8b42a191 100644
--- a/chatbot_ui/.vscode/launch.json
+++ b/chatbot_ui/.vscode/launch.json
@@ -5,6 +5,7 @@
     "version": "0.2.0",
     "configurations": [
         {
+            "python": "${workspaceFolder}/venv/bin/python",
             "name": "chatbot-ui",
             "cwd": "${workspaceFolder}/app",
             "type": "debugpy",
diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py
index f2e85231..02d6e01c 100644
--- a/chatbot_ui/app/run.py
+++ b/chatbot_ui/app/run.py
@@ -2,14 +2,21 @@ import json
 import os
 from openai import OpenAI, DefaultHttpxClient
 import gradio as gr
-import logging as log
+import logging
 from dotenv import load_dotenv

 load_dotenv()

+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(levelname)s - %(message)s",
+)
+
+log = logging.getLogger(__name__)
+
 CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT")
 ARCH_STATE_HEADER = "x-arch-state"
-log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT)
+log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}")

 client = OpenAI(
     api_key="--",
@@ -23,23 +30,19 @@ def predict(message, state):
         state["history"] = []
     history = state.get("history")
     history.append({"role": "user", "content": message})
-    log.info("history: ", history)
+    log.info(f"history: {history}")

     # Custom headers
     custom_headers = {
         "x-arch-deterministic-provider": "openai",
     }

-    metadata = None
-    if "arch_state" in state:
-        metadata = {ARCH_STATE_HEADER: state["arch_state"]}
-
     try:
         raw_response = client.chat.completions.with_raw_response.create(
             model="--",
             messages=history,
             temperature=1.0,
-            metadata=metadata,
+            # metadata=metadata,
             extra_headers=custom_headers,
         )
     except Exception as e:
@@ -49,26 +52,35 @@ def predict(message, state):
         log.info("Error calling gateway API: {}".format(e.message))
         raise gr.Error("Error calling gateway API: {}".format(e.message))

-    log.info("raw_response: ", raw_response.text)
+    log.error(f"raw_response: {raw_response.text}")
     response = raw_response.parse()

     # extract arch_state from metadata and store it in gradio session state
     # this state must be passed back to the gateway in the next request
    response_json = json.loads(raw_response.text)
-    arch_state = None
     if response_json:
-        metadata = response_json.get("metadata", {})
-        if metadata:
-            arch_state = metadata.get(ARCH_STATE_HEADER, None)
-            if arch_state:
-                state["arch_state"] = arch_state
+        # load arch_state from metadata
+        arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
+        # parse arch_state into json object
+        arch_state = json.loads(arch_state_str)
+        # load messages from arch_state
+        arch_messages_str = arch_state.get("messages", "[]")
+        # parse messages into json object
+        arch_messages = json.loads(arch_messages_str)
+        # append messages from arch gateway to history
+        for message in arch_messages:
+            history.append(message)

     content = response.choices[0].message.content
     history.append({"role": "assistant", "content": content, "model": response.model})
+
+    # for gradio UI we don't want to show raw tool calls and messages from developer application
+    # so we're filtering those out
+    history_view = [h for h in history if h["role"] != "tool" and "content" in h]

     messages = [
-        (history[i]["content"], history[i + 1]["content"])
-        for i in range(0, len(history) - 1, 2)
+        (history_view[i]["content"], history_view[i + 1]["content"])
+        for i in range(0, len(history_view) - 1, 2)
     ]

     return messages, state
diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs
index fb0f902c..c8f91e0f 100644
--- a/crates/common/src/common_types.rs
+++ b/crates/common/src/common_types.rs
@@ -188,6 +188,8 @@ pub mod open_ai {
         pub model: Option<String>,
         #[serde(skip_serializing_if = "Option::is_none")]
         pub tool_calls: Option<Vec<ToolCall>>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub tool_call_id: Option<String>,
     }

     #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -381,6 +383,7 @@ mod test {
                 content: Some("What city do you want to know the weather for?".to_string()),
                 model: None,
                 tool_calls: None,
+                tool_call_id: None,
             }],
             tools: Some(vec![super::open_ai::ChatCompletionTool {
                 tool_type: ToolType::Function,
diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs
index ce119eab..fdc21aed 100644
--- a/crates/common/src/consts.rs
+++ b/crates/common/src/consts.rs
@@ -5,6 +5,8 @@ pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25;
 pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector";
 pub const SYSTEM_ROLE: &str = "system";
 pub const USER_ROLE: &str = "user";
+pub const TOOL_ROLE: &str = "tool";
+pub const ASSISTANT_ROLE: &str = "assistant";
 pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
 pub const ARC_FC_CLUSTER: &str = "arch_fc";
 pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs
index da4d344f..18463b49 100644
--- a/crates/prompt_gateway/src/stream_context.rs
+++ b/crates/prompt_gateway/src/stream_context.rs
@@ -3,7 +3,7 @@ use acap::cos;
 use common::common_types::open_ai::{
     ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
     FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType,
-    StreamOptions, ToolCall, ToolCallState, ToolType,
+    StreamOptions, ToolCall, ToolType,
 };
 use common::common_types::{
     EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
@@ -14,9 +14,9 @@ use common::configuration::{Overrides, PromptGuards, PromptTarget};
 use common::consts::{
     ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
     ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
-    CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
+    ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
     DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
-    REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
+    REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE,
 };
 use common::embeddings::{
     CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
@@ -29,12 +29,12 @@
 use log::{debug, info, warn};
 use proxy_wasm::traits::*;
 use proxy_wasm::types::*;
 use serde_json::Value;
-use sha2::{Digest, Sha256};
 use std::cell::RefCell;
 use std::collections::HashMap;
 use std::rc::Rc;
 use std::str::FromStr;
 use std::time::Duration;
+use derivative::Derivative;

 use common::stats::IncrementingMetric;
@@ -49,11 +49,13 @@ enum ResponseHandlerType {
     DefaultTarget,
 }

-#[derive(Debug, Clone)]
+#[derive(Clone, Derivative)]
+#[derivative(Debug)]
 pub struct StreamCallContext {
     response_handler_type: ResponseHandlerType,
     user_message: Option<String>,
     prompt_target_name: Option<String>,
+    #[derivative(Debug = "ignore")]
     request_body: ChatCompletionsRequest,
     tool_calls: Option<Vec<ToolCall>>,
     similarity_scores: Option<Vec<(String, f64)>>,
@@ -306,6 +308,7 @@ impl StreamContext {
                 content: Some(response),
                 model: Some(ARCH_FC_MODEL_NAME.to_string()),
                 tool_calls: None,
+                tool_call_id: None,
             };

             let chat_completion_response = ChatCompletionsResponse {
@@ -797,7 +800,7 @@ impl StreamContext {
     fn function_call_response_handler(
         &mut self,
         body: Vec<u8>,
-        mut callout_context: StreamCallContext,
+        callout_context: StreamCallContext,
     ) {
         if let Some(http_status) = self.get_http_call_response_header(":status") {
             if http_status != StatusCode::OK.as_str() {
@@ -841,11 +844,18 @@ impl StreamContext {
                 content: system_prompt,
                 model: None,
                 tool_calls: None,
+                tool_call_id: None,
             };
             messages.push(system_prompt_message);
         }

-        messages.append(callout_context.request_body.messages.as_mut());
+        // don't send tools message and api response to chat gpt
+        for m in callout_context.request_body.messages.iter() {
+            if m.role == TOOL_ROLE || m.content.is_none() {
+                continue;
+            }
+            messages.push(m.clone());
+        }

         let user_message = match messages.pop() {
             Some(user_message) => user_message,
@@ -872,6 +882,7 @@ impl StreamContext {
                 content: Some(final_prompt),
                 model: None,
                 tool_calls: None,
+                tool_call_id: None,
             }
         });

@@ -1022,6 +1033,7 @@ impl StreamContext {
                 content: Some(system_prompt.clone()),
                 model: None,
                 tool_calls: None,
+                tool_call_id: None,
             };
             messages.push(system_prompt_message);
         }
@@ -1032,6 +1044,7 @@ impl StreamContext {
             content: Some(api_resp.clone()),
             model: None,
             tool_calls: None,
+            tool_call_id: None,
         });
         let chat_completion_request = ChatCompletionsRequest {
             model: GPT_35_TURBO.to_string(),
@@ -1296,55 +1309,42 @@ impl HttpContext for StreamContext {
             self.arch_state = Some(Vec::new());
         }

-        // compute sha hash from message history
-        let mut hasher = Sha256::new();
-        let prompts: Vec<String> = self
-            .chat_completions_request
-            .as_ref()
-            .unwrap()
-            .messages
-            .iter()
-            .filter(|msg| msg.role == USER_ROLE)
-            .map(|msg| msg.content.clone().unwrap())
-            .collect();
-        let prompts_merged = prompts.join("#.#");
-        hasher.update(prompts_merged.clone());
-        let hash_key = hasher.finalize();
-        // conver hash to hex string
-        let hash_key_str = format!("{:x}", hash_key);
-        debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged);
-
-        // create new tool call state
-        let tool_call_state = ToolCallState {
-            key: hash_key_str,
-            message: self.user_prompt.clone(),
-            tool_call: tool_calls[0].function.clone(),
-            tool_response: self.tool_call_response.clone().unwrap(),
-        };
-
-        // push tool call state to arch state
-        self.arch_state
-            .as_mut()
-            .unwrap()
-            .push(ArchState::ToolCall(vec![tool_call_state]));
-
         let mut data: Value = serde_json::from_slice(&body).unwrap();
         // use serde::Value to manipulate the json object and ensure that we don't lose any data
         if let Value::Object(ref mut map) = data {
             // serialize arch state and add to metadata
-            let arch_state_str = serde_json::to_string(&self.arch_state).unwrap();
-            debug!("arch_state: {}", arch_state_str);
             let metadata = map
                 .entry("metadata")
                 .or_insert(Value::Object(serde_json::Map::new()));
             if metadata == &Value::Null {
                 *metadata = Value::Object(serde_json::Map::new());
             }
+
+            // since arch gateway generates tool calls (using arch-fc) and calls upstream api to
+            // get response, we will send these back to developer so they can see the api response
+            // and tool call arch-fc generated
+            let mut fc_messages = Vec::new();
+            fc_messages.push(Message {
+                role: ASSISTANT_ROLE.to_string(),
+                content: None,
+                model: Some(ARCH_FC_MODEL_NAME.to_string()),
+                tool_calls: self.tool_calls.clone(),
+                tool_call_id: None,
+            });
+            fc_messages.push(Message {
+                role: TOOL_ROLE.to_string(),
+                content: self.tool_call_response.clone(),
+                model: None,
+                tool_calls: None,
+                tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()),
+            });
+            let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
+            let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
+            let arch_state_str = serde_json::to_string(&arch_state).unwrap();
             metadata.as_object_mut().unwrap().insert(
                 ARCH_STATE_HEADER.to_string(),
                 serde_json::Value::String(arch_state_str),
             );
-
             let data_serialized = serde_json::to_string(&data).unwrap();
             debug!("arch => user: {}", data_serialized);
             self.set_http_response_body(0, body_size, data_serialized.as_bytes());
diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs
index 0338f23b..14ca1aa2 100644
--- a/crates/prompt_gateway/tests/integration.rs
+++ b/crates/prompt_gateway/tests/integration.rs
@@ -546,6 +546,7 @@ fn request_to_llm_gateway() {
                     },
                 }]),
                 model: None,
+                tool_call_id: None,
             },
         }],
         model: String::from("test"),
@@ -647,6 +648,7 @@ fn request_to_llm_gateway() {
                 content: Some("hello from fake llm gateway".to_string()),
                 model: None,
                 tool_calls: None,
+                tool_call_id: None,
             },
         }],
         model: String::from("test"),
@@ -665,8 +667,6 @@ fn request_to_llm_gateway() {
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
         .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Debug), None)
-        .expect_log(Some(LogLevel::Debug), None)
         .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
         .expect_log(Some(LogLevel::Debug), None)
         .execute_and_expect(ReturnType::Action(Action::Continue))
diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py
index 3e4e6654..04078a1b 100644
--- a/model_server/app/function_calling/model_utils.py
+++ b/model_server/app/function_calling/model_utils.py
@@ -13,62 +13,38 @@
 logger = get_model_server_logger()


 class Message(BaseModel):
     role: str
-    content: str
+    content: str = ""
+    tool_calls: List[Dict[str, Any]] = []
+    tool_call_id: str = ""


 class ChatMessage(BaseModel):
     messages: list[Message]
     tools: List[Dict[str, Any]]
-    # TODO: make it default none
-    metadata: Dict[str, str] = {}
-


-def process_state(arch_state, history: list[Message]):
-    logger.info("state: {}".format(arch_state))
-    state_json = json.loads(arch_state)
-
-    state_map = {}
-    if state_json:
-        for tools_state in state_json:
-            for tool_state in tools_state:
-                state_map[tool_state["key"]] = tool_state
-
-    logger.info(f"state_map: {json.dumps(state_map)}")
-
-    sha_history = []
+def process_messages(history: list[Message]):
     updated_history = []
     for hist in history:
-        updated_history.append({"role": hist.role, "content": hist.content})
-        if hist.role == "user":
-            sha_history.append(hist.content)
-            sha256_hash = hashlib.sha256()
-            joined_key_str = ("#.#").join(sha_history)
-            sha256_hash.update(joined_key_str.encode())
-            sha_key = sha256_hash.hexdigest()
-            logger.info(f"sha_key: {sha_key}")
-            if sha_key in state_map:
-                tool_call_state = state_map[sha_key]
-                if "tool_call" in tool_call_state:
-                    tool_call_str = json.dumps(tool_call_state["tool_call"])
-                    updated_history.append(
-                        {
-                            "role": "assistant",
-                            "content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
-                        }
-                    )
-                if "tool_response" in tool_call_state:
-                    tool_resp = tool_call_state["tool_response"]
-                    # TODO: try with role = user as well
-                    updated_history.append(
-                        {
-                            "role": "user",
-                            "content": f"<tool_response>\n{tool_resp}\n</tool_response>",
-                        }
-                    )
-                # we dont want to match this state with any other messages
-                del state_map[sha_key]
-
+        if hist.tool_calls:
+            if len(hist.tool_calls) > 1:
+                raise ValueError("Only one tool call is supported")
+            tool_call_str = json.dumps(hist.tool_calls[0]["function"])
+            updated_history.append(
+                {
+                    "role": "assistant",
+                    "content": f"<tool_call>\n{tool_call_str}\n</tool_call>",
+                }
+            )
+        elif hist.role == "tool":
+            updated_history.append(
+                {
+                    "role": "user",
+                    "content": f"<tool_response>\n{hist.content}\n</tool_response>",
+                }
+            )
+        else:
+            updated_history.append({"role": hist.role, "content": hist.content})
     return updated_history
@@ -79,10 +55,7 @@ async def chat_completion(req: ChatMessage, res: Response):

     messages = [{"role": "system", "content": tools_encoded}]

-    metadata = req.metadata
-    arch_state = metadata.get("x-arch-state", "[]")
-
-    updated_history = process_state(arch_state, req.messages)
+    updated_history = process_messages(req.messages)
     for message in updated_history:
         messages.append({"role": message["role"], "content": message["content"]})

diff --git a/model_server/app/tests/test_state.py b/model_server/app/tests/test_state.py
new file mode 100644
index 00000000..9eb72c8c
--- /dev/null
+++ b/model_server/app/tests/test_state.py
@@ -0,0 +1,66 @@
+from typing import List
+import pytest
+import json
+from app.function_calling.model_utils import Message, process_messages
+
+test_input_history = """
+[
+  {
+    "role": "user",
+    "content": "how is the weather in chicago for next 5 days?"
+  },
+  {
+    "role": "assistant",
+    "model": "Arch-Function-1.5B",
+    "tool_calls": [
+      {
+        "id": "call_3394",
+        "type": "function",
+        "function": {
+          "name": "weather_forecast",
+          "arguments": { "city": "Chicago", "days": 5 }
+        }
+      }
+    ]
+  },
+  {
+    "role": "tool",
+    "content": "--",
+    "tool_call_id": "call_3394"
+  },
+  {
+    "role": "assistant",
+    "content": "--",
+    "model": "gpt-3.5-turbo-0125"
+  },
+  {
+    "role": "user",
+    "content": "how is the weather in chicago for next 5 days?"
+  },
+  {
+    "role": "assistant",
+    "tool_calls": [
+      {
+        "id": "call_5306",
+        "type": "function",
+        "function": {
+          "name": "weather_forecast",
+          "arguments": { "city": "Chicago", "days": 5 }
+        }
+      }
+    ]
+  }
+]
+"""
+
+
+def test_update_fc_history():
+    history = json.loads(test_input_history)
+    message_history = []
+    for h in history:
+        message_history.append(Message(**h))
+
+    updated_history = process_messages(message_history)
+    assert len(updated_history) == 6
+    # ensure that tool role does not exist anymore
+    assert all([h["role"] != "tool" for h in updated_history])
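For reference, the patch above makes the gateway return `x-arch-state` in the response `metadata` as a JSON-encoded string whose `messages` field is itself a JSON-encoded array (the arch-fc assistant tool call plus the tool response). A minimal client-side sketch of unpacking that format, using a fabricated sample response body purely for illustration ("--" mirrors the redactions used in this diff; it is not part of the patch):

```python
import json

# Hypothetical gateway response body, shaped like the one handled in chatbot_ui/app/run.py above.
response_json = {
    "choices": [{"message": {"role": "assistant", "content": "--"}}],
    "metadata": {
        # the gateway serializes {"messages": "<json array string>"} and stores it as a string
        "x-arch-state": json.dumps({
            "messages": json.dumps([
                {
                    "role": "assistant",
                    "model": "Arch-Function-1.5B",
                    "tool_calls": [{
                        "id": "call_3394",
                        "type": "function",
                        "function": {"name": "weather_forecast", "arguments": {"city": "Chicago", "days": 5}},
                    }],
                },
                {"role": "tool", "content": "--", "tool_call_id": "call_3394"},
            ])
        })
    },
}

# Unpack: metadata -> arch_state (JSON string) -> messages (JSON string) -> list of messages,
# mirroring the double json.loads() added to predict() in run.py.
arch_state = json.loads(response_json.get("metadata", {}).get("x-arch-state", "{}"))
arch_messages = json.loads(arch_state.get("messages", "[]"))
for m in arch_messages:
    print(m["role"], m.get("tool_call_id"))
```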