Pass the selected model name in the provider hint header when a route is chosen, including when usage preferences are used (#531)

This commit is contained in:
Adil Hafeez 2025-07-17 13:41:58 -07:00 committed by GitHub
parent 2340a45353
commit f819ee3507
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 150 additions and 117 deletions

View file

@ -104,7 +104,7 @@ pub async fn chat_completions(
debug!("usage preferences from request: {:?}", usage_preferences);
let mut determined_route = match router_service
let determined_model = match router_service
.determine_route(
&chat_completion_request.messages,
trace_parent.clone(),
@ -121,14 +121,17 @@ pub async fn chat_completions(
}
};
if determined_route.is_none() {
debug!("No LLM model selected, using default from request");
determined_route = Some(chat_completion_request.model.clone());
}
info!(
"sending request to llm provider: {} with llm model: {:?}",
llm_provider_endpoint, determined_route
"sending request to llm provider: {} determined_model: {:?}, model from request: {}",
llm_provider_endpoint, determined_model, chat_completion_request.model
);
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(
&determined_model.unwrap_or(chat_completion_request.model.clone()),
)
.unwrap(),
);
if let Some(trace_parent) = trace_parent {
@ -138,13 +141,6 @@ pub async fn chat_completions(
);
}
if let Some(selected_route) = determined_route {
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&selected_route).unwrap(),
);
}
let chat_request_parsed_bytes =
serde_json::to_string(&chat_request_user_preferences_removed).unwrap();

View file

