diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 1bd44498..1c0e8905 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -12,7 +12,7 @@ use hyper::{Request, Response, StatusCode}; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; -use tracing::{debug, info, warn}; +use tracing::{debug, info, trace, warn}; use crate::router::llm_router::RouterService; @@ -47,7 +47,7 @@ pub async fn chat_completions( } }; - debug!( + trace!( "arch-router request body: {}", &serde_json::to_string(&chat_completion_request).unwrap() ); diff --git a/crates/brightstaff/src/handlers/preferences.rs b/crates/brightstaff/src/handlers/preferences.rs index 9dd68dd1..a9c5a65d 100644 --- a/crates/brightstaff/src/handlers/preferences.rs +++ b/crates/brightstaff/src/handlers/preferences.rs @@ -14,7 +14,8 @@ pub async fn list_preferences( let providers_with_usage = prov .iter() .map(|provider| ModelUsagePreference { - model: provider.name.clone(), + name: provider.name.clone(), + model: provider.model.clone().unwrap_or_default(), usage: provider.usage.clone(), }) .collect::>(); @@ -101,7 +102,8 @@ pub async fn update_preferences( if let Some(usage_provider) = usage_model_map.get(&provider.name) { provider.usage = usage_provider.usage.clone(); updated_models_list.push(ModelUsagePreference { - model: provider.name.clone(), + name: provider.name.clone(), + model: provider.model.clone().unwrap_or_default(), usage: provider.usage.clone(), }); } diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index c72b19e9..78d634d5 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use common::{ configuration::{LlmProvider, LlmRoute, ModelUsagePreference}, @@ -19,6 +19,7 @@ pub struct RouterService { router_model: Arc, routing_model_name: String, llm_usage_defined: bool, + llm_provider_map: HashMap, } #[derive(Debug, Error)] @@ -55,12 +56,18 @@ impl RouterService { router_model_v1::MAX_TOKEN_LEN, )); + let llm_provider_map: HashMap = providers + .into_iter() + .map(|provider| (provider.name.clone(), provider)) + .collect(); + RouterService { router_url, client: reqwest::Client::new(), router_model, routing_model_name, llm_usage_defined: !providers_with_usage.is_empty(), + llm_provider_map, } } @@ -76,7 +83,7 @@ impl RouterService { let router_request = self .router_model - .generate_request(messages, usage_preferences); + .generate_request(messages, &usage_preferences); info!( "sending request to arch-router model: {}, endpoint: {}", @@ -147,13 +154,40 @@ impl RouterService { if let Some(ContentType::Text(content)) = &chat_completion_response.choices[0].message.content { + let mut selected_model: Option = None; + if let Some(selected_llm_name) = self.router_model.parse_response(content)? { + if selected_llm_name != "other" { + if let Some(usage_preferences) = usage_preferences { + for usage in usage_preferences { + if usage.name == selected_llm_name { + selected_model = Some(usage.model); + break; + } + } + if selected_model.is_none() { + warn!( + "Selected LLM model not found in usage preferences: {}", + selected_llm_name + ); + } + } else if let Some(provider) = self.llm_provider_map.get(&selected_llm_name) { + selected_model = provider.model.clone(); + } else { + warn!( + "Selected LLM model not found in provider map: {}", + selected_llm_name + ); + } + } + } info!( - "router response: {}, response time: {}ms", + "router response: {}, selected_model: {:?}, response time: {}ms", content.replace("\n", "\\n"), + selected_model, router_response_time.as_millis() ); - let selected_llm = self.router_model.parse_response(content)?; - Ok(selected_llm) + + Ok(selected_model) } else { Ok(None) } diff --git a/crates/brightstaff/src/router/router_model.rs b/crates/brightstaff/src/router/router_model.rs index b377b3a3..dafa8776 100644 --- a/crates/brightstaff/src/router/router_model.rs +++ b/crates/brightstaff/src/router/router_model.rs @@ -14,7 +14,7 @@ pub trait RouterModel: Send + Sync { fn generate_request( &self, messages: &[Message], - usage_preferences: Option>, + usage_preferences: &Option>, ) -> ChatCompletionsRequest; fn parse_response(&self, content: &str) -> Result>; fn get_model_name(&self) -> String; diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 7dd57223..e6ccd912 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -58,7 +58,7 @@ impl RouterModel for RouterModelV1 { fn generate_request( &self, messages: &[Message], - usage_preferences: Option>, + usage_preferences: &Option>, ) -> ChatCompletionsRequest { // remove system prompt, tool calls, tool call response and messages without content // if content is empty its likely a tool call @@ -137,7 +137,16 @@ impl RouterModel for RouterModelV1 { let llm_route_json = usage_preferences .as_ref() - .map(|prefs| serde_json::to_string(prefs).unwrap_or_default()) + .map(|prefs| { + let llm_route: Vec = prefs + .iter() + .map(|pref| LlmRoute { + name: pref.name.clone(), + description: pref.usage.clone().unwrap_or_default(), + }) + .collect(); + serde_json::to_string(&llm_route).unwrap_or_default() + }) .unwrap_or_else(|| self.llm_route_json_str.clone()); let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT @@ -268,7 +277,71 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation, None); + let req = router.generate_request(&conversation, &None); + + let prompt = req.messages[0].content.as_ref().unwrap(); + + assert_eq!(expected_prompt, prompt.to_string()); + } + + #[test] + fn test_system_prompt_format_usage_preferences() { + let expected_prompt = r#" +You are a helpful assistant designed to find the best suited route. +You are provided with route description within XML tags: + +[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}] + + + +[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}] + + +Your task is to decide which route is best suit with user intent on the conversation in XML tags. Follow the instruction: +1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}. +2. You must analyze the route descriptions and find the best match route for user latest intent. +3. You only response the name of the route that best matches the user's request, use the exact name in the . + +Based on your analysis, provide your response in the following JSON formats if you decide to match any route: +{"route": "route_name"} +"#; + let routes_str = r#" + [ + {"name": "Image generation", "description": "generating image"}, + {"name": "image conversion", "description": "convert images to provided format"}, + {"name": "image search", "description": "search image"}, + {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, + {"name": "Speech Recognition", "description": "Converting spoken language into written text"} + ] + "#; + let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let routing_model = "test-model".to_string(); + let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); + + let conversation_str = r#" + [ + { + "role": "user", + "content": "hi" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson" + } + ] + "#; + let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + + let usage_preferences = Some(vec![ModelUsagePreference { + name: "code-generation".to_string(), + model: "claude/claude-3-7-sonnet".to_string(), + usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()), + }]); + let req = router.generate_request(&conversation, &usage_preferences); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -329,7 +402,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation, None); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -390,7 +463,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation, None); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -459,7 +532,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation, None); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -529,7 +602,7 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation, None); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -625,7 +698,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation, None); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index f46b6cc6..44f1474e 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -180,6 +180,7 @@ impl Display for LlmProviderType { #[skip_serializing_none] #[derive(Serialize, Deserialize, Debug)] pub struct ModelUsagePreference { + pub name: String, pub model: String, pub usage: Option, }