diff --git a/.github/workflows/e2e_archgw.yml b/.github/workflows/e2e_archgw.yml index 5897e5f2..64454d57 100644 --- a/.github/workflows/e2e_archgw.yml +++ b/.github/workflows/e2e_archgw.yml @@ -24,7 +24,7 @@ jobs: - name: build arch docker image run: | - cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.2.8 + cd ../../ && docker build -f arch/Dockerfile . -t katanemo/archgw -t katanemo/archgw:0.2.8 -t katanemo/archgw:latest - name: start archgw env: diff --git a/crates/Cargo.lock b/crates/Cargo.lock index b2501deb..545cd2fe 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -205,6 +205,7 @@ dependencies = [ "opentelemetry-otlp", "opentelemetry-stdout", "opentelemetry_sdk", + "pretty_assertions", "reqwest", "serde", "serde_json", diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index 1ad17dc8..1bccb230 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -19,6 +19,7 @@ opentelemetry-http = "0.29.0" opentelemetry-otlp = "0.29.0" opentelemetry-stdout = "0.29.0" opentelemetry_sdk = "0.29.0" +pretty_assertions = "1.4.1" reqwest = { version = "0.12.15", features = ["stream"] } serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 6d1cb7fb..f4e44c26 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,6 +1,6 @@ use common::{ api::open_ai::{ChatCompletionsRequest, Message}, - consts::USER_ROLE, + consts::{SYSTEM_ROLE, USER_ROLE}, }; use serde::{Deserialize, Serialize}; use tracing::info; @@ -63,12 +63,19 @@ struct LlmRouterResponse { impl RouterModel for RouterModelV1 { fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest { + let messages_str = messages + .iter() + .filter(|m| m.role != SYSTEM_ROLE) + .map(|m| { + let content_json_str = serde_json::to_string(&m.content).unwrap_or_default(); + format!("{}: {}", m.role, content_json_str) + }) + .collect::>() + .join("\n"); + let message = ARCH_ROUTER_V1_SYSTEM_PROMPT .replace("{routes}", &self.llm_providers_with_usage_yaml) - .replace( - "{conversation}", - &serde_json::to_string_pretty(messages).unwrap(), - ); + .replace("{conversation}", messages_str.as_str()); ChatCompletionsRequest { model: self.routing_model.clone(), @@ -138,3 +145,81 @@ impl std::fmt::Debug for dyn RouterModel { write!(f, "RouterModel") } } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn test_system_prompt_format() { + let expected_prompt = r#" +You are an advanced Routing Assistant designed to select the optimal route based on user requests. +Your task is to analyze conversations and match them to the most appropriate predefined route. +Review the available routes config: + +# ROUTES CONFIG START +route1: description1 +route2: description2 +# ROUTES CONFIG END + +Examine the following conversation between a user and an assistant: + +# CONVERSATION START +user: "Hello, I want to book a flight." +assistant: "Sure, where would you like to go?" +user: "seattle" +# CONVERSATION END + +Your goal is to identify the most appropriate route that matches the user's LATEST intent. Follow these steps: + +1. Carefully read and analyze the provided conversation, focusing on the user's latest request and the conversation scenario. +2. Check if the user's request and scenario matches any of the routes in the routing configuration (focus on the description). +3. Find the route that best matches. +4. Use context clues from the entire conversation to determine the best fit. +5. Return the best match possible. You only response the name of the route that best matches the user's request, use the exact name in the routes config. +6. If no route relatively close to matches the user's latest intent or user last message is thank you or greeting, return an empty route ''. + +# OUTPUT FORMAT +Your final output must follow this JSON format: +{ + "route": "route_name" # The matched route name, or empty string '' if no match +} + +Based on your analysis, provide only the JSON object as your final output with no additional text, explanations, or whitespace. +"#; + + let routes_yaml = "route1: description1\nroute2: description2"; + let routing_model = "test-model".to_string(); + let router = RouterModelV1::new(routes_yaml.to_string(), routing_model.clone()); + + let messages = vec![ + Message { + role: "system".to_string(), + content: Some("You are a helpful assistant.".to_string()), + ..Default::default() + }, + Message { + role: "user".to_string(), + content: Some("Hello, I want to book a flight.".to_string()), + ..Default::default() + }, + Message { + role: "assistant".to_string(), + content: Some("Sure, where would you like to go?".to_string()), + ..Default::default() + }, + Message { + role: "user".to_string(), + content: Some("seattle".to_string()), + ..Default::default() + }, + ]; + + let req = router.generate_request(&messages); + + let prompt = req.messages[0].content.as_ref().unwrap(); + + assert_eq!(expected_prompt, prompt); + } +} diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs index d71b0d58..7b3cf66c 100644 --- a/crates/common/src/api/open_ai.rs +++ b/crates/common/src/api/open_ai.rs @@ -171,6 +171,18 @@ pub struct Message { pub tool_call_id: Option, } +impl Default for Message { + fn default() -> Self { + Message { + role: ASSISTANT_ROLE.to_string(), + content: None, + model: None, + tool_calls: None, + tool_call_id: None, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Choice { pub finish_reason: Option,