mirror of
https://github.com/katanemo/plano.git
synced 2026-05-06 14:22:51 +02:00
redesign model_metrics_sources, drop legacy per-provider routing, return ranked model list
This commit is contained in:
parent
b12bf74e5c
commit
76b1f37052
12 changed files with 639 additions and 429 deletions
|
|
@ -119,7 +119,6 @@ async fn llm_chat_inner(
|
|||
temperature,
|
||||
tool_names,
|
||||
user_message_preview,
|
||||
inline_routing_policy,
|
||||
inline_routing_preferences,
|
||||
client_api,
|
||||
provider_id,
|
||||
|
|
@ -262,7 +261,6 @@ async fn llm_chat_inner(
|
|||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_routing_policy,
|
||||
inline_routing_preferences,
|
||||
)
|
||||
.await
|
||||
|
|
@ -325,7 +323,6 @@ struct PreparedRequest {
|
|||
temperature: Option<f32>,
|
||||
tool_names: Option<Vec<String>>,
|
||||
user_message_preview: Option<String>,
|
||||
inline_routing_policy: Option<Vec<common::configuration::ModelUsagePreference>>,
|
||||
inline_routing_preferences: Option<Vec<common::configuration::TopLevelRoutingPreference>>,
|
||||
client_api: Option<SupportedAPIsFromClient>,
|
||||
provider_id: hermesllm::ProviderId,
|
||||
|
|
@ -355,16 +352,14 @@ async fn parse_and_validate_request(
|
|||
"request body received"
|
||||
);
|
||||
|
||||
// Extract routing_policy and routing_preferences from request body if present
|
||||
let (chat_request_bytes, inline_routing_policy, inline_routing_preferences) =
|
||||
crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err(
|
||||
|err| {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
|
||||
*r.status_mut() = StatusCode::BAD_REQUEST;
|
||||
r
|
||||
},
|
||||
)?;
|
||||
// Extract routing_preferences from request body if present
|
||||
let (chat_request_bytes, inline_routing_preferences) =
|
||||
crate::handlers::routing_service::extract_routing_policy(&raw_bytes).map_err(|err| {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
|
||||
*r.status_mut() = StatusCode::BAD_REQUEST;
|
||||
r
|
||||
})?;
|
||||
|
||||
let api_type = SupportedAPIsFromClient::from_endpoint(request_path).ok_or_else(|| {
|
||||
warn!(path = %request_path, "unsupported endpoint");
|
||||
|
|
@ -442,7 +437,6 @@ async fn parse_and_validate_request(
|
|||
temperature,
|
||||
tool_names,
|
||||
user_message_preview,
|
||||
inline_routing_policy,
|
||||
inline_routing_preferences,
|
||||
client_api,
|
||||
provider_id,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use common::configuration::{ModelUsagePreference, TopLevelRoutingPreference};
|
||||
use common::configuration::TopLevelRoutingPreference;
|
||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use hermesllm::ProviderRequestType;
|
||||
use hyper::StatusCode;
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
|
@ -10,7 +10,10 @@ use crate::streaming::truncate_message;
|
|||
use crate::tracing::routing;
|
||||
|
||||
pub struct RoutingResult {
|
||||
/// Primary model to use (first in the ranked list).
|
||||
pub model_name: String,
|
||||
/// Full ranked list — use subsequent entries as fallbacks on 429/5xx.
|
||||
pub models: Vec<String>,
|
||||
pub route_name: Option<String>,
|
||||
}
|
||||
|
||||
|
|
@ -39,12 +42,8 @@ pub async fn router_chat_get_upstream_model(
|
|||
traceparent: &str,
|
||||
request_path: &str,
|
||||
request_id: &str,
|
||||
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
) -> Result<RoutingResult, RoutingError> {
|
||||
// Clone metadata for routing before converting (which consumes client_request)
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
||||
let chat_request = match ProviderRequestType::try_from((
|
||||
client_request,
|
||||
|
|
@ -79,22 +78,6 @@ pub async fn router_chat_get_upstream_model(
|
|||
"router request"
|
||||
);
|
||||
|
||||
// Use inline preferences if provided, otherwise fall back to metadata extraction
|
||||
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
|
||||
{
|
||||
inline_usage_preferences
|
||||
} else {
|
||||
let usage_preferences_str: Option<String> =
|
||||
routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("plano_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
});
|
||||
usage_preferences_str
|
||||
.as_ref()
|
||||
.and_then(|s| serde_yaml::from_str(s).ok())
|
||||
};
|
||||
|
||||
// Prepare log message with latest message from chat request
|
||||
let latest_message_for_log = chat_request
|
||||
.messages
|
||||
|
|
@ -108,7 +91,6 @@ pub async fn router_chat_get_upstream_model(
|
|||
let latest_message_for_log = truncate_message(&latest_message_for_log, 50);
|
||||
|
||||
info!(
|
||||
has_usage_preferences = usage_preferences.is_some(),
|
||||
path = %request_path,
|
||||
latest_message = %latest_message_for_log,
|
||||
"processing router request"
|
||||
|
|
@ -122,7 +104,6 @@ pub async fn router_chat_get_upstream_model(
|
|||
.determine_route(
|
||||
&chat_request.messages,
|
||||
traceparent,
|
||||
usage_preferences,
|
||||
inline_routing_preferences,
|
||||
request_id,
|
||||
)
|
||||
|
|
@ -134,10 +115,12 @@ pub async fn router_chat_get_upstream_model(
|
|||
|
||||
match routing_result {
|
||||
Ok(route) => match route {
|
||||
Some((route_name, model_name)) => {
|
||||
Some((route_name, ranked_models)) => {
|
||||
let model_name = ranked_models.first().cloned().unwrap_or_default();
|
||||
current_span.record("route.selected_model", model_name.as_str());
|
||||
Ok(RoutingResult {
|
||||
model_name,
|
||||
models: ranked_models,
|
||||
route_name: Some(route_name),
|
||||
})
|
||||
}
|
||||
|
|
@ -149,6 +132,7 @@ pub async fn router_chat_get_upstream_model(
|
|||
|
||||
Ok(RoutingResult {
|
||||
model_name: "none".to_string(),
|
||||
models: vec!["none".to_string()],
|
||||
route_name: None,
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{ModelUsagePreference, SpanAttributes, TopLevelRoutingPreference};
|
||||
use common::configuration::{SpanAttributes, TopLevelRoutingPreference};
|
||||
use common::consts::REQUEST_ID_HEADER;
|
||||
use common::errors::BrightStaffError;
|
||||
use hermesllm::clients::SupportedAPIsFromClient;
|
||||
|
|
@ -15,61 +15,16 @@ use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
|
|||
use crate::router::llm::RouterService;
|
||||
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
|
||||
|
||||
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
|
||||
|
||||
type ExtractedRoutingPolicies = (
|
||||
Bytes,
|
||||
Option<Vec<ModelUsagePreference>>,
|
||||
Option<Vec<TopLevelRoutingPreference>>,
|
||||
);
|
||||
|
||||
/// Extracts `routing_policy` and `routing_preferences` from a JSON body, returning
|
||||
/// the cleaned body bytes and both sets of parsed preferences. Both fields are removed
|
||||
/// from the JSON before re-serializing so downstream parsers don't see them.
|
||||
///
|
||||
/// - `routing_policy` — legacy per-provider format (`Vec<ModelUsagePreference>`)
|
||||
/// - `routing_preferences` — v0.4.0+ format (`Vec<TopLevelRoutingPreference>`)
|
||||
///
|
||||
/// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB.
|
||||
/// Extracts `routing_preferences` from a JSON body, returning the cleaned body bytes
|
||||
/// and the parsed preferences. The field is removed from the JSON before re-serializing
|
||||
/// so downstream parsers don't see it.
|
||||
pub fn extract_routing_policy(
|
||||
raw_bytes: &[u8],
|
||||
warn_on_size: bool,
|
||||
) -> Result<ExtractedRoutingPolicies, String> {
|
||||
) -> Result<(Bytes, Option<Vec<TopLevelRoutingPreference>>), String> {
|
||||
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
|
||||
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
|
||||
|
||||
// Extract legacy routing_policy
|
||||
let legacy_preferences = json_body
|
||||
.as_object_mut()
|
||||
.and_then(|o| o.remove("routing_policy"))
|
||||
.and_then(|policy_value| {
|
||||
if warn_on_size {
|
||||
let policy_str = serde_json::to_string(&policy_value).unwrap_or_default();
|
||||
if policy_str.len() > ROUTING_POLICY_SIZE_WARNING_BYTES {
|
||||
warn!(
|
||||
size_bytes = policy_str.len(),
|
||||
limit_bytes = ROUTING_POLICY_SIZE_WARNING_BYTES,
|
||||
"routing_policy exceeds recommended size limit"
|
||||
);
|
||||
}
|
||||
}
|
||||
match serde_json::from_value::<Vec<ModelUsagePreference>>(policy_value) {
|
||||
Ok(prefs) => {
|
||||
info!(
|
||||
num_models = prefs.len(),
|
||||
"using inline routing_policy from request body"
|
||||
);
|
||||
Some(prefs)
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse routing_policy");
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Extract new v0.4.0 routing_preferences
|
||||
let top_level_preferences = json_body
|
||||
let routing_preferences = json_body
|
||||
.as_object_mut()
|
||||
.and_then(|o| o.remove("routing_preferences"))
|
||||
.and_then(
|
||||
|
|
@ -89,12 +44,13 @@ pub fn extract_routing_policy(
|
|||
);
|
||||
|
||||
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
|
||||
Ok((bytes, legacy_preferences, top_level_preferences))
|
||||
Ok((bytes, routing_preferences))
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct RoutingDecisionResponse {
|
||||
model: String,
|
||||
/// Ranked model list — use first, fall back to next on 429/5xx.
|
||||
models: Vec<String>,
|
||||
route: Option<String>,
|
||||
trace_id: String,
|
||||
}
|
||||
|
|
@ -166,19 +122,19 @@ async fn routing_decision_inner(
|
|||
"routing decision request body received"
|
||||
);
|
||||
|
||||
// Extract routing_policy and routing_preferences from body before parsing as ProviderRequestType
|
||||
let (chat_request_bytes, inline_preferences, inline_routing_preferences) =
|
||||
match extract_routing_policy(&raw_bytes, true) {
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request JSON: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
// Extract routing_preferences from body before parsing as ProviderRequestType
|
||||
let (chat_request_bytes, inline_routing_preferences) = match extract_routing_policy(&raw_bytes)
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(err) => {
|
||||
warn!(error = %err, "failed to parse request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Failed to parse request JSON: {}",
|
||||
err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
};
|
||||
|
||||
let client_request = match ProviderRequestType::try_from((
|
||||
&chat_request_bytes[..],
|
||||
|
|
@ -195,14 +151,12 @@ async fn routing_decision_inner(
|
|||
}
|
||||
};
|
||||
|
||||
// Call the existing routing logic with inline preferences
|
||||
let routing_result = router_chat_get_upstream_model(
|
||||
router_service,
|
||||
client_request,
|
||||
&traceparent,
|
||||
&request_path,
|
||||
&request_id,
|
||||
inline_preferences,
|
||||
inline_routing_preferences,
|
||||
)
|
||||
.await;
|
||||
|
|
@ -210,13 +164,14 @@ async fn routing_decision_inner(
|
|||
match routing_result {
|
||||
Ok(result) => {
|
||||
let response = RoutingDecisionResponse {
|
||||
model: result.model_name,
|
||||
models: result.models,
|
||||
route: result.route_name,
|
||||
trace_id,
|
||||
};
|
||||
|
||||
info!(
|
||||
model = %response.model,
|
||||
primary_model = %response.models.first().map(|s| s.as_str()).unwrap_or("none"),
|
||||
total_models = response.models.len(),
|
||||
route = ?response.route,
|
||||
"routing decision completed"
|
||||
);
|
||||
|
|
@ -259,95 +214,23 @@ mod tests {
|
|||
#[test]
|
||||
fn extract_routing_policy_no_policy() {
|
||||
let body = make_chat_body("");
|
||||
let (cleaned, prefs, top_prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
assert!(prefs.is_none());
|
||||
assert!(top_prefs.is_none());
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_valid_policy() {
|
||||
let policy = r#""routing_policy": [
|
||||
{
|
||||
"model": "openai/gpt-4o",
|
||||
"routing_preferences": [
|
||||
{"name": "coding", "description": "code generation tasks"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"model": "openai/gpt-4o-mini",
|
||||
"routing_preferences": [
|
||||
{"name": "general", "description": "general questions"}
|
||||
]
|
||||
}
|
||||
]"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs, top_prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
let prefs = prefs.expect("should have parsed preferences");
|
||||
assert_eq!(prefs.len(), 2);
|
||||
assert_eq!(prefs[0].model, "openai/gpt-4o");
|
||||
assert_eq!(prefs[0].routing_preferences[0].name, "coding");
|
||||
assert_eq!(prefs[1].model, "openai/gpt-4o-mini");
|
||||
assert_eq!(prefs[1].routing_preferences[0].name, "general");
|
||||
assert!(top_prefs.is_none());
|
||||
|
||||
// routing_policy should be stripped from cleaned body
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_invalid_policy_returns_none() {
|
||||
// routing_policy is present but has wrong shape
|
||||
let policy = r#""routing_policy": "not-an-array""#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs, _) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
// Invalid policy should be ignored (returns None), not error
|
||||
assert!(prefs.is_none());
|
||||
// routing_policy should still be stripped from cleaned body
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_invalid_json_returns_error() {
|
||||
let body = b"not valid json";
|
||||
let result = extract_routing_policy(body, false);
|
||||
let result = extract_routing_policy(body);
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Failed to parse JSON"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_empty_array() {
|
||||
let policy = r#""routing_policy": []"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (_, prefs, _) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
let prefs = prefs.expect("empty array is valid");
|
||||
assert_eq!(prefs.len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_preserves_other_fields() {
|
||||
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs, _) = extract_routing_policy(&body, false).unwrap();
|
||||
|
||||
assert!(prefs.is_some());
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert_eq!(cleaned_json["temperature"], 0.5);
|
||||
assert_eq!(cleaned_json["max_tokens"], 100);
|
||||
assert!(cleaned_json.get("routing_policy").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_top_level_routing_preferences() {
|
||||
fn extract_routing_policy_routing_preferences() {
|
||||
let policy = r#""routing_preferences": [
|
||||
{
|
||||
"name": "code generation",
|
||||
|
|
@ -357,31 +240,44 @@ mod tests {
|
|||
}
|
||||
]"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, legacy_prefs, top_prefs) = extract_routing_policy(&body, false).unwrap();
|
||||
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
assert!(legacy_prefs.is_none());
|
||||
let top_prefs = top_prefs.expect("should have parsed top-level routing_preferences");
|
||||
assert_eq!(top_prefs.len(), 1);
|
||||
assert_eq!(top_prefs[0].name, "code generation");
|
||||
assert_eq!(
|
||||
top_prefs[0].models,
|
||||
vec!["openai/gpt-4o", "openai/gpt-4o-mini"]
|
||||
);
|
||||
let prefs = prefs.expect("should have parsed routing_preferences");
|
||||
assert_eq!(prefs.len(), 1);
|
||||
assert_eq!(prefs[0].name, "code generation");
|
||||
assert_eq!(prefs[0].models, vec!["openai/gpt-4o", "openai/gpt-4o-mini"]);
|
||||
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert!(cleaned_json.get("routing_preferences").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_routing_policy_preserves_other_fields() {
|
||||
let policy = r#""routing_preferences": [{"name": "test", "description": "test", "models": ["gpt-4o"], "selection_policy": {"prefer": "none"}}], "temperature": 0.5, "max_tokens": 100"#;
|
||||
let body = make_chat_body(policy);
|
||||
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
|
||||
|
||||
assert!(prefs.is_some());
|
||||
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
|
||||
assert_eq!(cleaned_json["temperature"], 0.5);
|
||||
assert_eq!(cleaned_json["max_tokens"], 100);
|
||||
assert!(cleaned_json.get("routing_preferences").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn routing_decision_response_serialization() {
|
||||
let response = RoutingDecisionResponse {
|
||||
model: "openai/gpt-4o".to_string(),
|
||||
models: vec![
|
||||
"openai/gpt-4o-mini".to_string(),
|
||||
"openai/gpt-4o".to_string(),
|
||||
],
|
||||
route: Some("code_generation".to_string()),
|
||||
trace_id: "abc123".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed["model"], "openai/gpt-4o");
|
||||
assert_eq!(parsed["models"][0], "openai/gpt-4o-mini");
|
||||
assert_eq!(parsed["models"][1], "openai/gpt-4o");
|
||||
assert_eq!(parsed["route"], "code_generation");
|
||||
assert_eq!(parsed["trace_id"], "abc123");
|
||||
}
|
||||
|
|
@ -389,13 +285,13 @@ mod tests {
|
|||
#[test]
|
||||
fn routing_decision_response_serialization_no_route() {
|
||||
let response = RoutingDecisionResponse {
|
||||
model: "none".to_string(),
|
||||
models: vec!["none".to_string()],
|
||||
route: None,
|
||||
trace_id: "abc123".to_string(),
|
||||
};
|
||||
let json = serde_json::to_string(&response).unwrap();
|
||||
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed["model"], "none");
|
||||
assert_eq!(parsed["models"][0], "none");
|
||||
assert!(parsed["route"].is_null());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -174,28 +174,11 @@ async fn init_app_state(
|
|||
.map(|p| p.name.clone())
|
||||
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
|
||||
|
||||
// Validate version-gated routing_preferences rules.
|
||||
// Validate that top-level routing_preferences requires v0.4.0+.
|
||||
let config_version = parse_semver(&config.version);
|
||||
let is_v040_plus = config_version >= (0, 4, 0);
|
||||
|
||||
if is_v040_plus {
|
||||
// v0.4.0+: per-provider routing_preferences are forbidden.
|
||||
let providers_with_per_provider_prefs: Vec<&str> = config
|
||||
.model_providers
|
||||
.iter()
|
||||
.filter(|p| p.routing_preferences.is_some())
|
||||
.filter_map(|p| p.model.as_deref())
|
||||
.collect();
|
||||
if !providers_with_per_provider_prefs.is_empty() {
|
||||
return Err(format!(
|
||||
"routing_preferences inside model_providers is not allowed in v0.4.0+. \
|
||||
Use the top-level routing_preferences instead. \
|
||||
Offending models: {}",
|
||||
providers_with_per_provider_prefs.join(", ")
|
||||
)
|
||||
.into());
|
||||
}
|
||||
} else if config.routing_preferences.is_some() {
|
||||
if !is_v040_plus && config.routing_preferences.is_some() {
|
||||
return Err(
|
||||
"top-level routing_preferences requires version v0.4.0 or above. \
|
||||
Update the version field or remove routing_preferences."
|
||||
|
|
@ -224,17 +207,34 @@ async fn init_app_state(
|
|||
}
|
||||
}
|
||||
|
||||
// Initialize ModelMetricsService if model_metrics_sources is configured.
|
||||
let metrics_service: Option<Arc<ModelMetricsService>> =
|
||||
if let Some(ref sources) = config.model_metrics_sources {
|
||||
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
|
||||
Some(Arc::new(svc))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
// Validate and initialize ModelMetricsService if model_metrics_sources is configured.
|
||||
let metrics_service: Option<Arc<ModelMetricsService>> = if let Some(ref sources) =
|
||||
config.model_metrics_sources
|
||||
{
|
||||
use common::configuration::MetricsSource;
|
||||
let cost_count = sources
|
||||
.iter()
|
||||
.filter(|s| matches!(s, MetricsSource::CostMetrics { .. }))
|
||||
.count();
|
||||
let prom_count = sources
|
||||
.iter()
|
||||
.filter(|s| matches!(s, MetricsSource::PrometheusMetrics { .. }))
|
||||
.count();
|
||||
if cost_count > 1 {
|
||||
return Err("model_metrics_sources: only one cost_metrics source is allowed".into());
|
||||
}
|
||||
if prom_count > 1 {
|
||||
return Err(
|
||||
"model_metrics_sources: only one prometheus_metrics source is allowed".into(),
|
||||
);
|
||||
}
|
||||
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
|
||||
Some(Arc::new(svc))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let router_service = Arc::new(RouterService::new(
|
||||
config.model_providers.clone(),
|
||||
config.routing_preferences.clone(),
|
||||
metrics_service,
|
||||
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use common::{
|
||||
configuration::{
|
||||
LlmProvider, ModelUsagePreference, RoutingPreference, TopLevelRoutingPreference,
|
||||
},
|
||||
configuration::TopLevelRoutingPreference,
|
||||
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
|
||||
};
|
||||
|
||||
use super::router_model::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::Message;
|
||||
use hyper::header;
|
||||
use thiserror::Error;
|
||||
|
|
@ -22,7 +22,6 @@ pub struct RouterService {
|
|||
client: reqwest::Client,
|
||||
router_model: Arc<dyn RouterModel>,
|
||||
routing_provider_name: String,
|
||||
llm_usage_defined: bool,
|
||||
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
}
|
||||
|
|
@ -40,60 +39,37 @@ pub type Result<T> = std::result::Result<T, RoutingError>;
|
|||
|
||||
impl RouterService {
|
||||
pub fn new(
|
||||
providers: Vec<LlmProvider>,
|
||||
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
|
||||
metrics_service: Option<Arc<ModelMetricsService>>,
|
||||
router_url: String,
|
||||
routing_model_name: String,
|
||||
routing_provider_name: String,
|
||||
) -> Self {
|
||||
// Build top-level preference map and sentinel llm_routes when v0.4.0 format is used.
|
||||
let (top_level_preferences, llm_routes, llm_usage_defined) =
|
||||
if let Some(top_prefs) = top_level_prefs {
|
||||
let top_level_map: HashMap<String, TopLevelRoutingPreference> =
|
||||
top_prefs.into_iter().map(|p| (p.name.clone(), p)).collect();
|
||||
// Build sentinel routes: route_name → first model (RouterModelV1 needs a model
|
||||
// mapping, but RouterService overrides the selection via metrics_service).
|
||||
let sentinel_routes: HashMap<String, Vec<RoutingPreference>> = top_level_map
|
||||
.iter()
|
||||
.filter_map(|(name, pref)| {
|
||||
pref.models.first().map(|first_model| {
|
||||
(
|
||||
first_model.clone(),
|
||||
vec![RoutingPreference {
|
||||
name: name.clone(),
|
||||
description: pref.description.clone(),
|
||||
}],
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
let defined = !top_level_map.is_empty();
|
||||
(top_level_map, sentinel_routes, defined)
|
||||
} else {
|
||||
// Legacy per-provider format.
|
||||
let providers_with_usage = providers
|
||||
.iter()
|
||||
.filter(|provider| provider.routing_preferences.is_some())
|
||||
.cloned()
|
||||
.collect::<Vec<LlmProvider>>();
|
||||
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
|
||||
.map_or_else(HashMap::new, |prefs| {
|
||||
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
|
||||
});
|
||||
|
||||
let routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
|
||||
.iter()
|
||||
.filter_map(|provider| {
|
||||
provider
|
||||
.routing_preferences
|
||||
.as_ref()
|
||||
.map(|prefs| (provider.name.clone(), prefs.clone()))
|
||||
})
|
||||
.collect();
|
||||
|
||||
let defined = !providers_with_usage.is_empty();
|
||||
(HashMap::new(), routes, defined)
|
||||
};
|
||||
// Build sentinel routes for RouterModelV1: route_name → first model.
|
||||
// RouterModelV1 uses this to build its prompt; RouterService overrides
|
||||
// the model selection via rank_models() after the route is determined.
|
||||
let sentinel_routes: HashMap<String, Vec<RoutingPreference>> = top_level_preferences
|
||||
.iter()
|
||||
.filter_map(|(name, pref)| {
|
||||
pref.models.first().map(|first_model| {
|
||||
(
|
||||
first_model.clone(),
|
||||
vec![RoutingPreference {
|
||||
name: name.clone(),
|
||||
description: pref.description.clone(),
|
||||
}],
|
||||
)
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
|
||||
llm_routes,
|
||||
sentinel_routes,
|
||||
routing_model_name,
|
||||
router_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
|
@ -103,7 +79,6 @@ impl RouterService {
|
|||
client: reqwest::Client::new(),
|
||||
router_model,
|
||||
routing_provider_name,
|
||||
llm_usage_defined,
|
||||
top_level_preferences,
|
||||
metrics_service,
|
||||
}
|
||||
|
|
@ -113,10 +88,9 @@ impl RouterService {
|
|||
&self,
|
||||
messages: &[Message],
|
||||
traceparent: &str,
|
||||
usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
request_id: &str,
|
||||
) -> Result<Option<(String, String)>> {
|
||||
) -> Result<Option<(String, Vec<String>)>> {
|
||||
if messages.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
|
@ -126,41 +100,27 @@ impl RouterService {
|
|||
inline_routing_preferences
|
||||
.map(|prefs| prefs.into_iter().map(|p| (p.name.clone(), p)).collect());
|
||||
|
||||
// Determine whether any routing is defined.
|
||||
let has_top_level = inline_top_map.is_some() || !self.top_level_preferences.is_empty();
|
||||
|
||||
if usage_preferences
|
||||
.as_ref()
|
||||
.is_none_or(|prefs| prefs.len() < 2)
|
||||
&& !self.llm_usage_defined
|
||||
&& !has_top_level
|
||||
{
|
||||
// No routing defined — skip the router call entirely.
|
||||
if inline_top_map.is_none() && self.top_level_preferences.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// For top-level format, build a synthetic ModelUsagePreference list so RouterModelV1
|
||||
// For inline overrides, build synthetic ModelUsagePreference list so RouterModelV1
|
||||
// generates the correct prompt (route name + description pairs).
|
||||
// For config-level prefs the sentinel routes are already baked into RouterModelV1.
|
||||
let effective_usage_preferences: Option<Vec<ModelUsagePreference>> =
|
||||
if let Some(ref inline_map) = inline_top_map {
|
||||
Some(
|
||||
inline_map
|
||||
.values()
|
||||
.map(|p| ModelUsagePreference {
|
||||
model: p.models.first().cloned().unwrap_or_default(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: p.name.clone(),
|
||||
description: p.description.clone(),
|
||||
}],
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else if !self.top_level_preferences.is_empty() {
|
||||
// Config top-level prefs: already encoded as sentinel routes in RouterModelV1,
|
||||
// pass None so it uses the pre-built llm_route_json_str.
|
||||
None
|
||||
} else {
|
||||
usage_preferences.clone()
|
||||
};
|
||||
inline_top_map.as_ref().map(|inline_map| {
|
||||
inline_map
|
||||
.values()
|
||||
.map(|p| ModelUsagePreference {
|
||||
model: p.models.first().cloned().unwrap_or_default(),
|
||||
routing_preferences: vec![RoutingPreference {
|
||||
name: p.name.clone(),
|
||||
description: p.description.clone(),
|
||||
}],
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
let router_request = self
|
||||
.router_model
|
||||
|
|
@ -209,21 +169,20 @@ impl RouterService {
|
|||
.router_model
|
||||
.parse_response(&content, &effective_usage_preferences)?;
|
||||
|
||||
let result = if let Some((route_name, _sentinel_model)) = parsed {
|
||||
// Check if this route belongs to the top-level preference format.
|
||||
let result = if let Some((route_name, _sentinel)) = parsed {
|
||||
let top_pref = inline_top_map
|
||||
.as_ref()
|
||||
.and_then(|m| m.get(&route_name))
|
||||
.or_else(|| self.top_level_preferences.get(&route_name));
|
||||
|
||||
if let Some(pref) = top_pref {
|
||||
let selected_model = match &self.metrics_service {
|
||||
Some(svc) => svc.select_model(&pref.models, &pref.selection_policy).await,
|
||||
None => pref.models.first().cloned().unwrap_or_default(),
|
||||
let ranked = match &self.metrics_service {
|
||||
Some(svc) => svc.rank_models(&pref.models, &pref.selection_policy).await,
|
||||
None => pref.models.clone(),
|
||||
};
|
||||
Some((route_name, selected_model))
|
||||
Some((route_name, ranked))
|
||||
} else {
|
||||
Some((route_name, _sentinel_model))
|
||||
None
|
||||
}
|
||||
} else {
|
||||
None
|
||||
|
|
|
|||
|
|
@ -2,59 +2,75 @@ use std::collections::HashMap;
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use common::configuration::{ModelMetricsSources, SelectionPolicy, SelectionPreference};
|
||||
use serde::Deserialize;
|
||||
use common::configuration::{MetricsSource, SelectionPolicy, SelectionPreference};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::{info, warn};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct MetricsResponse {
|
||||
#[serde(default)]
|
||||
cost: HashMap<String, f64>,
|
||||
#[serde(default)]
|
||||
latency: HashMap<String, f64>,
|
||||
}
|
||||
|
||||
pub struct ModelMetricsService {
|
||||
cost: Arc<RwLock<HashMap<String, f64>>>,
|
||||
latency: Arc<RwLock<HashMap<String, f64>>>,
|
||||
}
|
||||
|
||||
impl ModelMetricsService {
|
||||
pub async fn new(sources: &ModelMetricsSources, client: reqwest::Client) -> Self {
|
||||
pub async fn new(sources: &[MetricsSource], client: reqwest::Client) -> Self {
|
||||
let cost_data = Arc::new(RwLock::new(HashMap::new()));
|
||||
let latency_data = Arc::new(RwLock::new(HashMap::new()));
|
||||
|
||||
let metrics = fetch_metrics(&sources.url, &client).await;
|
||||
info!(
|
||||
cost_models = metrics.cost.len(),
|
||||
latency_models = metrics.latency.len(),
|
||||
url = %sources.url,
|
||||
"fetched model metrics"
|
||||
);
|
||||
*cost_data.write().await = metrics.cost;
|
||||
*latency_data.write().await = metrics.latency;
|
||||
for source in sources {
|
||||
match source {
|
||||
MetricsSource::CostMetrics {
|
||||
url,
|
||||
refresh_interval,
|
||||
auth,
|
||||
} => {
|
||||
let data = fetch_cost_metrics(url, auth.as_ref(), &client).await;
|
||||
info!(models = data.len(), url = %url, "fetched cost metrics");
|
||||
*cost_data.write().await = data;
|
||||
|
||||
if let Some(interval_secs) = sources.refresh_interval {
|
||||
let cost_clone = Arc::clone(&cost_data);
|
||||
let latency_clone = Arc::clone(&latency_data);
|
||||
let client_clone = client.clone();
|
||||
let url = sources.url.clone();
|
||||
tokio::spawn(async move {
|
||||
let interval = Duration::from_secs(interval_secs);
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
let metrics = fetch_metrics(&url, &client_clone).await;
|
||||
info!(
|
||||
cost_models = metrics.cost.len(),
|
||||
latency_models = metrics.latency.len(),
|
||||
url = %url,
|
||||
"refreshed model metrics"
|
||||
);
|
||||
*cost_clone.write().await = metrics.cost;
|
||||
*latency_clone.write().await = metrics.latency;
|
||||
if let Some(interval_secs) = refresh_interval {
|
||||
let cost_clone = Arc::clone(&cost_data);
|
||||
let client_clone = client.clone();
|
||||
let url = url.clone();
|
||||
let auth = auth.clone();
|
||||
let interval = Duration::from_secs(*interval_secs);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
let data =
|
||||
fetch_cost_metrics(&url, auth.as_ref(), &client_clone).await;
|
||||
info!(models = data.len(), url = %url, "refreshed cost metrics");
|
||||
*cost_clone.write().await = data;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
MetricsSource::PrometheusMetrics {
|
||||
url,
|
||||
query,
|
||||
refresh_interval,
|
||||
} => {
|
||||
let data = fetch_prometheus_metrics(url, query, &client).await;
|
||||
info!(models = data.len(), url = %url, "fetched prometheus latency metrics");
|
||||
*latency_data.write().await = data;
|
||||
|
||||
if let Some(interval_secs) = refresh_interval {
|
||||
let latency_clone = Arc::clone(&latency_data);
|
||||
let client_clone = client.clone();
|
||||
let url = url.clone();
|
||||
let query = query.clone();
|
||||
let interval = Duration::from_secs(*interval_secs);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(interval).await;
|
||||
let data =
|
||||
fetch_prometheus_metrics(&url, &query, &client_clone).await;
|
||||
info!(models = data.len(), url = %url, "refreshed prometheus latency metrics");
|
||||
*latency_clone.write().await = data;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ModelMetricsService {
|
||||
|
|
@ -63,63 +79,136 @@ impl ModelMetricsService {
|
|||
}
|
||||
}
|
||||
|
||||
/// Select the best model from `models` according to `policy`.
|
||||
/// Falls back to `models[0]` if metric data is unavailable for all candidates.
|
||||
pub async fn select_model(&self, models: &[String], policy: &SelectionPolicy) -> String {
|
||||
/// Rank `models` by `policy`, returning them in preference order.
|
||||
/// Models with no metric data are appended at the end in their original order.
|
||||
pub async fn rank_models(&self, models: &[String], policy: &SelectionPolicy) -> Vec<String> {
|
||||
match policy.prefer {
|
||||
SelectionPreference::Cheapest => {
|
||||
let data = self.cost.read().await;
|
||||
select_by_ascending_metric(models, &data)
|
||||
rank_by_ascending_metric(models, &data)
|
||||
}
|
||||
SelectionPreference::Fastest => {
|
||||
let data = self.latency.read().await;
|
||||
select_by_ascending_metric(models, &data)
|
||||
}
|
||||
SelectionPreference::Random => {
|
||||
let idx = rand_index(models.len());
|
||||
models[idx].clone()
|
||||
rank_by_ascending_metric(models, &data)
|
||||
}
|
||||
SelectionPreference::Random => shuffle(models),
|
||||
SelectionPreference::None => models.to_vec(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn select_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> String {
|
||||
models
|
||||
fn rank_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> Vec<String> {
|
||||
let mut with_data: Vec<(&String, f64)> = models
|
||||
.iter()
|
||||
.filter_map(|m| data.get(m.as_str()).map(|v| (m, *v)))
|
||||
.min_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|(m, _)| m.clone())
|
||||
.unwrap_or_else(|| models[0].clone())
|
||||
.collect();
|
||||
with_data.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
let without_data: Vec<&String> = models
|
||||
.iter()
|
||||
.filter(|m| !data.contains_key(m.as_str()))
|
||||
.collect();
|
||||
|
||||
with_data
|
||||
.iter()
|
||||
.map(|(m, _)| (*m).clone())
|
||||
.chain(without_data.iter().map(|m| (*m).clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Simple non-crypto random index using system time nanoseconds.
|
||||
fn rand_index(len: usize) -> usize {
|
||||
fn shuffle(models: &[String]) -> Vec<String> {
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
let nanos = SystemTime::now()
|
||||
let seed = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map(|d| d.subsec_nanos() as usize)
|
||||
.unwrap_or(0);
|
||||
nanos % len
|
||||
let mut result = models.to_vec();
|
||||
let mut state = seed;
|
||||
for i in (1..result.len()).rev() {
|
||||
state = state
|
||||
.wrapping_mul(6364136223846793005)
|
||||
.wrapping_add(1442695040888963407);
|
||||
let j = state % (i + 1);
|
||||
result.swap(i, j);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
async fn fetch_metrics(url: &str, client: &reqwest::Client) -> MetricsResponse {
|
||||
match client.get(url).send().await {
|
||||
Ok(resp) => match resp.json::<MetricsResponse>().await {
|
||||
async fn fetch_cost_metrics(
|
||||
url: &str,
|
||||
auth: Option<&common::configuration::MetricsAuth>,
|
||||
client: &reqwest::Client,
|
||||
) -> HashMap<String, f64> {
|
||||
let mut req = client.get(url);
|
||||
if let Some(auth) = auth {
|
||||
if auth.auth_type == "bearer" {
|
||||
req = req.header("Authorization", format!("Bearer {}", auth.token));
|
||||
} else {
|
||||
warn!(auth_type = %auth.auth_type, "unsupported auth type for cost_metrics, skipping auth");
|
||||
}
|
||||
}
|
||||
match req.send().await {
|
||||
Ok(resp) => match resp.json::<HashMap<String, f64>>().await {
|
||||
Ok(data) => data,
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = %url, "failed to parse metrics response");
|
||||
MetricsResponse {
|
||||
cost: HashMap::new(),
|
||||
latency: HashMap::new(),
|
||||
}
|
||||
warn!(error = %err, url = %url, "failed to parse cost metrics response");
|
||||
HashMap::new()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = %url, "failed to fetch metrics");
|
||||
MetricsResponse {
|
||||
cost: HashMap::new(),
|
||||
latency: HashMap::new(),
|
||||
warn!(error = %err, url = %url, "failed to fetch cost metrics");
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct PrometheusResponse {
|
||||
data: PrometheusData,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct PrometheusData {
|
||||
result: Vec<PrometheusResult>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct PrometheusResult {
|
||||
metric: HashMap<String, String>,
|
||||
value: (f64, String), // (timestamp, value_str)
|
||||
}
|
||||
|
||||
async fn fetch_prometheus_metrics(
|
||||
url: &str,
|
||||
query: &str,
|
||||
client: &reqwest::Client,
|
||||
) -> HashMap<String, f64> {
|
||||
let query_url = format!("{}/api/v1/query", url.trim_end_matches('/'));
|
||||
match client
|
||||
.get(&query_url)
|
||||
.query(&[("query", query)])
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => match resp.json::<PrometheusResponse>().await {
|
||||
Ok(prom) => prom
|
||||
.data
|
||||
.result
|
||||
.into_iter()
|
||||
.filter_map(|r| {
|
||||
let model_name = r.metric.get("model_name")?.clone();
|
||||
let value: f64 = r.value.1.parse().ok()?;
|
||||
Some((model_name, value))
|
||||
})
|
||||
.collect(),
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = %query_url, "failed to parse prometheus response");
|
||||
HashMap::new()
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
warn!(error = %err, url = %query_url, "failed to fetch prometheus metrics");
|
||||
HashMap::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -134,32 +223,35 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_by_ascending_metric_picks_lowest() {
|
||||
fn test_rank_by_ascending_metric_picks_lowest_first() {
|
||||
let models = vec!["a".to_string(), "b".to_string(), "c".to_string()];
|
||||
let mut data = HashMap::new();
|
||||
data.insert("a".to_string(), 0.01);
|
||||
data.insert("b".to_string(), 0.005);
|
||||
data.insert("c".to_string(), 0.02);
|
||||
assert_eq!(select_by_ascending_metric(&models, &data), "b");
|
||||
assert_eq!(
|
||||
rank_by_ascending_metric(&models, &data),
|
||||
vec!["b", "a", "c"]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_by_ascending_metric_fallback_to_first() {
|
||||
fn test_rank_by_ascending_metric_no_data_preserves_order() {
|
||||
let models = vec!["x".to_string(), "y".to_string()];
|
||||
let data = HashMap::new();
|
||||
assert_eq!(select_by_ascending_metric(&models, &data), "x");
|
||||
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["x", "y"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_by_ascending_metric_partial_data() {
|
||||
fn test_rank_by_ascending_metric_partial_data() {
|
||||
let models = vec!["a".to_string(), "b".to_string()];
|
||||
let mut data = HashMap::new();
|
||||
data.insert("b".to_string(), 100.0);
|
||||
assert_eq!(select_by_ascending_metric(&models, &data), "b");
|
||||
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["b", "a"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_select_model_cheapest() {
|
||||
async fn test_rank_models_cheapest() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new({
|
||||
let mut m = HashMap::new();
|
||||
|
|
@ -171,13 +263,13 @@ mod tests {
|
|||
};
|
||||
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
|
||||
let result = service
|
||||
.select_model(&models, &make_policy(SelectionPreference::Cheapest))
|
||||
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
|
||||
.await;
|
||||
assert_eq!(result, "gpt-4o-mini");
|
||||
assert_eq!(result, vec!["gpt-4o-mini", "gpt-4o"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_select_model_fastest() {
|
||||
async fn test_rank_models_fastest() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new(HashMap::new())),
|
||||
latency: Arc::new(RwLock::new({
|
||||
|
|
@ -189,21 +281,57 @@ mod tests {
|
|||
};
|
||||
let models = vec!["gpt-4o".to_string(), "claude-sonnet".to_string()];
|
||||
let result = service
|
||||
.select_model(&models, &make_policy(SelectionPreference::Fastest))
|
||||
.rank_models(&models, &make_policy(SelectionPreference::Fastest))
|
||||
.await;
|
||||
assert_eq!(result, "claude-sonnet");
|
||||
assert_eq!(result, vec!["claude-sonnet", "gpt-4o"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_select_model_fallback_no_metrics() {
|
||||
async fn test_rank_models_fallback_no_metrics() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new(HashMap::new())),
|
||||
latency: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
let models = vec!["model-a".to_string(), "model-b".to_string()];
|
||||
let result = service
|
||||
.select_model(&models, &make_policy(SelectionPreference::Cheapest))
|
||||
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
|
||||
.await;
|
||||
assert_eq!(result, "model-a");
|
||||
assert_eq!(result, vec!["model-a", "model-b"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rank_models_partial_data_appended_last() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new({
|
||||
let mut m = HashMap::new();
|
||||
m.insert("gpt-4o".to_string(), 0.005);
|
||||
m
|
||||
})),
|
||||
latency: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
let models = vec!["gpt-4o-mini".to_string(), "gpt-4o".to_string()];
|
||||
let result = service
|
||||
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
|
||||
.await;
|
||||
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_rank_models_none_preserves_order() {
|
||||
let service = ModelMetricsService {
|
||||
cost: Arc::new(RwLock::new({
|
||||
let mut m = HashMap::new();
|
||||
m.insert("gpt-4o-mini".to_string(), 0.0001);
|
||||
m.insert("gpt-4o".to_string(), 0.005);
|
||||
m
|
||||
})),
|
||||
latency: Arc::new(RwLock::new(HashMap::new())),
|
||||
};
|
||||
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
|
||||
let result = service
|
||||
.rank_models(&models, &make_policy(SelectionPreference::None))
|
||||
.await;
|
||||
// none → original order, despite gpt-4o-mini being cheaper
|
||||
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
@ -10,6 +10,20 @@ pub enum RoutingModelError {
|
|||
|
||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
|
||||
/// Internal route descriptor passed to the router model to build its prompt.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoutingPreference {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
/// Groups a model with its routing preferences (used internally by RouterModelV1).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelUsagePreference {
|
||||
pub model: String,
|
||||
pub routing_preferences: Vec<RoutingPreference>,
|
||||
}
|
||||
|
||||
pub trait RouterModel: Send + Sync {
|
||||
fn generate_request(
|
||||
&self,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use common::configuration::{ModelUsagePreference, RoutingPreference};
|
||||
use super::router_model::{ModelUsagePreference, RoutingPreference};
|
||||
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
|
||||
use hermesllm::transforms::lib::ExtractText;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
|
|||
|
|
@ -110,6 +110,8 @@ pub enum SelectionPreference {
|
|||
Cheapest,
|
||||
Fastest,
|
||||
Random,
|
||||
/// Return models in the same order they were defined — no reordering.
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -126,9 +128,25 @@ pub struct TopLevelRoutingPreference {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelMetricsSources {
|
||||
pub url: String,
|
||||
pub refresh_interval: Option<u64>,
|
||||
pub struct MetricsAuth {
|
||||
#[serde(rename = "type")]
|
||||
pub auth_type: String, // only "bearer" supported
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum MetricsSource {
|
||||
CostMetrics {
|
||||
url: String,
|
||||
refresh_interval: Option<u64>,
|
||||
auth: Option<MetricsAuth>,
|
||||
},
|
||||
PrometheusMetrics {
|
||||
url: String,
|
||||
query: String,
|
||||
refresh_interval: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -150,7 +168,7 @@ pub struct Configuration {
|
|||
pub listeners: Vec<Listener>,
|
||||
pub state_storage: Option<StateStorageConfig>,
|
||||
pub routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
|
||||
pub model_metrics_sources: Option<ModelMetricsSources>,
|
||||
pub model_metrics_sources: Option<Vec<MetricsSource>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
|
|
@ -346,18 +364,6 @@ impl LlmProviderType {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize, Debug)]
|
||||
pub struct ModelUsagePreference {
|
||||
pub model: String,
|
||||
pub routing_preferences: Vec<RoutingPreference>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RoutingPreference {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct AgentUsagePreference {
|
||||
pub model: String,
|
||||
|
|
@ -407,7 +413,6 @@ pub struct LlmProvider {
|
|||
pub port: Option<u16>,
|
||||
pub rate_limits: Option<LlmRatelimit>,
|
||||
pub usage: Option<String>,
|
||||
pub routing_preferences: Option<Vec<RoutingPreference>>,
|
||||
pub cluster_name: Option<String>,
|
||||
pub base_url_path_prefix: Option<String>,
|
||||
pub internal: Option<bool>,
|
||||
|
|
@ -451,7 +456,6 @@ impl Default for LlmProvider {
|
|||
port: None,
|
||||
rate_limits: None,
|
||||
usage: None,
|
||||
routing_preferences: None,
|
||||
cluster_name: None,
|
||||
base_url_path_prefix: None,
|
||||
internal: None,
|
||||
|
|
|
|||
|
|
@ -274,7 +274,6 @@ mod tests {
|
|||
port: None,
|
||||
rate_limits: None,
|
||||
usage: None,
|
||||
routing_preferences: None,
|
||||
internal: None,
|
||||
stream: None,
|
||||
passthrough_auth: None,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue