diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index d846666a..ac33c76c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -19,12 +19,12 @@ jobs: - name: Setup | Install wasm toolchain run: rustup target add wasm32-wasi - - name: Build wasm module for prompt_gateway - run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi - - name: Run Tests on common crate run: cd crates/common && cargo test + - name: Build wasm module for prompt_gateway + run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi + - name: Run Tests on prompt_gateway crate run: cd crates/prompt_gateway && cargo test diff --git a/arch/Dockerfile b/arch/Dockerfile index 3a875a62..073c0b6b 100644 --- a/arch/Dockerfile +++ b/arch/Dockerfile @@ -12,8 +12,8 @@ FROM envoyproxy/envoy:v1.31-latest as envoy #Build config generator, so that we have a single build image for both Rust and Python FROM python:3-slim as arch -COPY --from=builder /arch/prompt_gateway/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm -COPY --from=builder /arch/llm_gateway/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm +COPY --from=builder /arch/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm +COPY --from=builder /arch/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm COPY --from=envoy /usr/local/bin/envoy /usr/local/bin/envoy WORKDIR /config COPY arch/requirements.txt . diff --git a/arch/tools/.vscode/settings.json b/arch/tools/.vscode/settings.json new file mode 100644 index 00000000..3302ded8 --- /dev/null +++ b/arch/tools/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python", +} diff --git a/gateway.code-workspace b/archgw.code-workspace similarity index 54% rename from gateway.code-workspace rename to archgw.code-workspace index cc1b4efc..9148057d 100644 --- a/gateway.code-workspace +++ b/archgw.code-workspace @@ -5,19 +5,11 @@ "path": "." }, { - "name": "common", - "path": "crates/common" + "name": "crates", + "path": "crates" }, { - "name": "prompt_gateway", - "path": "crates/prompt_gateway" - }, - { - "name": "llm_gateway", - "path": "crates/llm_gateway" - }, - { - "name": "arch/tools", + "name": "archgw_cli", "path": "arch/tools" }, { @@ -36,10 +28,15 @@ "name": "demos/insurance_agent", "path": "./demos/insurance_agent", }, - { - "name": "demos/function_calling/api_server", - "path": "./demos/function_calling/api_server", - }, ], - "settings": {} + "settings": { + }, + "extensions": { + "recommendations": [ + "ms-python.python", + "ms-python.debugpy", + "rust-lang.rust-analyzer", + "humao.rest-client" + ] + } } diff --git a/chatbot_ui/.vscode/launch.json b/chatbot_ui/.vscode/launch.json index 47ee5a58..8b42a191 100644 --- a/chatbot_ui/.vscode/launch.json +++ b/chatbot_ui/.vscode/launch.json @@ -5,6 +5,7 @@ "version": "0.2.0", "configurations": [ { + "python": "${workspaceFolder}/venv/bin/python", "name": "chatbot-ui", "cwd": "${workspaceFolder}/app", "type": "debugpy", diff --git a/chatbot_ui/.vscode/settings.json b/chatbot_ui/.vscode/settings.json new file mode 100644 index 00000000..3302ded8 --- /dev/null +++ b/chatbot_ui/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python", +} diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py index f2e85231..02d6e01c 100644 --- a/chatbot_ui/app/run.py +++ b/chatbot_ui/app/run.py @@ -2,14 +2,21 @@ import json import os from openai import OpenAI, DefaultHttpxClient import gradio as gr -import logging as log +import logging from dotenv import load_dotenv load_dotenv() +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) + +log = logging.getLogger(__name__) + CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT") ARCH_STATE_HEADER = "x-arch-state" -log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT) +log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}") client = OpenAI( api_key="--", @@ -23,23 +30,19 @@ def predict(message, state): state["history"] = [] history = state.get("history") history.append({"role": "user", "content": message}) - log.info("history: ", history) + log.info(f"history: {history}") # Custom headers custom_headers = { "x-arch-deterministic-provider": "openai", } - metadata = None - if "arch_state" in state: - metadata = {ARCH_STATE_HEADER: state["arch_state"]} - try: raw_response = client.chat.completions.with_raw_response.create( model="--", messages=history, temperature=1.0, - metadata=metadata, + # metadata=metadata, extra_headers=custom_headers, ) except Exception as e: @@ -49,26 +52,35 @@ def predict(message, state): log.info("Error calling gateway API: {}".format(e.message)) raise gr.Error("Error calling gateway API: {}".format(e.message)) - log.info("raw_response: ", raw_response.text) + log.error(f"raw_response: {raw_response.text}") response = raw_response.parse() # extract arch_state from metadata and store it in gradio session state # this state must be passed back to the gateway in the next request response_json = json.loads(raw_response.text) - arch_state = None if response_json: - metadata = response_json.get("metadata", {}) - if metadata: - arch_state = metadata.get(ARCH_STATE_HEADER, None) - if arch_state: - state["arch_state"] = arch_state + # load arch_state from metadata + arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") + # parse arch_state into json object + arch_state = json.loads(arch_state_str) + # load messages from arch_state + arch_messages_str = arch_state.get("messages", "[]") + # parse messages into json object + arch_messages = json.loads(arch_messages_str) + # append messages from arch gateway to history + for message in arch_messages: + history.append(message) content = response.choices[0].message.content history.append({"role": "assistant", "content": content, "model": response.model}) + + # for gradio UI we don't want to show raw tool calls and messages from developer application + # so we're filtering those out + history_view = [h for h in history if h["role"] != "tool" and "content" in h] messages = [ - (history[i]["content"], history[i + 1]["content"]) - for i in range(0, len(history) - 1, 2) + (history_view[i]["content"], history_view[i + 1]["content"]) + for i in range(0, len(history_view) - 1, 2) ] return messages, state diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index a362da9c..4651c610 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -14,6 +14,7 @@ derivative = "2.2.0" thiserror = "1.0.64" tiktoken-rs = "0.5.9" rand = "0.8.5" +serde_json = "1.0" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs index fb0f902c..c8f91e0f 100644 --- a/crates/common/src/common_types.rs +++ b/crates/common/src/common_types.rs @@ -188,6 +188,8 @@ pub mod open_ai { pub model: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -381,6 +383,7 @@ mod test { content: Some("What city do you want to know the weather for?".to_string()), model: None, tool_calls: None, + tool_call_id: None, }], tools: Some(vec![super::open_ai::ChatCompletionTool { tool_type: ToolType::Function, diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 63ab156c..293dad09 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -229,20 +229,6 @@ mod test { let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); assert_eq!(config.version, "v0.1"); - let open_ai_provider = config - .llm_providers - .iter() - .find(|p| p.name.to_lowercase() == "openai") - .unwrap(); - assert_eq!(open_ai_provider.name.to_lowercase(), "openai"); - assert_eq!( - open_ai_provider.access_key, - Some("OPENAI_API_KEY".to_string()) - ); - assert_eq!(open_ai_provider.model, "gpt-4o"); - assert_eq!(open_ai_provider.default, Some(true)); - assert_eq!(open_ai_provider.stream, Some(true)); - let prompt_guards = config.prompt_guards.as_ref().unwrap(); let input_guards = &prompt_guards.input_guards; let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap(); diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index badba77d..fe67f6a8 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -5,6 +5,8 @@ pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25; pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector"; pub const SYSTEM_ROLE: &str = "system"; pub const USER_ROLE: &str = "user"; +pub const TOOL_ROLE: &str = "tool"; +pub const ASSISTANT_ROLE: &str = "assistant"; pub const GPT_35_TURBO: &str = "gpt-3.5-turbo"; pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes pub const MODEL_SERVER_NAME: &str = "model_server"; @@ -16,7 +18,7 @@ pub const GUARD_INTERNAL_HOST: &str = "guard"; pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider"; pub const ARCH_MESSAGES_KEY: &str = "arch_messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; -pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions"; +pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const ARCH_STATE_HEADER: &str = "x-arch-state"; pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B"; pub const REQUEST_ID_HEADER: &str = "x-request-id"; diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs new file mode 100644 index 00000000..fd634915 --- /dev/null +++ b/crates/common/src/errors.rs @@ -0,0 +1,39 @@ +use proxy_wasm::types::Status; + +use crate::ratelimit; + +#[derive(thiserror::Error, Debug)] +pub enum ClientError { + #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")] + DispatchError { + upstream_name: String, + path: String, + internal_status: Status, + }, +} + +#[derive(thiserror::Error, Debug)] +pub enum ServerError { + #[error(transparent)] + HttpDispatch(ClientError), + #[error(transparent)] + Deserialization(serde_json::Error), + #[error(transparent)] + Serialization(serde_json::Error), + #[error("{0}")] + LogicError(String), + #[error("upstream error response authority={authority}, path={path}, status={status}")] + Upstream { + authority: String, + path: String, + status: String, + }, + #[error("jailbreak detected: {0}")] + Jailbreak(String), + #[error("{why}")] + NoMessagesFound { why: String }, + #[error(transparent)] + ExceededRatelimit(ratelimit::Error), + #[error("{why}")] + BadRequest { why: String }, +} diff --git a/crates/common/src/http.rs b/crates/common/src/http.rs index 21380b0f..842818e2 100644 --- a/crates/common/src/http.rs +++ b/crates/common/src/http.rs @@ -1,4 +1,4 @@ -use crate::stats::{Gauge, IncrementingMetric}; +use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}}; use derivative::Derivative; use log::debug; use proxy_wasm::{traits::Context, types::Status}; @@ -37,16 +37,6 @@ impl<'a> CallArgs<'a> { } } -#[derive(thiserror::Error, Debug)] -pub enum ClientError { - #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")] - DispatchError { - upstream_name: String, - path: String, - internal_status: Status, - }, -} - pub trait Client: Context { type CallContext: Debug; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 27a51803..c23443ca 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -10,3 +10,4 @@ pub mod ratelimit; pub mod routing; pub mod stats; pub mod tokenizer; +pub mod errors; diff --git a/crates/llm_gateway/Cargo.lock b/crates/llm_gateway/Cargo.lock index 35182863..19ce3747 100644 --- a/crates/llm_gateway/Cargo.lock +++ b/crates/llm_gateway/Cargo.lock @@ -228,6 +228,7 @@ dependencies = [ "proxy-wasm", "rand", "serde", + "serde_json", "serde_yaml", "thiserror", "tiktoken-rs", diff --git a/crates/llm_gateway/src/llm_filter_context.rs b/crates/llm_gateway/src/filter_context.rs similarity index 80% rename from crates/llm_gateway/src/llm_filter_context.rs rename to crates/llm_gateway/src/filter_context.rs index e1ed2620..be80c390 100644 --- a/crates/llm_gateway/src/llm_filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -1,4 +1,4 @@ -use crate::llm_stream_context::LlmGatewayStreamContext; +use crate::stream_context::StreamContext; use common::configuration::Configuration; use common::http::Client; use common::llm_providers::LlmProviders; @@ -28,19 +28,19 @@ impl WasmMetrics { } #[derive(Debug)] -pub struct FilterCallContext {} +pub struct CallContext {} #[derive(Debug)] -pub struct LlmGatewayFilterContext { +pub struct FilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. - callouts: RefCell>, + callouts: RefCell>, llm_providers: Option>, } -impl LlmGatewayFilterContext { - pub fn new() -> LlmGatewayFilterContext { - LlmGatewayFilterContext { +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), llm_providers: None, @@ -48,8 +48,8 @@ impl LlmGatewayFilterContext { } } -impl Client for LlmGatewayFilterContext { - type CallContext = FilterCallContext; +impl Client for FilterContext { + type CallContext = CallContext; fn callouts(&self) -> &RefCell> { &self.callouts @@ -60,10 +60,10 @@ impl Client for LlmGatewayFilterContext { } } -impl Context for LlmGatewayFilterContext {} +impl Context for FilterContext {} // RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for LlmGatewayFilterContext { +impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { let config_bytes = self .get_plugin_configuration() @@ -90,8 +90,7 @@ impl RootContext for LlmGatewayFilterContext { context_id ); - // No StreamContext can be created until the Embedding Store is fully initialized. - Some(Box::new(LlmGatewayStreamContext::new( + Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone( diff --git a/crates/llm_gateway/src/lib.rs b/crates/llm_gateway/src/lib.rs index 766d32bb..e2ad9025 100644 --- a/crates/llm_gateway/src/lib.rs +++ b/crates/llm_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use llm_filter_context::LlmGatewayFilterContext; +use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod llm_filter_context; -mod llm_stream_context; +mod filter_context; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(LlmGatewayFilterContext::new()) + Box::new(FilterContext::new()) }); }} diff --git a/crates/llm_gateway/src/llm_stream_context.rs b/crates/llm_gateway/src/stream_context.rs similarity index 94% rename from crates/llm_gateway/src/llm_stream_context.rs rename to crates/llm_gateway/src/stream_context.rs index 6c585a72..bd2fba5e 100644 --- a/crates/llm_gateway/src/llm_stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,4 +1,4 @@ -use crate::llm_filter_context::WasmMetrics; +use crate::filter_context::WasmMetrics; use common::common_types::open_ai::{ ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message, ToolCall, ToolCallState, @@ -8,6 +8,7 @@ use common::consts::{ ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE, }; +use common::errors::ServerError; use common::llm_providers::LlmProviders; use common::ratelimit::Header; use common::{ratelimit, routing, tokenizer}; @@ -22,25 +23,12 @@ use std::rc::Rc; use common::stats::IncrementingMetric; -#[derive(thiserror::Error, Debug)] -pub enum ServerError { - #[error(transparent)] - Deserialization(serde_json::Error), - #[error("{0}")] - LogicError(String), - #[error(transparent)] - ExceededRatelimit(ratelimit::Error), - #[error("{why}")] - BadRequest { why: String }, -} - -pub struct LlmGatewayStreamContext { +pub struct StreamContext { context_id: u32, metrics: Rc, tool_calls: Option>, tool_call_response: Option, arch_state: Option>, - request_body_size: usize, ratelimit_selector: Option
, streaming_response: bool, user_prompt: Option, @@ -52,17 +40,15 @@ pub struct LlmGatewayStreamContext { request_id: Option, } -impl LlmGatewayStreamContext { - #[allow(clippy::too_many_arguments)] +impl StreamContext { pub fn new(context_id: u32, metrics: Rc, llm_providers: Rc) -> Self { - LlmGatewayStreamContext { + StreamContext { context_id, metrics, chat_completions_request: None, tool_calls: None, tool_call_response: None, arch_state: None, - request_body_size: 0, ratelimit_selector: None, streaming_response: false, user_prompt: None, @@ -160,7 +146,7 @@ impl LlmGatewayStreamContext { } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for LlmGatewayStreamContext { +impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { @@ -198,8 +184,6 @@ impl HttpContext for LlmGatewayStreamContext { return Action::Continue; } - self.request_body_size = body_size; - // Deserialize body into spec. // Currently OpenAI API. let mut deserialized_body: ChatCompletionsRequest = @@ -225,7 +209,6 @@ impl HttpContext for LlmGatewayStreamContext { return Action::Pause; } }; - self.is_chat_completions_request = true; // remove metadata from the request body deserialized_body.metadata = None; @@ -333,10 +316,9 @@ impl HttpContext for LlmGatewayStreamContext { let chat_completions_response: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(de) => de, - Err(e) => { + Err(_e) => { debug!("invalid response: {}", String::from_utf8_lossy(&body)); - self.send_server_error(ServerError::Deserialization(e), None); - return Action::Pause; + return Action::Continue; } }; @@ -418,4 +400,4 @@ impl HttpContext for LlmGatewayStreamContext { } } -impl Context for LlmGatewayStreamContext {} +impl Context for StreamContext {} diff --git a/crates/prompt_gateway/Cargo.lock b/crates/prompt_gateway/Cargo.lock index 63de3b3f..7679b301 100644 --- a/crates/prompt_gateway/Cargo.lock +++ b/crates/prompt_gateway/Cargo.lock @@ -228,6 +228,7 @@ dependencies = [ "proxy-wasm", "rand", "serde", + "serde_json", "serde_yaml", "thiserror", "tiktoken-rs", diff --git a/crates/prompt_gateway/src/prompt_filter_context.rs b/crates/prompt_gateway/src/filter_context.rs similarity index 86% rename from crates/prompt_gateway/src/prompt_filter_context.rs rename to crates/prompt_gateway/src/filter_context.rs index 5d7c1e71..b60191a5 100644 --- a/crates/prompt_gateway/src/prompt_filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -1,7 +1,7 @@ -use crate::prompt_stream_context::PromptStreamContext; +use crate::stream_context::StreamContext; use common::common_types::EmbeddingType; -use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget}; use common::consts::{ARCH_INTERNAL_CLUSTER_NAME, EMBEDDINGS_INTERNAL_HOST}; +use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget}; use common::consts::ARCH_UPSTREAM_HOST_HEADER; use common::consts::DEFAULT_EMBEDDING_MODEL; use common::embeddings::{ @@ -9,7 +9,6 @@ use common::embeddings::{ }; use common::http::CallArgs; use common::http::Client; -use common::llm_providers::LlmProviders; use common::stats::Gauge; use common::stats::IncrementingMetric; use log::debug; @@ -44,31 +43,27 @@ pub struct FilterCallContext { } #[derive(Debug)] -pub struct PromptGatewayFilterContext { +pub struct FilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. callouts: RefCell>, overrides: Rc>, system_prompt: Rc>, prompt_targets: Rc>, - mode: GatewayMode, prompt_guards: Rc, - llm_providers: Option>, embeddings_store: Option>, temp_embeddings_store: EmbeddingsStore, } -impl PromptGatewayFilterContext { - pub fn new() -> PromptGatewayFilterContext { - PromptGatewayFilterContext { +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), system_prompt: Rc::new(None), prompt_targets: Rc::new(HashMap::new()), overrides: Rc::new(None), prompt_guards: Rc::new(PromptGuards::default()), - mode: GatewayMode::Prompt, - llm_providers: None, embeddings_store: Some(Rc::new(HashMap::new())), temp_embeddings_store: HashMap::new(), } @@ -116,7 +111,7 @@ impl PromptGatewayFilterContext { Duration::from_secs(60), ); - let call_context = crate::prompt_filter_context::FilterCallContext { + let call_context = crate::filter_context::FilterCallContext { prompt_target_name: String::from(prompt_target_name), embedding_type, }; @@ -193,7 +188,7 @@ impl PromptGatewayFilterContext { } } -impl Client for PromptGatewayFilterContext { +impl Client for FilterContext { type CallContext = FilterCallContext; fn callouts(&self) -> &RefCell> { @@ -205,7 +200,7 @@ impl Client for PromptGatewayFilterContext { } } -impl Context for PromptGatewayFilterContext { +impl Context for FilterContext { fn on_http_call_response( &mut self, token_id: u32, @@ -234,7 +229,7 @@ impl Context for PromptGatewayFilterContext { } // RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for PromptGatewayFilterContext { +impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { let config_bytes = self .get_plugin_configuration() @@ -253,17 +248,11 @@ impl RootContext for PromptGatewayFilterContext { } self.system_prompt = Rc::new(config.system_prompt); self.prompt_targets = Rc::new(prompt_targets); - self.mode = config.mode.unwrap_or_default(); if let Some(prompt_guards) = config.prompt_guards { self.prompt_guards = Rc::new(prompt_guards) } - match config.llm_providers.try_into() { - Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)), - Err(err) => panic!("{err}"), - } - true } @@ -273,12 +262,11 @@ impl RootContext for PromptGatewayFilterContext { context_id ); - // No StreamContext can be created until the Embedding Store is fully initialized. - let embedding_store = match self.mode { - GatewayMode::Llm => None, - GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())), + let embedding_store = match self.embeddings_store.as_ref() { + None => return None, + Some(store) => Some(Rc::clone(store)), }; - Some(Box::new(PromptStreamContext::new( + Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone(&self.system_prompt), @@ -299,11 +287,8 @@ impl RootContext for PromptGatewayFilterContext { } fn on_tick(&mut self) { - debug!("starting up arch filter in mode: {:?}", self.mode); - if self.mode == GatewayMode::Prompt { - self.process_prompt_targets(); - } - + debug!("starting up arch filter in mode: prompt gateway mode"); + self.process_prompt_targets(); self.set_tick_period(Duration::from_secs(0)); } } diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index 75edea5d..e2ad9025 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use prompt_filter_context::PromptGatewayFilterContext; +use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod prompt_filter_context; -mod prompt_stream_context; +mod filter_context; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(PromptGatewayFilterContext::new()) + Box::new(FilterContext::new()) }); }} diff --git a/crates/prompt_gateway/src/prompt_stream_context.rs b/crates/prompt_gateway/src/stream_context.rs similarity index 91% rename from crates/prompt_gateway/src/prompt_stream_context.rs rename to crates/prompt_gateway/src/stream_context.rs index 503ea8b5..33938deb 100644 --- a/crates/prompt_gateway/src/prompt_stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -1,9 +1,9 @@ -use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics}; +use crate::filter_context::{EmbeddingsStore, WasmMetrics}; use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, - StreamOptions, ToolCall, ToolCallState, ToolType, + StreamOptions, ToolCall, ToolType, }; use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, @@ -12,23 +12,25 @@ use common::common_types::{ }; use common::configuration::{Overrides, PromptGuards, PromptTarget}; use common::consts::{ - ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARCH_FC_INTERNAL_HOST, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST + ARCH_FC_INTERNAL_HOST, ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, EMBEDDINGS_INTERNAL_HOST, GPT_35_TURBO, GUARD_INTERNAL_HOST, HALLUCINATION_INTERNAL_HOST, REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, ZEROSHOT_INTERNAL_HOST }; use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use common::http::{CallArgs, Client, ClientError}; +use common::errors::ClientError; +use common::http::{CallArgs, Client}; use common::stats::Gauge; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::traits::*; use proxy_wasm::types::*; use serde_json::Value; -use sha2::{Digest, Sha256}; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; +use std::str::FromStr; use std::time::Duration; +use derivative::Derivative; use common::stats::IncrementingMetric; @@ -43,11 +45,13 @@ enum ResponseHandlerType { DefaultTarget, } -#[derive(Debug, Clone)] +#[derive(Clone, Derivative)] +#[derivative(Debug)] pub struct StreamCallContext { response_handler_type: ResponseHandlerType, user_message: Option, prompt_target_name: Option, + #[derivative(Debug = "ignore")] request_body: ChatCompletionsRequest, tool_calls: Option>, similarity_scores: Option>, @@ -65,11 +69,12 @@ pub enum ServerError { Serialization(serde_json::Error), #[error("{0}")] LogicError(String), - #[error("upstream error response authority={authority}, path={path}, status={status}")] + #[error("upstream application error host={host}, path={path}, status={status}, body={body}")] Upstream { - authority: String, + host: String, path: String, status: String, + body: String, }, #[error("jailbreak detected: {0}")] Jailbreak(String), @@ -77,7 +82,7 @@ pub enum ServerError { NoMessagesFound { why: String }, } -pub struct PromptStreamContext { +pub struct StreamContext { context_id: u32, metrics: Rc, system_prompt: Rc>, @@ -98,8 +103,7 @@ pub struct PromptStreamContext { request_id: Option, } -impl PromptStreamContext { - #[allow(clippy::too_many_arguments)] +impl StreamContext { pub fn new( context_id: u32, metrics: Rc, @@ -109,7 +113,7 @@ impl PromptStreamContext { overrides: Rc>, embeddings_store: Option>, ) -> Self { - PromptStreamContext { + StreamContext { context_id, metrics, system_prompt, @@ -145,7 +149,6 @@ impl PromptStreamContext { } fn send_server_error(&self, error: ServerError, override_status_code: Option) { - debug!("server error occurred: {}", error); self.send_http_response( override_status_code .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) @@ -160,6 +163,7 @@ impl PromptStreamContext { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { Ok(embedding_response) => embedding_response, Err(e) => { + debug!("error deserializing embedding response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -230,6 +234,7 @@ impl PromptStreamContext { let json_data: String = match serde_json::to_string(&zero_shot_classification_request) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing zero shot classification request: {}", error); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -259,6 +264,7 @@ impl PromptStreamContext { callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching zero shot classification request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } } @@ -272,6 +278,7 @@ impl PromptStreamContext { match serde_json::from_slice(&body) { Ok(hallucination_response) => hallucination_response, Err(e) => { + debug!("error deserializing hallucination response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -297,6 +304,7 @@ impl PromptStreamContext { content: Some(response), model: Some(ARCH_FC_MODEL_NAME.to_string()), tool_calls: None, + tool_call_id: None, }; let chat_completion_response = ChatCompletionsResponse { @@ -335,6 +343,7 @@ impl PromptStreamContext { match serde_json::from_slice(&body) { Ok(zeroshot_response) => zeroshot_response, Err(e) => { + debug!("error deserializing zero shot classification response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -446,6 +455,7 @@ impl PromptStreamContext { callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching default prompt target request: {}", e); return self.send_server_error( ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST), @@ -461,6 +471,7 @@ impl PromptStreamContext { let prompt_target = match self.prompt_targets.get(&prompt_target_name) { Some(prompt_target) => prompt_target.clone(), None => { + debug!("prompt target not found: {}", prompt_target_name); return self.send_server_error( ServerError::LogicError(format!( "Prompt target not found: {prompt_target_name}" @@ -533,6 +544,7 @@ impl PromptStreamContext { msg_body } Err(e) => { + debug!("error serializing arch_fc request body: {}", e); return self.send_server_error(ServerError::Serialization(e), None); } }; @@ -565,6 +577,7 @@ impl PromptStreamContext { callout_context.prompt_target_name = Some(prompt_target.name); if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching arch_fc request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); } } @@ -576,6 +589,7 @@ impl PromptStreamContext { let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { Ok(arch_fc_response) => arch_fc_response, Err(e) => { + debug!("error deserializing arch_fc response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -689,6 +703,7 @@ impl PromptStreamContext { match serde_json::to_string(&hallucination_classification_request) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing hallucination classification request: {}", error); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -781,17 +796,19 @@ impl PromptStreamContext { fn function_call_response_handler( &mut self, body: Vec, - mut callout_context: StreamCallContext, + callout_context: StreamCallContext, ) { if let Some(http_status) = self.get_http_call_response_header(":status") { if http_status != StatusCode::OK.as_str() { + debug!("upstream error response: {}", http_status); return self.send_server_error( ServerError::Upstream { - authority: callout_context.upstream_cluster.unwrap(), + host: callout_context.upstream_cluster.unwrap(), path: callout_context.upstream_cluster_path.unwrap(), - status: http_status, + status: http_status.clone(), + body: String::from_utf8(body).unwrap(), }, - None, + Some(StatusCode::from_str(http_status.as_str()).unwrap()), ); } } else { @@ -823,11 +840,18 @@ impl PromptStreamContext { content: system_prompt, model: None, tool_calls: None, + tool_call_id: None, }; messages.push(system_prompt_message); } - messages.append(callout_context.request_body.messages.as_mut()); + // don't send tools message and api response to chat gpt + for m in callout_context.request_body.messages.iter() { + if m.role == TOOL_ROLE || m.content.is_none() { + continue; + } + messages.push(m.clone()); + } let user_message = match messages.pop() { Some(user_message) => user_message, @@ -854,6 +878,7 @@ impl PromptStreamContext { content: Some(final_prompt), model: None, tool_calls: None, + tool_call_id: None, } }); @@ -889,6 +914,7 @@ impl PromptStreamContext { .prompt_guards .jailbreak_on_exception_message() .unwrap_or("refrain from discussing jailbreaking."); + debug!("jailbreak detected: {}", msg); return self.send_server_error( ServerError::Jailbreak(String::from(msg)), Some(StatusCode::BAD_REQUEST), @@ -912,6 +938,7 @@ impl PromptStreamContext { let json_data: String = match serde_json::to_string(&get_embeddings_input) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing get embeddings request: {}", error); return self.send_server_error(ServerError::Deserialization(error), None); } }; @@ -948,6 +975,7 @@ impl PromptStreamContext { }; if let Err(e) = self.http_call(call_args, call_context) { + debug!("error dispatching get embeddings request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } } @@ -981,6 +1009,7 @@ impl PromptStreamContext { let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(chat_completions_resp) => chat_completions_resp, Err(e) => { + debug!("error deserializing default target response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -1000,6 +1029,7 @@ impl PromptStreamContext { content: Some(system_prompt.clone()), model: None, tool_calls: None, + tool_call_id: None, }; messages.push(system_prompt_message); } @@ -1010,6 +1040,7 @@ impl PromptStreamContext { content: Some(api_resp.clone()), model: None, tool_calls: None, + tool_call_id: None, }); let chat_completion_request = ChatCompletionsRequest { model: GPT_35_TURBO.to_string(), @@ -1027,7 +1058,7 @@ impl PromptStreamContext { } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for PromptStreamContext { +impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { @@ -1090,7 +1121,6 @@ impl HttpContext for PromptStreamContext { return Action::Pause; } }; - self.is_chat_completions_request = true; self.arch_state = match deserialized_body.metadata { Some(ref metadata) => { @@ -1256,9 +1286,8 @@ impl HttpContext for PromptStreamContext { match serde_json::from_slice(&body) { Ok(de) => de, Err(e) => { - debug!("invalid response: {}", String::from_utf8_lossy(&body)); - self.send_server_error(ServerError::Deserialization(e), None); - return Action::Pause; + debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e); + return Action::Continue; } }; @@ -1276,55 +1305,42 @@ impl HttpContext for PromptStreamContext { self.arch_state = Some(Vec::new()); } - // compute sha hash from message history - let mut hasher = Sha256::new(); - let prompts: Vec = self - .chat_completions_request - .as_ref() - .unwrap() - .messages - .iter() - .filter(|msg| msg.role == USER_ROLE) - .map(|msg| msg.content.clone().unwrap()) - .collect(); - let prompts_merged = prompts.join("#.#"); - hasher.update(prompts_merged.clone()); - let hash_key = hasher.finalize(); - // conver hash to hex string - let hash_key_str = format!("{:x}", hash_key); - debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged); - - // create new tool call state - let tool_call_state = ToolCallState { - key: hash_key_str, - message: self.user_prompt.clone(), - tool_call: tool_calls[0].function.clone(), - tool_response: self.tool_call_response.clone().unwrap(), - }; - - // push tool call state to arch state - self.arch_state - .as_mut() - .unwrap() - .push(ArchState::ToolCall(vec![tool_call_state])); - let mut data: Value = serde_json::from_slice(&body).unwrap(); // use serde::Value to manipulate the json object and ensure that we don't lose any data if let Value::Object(ref mut map) = data { // serialize arch state and add to metadata - let arch_state_str = serde_json::to_string(&self.arch_state).unwrap(); - debug!("arch_state: {}", arch_state_str); let metadata = map .entry("metadata") .or_insert(Value::Object(serde_json::Map::new())); if metadata == &Value::Null { *metadata = Value::Object(serde_json::Map::new()); } + + // since arch gateway generates tool calls (using arch-fc) and calls upstream api to + // get response, we will send these back to developer so they can see the api response + // and tool call arch-fc generated + let mut fc_messages = Vec::new(); + fc_messages.push(Message { + role: ASSISTANT_ROLE.to_string(), + content: None, + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: self.tool_calls.clone(), + tool_call_id: None, + }); + fc_messages.push(Message { + role: TOOL_ROLE.to_string(), + content: self.tool_call_response.clone(), + model: None, + tool_calls: None, + tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()), + }); + let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); + let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); + let arch_state_str = serde_json::to_string(&arch_state).unwrap(); metadata.as_object_mut().unwrap().insert( ARCH_STATE_HEADER.to_string(), serde_json::Value::String(arch_state_str), ); - let data_serialized = serde_json::to_string(&data).unwrap(); debug!("arch => user: {}", data_serialized); self.set_http_response_body(0, body_size, data_serialized.as_bytes()); @@ -1342,7 +1358,7 @@ impl HttpContext for PromptStreamContext { } } -impl Context for PromptStreamContext { +impl Context for StreamContext { fn on_http_call_response( &mut self, token_id: u32, @@ -1388,7 +1404,7 @@ impl Context for PromptStreamContext { } } -impl Client for PromptStreamContext { +impl Client for StreamContext { type CallContext = StreamCallContext; fn callouts(&self) -> &RefCell> { diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 5f27adc3..14ca1aa2 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -487,7 +487,6 @@ fn bad_request_to_open_ai_chat_completions() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(incomplete_chat_completions_request_body)) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_send_local_response( Some(StatusCode::BAD_REQUEST.as_u16().into()), None, @@ -547,6 +546,7 @@ fn request_to_llm_gateway() { }, }]), model: None, + tool_call_id: None, }, }], model: String::from("test"), @@ -648,6 +648,7 @@ fn request_to_llm_gateway() { content: Some("hello from fake llm gateway".to_string()), model: None, tool_calls: None, + tool_call_id: None, }, }], model: String::from("test"), @@ -666,8 +667,6 @@ fn request_to_llm_gateway() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::Action(Action::Continue)) diff --git a/model_server/.vscode/settings.json b/model_server/.vscode/settings.json new file mode 100644 index 00000000..3302ded8 --- /dev/null +++ b/model_server/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python", +} diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py index 3e4e6654..04078a1b 100644 --- a/model_server/app/function_calling/model_utils.py +++ b/model_server/app/function_calling/model_utils.py @@ -13,62 +13,38 @@ logger = get_model_server_logger() class Message(BaseModel): role: str - content: str + content: str = "" + tool_calls: List[Dict[str, Any]] = [] + tool_call_id: str = "" class ChatMessage(BaseModel): messages: list[Message] tools: List[Dict[str, Any]] - # TODO: make it default none - metadata: Dict[str, str] = {} - -def process_state(arch_state, history: list[Message]): - logger.info("state: {}".format(arch_state)) - state_json = json.loads(arch_state) - - state_map = {} - if state_json: - for tools_state in state_json: - for tool_state in tools_state: - state_map[tool_state["key"]] = tool_state - - logger.info(f"state_map: {json.dumps(state_map)}") - - sha_history = [] +def process_messages(history: list[Message]): updated_history = [] for hist in history: - updated_history.append({"role": hist.role, "content": hist.content}) - if hist.role == "user": - sha_history.append(hist.content) - sha256_hash = hashlib.sha256() - joined_key_str = ("#.#").join(sha_history) - sha256_hash.update(joined_key_str.encode()) - sha_key = sha256_hash.hexdigest() - logger.info(f"sha_key: {sha_key}") - if sha_key in state_map: - tool_call_state = state_map[sha_key] - if "tool_call" in tool_call_state: - tool_call_str = json.dumps(tool_call_state["tool_call"]) - updated_history.append( - { - "role": "assistant", - "content": f"\n{tool_call_str}\n", - } - ) - if "tool_response" in tool_call_state: - tool_resp = tool_call_state["tool_response"] - # TODO: try with role = user as well - updated_history.append( - { - "role": "user", - "content": f"\n{tool_resp}\n", - } - ) - # we dont want to match this state with any other messages - del state_map[sha_key] - + if hist.tool_calls: + if len(hist.tool_calls) > 1: + raise ValueError("Only one tool call is supported") + tool_call_str = json.dumps(hist.tool_calls[0]["function"]) + updated_history.append( + { + "role": "assistant", + "content": f"\n{tool_call_str}\n", + } + ) + elif hist.role == "tool": + updated_history.append( + { + "role": "user", + "content": f"\n{hist.content}\n", + } + ) + else: + updated_history.append({"role": hist.role, "content": hist.content}) return updated_history @@ -79,10 +55,7 @@ async def chat_completion(req: ChatMessage, res: Response): messages = [{"role": "system", "content": tools_encoded}] - metadata = req.metadata - arch_state = metadata.get("x-arch-state", "[]") - - updated_history = process_state(arch_state, req.messages) + updated_history = process_messages(req.messages) for message in updated_history: messages.append({"role": message["role"], "content": message["content"]}) diff --git a/model_server/app/tests/test_state.py b/model_server/app/tests/test_state.py new file mode 100644 index 00000000..9eb72c8c --- /dev/null +++ b/model_server/app/tests/test_state.py @@ -0,0 +1,66 @@ +from typing import List +import pytest +import json +from app.function_calling.model_utils import Message, process_messages + +test_input_history = """ +[ + { + "role": "user", + "content": "how is the weather in chicago for next 5 days?" + }, + { + "role": "assistant", + "model": "Arch-Function-1.5B", + "tool_calls": [ + { + "id": "call_3394", + "type": "function", + "function": { + "name": "weather_forecast", + "arguments": { "city": "Chicago", "days": 5 } + } + } + ] + }, + { + "role": "tool", + "content": "--", + "tool_call_id": "call_3394" + }, + { + "role": "assistant", + "content": "--", + "model": "gpt-3.5-turbo-0125" + }, + { + "role": "user", + "content": "how is the weather in chicago for next 5 days?" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_5306", + "type": "function", + "function": { + "name": "weather_forecast", + "arguments": { "city": "Chicago", "days": 5 } + } + } + ] + } + ] + """ + + +def test_update_fc_history(): + history = json.loads(test_input_history) + message_history = [] + for h in history: + message_history.append(Message(**h)) + + updated_history = process_messages(message_history) + assert len(updated_history) == 6 + # ensure that tool role does not exist anymore + assert all([h["role"] != "tool" for h in updated_history])