diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 89c9ee13..614acb66 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -104,7 +104,7 @@ pub async fn chat_completions( debug!("usage preferences from request: {:?}", usage_preferences); - let mut determined_route = match router_service + let determined_model = match router_service .determine_route( &chat_completion_request.messages, trace_parent.clone(), @@ -121,14 +121,17 @@ pub async fn chat_completions( } }; - if determined_route.is_none() { - debug!("No LLM model selected, using default from request"); - determined_route = Some(chat_completion_request.model.clone()); - } - info!( - "sending request to llm provider: {} with llm model: {:?}", - llm_provider_endpoint, determined_route + "sending request to llm provider: {} determined_model: {:?}, model from request: {}", + llm_provider_endpoint, determined_model, chat_completion_request.model + ); + + request_headers.insert( + ARCH_PROVIDER_HINT_HEADER, + header::HeaderValue::from_str( + &determined_model.unwrap_or(chat_completion_request.model.clone()), + ) + .unwrap(), ); if let Some(trace_parent) = trace_parent { @@ -138,13 +141,6 @@ pub async fn chat_completions( ); } - if let Some(selected_route) = determined_route { - request_headers.insert( - ARCH_PROVIDER_HINT_HEADER, - header::HeaderValue::from_str(&selected_route).unwrap(), - ); - } - let chat_request_parsed_bytes = serde_json::to_string(&chat_request_user_preferences_removed).unwrap(); diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index c1320c66..a20bf7c4 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use common::{ configuration::{LlmProvider, ModelUsagePreference, RoutingPreference}, @@ -48,9 +48,14 @@ impl RouterService { .cloned() .collect::>(); - let llm_routes: Vec = providers_with_usage + let llm_routes: HashMap> = providers_with_usage .iter() - .flat_map(|provider| provider.routing_preferences.clone().unwrap_or_default()) + .filter_map(|provider| { + provider + .routing_preferences + .as_ref() + .map(|prefs| (provider.name.clone(), prefs.clone())) + }) .collect(); let router_model = Arc::new(router_model_v1::RouterModelV1::new( @@ -151,21 +156,22 @@ impl RouterService { if let Some(ContentType::Text(content)) = &chat_completion_response.choices[0].message.content { - let route_name = self.router_model.parse_response(content)?; + let parsed_response = self + .router_model + .parse_response(content, &usage_preferences)?; info!( "router response: {}, selected_model: {:?}, response time: {}ms", content.replace("\n", "\\n"), - route_name, + parsed_response, router_response_time.as_millis() ); - if let Some(ref route) = route_name { - if route == "other" { - return Ok(None); - } + if let Some(ref route) = parsed_response { + // return model name if route is found + return Ok(Some(route.1.clone())); } - Ok(route_name) + Ok(None) } else { Ok(None) } diff --git a/crates/brightstaff/src/router/router_model.rs b/crates/brightstaff/src/router/router_model.rs index dafa8776..ec0c1a1f 100644 --- a/crates/brightstaff/src/router/router_model.rs +++ b/crates/brightstaff/src/router/router_model.rs @@ -16,6 +16,10 @@ pub trait RouterModel: Send + Sync { messages: &[Message], usage_preferences: &Option>, ) -> ChatCompletionsRequest; - fn parse_response(&self, content: &str) -> Result>; + fn parse_response( + &self, + content: &str, + usage_preferences: &Option>, + ) -> Result>; fn get_model_name(&self) -> String; } diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 0dcefff6..dc0e1563 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use common::{ configuration::{ModelUsagePreference, RoutingPreference}, consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE}, @@ -32,21 +34,30 @@ Based on your analysis, provide your response in the following JSON formats if y pub type Result = std::result::Result; pub struct RouterModelV1 { llm_route_json_str: String, + llm_route_to_model_map: HashMap, routing_model: String, max_token_length: usize, } impl RouterModelV1 { pub fn new( - llm_routes: Vec, + llm_routes: HashMap>, routing_model: String, max_token_length: usize, ) -> Self { + let llm_route_values: Vec = + llm_routes.values().flatten().cloned().collect(); let llm_route_json_str = - serde_json::to_string(&llm_routes).unwrap_or_else(|_| "[]".to_string()); + serde_json::to_string(&llm_route_values).unwrap_or_else(|_| "[]".to_string()); + let llm_route_to_model_map: HashMap = llm_routes + .iter() + .flat_map(|(model, prefs)| prefs.iter().map(|pref| (pref.name.clone(), model.clone()))) + .collect(); + RouterModelV1 { routing_model, max_token_length, llm_route_json_str, + llm_route_to_model_map, } } } @@ -171,20 +182,51 @@ impl RouterModel for RouterModelV1 { } } - fn parse_response(&self, content: &str) -> Result> { + fn parse_response( + &self, + content: &str, + usage_preferences: &Option>, + ) -> Result> { if content.is_empty() { return Ok(None); } let router_resp_fixed = fix_json_response(content); let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?; - let selected_llm = router_response.route.unwrap_or_default().to_string(); + let selected_route = router_response.route.unwrap_or_default().to_string(); - if selected_llm.is_empty() { + if selected_route.is_empty() || selected_route == "other" { return Ok(None); } - Ok(Some(selected_llm)) + if let Some(usage_preferences) = usage_preferences { + // If usage preferences are defined, we need to find the model that matches the selected route + let matching_preference = usage_preferences + .iter() + .find(|pref| pref.name == selected_route); + + if let Some(preference) = matching_preference { + return Ok(Some((selected_route, preference.model.clone()))); + } else { + warn!( + "No matching model found for route: {}, usage preferences: {:?}", + selected_route, usage_preferences + ); + return Ok(None); + } + } + + // If no usage preferences are defined, we return the route with the routing model + if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() { + return Ok(Some((selected_route, model))); + } + + warn!( + "No model found for route: {}, router model preferences: {:?}", + selected_route, self.llm_route_to_model_map + ); + + Ok(None) } fn get_model_name(&self) -> String { @@ -235,7 +277,7 @@ mod tests { You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -251,15 +293,13 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -310,15 +350,13 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -358,7 +396,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -375,15 +413,13 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235); @@ -419,7 +455,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -436,15 +472,14 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200); @@ -480,7 +515,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -497,15 +532,13 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230); @@ -549,7 +582,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -565,15 +598,13 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -619,7 +650,7 @@ Based on your analysis, provide your response in the following JSON formats if y You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: -[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] +[{"name":"Image generation","description":"generating image"}] @@ -635,15 +666,13 @@ Based on your analysis, provide your response in the following JSON formats if y {"route": "route_name"} "#; let routes_str = r#" - [ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} - ] + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } "#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -712,56 +741,54 @@ Based on your analysis, provide your response in the following JSON formats if y #[test] fn test_parse_response() { let routes_str = r#" -[ - {"name": "Image generation", "description": "generating image"}, - {"name": "image conversion", "description": "convert images to provided format"}, - {"name": "image search", "description": "search image"}, - {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, - {"name": "Speech Recognition", "description": "Converting spoken language into written text"} -] -"#; - let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + { + "gpt-4o": [ + {"name": "Image generation", "description": "generating image"} + ] + } + "#; + let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000); // Case 1: Valid JSON with non-empty route - let input = r#"{"route": "route1"}"#; - let result = router.parse_response(input).unwrap(); - assert_eq!(result, Some("route1".to_string())); + let input = r#"{"route": "Image generation"}"#; + let result = router.parse_response(input, &None).unwrap(); + assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string()))); // Case 2: Valid JSON with empty route let input = r#"{"route": ""}"#; - let result = router.parse_response(input).unwrap(); + let result = router.parse_response(input, &None).unwrap(); assert_eq!(result, None); // Case 3: Valid JSON with null route let input = r#"{"route": null}"#; - let result = router.parse_response(input).unwrap(); + let result = router.parse_response(input, &None).unwrap(); assert_eq!(result, None); // Case 4: JSON missing route field let input = r#"{}"#; - let result = router.parse_response(input).unwrap(); + let result = router.parse_response(input, &None).unwrap(); assert_eq!(result, None); // Case 4.1: empty string let input = r#""#; - let result = router.parse_response(input).unwrap(); + let result = router.parse_response(input, &None).unwrap(); assert_eq!(result, None); // Case 5: Malformed JSON let input = r#"{"route": "route1""#; // missing closing } - let result = router.parse_response(input); + let result = router.parse_response(input, &None); assert!(result.is_err()); // Case 6: Single quotes and \n in JSON - let input = "{'route': 'route2'}\\n"; - let result = router.parse_response(input).unwrap(); - assert_eq!(result, Some("route2".to_string())); + let input = "{'route': 'Image generation'}\\n"; + let result = router.parse_response(input, &None).unwrap(); + assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string()))); // Case 7: Code block marker - let input = "```json\n{\"route\": \"route1\"}\n```"; - let result = router.parse_response(input).unwrap(); - assert_eq!(result, Some("route1".to_string())); + let input = "```json\n{\"route\": \"Image generation\"}\n```"; + let result = router.parse_response(input, &None).unwrap(); + assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string()))); } }