mirror of
https://github.com/katanemo/plano.git
synced 2026-05-10 16:22:42 +02:00
model routing: cost/latency ranking with ranked fallback list (#849)
This commit is contained in:
parent
3a531ce22a
commit
e5751d6b13
23 changed files with 1524 additions and 317 deletions
|
|
@ -119,7 +119,7 @@ async fn llm_chat_inner(
|
|||
temperature,
|
||||
tool_names,
|
||||
user_message_preview,
|
||||
inline_routing_policy,
|
||||
inline_routing_preferences,
|
||||
client_api,
|
||||
provider_id,
|
||||
} = parsed;
|
||||
|
|
@ -261,7 +261,7 @@ async fn llm_chat_inner(
|
|||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_policy,
|
||||
inline_routing_preferences,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
|
@ -323,7 +323,7 @@ struct PreparedRequest {
|
|||
temperature: Option<f32>,
|
||||
tool_names: Option<Vec<String>>,
|
||||
user_message_preview: Option<String>,
|
||||
inline_routing_policy: Option<Vec<common::configuration::ModelUsagePreference>>,
|
||||
inline_routing_preferences: Option<Vec<common::configuration::TopLevelRoutingPreference>>,
|
||||
client_api: Option<SupportedAPIsFromClient>,
|
||||
provider_id: hermesllm::ProviderId,
|
||||
}
|
||||
|
|
@ -352,16 +352,14 @@ async fn parse_and_validate_request(
|
|||
"request body received"
|
||||
);
|
||||
|
||||
// Extract routing_policy from request body if present
|
||||
let (chat_request_bytes, inline_routing_policy) =
|
||||
crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err(
|
||||
|err| {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
|
||||
*r.status_mut() = StatusCode::BAD_REQUEST;
|
||||
r
|
||||
},
|
||||
)?;
|
||||
// Extract routing_preferences from request body if present
|
||||
let (chat_request_bytes, inline_routing_preferences) =
|
||||
crate::handlers::routing_service::extract_routing_policy(&raw_bytes).map_err(|err| {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
|
||||
*r.status_mut() = StatusCode::BAD_REQUEST;
|
||||
r
|
||||
})?;
|
||||
|
||||
let api_type = SupportedAPIsFromClient::from_endpoint(request_path).ok_or_else(|| {
|
||||
warn!(path = %request_path, "unsupported endpoint");
|
||||
|
|
@ -439,7 +437,7 @@ async fn parse_and_validate_request(
|
|||
temperature,
|
||||
tool_names,
|
||||
user_message_preview,
|
||||
inline_routing_policy,
|
||||
inline_routing_preferences,
|
||||
client_api,
|
||||
provider_id,
|
||||
})
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use common::configuration::TopLevelRoutingPreference;
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hermesllm::ProviderRequestType;
|
||||
use hyper::StatusCode;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
|
@ -10,7 +10,10 @@ use crate::streaming::truncate_message;
|
|||
use crate::tracing::routing;
|
||||
|
||||
pub struct RoutingResult {
|
||||
/// Primary model to use (first in the ranked list).
|
||||
pub model_name: String,
|
||||
/// Full ranked list — use subsequent entries as fallbacks on 429/5xx.
|
||||
pub models: Vec<String>,
|
||||
pub route_name: Option<String>,
|
||||
}
|
||||
|
||||
|
|
@ -39,11 +42,8 @@ pub async fn router_chat_get_upstream_model(
|
|||
traceparent: &str,
|
||||
request_path: &str,
|
||||
request_id: &str,
|
||||
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
) -> Result<RoutingResult, RoutingError> {
|
||||
// Clone metadata for routing before converting (which consumes client_request)
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
||||
let chat_request = match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
|
|
@ -78,22 +78,6 @@ pub async fn router_chat_get_upstream_model(
|
|||
"router request"
|
||||
);
|
||||
|
||||
// Use inline preferences if provided, otherwise fall back to metadata extraction
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
|
||||
{
|
||||
inline_usage_preferences
|
||||
} else {
|
||||
let usage_preferences_str: Option<String> =
|
||||
routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("plano_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok())
|
||||
};
|
||||
|
||||
// Prepare log message with latest message from chat request
|
||||
let latest_message_for_log = chat_request
|
||||
.messages
|
||||
|
|
@ -107,7 +91,6 @@ pub async fn router_chat_get_upstream_model(
|
|||
let latest_message_for_log = truncate_message(&latest_message_for_log, 50);
|
||||
|
||||
info!(
|
||||
has_usage_preferences = usage_preferences.is_some(),
|
||||
path = %request_path,
|
||||
latest_message = %latest_message_for_log,
|
||||
"processing router request"
|
||||
|
|
@ -121,7 +104,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
.determine_route(
|
||||
&chat_request.messages,
|
||||
traceparent,
|
||||
usage_preferences,
|
||||
inline_routing_preferences,
|
||||
request_id,
|
||||
)
|
||||
.await;
|
||||
|
|
@ -132,10 +115,12 @@ pub async fn router_chat_get_upstream_model(
|
|||
|
||||
match routing_result {
|
||||
Ok(route) => match route {
|
||||
Some((route_name, model_name)) => {
|
||||
Some((route_name, ranked_models)) => {
|
||||
let model_name = ranked_models.first().cloned().unwrap_or_default();
|
||||
current_span.record("route.selected_model", model_name.as_str());
|
||||
Ok(RoutingResult {
|
||||
model_name,
|
||||
models: ranked_models,
|
||||
route_name: Some(route_name),
|
||||
})
|
||||
}
|
||||
|
|
@ -147,6 +132,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
|
||||
Ok(RoutingResult {
|
||||
model_name: "none".to_string(),
|
||||
models: vec!["none".to_string()],
|
||||
route_name: None,
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{ModelUsagePreference, SpanAttributes};
|
||||
use common::configuration::{SpanAttributes, TopLevelRoutingPreference};
|
||||
use common::consts::REQUEST_ID_HEADER;
|
||||
use common::errors::BrightStaffError;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
|
|
@ -15,56 +15,42 @@ use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
|
|||
use crate::router::llm::RouterService;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
|
||||
|
||||
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
|
||||
/// and parsed preferences. The `routing_policy` field is removed from the JSON
|
||||
/// before re-serializing so downstream parsers don't see the non-standard field.
|
||||
///
|
||||
/// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB.
|
||||
/// Extracts `routing_preferences` from a JSON body, returning the cleaned body bytes
|
||||
/// and the parsed preferences. The field is removed from the JSON before re-serializing
|
||||
/// so downstream parsers don't see it.
|
||||
pub fn extract_routing_policy(
|
||||
raw_bytes: &[u8],
|
||||
warn_on_size: bool,
|
||||
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
|
||||
) -> Result<(Bytes, Option<Vec<TopLevelRoutingPreference>>), String> {
|
||||
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
|
||||
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
|
||||
|
||||
let preferences = json_body
|
||||
let routing_preferences = json_body
|
||||
.as_object_mut()
|
||||
.and_then(|obj| obj.remove("routing_policy"))
|
||||
.and_then(|policy_value| {
|
||||
if warn_on_size {
|
||||
let policy_str = serde_json::to_string(&policy_value).unwrap_or_default();
|
||||
if policy_str.len() > ROUTING_POLICY_SIZE_WARNING_BYTES {
|
||||
warn!(
|
||||
size_bytes = policy_str.len(),
|
||||
limit_bytes = ROUTING_POLICY_SIZE_WARNING_BYTES,
|
||||
"routing_policy exceeds recommended size limit"
|
||||
);
|
||||
}
|
||||
}
|
||||
match serde_json::from_value::<Vec<ModelUsagePreference>>(policy_value) {
|
||||
.and_then(|o| o.remove("routing_preferences"))
|
||||
.and_then(
|
||||
|value| match serde_json::from_value::<Vec<TopLevelRoutingPreference>>(value) {
|
||||
Ok(prefs) => {
|
||||
info!(
|
||||
num_models = prefs.len(),
|
||||
"using inline routing_policy from request body"
|
||||
num_routes = prefs.len(),
|
||||
"using inline routing_preferences from request body"
|
||||
);
|
||||
Some(prefs)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse routing_policy");
|
||||
warn!(error = %err, "failed to parse routing_preferences");
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
|
||||
Ok((bytes, preferences))
|
||||
Ok((bytes, routing_preferences))
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct RoutingDecisionResponse {
|
||||
model: String,
|
||||
/// Ranked model list — use first, fall back to next on 429/5xx.
|
||||
models: Vec<String>,
|
||||
route: Option<String>,
|
||||
trace_id: String,
|
||||
}
|
||||
|
|
@ -136,8 +122,9 @@ async fn routing_decision_inner(
|
|||
"routing decision request body received"
|
||||
);
|
||||
|
||||
// Extract routing_policy from request body before parsing as ProviderRequestType
|
||||
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
|
||||
// Extract routing_preferences from body before parsing as ProviderRequestType
|
||||
let (chat_request_bytes, inline_routing_preferences) = match extract_routing_policy(&raw_bytes)
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
|
|
@ -164,27 +151,27 @@ async fn routing_decision_inner(
|
|||
}
|
||||
};
|
||||
|
||||
// Call the existing routing logic with inline preferences
|
||||
let routing_result = router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_preferences,
|
||||
inline_routing_preferences,
|
||||
)
|
||||
.await;
|
||||
|
||||
match routing_result {
|
||||
Ok(result) => {
|
||||
let response = RoutingDecisionResponse {
|
||||
model: result.model_name,
|
||||
models: result.models,
|
||||
route: result.route_name,
|
||||
trace_id,
|
||||
};
|
||||
|
||||
info!(
|
||||
model = %response.model,
|
||||
primary_model = %response.models.first().map(|s| s.as_str()).unwrap_or("none"),
|
||||
total_models = response.models.len(),
|
||||
route = ?response.route,
|
||||
"routing decision completed"
|
||||
);
|
||||
|
|
@ -227,101 +214,70 @@ mod tests {
|
|||
#[test]
|
||||
fn extract_routing_policy_no_policy() {
|
||||
let body = make_chat_body("");
|
||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
assert!(prefs.is_none());
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_valid_policy() {
|
||||
let policy = r#""routing_policy": [
|
||||
{
|
||||
"model": "openai/gpt-4o",
|
||||
"routing_preferences": [
|
||||
{"name": "coding", "description": "code generation tasks"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"routing_preferences": [
|
||||
{"name": "general", "description": "general questions"}
|
||||
]
|
||||
}
|
||||
]"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
let prefs = prefs.expect("should have parsed preferences");
|
||||
assert_eq!(prefs.len(), 2);
|
||||
assert_eq!(prefs[0].model, "openai/gpt-4o");
|
||||
assert_eq!(prefs[0].routing_preferences[0].name, "coding");
|
||||
assert_eq!(prefs[1].model, "openai/gpt-4o-mini");
|
||||
assert_eq!(prefs[1].routing_preferences[0].name, "general");
|
||||
|
||||
// routing_policy should be stripped from cleaned body
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_invalid_policy_returns_none() {
|
||||
// routing_policy is present but has wrong shape
|
||||
let policy = r#""routing_policy": "not-an-array""#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
// Invalid policy should be ignored (returns None), not error
|
||||
assert!(prefs.is_none());
|
||||
// routing_policy should still be stripped from cleaned body
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_invalid_json_returns_error() {
|
||||
let body = b"not valid json";
|
||||
let result = extract_routing_policy(body, false);
|
||||
let result = extract_routing_policy(body);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Failed to parse JSON"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_empty_array() {
|
||||
let policy = r#""routing_policy": []"#;
|
||||
fn extract_routing_policy_routing_preferences() {
|
||||
let policy = r#""routing_preferences": [
|
||||
{
|
||||
"name": "code generation",
|
||||
"description": "generate new code",
|
||||
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
|
||||
"selection_policy": {"prefer": "fastest"}
|
||||
}
|
||||
]"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (_, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
let prefs = prefs.expect("empty array is valid");
|
||||
assert_eq!(prefs.len(), 0);
|
||||
let prefs = prefs.expect("should have parsed routing_preferences");
|
||||
assert_eq!(prefs.len(), 1);
|
||||
assert_eq!(prefs[0].name, "code generation");
|
||||
assert_eq!(prefs[0].models, vec!["openai/gpt-4o", "openai/gpt-4o-mini"]);
|
||||
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("routing_preferences").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_preserves_other_fields() {
|
||||
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
|
||||
let policy = r#""routing_preferences": [{"name": "test", "description": "test", "models": ["gpt-4o"], "selection_policy": {"prefer": "none"}}], "temperature": 0.5, "max_tokens": 100"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
assert!(prefs.is_some());
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert_eq!(cleaned_json["temperature"], 0.5);
|
||||
assert_eq!(cleaned_json["max_tokens"], 100);
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
assert!(cleaned_json.get("routing_preferences").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn routing_decision_response_serialization() {
|
||||
let response = RoutingDecisionResponse {
|
||||
model: "openai/gpt-4o".to_string(),
|
||||
models: vec![
|
||||
"openai/gpt-4o-mini".to_string(),
|
||||
"openai/gpt-4o".to_string(),
|
||||
],
|
||||
route: Some("code_generation".to_string()),
|
||||
trace_id: "abc123".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed["model"], "openai/gpt-4o");
|
||||
assert_eq!(parsed["models"][0], "openai/gpt-4o-mini");
|
||||
assert_eq!(parsed["models"][1], "openai/gpt-4o");
|
||||
assert_eq!(parsed["route"], "code_generation");
|
||||
assert_eq!(parsed["trace_id"], "abc123");
|
||||
}
|
||||
|
|
@ -329,13 +285,13 @@ mod tests {
|
|||
#[test]
|
||||
fn routing_decision_response_serialization_no_route() {
|
||||
let response = RoutingDecisionResponse {
|
||||
model: "none".to_string(),
|
||||
models: vec!["none".to_string()],
|
||||
route: None,
|
||||
trace_id: "abc123".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed["model"], "none");
|
||||
assert_eq!(parsed["models"][0], "none");
|
||||
assert!(parsed["route"].is_null());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue