From c6ba28dfcc3f3fa97184f6f4efc5dcb8e62f7e38 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 18 Oct 2024 12:53:44 -0700 Subject: [PATCH 1/5] Code refactor and some improvements - see description (#194) --- .github/workflows/checks.yml | 6 +-- crates/common/Cargo.toml | 1 + crates/common/src/configuration.rs | 14 ------ crates/common/src/consts.rs | 2 +- crates/common/src/errors.rs | 39 +++++++++++++++ crates/common/src/http.rs | 12 +---- crates/common/src/lib.rs | 1 + crates/llm_gateway/Cargo.lock | 1 + ...lm_filter_context.rs => filter_context.rs} | 25 +++++----- crates/llm_gateway/src/lib.rs | 8 ++-- ...lm_stream_context.rs => stream_context.rs} | 31 +++--------- crates/prompt_gateway/Cargo.lock | 1 + ...pt_filter_context.rs => filter_context.rs} | 47 +++++++------------ crates/prompt_gateway/src/lib.rs | 8 ++-- ...pt_stream_context.rs => stream_context.rs} | 19 ++++---- 15 files changed, 100 insertions(+), 115 deletions(-) create mode 100644 crates/common/src/errors.rs rename crates/llm_gateway/src/{llm_filter_context.rs => filter_context.rs} (80%) rename crates/llm_gateway/src/{llm_stream_context.rs => stream_context.rs} (95%) rename crates/prompt_gateway/src/{prompt_filter_context.rs => filter_context.rs} (86%) rename crates/prompt_gateway/src/{prompt_stream_context.rs => stream_context.rs} (99%) diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index d846666a..ac33c76c 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -19,12 +19,12 @@ jobs: - name: Setup | Install wasm toolchain run: rustup target add wasm32-wasi - - name: Build wasm module for prompt_gateway - run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi - - name: Run Tests on common crate run: cd crates/common && cargo test + - name: Build wasm module for prompt_gateway + run: cd crates/prompt_gateway && cargo build --release --target=wasm32-wasi + - name: Run Tests on prompt_gateway crate run: cd crates/prompt_gateway && cargo test diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index a362da9c..4651c610 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -14,6 +14,7 @@ derivative = "2.2.0" thiserror = "1.0.64" tiktoken-rs = "0.5.9" rand = "0.8.5" +serde_json = "1.0" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 63ab156c..293dad09 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -229,20 +229,6 @@ mod test { let config: super::Configuration = serde_yaml::from_str(&ref_config).unwrap(); assert_eq!(config.version, "v0.1"); - let open_ai_provider = config - .llm_providers - .iter() - .find(|p| p.name.to_lowercase() == "openai") - .unwrap(); - assert_eq!(open_ai_provider.name.to_lowercase(), "openai"); - assert_eq!( - open_ai_provider.access_key, - Some("OPENAI_API_KEY".to_string()) - ); - assert_eq!(open_ai_provider.model, "gpt-4o"); - assert_eq!(open_ai_provider.default, Some(true)); - assert_eq!(open_ai_provider.stream, Some(true)); - let prompt_guards = config.prompt_guards.as_ref().unwrap(); let input_guards = &prompt_guards.input_guards; let jailbreak_guard = input_guards.get(&GuardType::Jailbreak).unwrap(); diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 76244f6b..ce119eab 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -12,7 +12,7 @@ pub const MODEL_SERVER_NAME: &str = "model_server"; pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider"; pub const ARCH_MESSAGES_KEY: &str = "arch_messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; -pub const CHAT_COMPLETIONS_PATH: &str = "v1/chat/completions"; +pub const CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions"; pub const ARCH_STATE_HEADER: &str = "x-arch-state"; pub const ARCH_FC_MODEL_NAME: &str = "Arch-Function-1.5B"; pub const REQUEST_ID_HEADER: &str = "x-request-id"; diff --git a/crates/common/src/errors.rs b/crates/common/src/errors.rs new file mode 100644 index 00000000..fd634915 --- /dev/null +++ b/crates/common/src/errors.rs @@ -0,0 +1,39 @@ +use proxy_wasm::types::Status; + +use crate::ratelimit; + +#[derive(thiserror::Error, Debug)] +pub enum ClientError { + #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")] + DispatchError { + upstream_name: String, + path: String, + internal_status: Status, + }, +} + +#[derive(thiserror::Error, Debug)] +pub enum ServerError { + #[error(transparent)] + HttpDispatch(ClientError), + #[error(transparent)] + Deserialization(serde_json::Error), + #[error(transparent)] + Serialization(serde_json::Error), + #[error("{0}")] + LogicError(String), + #[error("upstream error response authority={authority}, path={path}, status={status}")] + Upstream { + authority: String, + path: String, + status: String, + }, + #[error("jailbreak detected: {0}")] + Jailbreak(String), + #[error("{why}")] + NoMessagesFound { why: String }, + #[error(transparent)] + ExceededRatelimit(ratelimit::Error), + #[error("{why}")] + BadRequest { why: String }, +} diff --git a/crates/common/src/http.rs b/crates/common/src/http.rs index 21380b0f..842818e2 100644 --- a/crates/common/src/http.rs +++ b/crates/common/src/http.rs @@ -1,4 +1,4 @@ -use crate::stats::{Gauge, IncrementingMetric}; +use crate::{errors::ClientError, stats::{Gauge, IncrementingMetric}}; use derivative::Derivative; use log::debug; use proxy_wasm::{traits::Context, types::Status}; @@ -37,16 +37,6 @@ impl<'a> CallArgs<'a> { } } -#[derive(thiserror::Error, Debug)] -pub enum ClientError { - #[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")] - DispatchError { - upstream_name: String, - path: String, - internal_status: Status, - }, -} - pub trait Client: Context { type CallContext: Debug; diff --git a/crates/common/src/lib.rs b/crates/common/src/lib.rs index 27a51803..c23443ca 100644 --- a/crates/common/src/lib.rs +++ b/crates/common/src/lib.rs @@ -10,3 +10,4 @@ pub mod ratelimit; pub mod routing; pub mod stats; pub mod tokenizer; +pub mod errors; diff --git a/crates/llm_gateway/Cargo.lock b/crates/llm_gateway/Cargo.lock index 35182863..19ce3747 100644 --- a/crates/llm_gateway/Cargo.lock +++ b/crates/llm_gateway/Cargo.lock @@ -228,6 +228,7 @@ dependencies = [ "proxy-wasm", "rand", "serde", + "serde_json", "serde_yaml", "thiserror", "tiktoken-rs", diff --git a/crates/llm_gateway/src/llm_filter_context.rs b/crates/llm_gateway/src/filter_context.rs similarity index 80% rename from crates/llm_gateway/src/llm_filter_context.rs rename to crates/llm_gateway/src/filter_context.rs index e1ed2620..be80c390 100644 --- a/crates/llm_gateway/src/llm_filter_context.rs +++ b/crates/llm_gateway/src/filter_context.rs @@ -1,4 +1,4 @@ -use crate::llm_stream_context::LlmGatewayStreamContext; +use crate::stream_context::StreamContext; use common::configuration::Configuration; use common::http::Client; use common::llm_providers::LlmProviders; @@ -28,19 +28,19 @@ impl WasmMetrics { } #[derive(Debug)] -pub struct FilterCallContext {} +pub struct CallContext {} #[derive(Debug)] -pub struct LlmGatewayFilterContext { +pub struct FilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. - callouts: RefCell>, + callouts: RefCell>, llm_providers: Option>, } -impl LlmGatewayFilterContext { - pub fn new() -> LlmGatewayFilterContext { - LlmGatewayFilterContext { +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), llm_providers: None, @@ -48,8 +48,8 @@ impl LlmGatewayFilterContext { } } -impl Client for LlmGatewayFilterContext { - type CallContext = FilterCallContext; +impl Client for FilterContext { + type CallContext = CallContext; fn callouts(&self) -> &RefCell> { &self.callouts @@ -60,10 +60,10 @@ impl Client for LlmGatewayFilterContext { } } -impl Context for LlmGatewayFilterContext {} +impl Context for FilterContext {} // RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for LlmGatewayFilterContext { +impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { let config_bytes = self .get_plugin_configuration() @@ -90,8 +90,7 @@ impl RootContext for LlmGatewayFilterContext { context_id ); - // No StreamContext can be created until the Embedding Store is fully initialized. - Some(Box::new(LlmGatewayStreamContext::new( + Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone( diff --git a/crates/llm_gateway/src/lib.rs b/crates/llm_gateway/src/lib.rs index 766d32bb..e2ad9025 100644 --- a/crates/llm_gateway/src/lib.rs +++ b/crates/llm_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use llm_filter_context::LlmGatewayFilterContext; +use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod llm_filter_context; -mod llm_stream_context; +mod filter_context; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(LlmGatewayFilterContext::new()) + Box::new(FilterContext::new()) }); }} diff --git a/crates/llm_gateway/src/llm_stream_context.rs b/crates/llm_gateway/src/stream_context.rs similarity index 95% rename from crates/llm_gateway/src/llm_stream_context.rs rename to crates/llm_gateway/src/stream_context.rs index 6c585a72..655f76ff 100644 --- a/crates/llm_gateway/src/llm_stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,4 +1,4 @@ -use crate::llm_filter_context::WasmMetrics; +use crate::filter_context::WasmMetrics; use common::common_types::open_ai::{ ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message, ToolCall, ToolCallState, @@ -8,6 +8,7 @@ use common::consts::{ ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH, RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE, }; +use common::errors::ServerError; use common::llm_providers::LlmProviders; use common::ratelimit::Header; use common::{ratelimit, routing, tokenizer}; @@ -22,25 +23,12 @@ use std::rc::Rc; use common::stats::IncrementingMetric; -#[derive(thiserror::Error, Debug)] -pub enum ServerError { - #[error(transparent)] - Deserialization(serde_json::Error), - #[error("{0}")] - LogicError(String), - #[error(transparent)] - ExceededRatelimit(ratelimit::Error), - #[error("{why}")] - BadRequest { why: String }, -} - -pub struct LlmGatewayStreamContext { +pub struct StreamContext { context_id: u32, metrics: Rc, tool_calls: Option>, tool_call_response: Option, arch_state: Option>, - request_body_size: usize, ratelimit_selector: Option
, streaming_response: bool, user_prompt: Option, @@ -52,17 +40,15 @@ pub struct LlmGatewayStreamContext { request_id: Option, } -impl LlmGatewayStreamContext { - #[allow(clippy::too_many_arguments)] +impl StreamContext { pub fn new(context_id: u32, metrics: Rc, llm_providers: Rc) -> Self { - LlmGatewayStreamContext { + StreamContext { context_id, metrics, chat_completions_request: None, tool_calls: None, tool_call_response: None, arch_state: None, - request_body_size: 0, ratelimit_selector: None, streaming_response: false, user_prompt: None, @@ -160,7 +146,7 @@ impl LlmGatewayStreamContext { } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for LlmGatewayStreamContext { +impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { @@ -198,8 +184,6 @@ impl HttpContext for LlmGatewayStreamContext { return Action::Continue; } - self.request_body_size = body_size; - // Deserialize body into spec. // Currently OpenAI API. let mut deserialized_body: ChatCompletionsRequest = @@ -225,7 +209,6 @@ impl HttpContext for LlmGatewayStreamContext { return Action::Pause; } }; - self.is_chat_completions_request = true; // remove metadata from the request body deserialized_body.metadata = None; @@ -418,4 +401,4 @@ impl HttpContext for LlmGatewayStreamContext { } } -impl Context for LlmGatewayStreamContext {} +impl Context for StreamContext {} diff --git a/crates/prompt_gateway/Cargo.lock b/crates/prompt_gateway/Cargo.lock index 63de3b3f..7679b301 100644 --- a/crates/prompt_gateway/Cargo.lock +++ b/crates/prompt_gateway/Cargo.lock @@ -228,6 +228,7 @@ dependencies = [ "proxy-wasm", "rand", "serde", + "serde_json", "serde_yaml", "thiserror", "tiktoken-rs", diff --git a/crates/prompt_gateway/src/prompt_filter_context.rs b/crates/prompt_gateway/src/filter_context.rs similarity index 86% rename from crates/prompt_gateway/src/prompt_filter_context.rs rename to crates/prompt_gateway/src/filter_context.rs index 0c25ee5c..655b391f 100644 --- a/crates/prompt_gateway/src/prompt_filter_context.rs +++ b/crates/prompt_gateway/src/filter_context.rs @@ -1,6 +1,6 @@ -use crate::prompt_stream_context::PromptStreamContext; +use crate::stream_context::StreamContext; use common::common_types::EmbeddingType; -use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget}; +use common::configuration::{Configuration, Overrides, PromptGuards, PromptTarget}; use common::consts::ARCH_INTERNAL_CLUSTER_NAME; use common::consts::ARCH_UPSTREAM_HOST_HEADER; use common::consts::DEFAULT_EMBEDDING_MODEL; @@ -10,7 +10,6 @@ use common::embeddings::{ }; use common::http::CallArgs; use common::http::Client; -use common::llm_providers::LlmProviders; use common::stats::Gauge; use common::stats::IncrementingMetric; use log::debug; @@ -45,31 +44,27 @@ pub struct FilterCallContext { } #[derive(Debug)] -pub struct PromptGatewayFilterContext { +pub struct FilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. callouts: RefCell>, overrides: Rc>, system_prompt: Rc>, prompt_targets: Rc>, - mode: GatewayMode, prompt_guards: Rc, - llm_providers: Option>, embeddings_store: Option>, temp_embeddings_store: EmbeddingsStore, } -impl PromptGatewayFilterContext { - pub fn new() -> PromptGatewayFilterContext { - PromptGatewayFilterContext { +impl FilterContext { + pub fn new() -> FilterContext { + FilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), system_prompt: Rc::new(None), prompt_targets: Rc::new(HashMap::new()), overrides: Rc::new(None), prompt_guards: Rc::new(PromptGuards::default()), - mode: GatewayMode::Prompt, - llm_providers: None, embeddings_store: Some(Rc::new(HashMap::new())), temp_embeddings_store: HashMap::new(), } @@ -117,7 +112,7 @@ impl PromptGatewayFilterContext { Duration::from_secs(60), ); - let call_context = crate::prompt_filter_context::FilterCallContext { + let call_context = crate::filter_context::FilterCallContext { prompt_target_name: String::from(prompt_target_name), embedding_type, }; @@ -194,7 +189,7 @@ impl PromptGatewayFilterContext { } } -impl Client for PromptGatewayFilterContext { +impl Client for FilterContext { type CallContext = FilterCallContext; fn callouts(&self) -> &RefCell> { @@ -206,7 +201,7 @@ impl Client for PromptGatewayFilterContext { } } -impl Context for PromptGatewayFilterContext { +impl Context for FilterContext { fn on_http_call_response( &mut self, token_id: u32, @@ -235,7 +230,7 @@ impl Context for PromptGatewayFilterContext { } // RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for PromptGatewayFilterContext { +impl RootContext for FilterContext { fn on_configure(&mut self, _: usize) -> bool { let config_bytes = self .get_plugin_configuration() @@ -254,17 +249,11 @@ impl RootContext for PromptGatewayFilterContext { } self.system_prompt = Rc::new(config.system_prompt); self.prompt_targets = Rc::new(prompt_targets); - self.mode = config.mode.unwrap_or_default(); if let Some(prompt_guards) = config.prompt_guards { self.prompt_guards = Rc::new(prompt_guards) } - match config.llm_providers.try_into() { - Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)), - Err(err) => panic!("{err}"), - } - true } @@ -274,12 +263,11 @@ impl RootContext for PromptGatewayFilterContext { context_id ); - // No StreamContext can be created until the Embedding Store is fully initialized. - let embedding_store = match self.mode { - GatewayMode::Llm => None, - GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())), + let embedding_store = match self.embeddings_store.as_ref() { + None => return None, + Some(store) => Some(Rc::clone(store)), }; - Some(Box::new(PromptStreamContext::new( + Some(Box::new(StreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone(&self.system_prompt), @@ -300,11 +288,8 @@ impl RootContext for PromptGatewayFilterContext { } fn on_tick(&mut self) { - debug!("starting up arch filter in mode: {:?}", self.mode); - if self.mode == GatewayMode::Prompt { - self.process_prompt_targets(); - } - + debug!("starting up arch filter in mode: prompt gateway mode"); + self.process_prompt_targets(); self.set_tick_period(Duration::from_secs(0)); } } diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index 75edea5d..e2ad9025 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use prompt_filter_context::PromptGatewayFilterContext; +use filter_context::FilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod prompt_filter_context; -mod prompt_stream_context; +mod filter_context; +mod stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(PromptGatewayFilterContext::new()) + Box::new(FilterContext::new()) }); }} diff --git a/crates/prompt_gateway/src/prompt_stream_context.rs b/crates/prompt_gateway/src/stream_context.rs similarity index 99% rename from crates/prompt_gateway/src/prompt_stream_context.rs rename to crates/prompt_gateway/src/stream_context.rs index d208f5e8..602f1629 100644 --- a/crates/prompt_gateway/src/prompt_stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -1,4 +1,4 @@ -use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics}; +use crate::filter_context::{EmbeddingsStore, WasmMetrics}; use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, @@ -21,7 +21,8 @@ use common::consts::{ use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; -use common::http::{CallArgs, Client, ClientError}; +use common::errors::ClientError; +use common::http::{CallArgs, Client}; use common::stats::Gauge; use http::StatusCode; use log::{debug, info, warn}; @@ -81,7 +82,7 @@ pub enum ServerError { NoMessagesFound { why: String }, } -pub struct PromptStreamContext { +pub struct StreamContext { context_id: u32, metrics: Rc, system_prompt: Rc>, @@ -102,8 +103,7 @@ pub struct PromptStreamContext { request_id: Option, } -impl PromptStreamContext { - #[allow(clippy::too_many_arguments)] +impl StreamContext { pub fn new( context_id: u32, metrics: Rc, @@ -113,7 +113,7 @@ impl PromptStreamContext { overrides: Rc>, embeddings_store: Option>, ) -> Self { - PromptStreamContext { + StreamContext { context_id, metrics, system_prompt, @@ -1031,7 +1031,7 @@ impl PromptStreamContext { } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for PromptStreamContext { +impl HttpContext for StreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { @@ -1094,7 +1094,6 @@ impl HttpContext for PromptStreamContext { return Action::Pause; } }; - self.is_chat_completions_request = true; self.arch_state = match deserialized_body.metadata { Some(ref metadata) => { @@ -1346,7 +1345,7 @@ impl HttpContext for PromptStreamContext { } } -impl Context for PromptStreamContext { +impl Context for StreamContext { fn on_http_call_response( &mut self, token_id: u32, @@ -1392,7 +1391,7 @@ impl Context for PromptStreamContext { } } -impl Client for PromptStreamContext { +impl Client for StreamContext { type CallContext = StreamCallContext; fn callouts(&self) -> &RefCell> { From 28421353fd9ce819d186a1d3800a0fad23ba4280 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 18 Oct 2024 12:57:58 -0700 Subject: [PATCH 2/5] Update vscode workspce (#199) - add recommended extensions - set python interpreter path for all python projects to be venv/bin/python - update project structure in workspace - rename project file from gatewa -> archgw --- arch/tools/.vscode/settings.json | 3 ++ ...ay.code-workspace => archgw.code-workspace | 29 +++++++++---------- chatbot_ui/.vscode/settings.json | 3 ++ model_server/.vscode/settings.json | 3 ++ 4 files changed, 22 insertions(+), 16 deletions(-) create mode 100644 arch/tools/.vscode/settings.json rename gateway.code-workspace => archgw.code-workspace (54%) create mode 100644 chatbot_ui/.vscode/settings.json create mode 100644 model_server/.vscode/settings.json diff --git a/arch/tools/.vscode/settings.json b/arch/tools/.vscode/settings.json new file mode 100644 index 00000000..3302ded8 --- /dev/null +++ b/arch/tools/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python", +} diff --git a/gateway.code-workspace b/archgw.code-workspace similarity index 54% rename from gateway.code-workspace rename to archgw.code-workspace index cc1b4efc..9148057d 100644 --- a/gateway.code-workspace +++ b/archgw.code-workspace @@ -5,19 +5,11 @@ "path": "." }, { - "name": "common", - "path": "crates/common" + "name": "crates", + "path": "crates" }, { - "name": "prompt_gateway", - "path": "crates/prompt_gateway" - }, - { - "name": "llm_gateway", - "path": "crates/llm_gateway" - }, - { - "name": "arch/tools", + "name": "archgw_cli", "path": "arch/tools" }, { @@ -36,10 +28,15 @@ "name": "demos/insurance_agent", "path": "./demos/insurance_agent", }, - { - "name": "demos/function_calling/api_server", - "path": "./demos/function_calling/api_server", - }, ], - "settings": {} + "settings": { + }, + "extensions": { + "recommendations": [ + "ms-python.python", + "ms-python.debugpy", + "rust-lang.rust-analyzer", + "humao.rest-client" + ] + } } diff --git a/chatbot_ui/.vscode/settings.json b/chatbot_ui/.vscode/settings.json new file mode 100644 index 00000000..3302ded8 --- /dev/null +++ b/chatbot_ui/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python", +} diff --git a/model_server/.vscode/settings.json b/model_server/.vscode/settings.json new file mode 100644 index 00000000..3302ded8 --- /dev/null +++ b/model_server/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.defaultInterpreterPath": "${workspaceFolder}/venv/bin/python", +} From 1719b7d5f8bbf141bb82a81da4a7d991236aa54d Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 18 Oct 2024 13:14:18 -0700 Subject: [PATCH 3/5] Send back developer error correctly (#195) --- crates/llm_gateway/src/stream_context.rs | 5 ++- crates/prompt_gateway/src/stream_context.rs | 35 +++++++++++++++------ crates/prompt_gateway/tests/integration.rs | 1 - 3 files changed, 28 insertions(+), 13 deletions(-) diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 655f76ff..bd2fba5e 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -316,10 +316,9 @@ impl HttpContext for StreamContext { let chat_completions_response: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(de) => de, - Err(e) => { + Err(_e) => { debug!("invalid response: {}", String::from_utf8_lossy(&body)); - self.send_server_error(ServerError::Deserialization(e), None); - return Action::Pause; + return Action::Continue; } }; diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 602f1629..da4d344f 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -33,6 +33,7 @@ use sha2::{Digest, Sha256}; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; +use std::str::FromStr; use std::time::Duration; use common::stats::IncrementingMetric; @@ -70,11 +71,12 @@ pub enum ServerError { Serialization(serde_json::Error), #[error("{0}")] LogicError(String), - #[error("upstream error response authority={authority}, path={path}, status={status}")] + #[error("upstream application error host={host}, path={path}, status={status}, body={body}")] Upstream { - authority: String, + host: String, path: String, status: String, + body: String, }, #[error("jailbreak detected: {0}")] Jailbreak(String), @@ -149,7 +151,6 @@ impl StreamContext { } fn send_server_error(&self, error: ServerError, override_status_code: Option) { - debug!("server error occurred: {}", error); self.send_http_response( override_status_code .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) @@ -164,6 +165,7 @@ impl StreamContext { let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { Ok(embedding_response) => embedding_response, Err(e) => { + debug!("error deserializing embedding response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -234,6 +236,7 @@ impl StreamContext { let json_data: String = match serde_json::to_string(&zero_shot_classification_request) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing zero shot classification request: {}", error); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -263,6 +266,7 @@ impl StreamContext { callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching zero shot classification request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } } @@ -276,6 +280,7 @@ impl StreamContext { match serde_json::from_slice(&body) { Ok(hallucination_response) => hallucination_response, Err(e) => { + debug!("error deserializing hallucination response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -339,6 +344,7 @@ impl StreamContext { match serde_json::from_slice(&body) { Ok(zeroshot_response) => zeroshot_response, Err(e) => { + debug!("error deserializing zero shot classification response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -450,6 +456,7 @@ impl StreamContext { callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching default prompt target request: {}", e); return self.send_server_error( ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST), @@ -465,6 +472,7 @@ impl StreamContext { let prompt_target = match self.prompt_targets.get(&prompt_target_name) { Some(prompt_target) => prompt_target.clone(), None => { + debug!("prompt target not found: {}", prompt_target_name); return self.send_server_error( ServerError::LogicError(format!( "Prompt target not found: {prompt_target_name}" @@ -537,6 +545,7 @@ impl StreamContext { msg_body } Err(e) => { + debug!("error serializing arch_fc request body: {}", e); return self.send_server_error(ServerError::Serialization(e), None); } }; @@ -569,6 +578,7 @@ impl StreamContext { callout_context.prompt_target_name = Some(prompt_target.name); if let Err(e) = self.http_call(call_args, callout_context) { + debug!("error dispatching arch_fc request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); } } @@ -580,6 +590,7 @@ impl StreamContext { let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { Ok(arch_fc_response) => arch_fc_response, Err(e) => { + debug!("error deserializing arch_fc response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -693,6 +704,7 @@ impl StreamContext { match serde_json::to_string(&hallucination_classification_request) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing hallucination classification request: {}", error); return self.send_server_error(ServerError::Serialization(error), None); } }; @@ -789,13 +801,15 @@ impl StreamContext { ) { if let Some(http_status) = self.get_http_call_response_header(":status") { if http_status != StatusCode::OK.as_str() { + debug!("upstream error response: {}", http_status); return self.send_server_error( ServerError::Upstream { - authority: callout_context.upstream_cluster.unwrap(), + host: callout_context.upstream_cluster.unwrap(), path: callout_context.upstream_cluster_path.unwrap(), - status: http_status, + status: http_status.clone(), + body: String::from_utf8(body).unwrap(), }, - None, + Some(StatusCode::from_str(http_status.as_str()).unwrap()), ); } } else { @@ -893,6 +907,7 @@ impl StreamContext { .prompt_guards .jailbreak_on_exception_message() .unwrap_or("refrain from discussing jailbreaking."); + debug!("jailbreak detected: {}", msg); return self.send_server_error( ServerError::Jailbreak(String::from(msg)), Some(StatusCode::BAD_REQUEST), @@ -916,6 +931,7 @@ impl StreamContext { let json_data: String = match serde_json::to_string(&get_embeddings_input) { Ok(json_data) => json_data, Err(error) => { + debug!("error serializing get embeddings request: {}", error); return self.send_server_error(ServerError::Deserialization(error), None); } }; @@ -952,6 +968,7 @@ impl StreamContext { }; if let Err(e) = self.http_call(call_args, call_context) { + debug!("error dispatching get embeddings request: {}", e); self.send_server_error(ServerError::HttpDispatch(e), None); } } @@ -985,6 +1002,7 @@ impl StreamContext { let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { Ok(chat_completions_resp) => chat_completions_resp, Err(e) => { + debug!("error deserializing default target response: {}", e); return self.send_server_error(ServerError::Deserialization(e), None); } }; @@ -1259,9 +1277,8 @@ impl HttpContext for StreamContext { match serde_json::from_slice(&body) { Ok(de) => de, Err(e) => { - debug!("invalid response: {}", String::from_utf8_lossy(&body)); - self.send_server_error(ServerError::Deserialization(e), None); - return Action::Pause; + debug!("invalid response: {}, {}", String::from_utf8_lossy(&body), e); + return Action::Continue; } }; diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 5f27adc3..0338f23b 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -487,7 +487,6 @@ fn bad_request_to_open_ai_chat_completions() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(incomplete_chat_completions_request_body)) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_send_local_response( Some(StatusCode::BAD_REQUEST.as_u16().into()), None, From 62a000036e286ce04cf7ba9ae521a36905c1b83e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Ulises=20Ni=C3=B1o=20Rivera?= Date: Fri, 18 Oct 2024 13:15:19 -0700 Subject: [PATCH 4/5] Update arch Dockerfile (#200) --- arch/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arch/Dockerfile b/arch/Dockerfile index 3a875a62..073c0b6b 100644 --- a/arch/Dockerfile +++ b/arch/Dockerfile @@ -12,8 +12,8 @@ FROM envoyproxy/envoy:v1.31-latest as envoy #Build config generator, so that we have a single build image for both Rust and Python FROM python:3-slim as arch -COPY --from=builder /arch/prompt_gateway/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm -COPY --from=builder /arch/llm_gateway/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm +COPY --from=builder /arch/target/wasm32-wasi/release/prompt_gateway.wasm /etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm +COPY --from=builder /arch/target/wasm32-wasi/release/llm_gateway.wasm /etc/envoy/proxy-wasm-plugins/llm_gateway.wasm COPY --from=envoy /usr/local/bin/envoy /usr/local/bin/envoy WORKDIR /config COPY arch/requirements.txt . From dd1c7be706d72dc2f0e880148cd8eac4f270f1fb Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 18 Oct 2024 13:25:39 -0700 Subject: [PATCH 5/5] Pass tool call and app function response back in metadata (#193) --- chatbot_ui/.vscode/launch.json | 1 + chatbot_ui/app/run.py | 46 ++++++---- crates/common/src/common_types.rs | 3 + crates/common/src/consts.rs | 2 + crates/prompt_gateway/src/stream_context.rs | 84 +++++++++---------- crates/prompt_gateway/tests/integration.rs | 4 +- .../app/function_calling/model_utils.py | 75 ++++++----------- model_server/app/tests/test_state.py | 66 +++++++++++++++ 8 files changed, 169 insertions(+), 112 deletions(-) create mode 100644 model_server/app/tests/test_state.py diff --git a/chatbot_ui/.vscode/launch.json b/chatbot_ui/.vscode/launch.json index 47ee5a58..8b42a191 100644 --- a/chatbot_ui/.vscode/launch.json +++ b/chatbot_ui/.vscode/launch.json @@ -5,6 +5,7 @@ "version": "0.2.0", "configurations": [ { + "python": "${workspaceFolder}/venv/bin/python", "name": "chatbot-ui", "cwd": "${workspaceFolder}/app", "type": "debugpy", diff --git a/chatbot_ui/app/run.py b/chatbot_ui/app/run.py index f2e85231..02d6e01c 100644 --- a/chatbot_ui/app/run.py +++ b/chatbot_ui/app/run.py @@ -2,14 +2,21 @@ import json import os from openai import OpenAI, DefaultHttpxClient import gradio as gr -import logging as log +import logging from dotenv import load_dotenv load_dotenv() +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(levelname)s - %(message)s", +) + +log = logging.getLogger(__name__) + CHAT_COMPLETION_ENDPOINT = os.getenv("CHAT_COMPLETION_ENDPOINT") ARCH_STATE_HEADER = "x-arch-state" -log.info("CHAT_COMPLETION_ENDPOINT: ", CHAT_COMPLETION_ENDPOINT) +log.info(f"CHAT_COMPLETION_ENDPOINT: {CHAT_COMPLETION_ENDPOINT}") client = OpenAI( api_key="--", @@ -23,23 +30,19 @@ def predict(message, state): state["history"] = [] history = state.get("history") history.append({"role": "user", "content": message}) - log.info("history: ", history) + log.info(f"history: {history}") # Custom headers custom_headers = { "x-arch-deterministic-provider": "openai", } - metadata = None - if "arch_state" in state: - metadata = {ARCH_STATE_HEADER: state["arch_state"]} - try: raw_response = client.chat.completions.with_raw_response.create( model="--", messages=history, temperature=1.0, - metadata=metadata, + # metadata=metadata, extra_headers=custom_headers, ) except Exception as e: @@ -49,26 +52,35 @@ def predict(message, state): log.info("Error calling gateway API: {}".format(e.message)) raise gr.Error("Error calling gateway API: {}".format(e.message)) - log.info("raw_response: ", raw_response.text) + log.error(f"raw_response: {raw_response.text}") response = raw_response.parse() # extract arch_state from metadata and store it in gradio session state # this state must be passed back to the gateway in the next request response_json = json.loads(raw_response.text) - arch_state = None if response_json: - metadata = response_json.get("metadata", {}) - if metadata: - arch_state = metadata.get(ARCH_STATE_HEADER, None) - if arch_state: - state["arch_state"] = arch_state + # load arch_state from metadata + arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}") + # parse arch_state into json object + arch_state = json.loads(arch_state_str) + # load messages from arch_state + arch_messages_str = arch_state.get("messages", "[]") + # parse messages into json object + arch_messages = json.loads(arch_messages_str) + # append messages from arch gateway to history + for message in arch_messages: + history.append(message) content = response.choices[0].message.content history.append({"role": "assistant", "content": content, "model": response.model}) + + # for gradio UI we don't want to show raw tool calls and messages from developer application + # so we're filtering those out + history_view = [h for h in history if h["role"] != "tool" and "content" in h] messages = [ - (history[i]["content"], history[i + 1]["content"]) - for i in range(0, len(history) - 1, 2) + (history_view[i]["content"], history_view[i + 1]["content"]) + for i in range(0, len(history_view) - 1, 2) ] return messages, state diff --git a/crates/common/src/common_types.rs b/crates/common/src/common_types.rs index fb0f902c..c8f91e0f 100644 --- a/crates/common/src/common_types.rs +++ b/crates/common/src/common_types.rs @@ -188,6 +188,8 @@ pub mod open_ai { pub model: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -381,6 +383,7 @@ mod test { content: Some("What city do you want to know the weather for?".to_string()), model: None, tool_calls: None, + tool_call_id: None, }], tools: Some(vec![super::open_ai::ChatCompletionTool { tool_type: ToolType::Function, diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index ce119eab..fdc21aed 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -5,6 +5,8 @@ pub const DEFAULT_HALLUCINATED_THRESHOLD: f64 = 0.25; pub const RATELIMIT_SELECTOR_HEADER_KEY: &str = "x-arch-ratelimit-selector"; pub const SYSTEM_ROLE: &str = "system"; pub const USER_ROLE: &str = "user"; +pub const TOOL_ROLE: &str = "tool"; +pub const ASSISTANT_ROLE: &str = "assistant"; pub const GPT_35_TURBO: &str = "gpt-3.5-turbo"; pub const ARC_FC_CLUSTER: &str = "arch_fc"; pub const ARCH_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index da4d344f..18463b49 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -3,7 +3,7 @@ use acap::cos; use common::common_types::open_ai::{ ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, - StreamOptions, ToolCall, ToolCallState, ToolType, + StreamOptions, ToolCall, ToolType, }; use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, @@ -14,9 +14,9 @@ use common::configuration::{Overrides, PromptGuards, PromptTarget}; use common::consts::{ ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, - CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, + ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, - REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, + REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, USER_ROLE, }; use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, @@ -29,12 +29,12 @@ use log::{debug, info, warn}; use proxy_wasm::traits::*; use proxy_wasm::types::*; use serde_json::Value; -use sha2::{Digest, Sha256}; use std::cell::RefCell; use std::collections::HashMap; use std::rc::Rc; use std::str::FromStr; use std::time::Duration; +use derivative::Derivative; use common::stats::IncrementingMetric; @@ -49,11 +49,13 @@ enum ResponseHandlerType { DefaultTarget, } -#[derive(Debug, Clone)] +#[derive(Clone, Derivative)] +#[derivative(Debug)] pub struct StreamCallContext { response_handler_type: ResponseHandlerType, user_message: Option, prompt_target_name: Option, + #[derivative(Debug = "ignore")] request_body: ChatCompletionsRequest, tool_calls: Option>, similarity_scores: Option>, @@ -306,6 +308,7 @@ impl StreamContext { content: Some(response), model: Some(ARCH_FC_MODEL_NAME.to_string()), tool_calls: None, + tool_call_id: None, }; let chat_completion_response = ChatCompletionsResponse { @@ -797,7 +800,7 @@ impl StreamContext { fn function_call_response_handler( &mut self, body: Vec, - mut callout_context: StreamCallContext, + callout_context: StreamCallContext, ) { if let Some(http_status) = self.get_http_call_response_header(":status") { if http_status != StatusCode::OK.as_str() { @@ -841,11 +844,18 @@ impl StreamContext { content: system_prompt, model: None, tool_calls: None, + tool_call_id: None, }; messages.push(system_prompt_message); } - messages.append(callout_context.request_body.messages.as_mut()); + // don't send tools message and api response to chat gpt + for m in callout_context.request_body.messages.iter() { + if m.role == TOOL_ROLE || m.content.is_none() { + continue; + } + messages.push(m.clone()); + } let user_message = match messages.pop() { Some(user_message) => user_message, @@ -872,6 +882,7 @@ impl StreamContext { content: Some(final_prompt), model: None, tool_calls: None, + tool_call_id: None, } }); @@ -1022,6 +1033,7 @@ impl StreamContext { content: Some(system_prompt.clone()), model: None, tool_calls: None, + tool_call_id: None, }; messages.push(system_prompt_message); } @@ -1032,6 +1044,7 @@ impl StreamContext { content: Some(api_resp.clone()), model: None, tool_calls: None, + tool_call_id: None, }); let chat_completion_request = ChatCompletionsRequest { model: GPT_35_TURBO.to_string(), @@ -1296,55 +1309,42 @@ impl HttpContext for StreamContext { self.arch_state = Some(Vec::new()); } - // compute sha hash from message history - let mut hasher = Sha256::new(); - let prompts: Vec = self - .chat_completions_request - .as_ref() - .unwrap() - .messages - .iter() - .filter(|msg| msg.role == USER_ROLE) - .map(|msg| msg.content.clone().unwrap()) - .collect(); - let prompts_merged = prompts.join("#.#"); - hasher.update(prompts_merged.clone()); - let hash_key = hasher.finalize(); - // conver hash to hex string - let hash_key_str = format!("{:x}", hash_key); - debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged); - - // create new tool call state - let tool_call_state = ToolCallState { - key: hash_key_str, - message: self.user_prompt.clone(), - tool_call: tool_calls[0].function.clone(), - tool_response: self.tool_call_response.clone().unwrap(), - }; - - // push tool call state to arch state - self.arch_state - .as_mut() - .unwrap() - .push(ArchState::ToolCall(vec![tool_call_state])); - let mut data: Value = serde_json::from_slice(&body).unwrap(); // use serde::Value to manipulate the json object and ensure that we don't lose any data if let Value::Object(ref mut map) = data { // serialize arch state and add to metadata - let arch_state_str = serde_json::to_string(&self.arch_state).unwrap(); - debug!("arch_state: {}", arch_state_str); let metadata = map .entry("metadata") .or_insert(Value::Object(serde_json::Map::new())); if metadata == &Value::Null { *metadata = Value::Object(serde_json::Map::new()); } + + // since arch gateway generates tool calls (using arch-fc) and calls upstream api to + // get response, we will send these back to developer so they can see the api response + // and tool call arch-fc generated + let mut fc_messages = Vec::new(); + fc_messages.push(Message { + role: ASSISTANT_ROLE.to_string(), + content: None, + model: Some(ARCH_FC_MODEL_NAME.to_string()), + tool_calls: self.tool_calls.clone(), + tool_call_id: None, + }); + fc_messages.push(Message { + role: TOOL_ROLE.to_string(), + content: self.tool_call_response.clone(), + model: None, + tool_calls: None, + tool_call_id: Some(self.tool_calls.as_ref().unwrap()[0].id.clone()), + }); + let fc_messages_str = serde_json::to_string(&fc_messages).unwrap(); + let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]); + let arch_state_str = serde_json::to_string(&arch_state).unwrap(); metadata.as_object_mut().unwrap().insert( ARCH_STATE_HEADER.to_string(), serde_json::Value::String(arch_state_str), ); - let data_serialized = serde_json::to_string(&data).unwrap(); debug!("arch => user: {}", data_serialized); self.set_http_response_body(0, body_size, data_serialized.as_bytes()); diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 0338f23b..14ca1aa2 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -546,6 +546,7 @@ fn request_to_llm_gateway() { }, }]), model: None, + tool_call_id: None, }, }], model: String::from("test"), @@ -647,6 +648,7 @@ fn request_to_llm_gateway() { content: Some("hello from fake llm gateway".to_string()), model: None, tool_calls: None, + tool_call_id: None, }, }], model: String::from("test"), @@ -665,8 +667,6 @@ fn request_to_llm_gateway() { .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) .expect_log(Some(LogLevel::Debug), None) .execute_and_expect(ReturnType::Action(Action::Continue)) diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py index 3e4e6654..04078a1b 100644 --- a/model_server/app/function_calling/model_utils.py +++ b/model_server/app/function_calling/model_utils.py @@ -13,62 +13,38 @@ logger = get_model_server_logger() class Message(BaseModel): role: str - content: str + content: str = "" + tool_calls: List[Dict[str, Any]] = [] + tool_call_id: str = "" class ChatMessage(BaseModel): messages: list[Message] tools: List[Dict[str, Any]] - # TODO: make it default none - metadata: Dict[str, str] = {} - -def process_state(arch_state, history: list[Message]): - logger.info("state: {}".format(arch_state)) - state_json = json.loads(arch_state) - - state_map = {} - if state_json: - for tools_state in state_json: - for tool_state in tools_state: - state_map[tool_state["key"]] = tool_state - - logger.info(f"state_map: {json.dumps(state_map)}") - - sha_history = [] +def process_messages(history: list[Message]): updated_history = [] for hist in history: - updated_history.append({"role": hist.role, "content": hist.content}) - if hist.role == "user": - sha_history.append(hist.content) - sha256_hash = hashlib.sha256() - joined_key_str = ("#.#").join(sha_history) - sha256_hash.update(joined_key_str.encode()) - sha_key = sha256_hash.hexdigest() - logger.info(f"sha_key: {sha_key}") - if sha_key in state_map: - tool_call_state = state_map[sha_key] - if "tool_call" in tool_call_state: - tool_call_str = json.dumps(tool_call_state["tool_call"]) - updated_history.append( - { - "role": "assistant", - "content": f"\n{tool_call_str}\n", - } - ) - if "tool_response" in tool_call_state: - tool_resp = tool_call_state["tool_response"] - # TODO: try with role = user as well - updated_history.append( - { - "role": "user", - "content": f"\n{tool_resp}\n", - } - ) - # we dont want to match this state with any other messages - del state_map[sha_key] - + if hist.tool_calls: + if len(hist.tool_calls) > 1: + raise ValueError("Only one tool call is supported") + tool_call_str = json.dumps(hist.tool_calls[0]["function"]) + updated_history.append( + { + "role": "assistant", + "content": f"\n{tool_call_str}\n", + } + ) + elif hist.role == "tool": + updated_history.append( + { + "role": "user", + "content": f"\n{hist.content}\n", + } + ) + else: + updated_history.append({"role": hist.role, "content": hist.content}) return updated_history @@ -79,10 +55,7 @@ async def chat_completion(req: ChatMessage, res: Response): messages = [{"role": "system", "content": tools_encoded}] - metadata = req.metadata - arch_state = metadata.get("x-arch-state", "[]") - - updated_history = process_state(arch_state, req.messages) + updated_history = process_messages(req.messages) for message in updated_history: messages.append({"role": message["role"], "content": message["content"]}) diff --git a/model_server/app/tests/test_state.py b/model_server/app/tests/test_state.py new file mode 100644 index 00000000..9eb72c8c --- /dev/null +++ b/model_server/app/tests/test_state.py @@ -0,0 +1,66 @@ +from typing import List +import pytest +import json +from app.function_calling.model_utils import Message, process_messages + +test_input_history = """ +[ + { + "role": "user", + "content": "how is the weather in chicago for next 5 days?" + }, + { + "role": "assistant", + "model": "Arch-Function-1.5B", + "tool_calls": [ + { + "id": "call_3394", + "type": "function", + "function": { + "name": "weather_forecast", + "arguments": { "city": "Chicago", "days": 5 } + } + } + ] + }, + { + "role": "tool", + "content": "--", + "tool_call_id": "call_3394" + }, + { + "role": "assistant", + "content": "--", + "model": "gpt-3.5-turbo-0125" + }, + { + "role": "user", + "content": "how is the weather in chicago for next 5 days?" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_5306", + "type": "function", + "function": { + "name": "weather_forecast", + "arguments": { "city": "Chicago", "days": 5 } + } + } + ] + } + ] + """ + + +def test_update_fc_history(): + history = json.loads(test_input_history) + message_history = [] + for h in history: + message_history.append(Message(**h)) + + updated_history = process_messages(message_history) + assert len(updated_history) == 6 + # ensure that tool role does not exist anymore + assert all([h["role"] != "tool" for h in updated_history])