@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
@ -48,9 +48,14 @@ impl RouterService {
.cloned()
.collect::<Vec<LlmProvider>>();
let llm_routes: Vec<RoutingPreference> = providers_with_usage
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
.iter()
.flat_map(|provider| provider.routing_preferences.clone().unwrap_or_default())
.filter_map(|provider| {
provider
.routing_preferences
.as_ref()
.map(|prefs| (provider.name.clone(), prefs.clone()))
})
.collect();
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
@ -151,21 +156,22 @@ impl RouterService {
if let Some(ContentType::Text(content)) =
&chat_completion_response.choices[0].message.content
{
let route_name = self.router_model.parse_response(content)?;
let parsed_response = self
.router_model
.parse_response(content, &usage_preferences)?;
info!(
"router response: {}, selected_model: {:?}, response time: {}ms",
content.replace("\n", "\\n"),
route_name,
parsed_response,
router_response_time.as_millis()
);
if let Some(ref route) = route_name {
if route == "other" {
return Ok(None);
}
if let Some(ref route) = parsed_response {
// return model name if route is found
return Ok(Some(route.1.clone()));
}
Ok(route_name)
Ok(None)
} else {
Ok(None)
}

View file

@ -16,6 +16,10 @@ pub trait RouterModel: Send + Sync {
messages: &[Message],
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest;
fn parse_response(&self, content: &str) -> Result<Option<String>>;
fn parse_response(
&self,
content: &str,
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> Result<Option<(String, String)>>;
fn get_model_name(&self) -> String;
}

View file

@ -1,3 +1,5 @@
use std::collections::HashMap;
use common::{
configuration::{ModelUsagePreference, RoutingPreference},
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
@ -32,21 +34,30 @@ Based on your analysis, provide your response in the following JSON formats if y
pub type Result<T> = std::result::Result<T, RoutingModelError>;
pub struct RouterModelV1 {
llm_route_json_str: String,
llm_route_to_model_map: HashMap<String, String>,
routing_model: String,
max_token_length: usize,
}
impl RouterModelV1 {
pub fn new(
llm_routes: Vec<RoutingPreference>,
llm_routes: HashMap<String, Vec<RoutingPreference>>,
routing_model: String,
max_token_length: usize,
) -> Self {
let llm_route_values: Vec<RoutingPreference> =
llm_routes.values().flatten().cloned().collect();
let llm_route_json_str =
serde_json::to_string(&llm_routes).unwrap_or_else(|_| "[]".to_string());
serde_json::to_string(&llm_route_values).unwrap_or_else(|_| "[]".to_string());
let llm_route_to_model_map: HashMap<String, String> = llm_routes
.iter()
.flat_map(|(model, prefs)| prefs.iter().map(|pref| (pref.name.clone(), model.clone())))
.collect();
RouterModelV1 {
routing_model,
max_token_length,
llm_route_json_str,
llm_route_to_model_map,
}
}
}
@ -171,20 +182,51 @@ impl RouterModel for RouterModelV1 {
}
}
fn parse_response(&self, content: &str) -> Result<Option<String>> {
fn parse_response(
&self,
content: &str,
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> Result<Option<(String, String)>> {
if content.is_empty() {
return Ok(None);
}
let router_resp_fixed = fix_json_response(content);
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
let selected_llm = router_response.route.unwrap_or_default().to_string();
let selected_route = router_response.route.unwrap_or_default().to_string();
if selected_llm.is_empty() {
if selected_route.is_empty() || selected_route == "other" {
return Ok(None);
}
Ok(Some(selected_llm))
if let Some(usage_preferences) = usage_preferences {
// If usage preferences are defined, we need to find the model that matches the selected route
let matching_preference = usage_preferences
.iter()
.find(|pref| pref.name == selected_route);
if let Some(preference) = matching_preference {
return Ok(Some((selected_route, preference.model.clone())));
} else {
warn!(
"No matching model found for route: {}, usage preferences: {:?}",
selected_route, usage_preferences
);
return Ok(None);
}
}
// If no usage preferences are defined, fall back to the model configured for the selected route in the router's route-to-model map
if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() {
return Ok(Some((selected_route, model)));
}
warn!(
"No model found for route: {}, router model preferences: {:?}",
selected_route, self.llm_route_to_model_map
);
Ok(None)
}
fn get_model_name(&self) -> String {
@ -235,7 +277,7 @@ mod tests {
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -251,15 +293,13 @@ Based on your analysis, provide your response in the following JSON formats if y
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -310,15 +350,13 @@ Based on your analysis, provide your response in the following JSON formats if y
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -358,7 +396,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -375,15 +413,13 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
@ -419,7 +455,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -436,15 +472,14 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
@ -480,7 +515,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -497,15 +532,13 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
@ -549,7 +582,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -565,15 +598,13 @@ Based on your analysis, provide your response in the following JSON formats if y
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -619,7 +650,7 @@ Based on your analysis, provide your response in the following JSON formats if y
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
[{"name":"Image generation","description":"generating image"}]
</routes>
<conversation>
@ -635,15 +666,13 @@ Based on your analysis, provide your response in the following JSON formats if y
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -712,56 +741,54 @@ Based on your analysis, provide your response in the following JSON formats if y
#[test]
fn test_parse_response() {
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
{
"gpt-4o": [
{"name": "Image generation", "description": "generating image"}
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
// Case 1: Valid JSON with non-empty route
let input = r#"{"route": "route1"}"#;
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route1".to_string()));
let input = r#"{"route": "Image generation"}"#;
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
// Case 2: Valid JSON with empty route
let input = r#"{"route": ""}"#;
let result = router.parse_response(input).unwrap();
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 3: Valid JSON with null route
let input = r#"{"route": null}"#;
let result = router.parse_response(input).unwrap();
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 4: JSON missing route field
let input = r#"{}"#;
let result = router.parse_response(input).unwrap();
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 4.1: empty string
let input = r#""#;
let result = router.parse_response(input).unwrap();
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, None);
// Case 5: Malformed JSON
let input = r#"{"route": "route1""#; // missing closing }
let result = router.parse_response(input);
let result = router.parse_response(input, &None);
assert!(result.is_err());
// Case 6: Single quotes and \n in JSON
let input = "{'route': 'route2'}\\n";
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route2".to_string()));
let input = "{'route': 'Image generation'}\\n";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
// Case 7: Code block marker
let input = "```json\n{\"route\": \"route1\"}\n```";
let result = router.parse_response(input).unwrap();
assert_eq!(result, Some("route1".to_string()));
let input = "```json\n{\"route\": \"Image generation\"}\n```";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
}
}