In request path use same format for usage preferences as arch_config (#533)

This commit is contained in:
Adil Hafeez 2025-07-21 18:31:19 -07:00 committed by GitHub
parent 79a62fffe8
commit d341f4365b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 83 additions and 187 deletions

View file

@ -73,7 +73,7 @@ impl RouterModel for RouterModelV1 {
fn generate_request(
&self,
messages: &[Message],
usage_preferences: &Option<Vec<ModelUsagePreference>>,
usage_preferences_from_request: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest {
// remove system prompt, tool calls, tool call response and messages without content
// if content is empty its likely a tool call
@ -150,31 +150,17 @@ impl RouterModel for RouterModelV1 {
})
.collect::<Vec<Message>>();
let llm_route_json = usage_preferences
.as_ref()
.map(|prefs| {
let llm_route: Vec<RoutingPreference> = prefs
.iter()
.map(|pref| RoutingPreference {
name: pref.name.clone(),
description: pref.usage.clone().unwrap_or_default(),
})
.collect();
serde_json::to_string(&llm_route).unwrap_or_default()
})
.unwrap_or_else(|| self.llm_route_json_str.clone());
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", &llm_route_json)
.replace(
"{conversation}",
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
);
// Generate the router request message based on the usage preferences.
// If preferences are passed in request then we use them otherwise we use the default routing model preferences.
let router_message = match convert_to_router_preferences(usage_preferences_from_request) {
Some(prefs) => generate_router_message(&prefs, &selected_conversation_list),
None => generate_router_message(&self.llm_route_json_str, &selected_conversation_list),
};
ChatCompletionsRequest {
model: self.routing_model.clone(),
messages: vec![Message {
content: Some(ContentType::Text(messages_content)),
content: Some(ContentType::Text(router_message)),
role: USER_ROLE.to_string(),
}],
temperature: Some(0.01),
@ -201,12 +187,18 @@ impl RouterModel for RouterModelV1 {
if let Some(usage_preferences) = usage_preferences {
// If usage preferences are defined, we need to find the model that matches the selected route
let matching_preference = usage_preferences
let model_name: Option<String> = usage_preferences
.iter()
.find(|pref| pref.name == selected_route);
.map(|pref| {
pref.routing_preferences
.iter()
.find(|routing_pref| routing_pref.name == selected_route)
.map(|_| pref.model.clone())
})
.find_map(|model| model);
if let Some(preference) = matching_preference {
return Ok(Some((selected_route, preference.model.clone())));
if let Some(model_name) = model_name {
return Ok(Some((selected_route, model_name)));
} else {
warn!(
"No matching model found for route: {}, usage preferences: {:?}",
@ -216,7 +208,7 @@ impl RouterModel for RouterModelV1 {
}
}
// If no usage preferences are defined, we return the route with the routing model
// If no usage preferences are passed in request then use the default routing model preferences
if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() {
return Ok(Some((selected_route, model)));
}
@ -234,6 +226,37 @@ impl RouterModel for RouterModelV1 {
}
}
fn generate_router_message(prefs: &str, selected_conversation_list: &Vec<Message>) -> String {
ARCH_ROUTER_V1_SYSTEM_PROMPT
.replace("{routes}", prefs)
.replace(
"{conversation}",
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
)
}
fn convert_to_router_preferences(
prefs_from_request: &Option<Vec<ModelUsagePreference>>,
) -> Option<String> {
if let Some(usage_preferences) = prefs_from_request {
let routing_preferences = usage_preferences
.iter()
.flat_map(|pref| {
pref.routing_preferences
.iter()
.map(|routing_pref| RoutingPreference {
name: routing_pref.name.clone(),
description: routing_pref.description.clone(),
})
})
.collect::<Vec<RoutingPreference>>();
return Some(serde_json::to_string(&routing_preferences).unwrap_or_default());
}
None
}
fn fix_json_response(body: &str) -> String {
let mut updated_body = body.to_string();
@ -299,7 +322,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -356,7 +380,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -379,9 +404,11 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let usage_preferences = Some(vec![ModelUsagePreference {
name: "code-generation".to_string(),
model: "claude/claude-3-7-sonnet".to_string(),
usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()),
routing_preferences: vec![RoutingPreference {
name: "code-generation".to_string(),
description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(),
}],
}]);
let req = router.generate_request(&conversation, &usage_preferences);
@ -419,7 +446,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
@ -478,7 +506,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
@ -538,7 +567,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
@ -604,7 +634,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -672,7 +703,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -747,14 +779,18 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
// Case 1: Valid JSON with non-empty route
let input = r#"{"route": "Image generation"}"#;
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
// Case 2: Valid JSON with empty route
let input = r#"{"route": ""}"#;
@ -784,11 +820,17 @@ Based on your analysis, provide your response in the following JSON formats if y
// Case 6: Single quotes and \n in JSON
let input = "{'route': 'Image generation'}\\n";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
// Case 7: Code block marker
let input = "```json\n{\"route\": \"Image generation\"}\n```";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
}
}