model routing: cost/latency ranking with ranked fallback list (#849)

This commit is contained in:
Adil Hafeez 2026-03-30 13:46:52 -07:00 committed by GitHub
parent 3a531ce22a
commit e5751d6b13
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 1524 additions and 317 deletions

View file

@ -119,7 +119,7 @@ async fn llm_chat_inner(
temperature,
tool_names,
user_message_preview,
inline_routing_policy,
inline_routing_preferences,
client_api,
provider_id,
} = parsed;
@ -261,7 +261,7 @@ async fn llm_chat_inner(
&traceparent,
&request_path,
&request_id,
inline_routing_policy,
inline_routing_preferences,
)
.await
}
@ -323,7 +323,7 @@ struct PreparedRequest {
temperature: Option<f32>,
tool_names: Option<Vec<String>>,
user_message_preview: Option<String>,
inline_routing_policy: Option<Vec<common::configuration::ModelUsagePreference>>,
inline_routing_preferences: Option<Vec<common::configuration::TopLevelRoutingPreference>>,
client_api: Option<SupportedAPIsFromClient>,
provider_id: hermesllm::ProviderId,
}
@ -352,16 +352,14 @@ async fn parse_and_validate_request(
"request body received"
);
// Extract routing_policy from request body if present
let (chat_request_bytes, inline_routing_policy) =
crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err(
|err| {
warn!(error = %err, "failed to parse request JSON");
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
*r.status_mut() = StatusCode::BAD_REQUEST;
r
},
)?;
// Extract routing_preferences from request body if present
let (chat_request_bytes, inline_routing_preferences) =
crate::handlers::routing_service::extract_routing_policy(&raw_bytes).map_err(|err| {
warn!(error = %err, "failed to parse request JSON");
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
*r.status_mut() = StatusCode::BAD_REQUEST;
r
})?;
let api_type = SupportedAPIsFromClient::from_endpoint(request_path).ok_or_else(|| {
warn!(path = %request_path, "unsupported endpoint");
@ -439,7 +437,7 @@ async fn parse_and_validate_request(
temperature,
tool_names,
user_message_preview,
inline_routing_policy,
inline_routing_preferences,
client_api,
provider_id,
})

View file

@ -1,6 +1,6 @@
use common::configuration::ModelUsagePreference;
use common::configuration::TopLevelRoutingPreference;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hermesllm::ProviderRequestType;
use hyper::StatusCode;
use std::sync::Arc;
use tracing::{debug, info, warn};
@ -10,7 +10,10 @@ use crate::streaming::truncate_message;
use crate::tracing::routing;
/// Outcome of a routing decision: the primary model, the ranked list it was
/// taken from, and the name of the route that produced it (if any).
///
/// `Debug` and `Clone` are derived so results can be logged and handed to
/// retry/fallback logic without being rebuilt field by field.
#[derive(Debug, Clone)]
pub struct RoutingResult {
    /// Primary model to use (first in the ranked list).
    pub model_name: String,
    /// Full ranked list — use subsequent entries as fallbacks on 429/5xx.
    pub models: Vec<String>,
    /// Name of the selected route, if any.
    pub route_name: Option<String>,
}
@ -39,11 +42,8 @@ pub async fn router_chat_get_upstream_model(
traceparent: &str,
request_path: &str,
request_id: &str,
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
) -> Result<RoutingResult, RoutingError> {
// Clone metadata for routing before converting (which consumes client_request)
let routing_metadata = client_request.metadata().clone();
// Convert to ChatCompletionsRequest for routing (regardless of input type)
let chat_request = match ProviderRequestType::try_from((
client_request,
@ -78,22 +78,6 @@ pub async fn router_chat_get_upstream_model(
"router request"
);
// Use inline preferences if provided, otherwise fall back to metadata extraction
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
{
inline_usage_preferences
} else {
let usage_preferences_str: Option<String> =
routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("plano_preference_config")
.map(|value| value.to_string())
});
usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok())
};
// Prepare log message with latest message from chat request
let latest_message_for_log = chat_request
.messages
@ -107,7 +91,6 @@ pub async fn router_chat_get_upstream_model(
let latest_message_for_log = truncate_message(&latest_message_for_log, 50);
info!(
has_usage_preferences = usage_preferences.is_some(),
path = %request_path,
latest_message = %latest_message_for_log,
"processing router request"
@ -121,7 +104,7 @@ pub async fn router_chat_get_upstream_model(
.determine_route(
&chat_request.messages,
traceparent,
usage_preferences,
inline_routing_preferences,
request_id,
)
.await;
@ -132,10 +115,12 @@ pub async fn router_chat_get_upstream_model(
match routing_result {
Ok(route) => match route {
Some((route_name, model_name)) => {
Some((route_name, ranked_models)) => {
let model_name = ranked_models.first().cloned().unwrap_or_default();
current_span.record("route.selected_model", model_name.as_str());
Ok(RoutingResult {
model_name,
models: ranked_models,
route_name: Some(route_name),
})
}
@ -147,6 +132,7 @@ pub async fn router_chat_get_upstream_model(
Ok(RoutingResult {
model_name: "none".to_string(),
models: vec!["none".to_string()],
route_name: None,
})
}

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::{ModelUsagePreference, SpanAttributes};
use common::configuration::{SpanAttributes, TopLevelRoutingPreference};
use common::consts::REQUEST_ID_HEADER;
use common::errors::BrightStaffError;
use hermesllm::clients::SupportedAPIsFromClient;
@ -15,56 +15,42 @@ use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
use crate::router::llm::RouterService;
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
/// and parsed preferences. The `routing_policy` field is removed from the JSON
/// before re-serializing so downstream parsers don't see the non-standard field.
///
/// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB.
/// Extracts `routing_preferences` from a JSON body, returning the cleaned body bytes
/// and the parsed preferences. The field is removed from the JSON before re-serializing
/// so downstream parsers don't see it.
pub fn extract_routing_policy(
raw_bytes: &[u8],
warn_on_size: bool,
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
) -> Result<(Bytes, Option<Vec<TopLevelRoutingPreference>>), String> {
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
let preferences = json_body
let routing_preferences = json_body
.as_object_mut()
.and_then(|obj| obj.remove("routing_policy"))
.and_then(|policy_value| {
if warn_on_size {
let policy_str = serde_json::to_string(&policy_value).unwrap_or_default();
if policy_str.len() > ROUTING_POLICY_SIZE_WARNING_BYTES {
warn!(
size_bytes = policy_str.len(),
limit_bytes = ROUTING_POLICY_SIZE_WARNING_BYTES,
"routing_policy exceeds recommended size limit"
);
}
}
match serde_json::from_value::<Vec<ModelUsagePreference>>(policy_value) {
.and_then(|o| o.remove("routing_preferences"))
.and_then(
|value| match serde_json::from_value::<Vec<TopLevelRoutingPreference>>(value) {
Ok(prefs) => {
info!(
num_models = prefs.len(),
"using inline routing_policy from request body"
num_routes = prefs.len(),
"using inline routing_preferences from request body"
);
Some(prefs)
}
Err(err) => {
warn!(error = %err, "failed to parse routing_policy");
warn!(error = %err, "failed to parse routing_preferences");
None
}
}
});
},
);
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
Ok((bytes, preferences))
Ok((bytes, routing_preferences))
}
/// Serializable summary of a routing decision.
#[derive(serde::Serialize)]
struct RoutingDecisionResponse {
/// Primary model name (filled from the routing result's `model_name`).
model: String,
/// Ranked model list — use first, fall back to next on 429/5xx.
models: Vec<String>,
/// Name of the selected route; serializes as JSON `null` when `None`.
route: Option<String>,
/// Trace identifier reported alongside the decision.
trace_id: String,
}
@ -136,8 +122,9 @@ async fn routing_decision_inner(
"routing decision request body received"
);
// Extract routing_policy from request body before parsing as ProviderRequestType
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
// Extract routing_preferences from body before parsing as ProviderRequestType
let (chat_request_bytes, inline_routing_preferences) = match extract_routing_policy(&raw_bytes)
{
Ok(result) => result,
Err(err) => {
warn!(error = %err, "failed to parse request JSON");
@ -164,27 +151,27 @@ async fn routing_decision_inner(
}
};
// Call the existing routing logic with inline preferences
let routing_result = router_chat_get_upstream_model(
router_service,
client_request,
&traceparent,
&request_path,
&request_id,
inline_preferences,
inline_routing_preferences,
)
.await;
match routing_result {
Ok(result) => {
let response = RoutingDecisionResponse {
model: result.model_name,
models: result.models,
route: result.route_name,
trace_id,
};
info!(
model = %response.model,
primary_model = %response.models.first().map(|s| s.as_str()).unwrap_or("none"),
total_models = response.models.len(),
route = ?response.route,
"routing decision completed"
);
@ -227,101 +214,70 @@ mod tests {
#[test]
fn extract_routing_policy_no_policy() {
let body = make_chat_body("");
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
assert!(prefs.is_none());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn extract_routing_policy_valid_policy() {
let policy = r#""routing_policy": [
{
"model": "openai/gpt-4o",
"routing_preferences": [
{"name": "coding", "description": "code generation tasks"}
]
},
{
"model": "openai/gpt-4o-mini",
"routing_preferences": [
{"name": "general", "description": "general questions"}
]
}
]"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("should have parsed preferences");
assert_eq!(prefs.len(), 2);
assert_eq!(prefs[0].model, "openai/gpt-4o");
assert_eq!(prefs[0].routing_preferences[0].name, "coding");
assert_eq!(prefs[1].model, "openai/gpt-4o-mini");
assert_eq!(prefs[1].routing_preferences[0].name, "general");
// routing_policy should be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
}
#[test]
fn extract_routing_policy_invalid_policy_returns_none() {
// routing_policy is present but has wrong shape
let policy = r#""routing_policy": "not-an-array""#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
// Invalid policy should be ignored (returns None), not error
assert!(prefs.is_none());
// routing_policy should still be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn extract_routing_policy_invalid_json_returns_error() {
let body = b"not valid json";
let result = extract_routing_policy(body, false);
let result = extract_routing_policy(body);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Failed to parse JSON"));
}
#[test]
fn extract_routing_policy_empty_array() {
let policy = r#""routing_policy": []"#;
fn extract_routing_policy_routing_preferences() {
let policy = r#""routing_preferences": [
{
"name": "code generation",
"description": "generate new code",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
"selection_policy": {"prefer": "fastest"}
}
]"#;
let body = make_chat_body(policy);
let (_, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
let prefs = prefs.expect("empty array is valid");
assert_eq!(prefs.len(), 0);
let prefs = prefs.expect("should have parsed routing_preferences");
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].name, "code generation");
assert_eq!(prefs[0].models, vec!["openai/gpt-4o", "openai/gpt-4o-mini"]);
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_preferences").is_none());
}
#[test]
fn extract_routing_policy_preserves_other_fields() {
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
let policy = r#""routing_preferences": [{"name": "test", "description": "test", "models": ["gpt-4o"], "selection_policy": {"prefer": "none"}}], "temperature": 0.5, "max_tokens": 100"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
assert!(prefs.is_some());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["temperature"], 0.5);
assert_eq!(cleaned_json["max_tokens"], 100);
assert!(cleaned_json.get("routing_policy").is_none());
assert!(cleaned_json.get("routing_preferences").is_none());
}
#[test]
fn routing_decision_response_serialization() {
    // `model` is the primary pick, so it must equal the head of the ranked
    // `models` list; the remaining entries are fallbacks in order.
    let response = RoutingDecisionResponse {
        model: "openai/gpt-4o".to_string(),
        models: vec![
            "openai/gpt-4o".to_string(),
            "openai/gpt-4o-mini".to_string(),
        ],
        route: Some("code_generation".to_string()),
        trace_id: "abc123".to_string(),
    };
    let json = serde_json::to_string(&response).unwrap();
    let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
    assert_eq!(parsed["model"], "openai/gpt-4o");
    assert_eq!(parsed["models"][0], "openai/gpt-4o");
    assert_eq!(parsed["models"][1], "openai/gpt-4o-mini");
    // The serialized primary model and the head of the ranked list agree.
    assert_eq!(parsed["model"], parsed["models"][0]);
    assert_eq!(parsed["route"], "code_generation");
    assert_eq!(parsed["trace_id"], "abc123");
}
@ -329,13 +285,13 @@ mod tests {
#[test]
fn routing_decision_response_serialization_no_route() {
    // Decision without a route: both the primary model and the single
    // ranked entry carry the "none" placeholder.
    let placeholder = "none".to_string();
    let decision = RoutingDecisionResponse {
        model: placeholder.clone(),
        models: vec![placeholder],
        route: None,
        trace_id: "abc123".to_string(),
    };

    let encoded = serde_json::to_string(&decision).unwrap();
    let decoded: serde_json::Value = serde_json::from_str(&encoded).unwrap();

    assert_eq!(decoded["model"], "none");
    assert_eq!(decoded["models"][0], "none");
    // A missing route is emitted as JSON null.
    assert!(decoded["route"].is_null());
}
}