merge main into plano-session_pinning

Adil Hafeez 2026-04-02 22:58:47 -07:00
commit f699cfb059
86 changed files with 11996 additions and 8063 deletions

View file

@ -762,7 +762,7 @@ impl ArchFunctionHandler {
// Keep system message if present
if let Some(first) = messages.first() {
if first.role == Role::System {
if first.role == Role::System || first.role == Role::Developer {
if let Some(MessageContent::Text(content)) = &first.content {
num_tokens += content.len() / 4; // Approximate 4 chars per token
}

View file

@ -38,6 +38,8 @@ use crate::tracing::{
};
use model_selection::router_chat_get_upstream_model;
const PERPLEXITY_PROVIDER_PREFIX: &str = "perplexity/";
pub async fn llm_chat(
request: Request<hyper::body::Incoming>,
state: Arc<AppState>,
@ -134,7 +136,7 @@ async fn llm_chat_inner(
temperature,
tool_names,
user_message_preview,
inline_routing_policy,
inline_routing_preferences,
client_api,
provider_id,
} = parsed;
@ -284,7 +286,7 @@ async fn llm_chat_inner(
&traceparent,
&request_path,
&request_id,
inline_routing_policy,
inline_routing_preferences,
)
.await
}
@ -357,7 +359,7 @@ struct PreparedRequest {
temperature: Option<f32>,
tool_names: Option<Vec<String>>,
user_message_preview: Option<String>,
inline_routing_policy: Option<Vec<common::configuration::ModelUsagePreference>>,
inline_routing_preferences: Option<Vec<common::configuration::TopLevelRoutingPreference>>,
client_api: Option<SupportedAPIsFromClient>,
provider_id: hermesllm::ProviderId,
}
@ -386,16 +388,14 @@ async fn parse_and_validate_request(
"request body received"
);
// Extract routing_policy from request body if present
let (chat_request_bytes, inline_routing_policy) =
crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err(
|err| {
warn!(error = %err, "failed to parse request JSON");
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
*r.status_mut() = StatusCode::BAD_REQUEST;
r
},
)?;
// Extract routing_preferences from request body if present
let (chat_request_bytes, inline_routing_preferences) =
crate::handlers::routing_service::extract_routing_policy(&raw_bytes).map_err(|err| {
warn!(error = %err, "failed to parse request JSON");
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
*r.status_mut() = StatusCode::BAD_REQUEST;
r
})?;
let api_type = SupportedAPIsFromClient::from_endpoint(request_path).ok_or_else(|| {
warn!(path = %request_path, "unsupported endpoint");
@ -420,7 +420,7 @@ async fn parse_and_validate_request(
let temperature = client_request.get_temperature();
let is_streaming_request = client_request.is_streaming();
let alias_resolved_model = resolve_model_alias(&model_from_request, model_aliases);
let (provider_id, _) = get_provider_info(llm_providers, &alias_resolved_model).await;
let (provider_id, _, _) = get_provider_info(llm_providers, &alias_resolved_model).await;
// Validate model exists in configuration
if llm_providers
@ -473,7 +473,7 @@ async fn parse_and_validate_request(
temperature,
tool_names,
user_message_preview,
inline_routing_policy,
inline_routing_preferences,
client_api,
provider_id,
})
@ -775,7 +775,8 @@ async fn get_upstream_path(
resolved_model: &str,
is_streaming: bool,
) -> String {
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
let (provider_id, base_url_path_prefix, use_unversioned_paths) =
get_provider_info(llm_providers, model_name).await;
let Some(client_api) = SupportedAPIsFromClient::from_endpoint(request_path) else {
return request_path.to_string();
@ -787,6 +788,7 @@ async fn get_upstream_path(
resolved_model,
is_streaming,
base_url_path_prefix.as_deref(),
use_unversioned_paths,
)
}
@ -794,21 +796,124 @@ async fn get_upstream_path(
async fn get_provider_info(
llm_providers: &Arc<RwLock<LlmProviders>>,
model_name: &str,
) -> (hermesllm::ProviderId, Option<String>) {
) -> (hermesllm::ProviderId, Option<String>, bool) {
let providers_lock = llm_providers.read().await;
if let Some(provider) = providers_lock.get(model_name) {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
return (provider_id, prefix);
let use_unversioned_paths = provider.name.starts_with(PERPLEXITY_PROVIDER_PREFIX);
return (provider_id, prefix, use_unversioned_paths);
}
if let Some(provider) = providers_lock.default() {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
(provider_id, prefix)
let use_unversioned_paths = provider.name.starts_with(PERPLEXITY_PROVIDER_PREFIX);
(provider_id, prefix, use_unversioned_paths)
} else {
warn!("No default provider found, falling back to OpenAI");
(hermesllm::ProviderId::OpenAI, None)
(hermesllm::ProviderId::OpenAI, None, false)
}
}
#[cfg(test)]
mod tests {
use super::{get_provider_info, get_upstream_path};
use common::configuration::{LlmProvider, LlmProviderType};
use common::llm_providers::LlmProviders;
use hermesllm::apis::OpenAIApi;
use hermesllm::clients::SupportedAPIsFromClient;
use std::sync::Arc;
use tokio::sync::RwLock;
fn build_provider(name: &str, model: &str) -> LlmProvider {
LlmProvider {
name: name.to_string(),
provider_interface: LlmProviderType::OpenAI,
access_key: Some("test_key".to_string()),
model: Some(model.to_string()),
default: Some(false),
..Default::default()
}
}
fn providers_lock(providers: Vec<LlmProvider>) -> Arc<RwLock<LlmProviders>> {
Arc::new(RwLock::new(
LlmProviders::try_from(providers).expect("test providers should be valid"),
))
}
#[tokio::test]
async fn test_get_provider_info_marks_perplexity_as_unversioned() {
let providers = providers_lock(vec![build_provider("perplexity/sonar-pro", "sonar-pro")]);
let (provider_id, prefix, use_unversioned_paths) =
get_provider_info(&providers, "perplexity/sonar-pro").await;
assert_eq!(provider_id, hermesllm::ProviderId::OpenAI);
assert_eq!(prefix, None);
assert!(use_unversioned_paths);
}
#[tokio::test]
async fn test_get_upstream_path_for_perplexity_uses_unversioned_chat_endpoint() {
let providers = providers_lock(vec![build_provider("perplexity/sonar-pro", "sonar-pro")]);
let upstream_path = get_upstream_path(
&providers,
"perplexity/sonar-pro",
"/v1/chat/completions",
"sonar-pro",
false,
)
.await;
assert_eq!(upstream_path, "/chat/completions");
}
#[tokio::test]
async fn test_get_upstream_path_for_non_perplexity_keeps_v1_chat_endpoint() {
let providers = providers_lock(vec![build_provider("openai/gpt-4o-mini", "gpt-4o-mini")]);
let upstream_path = get_upstream_path(
&providers,
"openai/gpt-4o-mini",
"/v1/chat/completions",
"gpt-4o-mini",
false,
)
.await;
assert_eq!(upstream_path, "/v1/chat/completions");
}
#[tokio::test]
async fn test_perplexity_with_and_without_versioning_paths() {
let providers = providers_lock(vec![build_provider("perplexity/sonar-pro", "sonar-pro")]);
// This is the path Plano should use for Perplexity (works).
let success_path = get_upstream_path(
&providers,
"perplexity/sonar-pro",
"/v1/chat/completions",
"sonar-pro",
false,
)
.await;
assert_eq!(success_path, "/chat/completions");
// This is the generic OpenAI default path; for Perplexity this would 404.
let fail_path = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions)
.target_endpoint_for_provider(
&hermesllm::ProviderId::OpenAI,
"/v1/chat/completions",
"sonar-pro",
false,
None,
false,
);
assert_eq!(fail_path, "/v1/chat/completions");
assert_ne!(success_path, fail_path);
}
}

View file

@ -1,6 +1,6 @@
use common::configuration::ModelUsagePreference;
use common::configuration::TopLevelRoutingPreference;
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
use hermesllm::{ProviderRequest, ProviderRequestType};
use hermesllm::ProviderRequestType;
use hyper::StatusCode;
use std::sync::Arc;
use tracing::{debug, info, warn};
@ -10,7 +10,10 @@ use crate::streaming::truncate_message;
use crate::tracing::routing;
pub struct RoutingResult {
/// Primary model to use (first in the ranked list).
pub model_name: String,
/// Full ranked list — use subsequent entries as fallbacks on 429/5xx.
pub models: Vec<String>,
pub route_name: Option<String>,
}
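
For context, a minimal sketch (not part of this change) of how a caller might walk the ranked list, falling back only on 429/5xx as the doc comment describes; `send` stands in for the real upstream call and is purely illustrative:

async fn first_healthy_model<F, Fut>(result: &RoutingResult, send: F) -> Option<String>
where
    F: Fn(String) -> Fut,
    Fut: std::future::Future<Output = u16>, // upstream HTTP status code
{
    for model in &result.models {
        let status = send(model.clone()).await;
        // Fall back to the next ranked model only on 429 or any 5xx.
        if status != 429 && status < 500 {
            return Some(model.clone());
        }
    }
    None
}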
@ -39,11 +42,8 @@ pub async fn router_chat_get_upstream_model(
traceparent: &str,
request_path: &str,
request_id: &str,
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
) -> Result<RoutingResult, RoutingError> {
// Clone metadata for routing before converting (which consumes client_request)
let routing_metadata = client_request.metadata().clone();
// Convert to ChatCompletionsRequest for routing (regardless of input type)
let chat_request = match ProviderRequestType::try_from((
client_request,
@ -78,22 +78,6 @@ pub async fn router_chat_get_upstream_model(
"router request"
);
// Use inline preferences if provided, otherwise fall back to metadata extraction
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
{
inline_usage_preferences
} else {
let usage_preferences_str: Option<String> =
routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("plano_preference_config")
.map(|value| value.to_string())
});
usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok())
};
// Prepare log message with latest message from chat request
let latest_message_for_log = chat_request
.messages
@ -107,7 +91,6 @@ pub async fn router_chat_get_upstream_model(
let latest_message_for_log = truncate_message(&latest_message_for_log, 50);
info!(
has_usage_preferences = usage_preferences.is_some(),
path = %request_path,
latest_message = %latest_message_for_log,
"processing router request"
@ -121,7 +104,7 @@ pub async fn router_chat_get_upstream_model(
.determine_route(
&chat_request.messages,
traceparent,
usage_preferences,
inline_routing_preferences,
request_id,
)
.await;
@ -132,10 +115,12 @@ pub async fn router_chat_get_upstream_model(
match routing_result {
Ok(route) => match route {
Some((route_name, model_name)) => {
Some((route_name, ranked_models)) => {
let model_name = ranked_models.first().cloned().unwrap_or_default();
current_span.record("route.selected_model", model_name.as_str());
Ok(RoutingResult {
model_name,
models: ranked_models,
route_name: Some(route_name),
})
}
@ -147,6 +132,7 @@ pub async fn router_chat_get_upstream_model(
Ok(RoutingResult {
model_name: "none".to_string(),
models: vec!["none".to_string()],
route_name: None,
})
}

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::{ModelUsagePreference, SpanAttributes};
use common::configuration::{SpanAttributes, TopLevelRoutingPreference};
use common::consts::{REQUEST_ID_HEADER, SESSION_ID_HEADER};
use common::errors::BrightStaffError;
use hermesllm::clients::SupportedAPIsFromClient;
@ -15,56 +15,42 @@ use crate::handlers::llm::model_selection::router_chat_get_upstream_model;
use crate::router::llm::RouterService;
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
/// and parsed preferences. The `routing_policy` field is removed from the JSON
/// before re-serializing so downstream parsers don't see the non-standard field.
///
/// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB.
/// Extracts `routing_preferences` from a JSON body, returning the cleaned body bytes
/// and the parsed preferences. The field is removed from the JSON before re-serializing
/// so downstream parsers don't see it.
pub fn extract_routing_policy(
raw_bytes: &[u8],
warn_on_size: bool,
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
) -> Result<(Bytes, Option<Vec<TopLevelRoutingPreference>>), String> {
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
let preferences = json_body
let routing_preferences = json_body
.as_object_mut()
.and_then(|obj| obj.remove("routing_policy"))
.and_then(|policy_value| {
if warn_on_size {
let policy_str = serde_json::to_string(&policy_value).unwrap_or_default();
if policy_str.len() > ROUTING_POLICY_SIZE_WARNING_BYTES {
warn!(
size_bytes = policy_str.len(),
limit_bytes = ROUTING_POLICY_SIZE_WARNING_BYTES,
"routing_policy exceeds recommended size limit"
);
}
}
match serde_json::from_value::<Vec<ModelUsagePreference>>(policy_value) {
.and_then(|o| o.remove("routing_preferences"))
.and_then(
|value| match serde_json::from_value::<Vec<TopLevelRoutingPreference>>(value) {
Ok(prefs) => {
info!(
num_models = prefs.len(),
"using inline routing_policy from request body"
num_routes = prefs.len(),
"using inline routing_preferences from request body"
);
Some(prefs)
}
Err(err) => {
warn!(error = %err, "failed to parse routing_policy");
warn!(error = %err, "failed to parse routing_preferences");
None
}
}
});
},
);
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
Ok((bytes, preferences))
Ok((bytes, routing_preferences))
}
#[derive(serde::Serialize)]
struct RoutingDecisionResponse {
model: String,
/// Ranked model list — use first, fall back to next on 429/5xx.
models: Vec<String>,
route: Option<String>,
trace_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
@ -148,7 +134,7 @@ async fn routing_decision_inner(
"returning pinned routing decision from cache"
);
let response = RoutingDecisionResponse {
model: cached.model_name,
models: vec![cached.model_name],
route: cached.route_name,
trace_id,
session_id: Some(sid.clone()),
@ -174,8 +160,9 @@ async fn routing_decision_inner(
"routing decision request body received"
);
// Extract routing_policy from request body before parsing as ProviderRequestType
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
// Extract routing_preferences from body before parsing as ProviderRequestType
let (chat_request_bytes, inline_routing_preferences) = match extract_routing_policy(&raw_bytes)
{
Ok(result) => result,
Err(err) => {
warn!(error = %err, "failed to parse request JSON");
@ -202,14 +189,13 @@ async fn routing_decision_inner(
}
};
// Call the existing routing logic with inline preferences
let routing_result = router_chat_get_upstream_model(
Arc::clone(&router_service),
client_request,
&traceparent,
&request_path,
&request_id,
inline_preferences,
inline_routing_preferences,
)
.await;
@ -227,7 +213,7 @@ async fn routing_decision_inner(
}
let response = RoutingDecisionResponse {
model: result.model_name,
models: result.models,
route: result.route_name,
trace_id,
session_id,
@ -235,7 +221,8 @@ async fn routing_decision_inner(
};
info!(
model = %response.model,
primary_model = %response.models.first().map(|s| s.as_str()).unwrap_or("none"),
total_models = response.models.len(),
route = ?response.route,
"routing decision completed"
);
@ -261,6 +248,7 @@ async fn routing_decision_inner(
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::SelectionPreference;
fn make_chat_body(extra_fields: &str) -> Vec<u8> {
let extra = if extra_fields.is_empty() {
@ -278,95 +266,118 @@ mod tests {
#[test]
fn extract_routing_policy_no_policy() {
let body = make_chat_body("");
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
assert!(prefs.is_none());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn extract_routing_policy_valid_policy() {
let policy = r#""routing_policy": [
{
"model": "openai/gpt-4o",
"routing_preferences": [
{"name": "coding", "description": "code generation tasks"}
]
},
{
"model": "openai/gpt-4o-mini",
"routing_preferences": [
{"name": "general", "description": "general questions"}
]
}
]"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("should have parsed preferences");
assert_eq!(prefs.len(), 2);
assert_eq!(prefs[0].model, "openai/gpt-4o");
assert_eq!(prefs[0].routing_preferences[0].name, "coding");
assert_eq!(prefs[1].model, "openai/gpt-4o-mini");
assert_eq!(prefs[1].routing_preferences[0].name, "general");
// routing_policy should be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
}
#[test]
fn extract_routing_policy_invalid_policy_returns_none() {
// routing_policy is present but has wrong shape
let policy = r#""routing_policy": "not-an-array""#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
// Invalid policy should be ignored (returns None), not error
assert!(prefs.is_none());
// routing_policy should still be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn extract_routing_policy_invalid_json_returns_error() {
let body = b"not valid json";
let result = extract_routing_policy(body, false);
let result = extract_routing_policy(body);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Failed to parse JSON"));
}
#[test]
fn extract_routing_policy_empty_array() {
let policy = r#""routing_policy": []"#;
fn extract_routing_policy_routing_preferences() {
let policy = r#""routing_preferences": [
{
"name": "code generation",
"description": "generate new code",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
"selection_policy": {"prefer": "fastest"}
}
]"#;
let body = make_chat_body(policy);
let (_, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
let prefs = prefs.expect("empty array is valid");
assert_eq!(prefs.len(), 0);
let prefs = prefs.expect("should have parsed routing_preferences");
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].name, "code generation");
assert_eq!(prefs[0].models, vec!["openai/gpt-4o", "openai/gpt-4o-mini"]);
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_preferences").is_none());
}
#[test]
fn extract_routing_policy_preserves_other_fields() {
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
let policy = r#""routing_preferences": [{"name": "test", "description": "test", "models": ["gpt-4o"], "selection_policy": {"prefer": "none"}}], "temperature": 0.5, "max_tokens": 100"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let (cleaned, prefs) = extract_routing_policy(&body).unwrap();
assert!(prefs.is_some());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["temperature"], 0.5);
assert_eq!(cleaned_json["max_tokens"], 100);
assert!(cleaned_json.get("routing_policy").is_none());
assert!(cleaned_json.get("routing_preferences").is_none());
}
#[test]
fn extract_routing_policy_prefer_null_defaults_to_none() {
let policy = r#""routing_preferences": [
{
"name": "coding",
"description": "code generation, writing functions, debugging",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
"selection_policy": {"prefer": null}
}
]"#;
let body = make_chat_body(policy);
let (_cleaned, prefs) = extract_routing_policy(&body).unwrap();
let prefs = prefs.expect("should parse routing_preferences when prefer is null");
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::None);
}
#[test]
fn extract_routing_policy_selection_policy_missing_defaults_to_none() {
let policy = r#""routing_preferences": [
{
"name": "coding",
"description": "code generation, writing functions, debugging",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"]
}
]"#;
let body = make_chat_body(policy);
let (_cleaned, prefs) = extract_routing_policy(&body).unwrap();
let prefs =
prefs.expect("should parse routing_preferences when selection_policy is missing");
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::None);
}
#[test]
fn extract_routing_policy_prefer_empty_string_defaults_to_none() {
let policy = r#""routing_preferences": [
{
"name": "coding",
"description": "code generation, writing functions, debugging",
"models": ["openai/gpt-4o", "openai/gpt-4o-mini"],
"selection_policy": {"prefer": ""}
}
]"#;
let body = make_chat_body(policy);
let (_cleaned, prefs) = extract_routing_policy(&body).unwrap();
let prefs =
prefs.expect("should parse routing_preferences when selection_policy.prefer is empty");
assert_eq!(prefs.len(), 1);
assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::None);
}
#[test]
fn routing_decision_response_serialization() {
let response = RoutingDecisionResponse {
model: "openai/gpt-4o".to_string(),
models: vec![
"openai/gpt-4o-mini".to_string(),
"openai/gpt-4o".to_string(),
],
route: Some("code_generation".to_string()),
trace_id: "abc123".to_string(),
session_id: Some("sess-abc".to_string()),
@ -374,7 +385,8 @@ mod tests {
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "openai/gpt-4o");
assert_eq!(parsed["models"][0], "openai/gpt-4o-mini");
assert_eq!(parsed["models"][1], "openai/gpt-4o");
assert_eq!(parsed["route"], "code_generation");
assert_eq!(parsed["trace_id"], "abc123");
assert_eq!(parsed["session_id"], "sess-abc");
@ -384,7 +396,7 @@ mod tests {
#[test]
fn routing_decision_response_serialization_no_route() {
let response = RoutingDecisionResponse {
model: "none".to_string(),
models: vec!["none".to_string()],
route: None,
trace_id: "abc123".to_string(),
session_id: None,
@ -392,7 +404,7 @@ mod tests {
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "none");
assert_eq!(parsed["models"][0], "none");
assert!(parsed["route"].is_null());
assert!(parsed.get("session_id").is_none());
assert_eq!(parsed["pinned"], false);

View file

@ -6,6 +6,7 @@ use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::router::llm::RouterService;
use brightstaff::router::model_metrics::ModelMetricsService;
use brightstaff::router::orchestrator::OrchestratorService;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
@ -40,6 +41,17 @@ const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator";
const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator";
/// Parse a version string like `v0.4.0`, `v0.3.0`, `0.2.0` into a `(major, minor, patch)` tuple.
/// Missing parts default to 0. Non-numeric parts are treated as 0.
fn parse_semver(version: &str) -> (u32, u32, u32) {
let v = version.trim_start_matches('v');
let mut parts = v.splitn(3, '.').map(|p| p.parse::<u32>().unwrap_or(0));
let major = parts.next().unwrap_or(0);
let minor = parts.next().unwrap_or(0);
let patch = parts.next().unwrap_or(0);
(major, minor, patch)
}
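
A few illustrative assertions (a unit-test sketch, not part of this diff) based on the helper above; they mirror the tuple comparison used by the v0.4.0 gate further down:

#[test]
fn parse_semver_examples() {
    assert_eq!(parse_semver("v0.4.0"), (0, 4, 0));
    assert_eq!(parse_semver("0.3"), (0, 3, 0)); // missing patch defaults to 0
    assert_eq!(parse_semver("v0.4.1-rc1"), (0, 4, 0)); // non-numeric patch treated as 0
    assert!(parse_semver("v0.4.0") >= (0, 4, 0));
    assert!(parse_semver("v0.3.9") < (0, 4, 0));
}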
/// CORS pre-flight response for the models endpoint.
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let mut response = Response::new(empty());
@ -163,11 +175,131 @@ async fn init_app_state(
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let session_ttl_seconds = config.routing.as_ref().and_then(|r| r.session_ttl_seconds);
let session_max_entries = config.routing.as_ref().and_then(|r| r.session_max_entries);
// Validate the config version: top-level routing_preferences requires v0.4.0 or above.
let config_version = parse_semver(&config.version);
let is_v040_plus = config_version >= (0, 4, 0);
if !is_v040_plus && config.routing_preferences.is_some() {
return Err(
"top-level routing_preferences requires version v0.4.0 or above. \
Update the version field or remove routing_preferences."
.into(),
);
}
// Validate that all models referenced in top-level routing_preferences exist in model_providers.
// The CLI renders model_providers with `name` = "openai/gpt-4o" and `model` = "gpt-4o",
// so we accept a match against either field.
if let Some(ref route_prefs) = config.routing_preferences {
let provider_model_names: std::collections::HashSet<&str> = config
.model_providers
.iter()
.flat_map(|p| std::iter::once(p.name.as_str()).chain(p.model.as_deref()))
.collect();
for pref in route_prefs {
for model in &pref.models {
if !provider_model_names.contains(model.as_str()) {
return Err(format!(
"routing_preferences route '{}' references model '{}' \
which is not declared in model_providers",
pref.name, model
)
.into());
}
}
}
}
// Validate and initialize ModelMetricsService if model_metrics_sources is configured.
let metrics_service: Option<Arc<ModelMetricsService>> = if let Some(ref sources) =
config.model_metrics_sources
{
use common::configuration::MetricsSource;
let cost_count = sources
.iter()
.filter(|s| matches!(s, MetricsSource::Cost(_)))
.count();
let latency_count = sources
.iter()
.filter(|s| matches!(s, MetricsSource::Latency(_)))
.count();
if cost_count > 1 {
return Err("model_metrics_sources: only one cost metrics source is allowed".into());
}
if latency_count > 1 {
return Err("model_metrics_sources: only one latency metrics source is allowed".into());
}
let svc = ModelMetricsService::new(sources, reqwest::Client::new()).await;
Some(Arc::new(svc))
} else {
None
};
// Validate that selection_policy.prefer is compatible with the configured metric sources.
if let Some(ref prefs) = config.routing_preferences {
use common::configuration::{MetricsSource, SelectionPreference};
let has_cost_source = config
.model_metrics_sources
.as_deref()
.unwrap_or_default()
.iter()
.any(|s| matches!(s, MetricsSource::Cost(_)));
let has_latency_source = config
.model_metrics_sources
.as_deref()
.unwrap_or_default()
.iter()
.any(|s| matches!(s, MetricsSource::Latency(_)));
for pref in prefs {
if pref.selection_policy.prefer == SelectionPreference::Cheapest && !has_cost_source {
return Err(format!(
"routing_preferences route '{}' uses prefer: cheapest but no cost metrics source is configured — \
add a cost metrics source to model_metrics_sources",
pref.name
)
.into());
}
if pref.selection_policy.prefer == SelectionPreference::Fastest && !has_latency_source {
return Err(format!(
"routing_preferences route '{}' uses prefer: fastest but no latency metrics source is configured — \
add a latency metrics source to model_metrics_sources",
pref.name
)
.into());
}
}
}
// Warn about models in routing_preferences that have no matching pricing/latency data.
if let (Some(ref prefs), Some(ref svc)) = (&config.routing_preferences, &metrics_service) {
let cost_data = svc.cost_snapshot().await;
let latency_data = svc.latency_snapshot().await;
for pref in prefs {
use common::configuration::SelectionPreference;
for model in &pref.models {
let missing = match pref.selection_policy.prefer {
SelectionPreference::Cheapest => !cost_data.contains_key(model.as_str()),
SelectionPreference::Fastest => !latency_data.contains_key(model.as_str()),
_ => false,
};
if missing {
warn!(
model = %model,
route = %pref.name,
"model has no metric data — will be ranked last"
);
}
}
}
}
let router_service = Arc::new(RouterService::new(
config.model_providers.clone(),
config.routing_preferences.clone(),
metrics_service,
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
routing_model_name,
routing_llm_provider,

View file

@ -1,9 +1,11 @@
use std::{collections::HashMap, sync::Arc, time::Duration, time::Instant};
use common::{
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
configuration::TopLevelRoutingPreference,
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
};
use super::router_model::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::Message;
use hyper::header;
use thiserror::Error;
@ -11,6 +13,7 @@ use tokio::sync::RwLock;
use tracing::{debug, info};
use super::http::{self, post_and_extract_content};
use super::model_metrics::ModelMetricsService;
use super::router_model::RouterModel;
use crate::router::router_model_v1;
@ -30,7 +33,8 @@ pub struct RouterService {
client: reqwest::Client,
router_model: Arc<dyn RouterModel>,
routing_provider_name: String,
llm_usage_defined: bool,
top_level_preferences: HashMap<String, TopLevelRoutingPreference>,
metrics_service: Option<Arc<ModelMetricsService>>,
session_cache: RwLock<HashMap<String, CachedRoute>>,
session_ttl: Duration,
session_max_entries: usize,
@ -49,31 +53,39 @@ pub type Result<T> = std::result::Result<T, RoutingError>;
impl RouterService {
pub fn new(
providers: Vec<LlmProvider>,
top_level_prefs: Option<Vec<TopLevelRoutingPreference>>,
metrics_service: Option<Arc<ModelMetricsService>>,
router_url: String,
routing_model_name: String,
routing_provider_name: String,
session_ttl_seconds: Option<u64>,
session_max_entries: Option<usize>,
) -> Self {
let providers_with_usage = providers
.iter()
.filter(|provider| provider.routing_preferences.is_some())
.cloned()
.collect::<Vec<LlmProvider>>();
let top_level_preferences: HashMap<String, TopLevelRoutingPreference> = top_level_prefs
.map_or_else(HashMap::new, |prefs| {
prefs.into_iter().map(|p| (p.name.clone(), p)).collect()
});
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
// Build sentinel routes for RouterModelV1, keyed by each route's first model
// and carrying that route's name + description. RouterModelV1 uses this to
// build its prompt; RouterService overrides the model selection via
// rank_models() after the route is determined.
let sentinel_routes: HashMap<String, Vec<RoutingPreference>> = top_level_preferences
.iter()
.filter_map(|provider| {
provider
.routing_preferences
.as_ref()
.map(|prefs| (provider.name.clone(), prefs.clone()))
.filter_map(|(name, pref)| {
pref.models.first().map(|first_model| {
(
first_model.clone(),
vec![RoutingPreference {
name: name.clone(),
description: pref.description.clone(),
}],
)
})
})
.collect();
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
llm_routes,
sentinel_routes,
routing_model_name,
router_model_v1::MAX_TOKEN_LEN,
));
@ -87,7 +99,8 @@ impl RouterService {
client: reqwest::Client::new(),
router_model,
routing_provider_name,
llm_usage_defined: !providers_with_usage.is_empty(),
top_level_preferences,
metrics_service,
session_cache: RwLock::new(HashMap::new()),
session_ttl,
session_max_entries,
@ -153,24 +166,43 @@ impl RouterService {
&self,
messages: &[Message],
traceparent: &str,
usage_preferences: Option<Vec<ModelUsagePreference>>,
inline_routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
request_id: &str,
) -> Result<Option<(String, String)>> {
) -> Result<Option<(String, Vec<String>)>> {
if messages.is_empty() {
return Ok(None);
}
if usage_preferences
.as_ref()
.is_none_or(|prefs| prefs.len() < 2)
&& !self.llm_usage_defined
{
// Build inline top-level map from request if present (inline overrides config).
let inline_top_map: Option<HashMap<String, TopLevelRoutingPreference>> =
inline_routing_preferences
.map(|prefs| prefs.into_iter().map(|p| (p.name.clone(), p)).collect());
// No routing defined — skip the router call entirely.
if inline_top_map.is_none() && self.top_level_preferences.is_empty() {
return Ok(None);
}
// For inline overrides, build synthetic ModelUsagePreference list so RouterModelV1
// generates the correct prompt (route name + description pairs).
// For config-level prefs the sentinel routes are already baked into RouterModelV1.
let effective_usage_preferences: Option<Vec<ModelUsagePreference>> =
inline_top_map.as_ref().map(|inline_map| {
inline_map
.values()
.map(|p| ModelUsagePreference {
model: p.models.first().cloned().unwrap_or_default(),
routing_preferences: vec![RoutingPreference {
name: p.name.clone(),
description: p.description.clone(),
}],
})
.collect()
});
let router_request = self
.router_model
.generate_request(messages, &usage_preferences);
.generate_request(messages, &effective_usage_preferences);
debug!(
model = %self.router_model.get_model_name(),
@ -210,18 +242,38 @@ impl RouterService {
return Ok(None);
};
// Parse the route name from the router response.
let parsed = self
.router_model
.parse_response(&content, &usage_preferences)?;
.parse_response(&content, &effective_usage_preferences)?;
let result = if let Some((route_name, _sentinel)) = parsed {
let top_pref = inline_top_map
.as_ref()
.and_then(|m| m.get(&route_name))
.or_else(|| self.top_level_preferences.get(&route_name));
if let Some(pref) = top_pref {
let ranked = match &self.metrics_service {
Some(svc) => svc.rank_models(&pref.models, &pref.selection_policy).await,
None => pref.models.clone(),
};
Some((route_name, ranked))
} else {
None
}
} else {
None
};
info!(
content = %content.replace("\n", "\\n"),
selected_model = ?parsed,
selected_model = ?result,
response_time_ms = elapsed.as_millis(),
"arch-router determined route"
);
Ok(parsed)
Ok(result)
}
}
@ -231,7 +283,8 @@ mod tests {
fn make_router_service(ttl_seconds: u64, max_entries: usize) -> RouterService {
RouterService::new(
vec![],
None,
None,
"http://localhost:12001/v1/chat/completions".to_string(),
"Arch-Router".to_string(),
"arch-router".to_string(),

View file

@ -1,5 +1,6 @@
pub(crate) mod http;
pub mod llm;
pub mod model_metrics;
pub mod orchestrator;
pub mod orchestrator_model;
pub mod orchestrator_model_v1;

View file

@ -0,0 +1,388 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use common::configuration::{
CostProvider, LatencyProvider, MetricsSource, SelectionPolicy, SelectionPreference,
};
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
const DO_PRICING_URL: &str = "https://api.digitalocean.com/v2/gen-ai/models/catalog";
pub struct ModelMetricsService {
cost: Arc<RwLock<HashMap<String, f64>>>,
latency: Arc<RwLock<HashMap<String, f64>>>,
}
impl ModelMetricsService {
pub async fn new(sources: &[MetricsSource], client: reqwest::Client) -> Self {
let cost_data = Arc::new(RwLock::new(HashMap::new()));
let latency_data = Arc::new(RwLock::new(HashMap::new()));
for source in sources {
match source {
MetricsSource::Cost(cfg) => match cfg.provider {
CostProvider::Digitalocean => {
let aliases = cfg.model_aliases.clone().unwrap_or_default();
let data = fetch_do_pricing(&client, &aliases).await;
info!(models = data.len(), "fetched digitalocean pricing");
*cost_data.write().await = data;
if let Some(interval_secs) = cfg.refresh_interval {
let cost_clone = Arc::clone(&cost_data);
let client_clone = client.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data = fetch_do_pricing(&client_clone, &aliases).await;
info!(models = data.len(), "refreshed digitalocean pricing");
*cost_clone.write().await = data;
}
});
}
}
},
MetricsSource::Latency(cfg) => match cfg.provider {
LatencyProvider::Prometheus => {
let data = fetch_prometheus_metrics(&cfg.url, &cfg.query, &client).await;
info!(models = data.len(), url = %cfg.url, "fetched latency metrics");
*latency_data.write().await = data;
if let Some(interval_secs) = cfg.refresh_interval {
let latency_clone = Arc::clone(&latency_data);
let client_clone = client.clone();
let url = cfg.url.clone();
let query = cfg.query.clone();
let interval = Duration::from_secs(interval_secs);
tokio::spawn(async move {
loop {
tokio::time::sleep(interval).await;
let data =
fetch_prometheus_metrics(&url, &query, &client_clone).await;
info!(models = data.len(), url = %url, "refreshed latency metrics");
*latency_clone.write().await = data;
}
});
}
}
},
}
}
ModelMetricsService {
cost: cost_data,
latency: latency_data,
}
}
/// Rank `models` by `policy`, returning them in preference order.
/// Models with no metric data are appended at the end in their original order.
pub async fn rank_models(&self, models: &[String], policy: &SelectionPolicy) -> Vec<String> {
let cost_data = self.cost.read().await;
let latency_data = self.latency.read().await;
debug!(
input_models = ?models,
cost_data = ?cost_data.iter().collect::<Vec<_>>(),
latency_data = ?latency_data.iter().collect::<Vec<_>>(),
prefer = ?policy.prefer,
"rank_models called"
);
match policy.prefer {
SelectionPreference::Cheapest => {
for m in models {
if !cost_data.contains_key(m.as_str()) {
warn!(model = %m, "no cost data for model — ranking last (prefer: cheapest)");
}
}
rank_by_ascending_metric(models, &cost_data)
}
SelectionPreference::Fastest => {
for m in models {
if !latency_data.contains_key(m.as_str()) {
warn!(model = %m, "no latency data for model — ranking last (prefer: fastest)");
}
}
rank_by_ascending_metric(models, &latency_data)
}
SelectionPreference::None => models.to_vec(),
}
}
/// Returns a snapshot of the current cost data. Used at startup to warn about unmatched models.
pub async fn cost_snapshot(&self) -> HashMap<String, f64> {
self.cost.read().await.clone()
}
/// Returns a snapshot of the current latency data. Used at startup to warn about unmatched models.
pub async fn latency_snapshot(&self) -> HashMap<String, f64> {
self.latency.read().await.clone()
}
}
fn rank_by_ascending_metric(models: &[String], data: &HashMap<String, f64>) -> Vec<String> {
let mut with_data: Vec<(&String, f64)> = models
.iter()
.filter_map(|m| {
let v = *data.get(m.as_str())?;
if v.is_nan() {
None
} else {
Some((m, v))
}
})
.collect();
with_data.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
let without_data: Vec<&String> = models
.iter()
.filter(|m| data.get(m.as_str()).is_none_or(|v| v.is_nan()))
.collect();
with_data
.iter()
.map(|(m, _)| (*m).clone())
.chain(without_data.iter().map(|m| (*m).clone()))
.collect()
}
#[derive(serde::Deserialize)]
struct DoModelList {
data: Vec<DoModel>,
}
#[derive(serde::Deserialize)]
struct DoModel {
model_id: String,
pricing: Option<DoPricing>,
}
#[derive(serde::Deserialize)]
struct DoPricing {
input_price_per_million: Option<f64>,
output_price_per_million: Option<f64>,
}
async fn fetch_do_pricing(
client: &reqwest::Client,
aliases: &HashMap<String, String>,
) -> HashMap<String, f64> {
match client.get(DO_PRICING_URL).send().await {
Ok(resp) => match resp.json::<DoModelList>().await {
Ok(list) => list
.data
.into_iter()
.filter_map(|m| {
let pricing = m.pricing?;
let raw_key = m.model_id.clone();
let key = aliases.get(&raw_key).cloned().unwrap_or(raw_key);
let cost = pricing.input_price_per_million.unwrap_or(0.0)
+ pricing.output_price_per_million.unwrap_or(0.0);
Some((key, cost))
})
.collect(),
Err(err) => {
warn!(error = %err, url = DO_PRICING_URL, "failed to parse digitalocean pricing response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = DO_PRICING_URL, "failed to fetch digitalocean pricing");
HashMap::new()
}
}
}
#[derive(serde::Deserialize)]
struct PrometheusResponse {
data: PrometheusData,
}
#[derive(serde::Deserialize)]
struct PrometheusData {
result: Vec<PrometheusResult>,
}
#[derive(serde::Deserialize)]
struct PrometheusResult {
metric: HashMap<String, String>,
value: (f64, String), // (timestamp, value_str)
}
async fn fetch_prometheus_metrics(
url: &str,
query: &str,
client: &reqwest::Client,
) -> HashMap<String, f64> {
let query_url = format!("{}/api/v1/query", url.trim_end_matches('/'));
match client
.get(&query_url)
.query(&[("query", query)])
.send()
.await
{
Ok(resp) => match resp.json::<PrometheusResponse>().await {
Ok(prom) => prom
.data
.result
.into_iter()
.filter_map(|r| {
let model_name = r.metric.get("model_name")?.clone();
let value: f64 = r.value.1.parse().ok()?;
Some((model_name, value))
})
.collect(),
Err(err) => {
warn!(error = %err, url = %query_url, "failed to parse prometheus response");
HashMap::new()
}
},
Err(err) => {
warn!(error = %err, url = %query_url, "failed to fetch prometheus metrics");
HashMap::new()
}
}
}
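
For reference, a sketch (not from this diff) of an instant-vector response in the standard Prometheus /api/v1/query format that the structs above can decode; `model_name` is the label the extractor keys on, and the values are placeholders. This could live in the test module below:

#[test]
fn prometheus_response_shape_example() {
    let body = r#"{
        "status": "success",
        "data": {
            "resultType": "vector",
            "result": [
                { "metric": { "model_name": "openai/gpt-4o" }, "value": [1712102327.0, "0.213"] }
            ]
        }
    }"#;
    let parsed: PrometheusResponse = serde_json::from_str(body).unwrap();
    let first = &parsed.data.result[0];
    assert_eq!(first.metric["model_name"], "openai/gpt-4o");
    assert_eq!(first.value.1, "0.213");
}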
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::SelectionPreference;
fn make_policy(prefer: SelectionPreference) -> SelectionPolicy {
SelectionPolicy { prefer }
}
#[test]
fn test_rank_by_ascending_metric_picks_lowest_first() {
let models = vec!["a".to_string(), "b".to_string(), "c".to_string()];
let mut data = HashMap::new();
data.insert("a".to_string(), 0.01);
data.insert("b".to_string(), 0.005);
data.insert("c".to_string(), 0.02);
assert_eq!(
rank_by_ascending_metric(&models, &data),
vec!["b", "a", "c"]
);
}
#[test]
fn test_rank_by_ascending_metric_no_data_preserves_order() {
let models = vec!["x".to_string(), "y".to_string()];
let data = HashMap::new();
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["x", "y"]);
}
#[test]
fn test_rank_by_ascending_metric_partial_data() {
let models = vec!["a".to_string(), "b".to_string()];
let mut data = HashMap::new();
data.insert("b".to_string(), 100.0);
assert_eq!(rank_by_ascending_metric(&models, &data), vec!["b", "a"]);
}
#[tokio::test]
async fn test_rank_models_cheapest() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 0.005);
m.insert("gpt-4o-mini".to_string(), 0.0001);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, vec!["gpt-4o-mini", "gpt-4o"]);
}
#[tokio::test]
async fn test_rank_models_fastest() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new(HashMap::new())),
latency: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 200.0);
m.insert("claude-sonnet".to_string(), 120.0);
m
})),
};
let models = vec!["gpt-4o".to_string(), "claude-sonnet".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Fastest))
.await;
assert_eq!(result, vec!["claude-sonnet", "gpt-4o"]);
}
#[tokio::test]
async fn test_rank_models_fallback_no_metrics() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new(HashMap::new())),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["model-a".to_string(), "model-b".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, vec!["model-a", "model-b"]);
}
#[tokio::test]
async fn test_rank_models_partial_data_appended_last() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o".to_string(), 0.005);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o-mini".to_string(), "gpt-4o".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::Cheapest))
.await;
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
}
#[tokio::test]
async fn test_rank_models_none_preserves_order() {
let service = ModelMetricsService {
cost: Arc::new(RwLock::new({
let mut m = HashMap::new();
m.insert("gpt-4o-mini".to_string(), 0.0001);
m.insert("gpt-4o".to_string(), 0.005);
m
})),
latency: Arc::new(RwLock::new(HashMap::new())),
};
let models = vec!["gpt-4o".to_string(), "gpt-4o-mini".to_string()];
let result = service
.rank_models(&models, &make_policy(SelectionPreference::None))
.await;
// none → original order, despite gpt-4o-mini being cheaper
assert_eq!(result, vec!["gpt-4o", "gpt-4o-mini"]);
}
#[test]
fn test_rank_by_ascending_metric_nan_treated_as_missing() {
let models = vec![
"a".to_string(),
"b".to_string(),
"c".to_string(),
"d".to_string(),
];
let mut data = HashMap::new();
data.insert("a".to_string(), f64::NAN);
data.insert("b".to_string(), 0.5);
data.insert("c".to_string(), 0.1);
// "d" has no entry at all
let result = rank_by_ascending_metric(&models, &data);
// c (0.1) < b (0.5), then NaN "a" and missing "d" appended in original order
assert_eq!(result, vec!["c", "b", "a", "d"]);
}
}

View file

@ -183,6 +183,7 @@ impl OrchestratorModel for OrchestratorModelV1 {
.iter()
.filter(|m| {
m.role != Role::System
&& m.role != Role::Developer
&& m.role != Role::Tool
&& !m.content.extract_text().is_empty()
})

View file

@ -1,5 +1,5 @@
use common::configuration::ModelUsagePreference;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use serde::{Deserialize, Serialize};
use thiserror::Error;
#[derive(Debug, Error)]
@ -10,6 +10,20 @@ pub enum RoutingModelError {
pub type Result<T> = std::result::Result<T, RoutingModelError>;
/// Internal route descriptor passed to the router model to build its prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingPreference {
pub name: String,
pub description: String,
}
/// Groups a model with its routing preferences (used internally by RouterModelV1).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUsagePreference {
pub model: String,
pub routing_preferences: Vec<RoutingPreference>,
}
pub trait RouterModel: Send + Sync {
fn generate_request(
&self,

View file

@ -1,6 +1,6 @@
use std::collections::HashMap;
use common::configuration::{ModelUsagePreference, RoutingPreference};
use super::router_model::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use hermesllm::transforms::lib::ExtractText;
use serde::{Deserialize, Serialize};
@ -80,6 +80,7 @@ impl RouterModel for RouterModelV1 {
.iter()
.filter(|m| {
m.role != Role::System
&& m.role != Role::Developer
&& m.role != Role::Tool
&& !m.content.extract_text().is_empty()
})

View file

@ -2,7 +2,6 @@ use crate::{
configuration::LlmProvider,
consts::{ARCH_FC_MODEL_NAME, ASSISTANT_ROLE},
};
use core::{panic, str};
use serde::{ser::SerializeMap, Deserialize, Serialize};
use std::{
collections::{HashMap, VecDeque},
@ -193,7 +192,7 @@ impl Display for ContentType {
// skip image URLs or their data in text representation
None
} else {
panic!("Unsupported content type: {:?}", part.content_type);
None
}
})
.collect();

View file

@ -1,5 +1,5 @@
use hermesllm::apis::openai::{ModelDetail, ModelObject, Models};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
use std::fmt::Display;
@ -112,6 +112,77 @@ pub enum StateStorageType {
Postgres,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum SelectionPreference {
Cheapest,
Fastest,
/// Return models in the same order they were defined — no reordering.
#[default]
#[serde(alias = "")]
None,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SelectionPolicy {
#[serde(default, deserialize_with = "deserialize_selection_preference")]
pub prefer: SelectionPreference,
}
fn deserialize_selection_preference<'de, D>(
deserializer: D,
) -> Result<SelectionPreference, D::Error>
where
D: Deserializer<'de>,
{
Ok(Option::<SelectionPreference>::deserialize(deserializer)?.unwrap_or_default())
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TopLevelRoutingPreference {
pub name: String,
pub description: String,
pub models: Vec<String>,
#[serde(default)]
pub selection_policy: SelectionPolicy,
}
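
As a sketch (not in this diff) of the YAML these structs accept, assuming serde_yaml is available in the test context; route and model names are illustrative, and the second entry exercises the `selection_policy` default:

#[test]
fn routing_preferences_yaml_example() {
    let yaml = r#"
- name: coding
  description: code generation and debugging
  models: [openai/gpt-4o, openai/gpt-4o-mini]
  selection_policy:
    prefer: cheapest
- name: general
  description: general questions
  models: [openai/gpt-4o-mini]
"#;
    let prefs: Vec<TopLevelRoutingPreference> = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(prefs[0].selection_policy.prefer, SelectionPreference::Cheapest);
    assert_eq!(prefs[1].selection_policy.prefer, SelectionPreference::None); // default
}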
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum MetricsSource {
Cost(CostMetricsConfig),
Latency(LatencyMetricsConfig),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CostMetricsConfig {
pub provider: CostProvider,
pub refresh_interval: Option<u64>,
/// Map DO catalog keys (`lowercase(creator)/model_id`) to Plano model names.
/// Example: `openai/openai-gpt-oss-120b: openai/gpt-4o`
pub model_aliases: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CostProvider {
Digitalocean,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LatencyMetricsConfig {
pub provider: LatencyProvider,
pub url: String,
pub query: String,
pub refresh_interval: Option<u64>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LatencyProvider {
Prometheus,
}
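
And a companion sketch (also not in this diff, same serde_yaml assumption) of a `model_metrics_sources` list with one cost and one latency source; the URL, query, and alias values are placeholders:

#[test]
fn model_metrics_sources_yaml_example() {
    let yaml = r#"
- type: cost
  provider: digitalocean
  refresh_interval: 3600
  model_aliases:
    openai/openai-gpt-oss-120b: openai/gpt-4o
- type: latency
  provider: prometheus
  url: http://prometheus:9090
  query: model_latency_seconds
  refresh_interval: 300
"#;
    let sources: Vec<MetricsSource> = serde_yaml::from_str(yaml).unwrap();
    assert!(matches!(sources[0], MetricsSource::Cost(_)));
    assert!(matches!(sources[1], MetricsSource::Latency(_)));
}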
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
@ -131,6 +202,8 @@ pub struct Configuration {
pub filters: Option<Vec<Agent>>,
pub listeners: Vec<Listener>,
pub state_storage: Option<StateStorageConfig>,
pub routing_preferences: Option<Vec<TopLevelRoutingPreference>>,
pub model_metrics_sources: Option<Vec<MetricsSource>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
@ -246,6 +319,8 @@ pub enum TimeUnit {
Minute,
#[serde(rename = "hour")]
Hour,
#[serde(rename = "day")]
Day,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
@ -326,18 +401,6 @@ impl LlmProviderType {
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct ModelUsagePreference {
pub model: String,
pub routing_preferences: Vec<RoutingPreference>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoutingPreference {
pub name: String,
pub description: String,
}
#[derive(Serialize, Deserialize, Debug)]
pub struct AgentUsagePreference {
pub model: String,
@ -387,7 +450,6 @@ pub struct LlmProvider {
pub port: Option<u16>,
pub rate_limits: Option<LlmRatelimit>,
pub usage: Option<String>,
pub routing_preferences: Option<Vec<RoutingPreference>>,
pub cluster_name: Option<String>,
pub base_url_path_prefix: Option<String>,
pub internal: Option<bool>,
@ -431,7 +493,6 @@ impl Default for LlmProvider {
port: None,
rate_limits: None,
usage: None,
routing_preferences: None,
cluster_name: None,
base_url_path_prefix: None,
internal: None,

View file

@ -75,7 +75,10 @@ pub trait Client: Context {
fn add_call_context(&self, id: u32, call_context: Self::CallContext) {
let callouts = self.callouts();
if callouts.borrow_mut().insert(id, call_context).is_some() {
panic!("Duplicate http call with id={}", id);
log::warn!(
"Duplicate http call with id={}, previous context overwritten",
id
);
}
self.active_http_calls().increment(1);
}

View file

@ -274,7 +274,6 @@ mod tests {
port: None,
rate_limits: None,
usage: None,
routing_preferences: None,
internal: None,
stream: None,
passthrough_auth: None,

View file

@ -73,7 +73,10 @@ impl RatelimitMap {
match new_ratelimit_map.datastore.get_mut(&ratelimit_config.model) {
Some(limits) => match limits.get_mut(&ratelimit_config.selector) {
Some(_) => {
panic!("repeated selector. Selectors per provider must be unique")
log::error!(
"repeated selector for model '{}'. Selectors per provider must be unique, skipping duplicate",
ratelimit_config.model
);
}
None => {
limits.insert(ratelimit_config.selector, limit);
@ -150,6 +153,10 @@ fn get_quota(limit: Limit) -> Quota {
TimeUnit::Second => Quota::per_second(tokens),
TimeUnit::Minute => Quota::per_minute(tokens),
TimeUnit::Hour => Quota::per_hour(tokens),
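// Approximate a per-day budget as an hourly quota: tokens / 24, floored,
// but never below 1 token per hour.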
TimeUnit::Day => {
let per_hour = limit.tokens.saturating_div(24).max(1);
Quota::per_hour(NonZero::new(per_hour).expect("per_hour must be positive"))
}
}
}

View file

@ -572,7 +572,9 @@ impl ProviderRequest for MessagesRequest {
let mut regular_messages = Vec::new();
for msg in messages {
if msg.role == crate::apis::openai::Role::System {
if msg.role == crate::apis::openai::Role::System
|| msg.role == crate::apis::openai::Role::Developer
{
system_messages.push(msg.clone());
} else {
regular_messages.push(msg.clone());

View file

@ -150,6 +150,7 @@ pub enum Role {
User,
Assistant,
Tool,
Developer,
}
#[skip_serializing_none]
@ -736,6 +737,7 @@ impl ProviderStreamResponse for ChatCompletionsStreamResponse {
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
Role::Developer => "developer",
})
})
}

View file

@ -92,6 +92,7 @@ impl SupportedAPIsFromClient {
model_id: &str,
is_streaming: bool,
base_url_path_prefix: Option<&str>,
use_unversioned_paths: bool,
) -> String {
// Helper function to build endpoint with optional prefix override
let build_endpoint = |provider_prefix: &str, suffix: &str| -> String {
@ -161,7 +162,13 @@ impl SupportedAPIsFromClient {
build_endpoint("/v1", endpoint_suffix)
}
}
_ => build_endpoint("/v1", endpoint_suffix),
_ => {
if use_unversioned_paths {
build_endpoint("", endpoint_suffix)
} else {
build_endpoint("/v1", endpoint_suffix)
}
}
}
};
@ -343,7 +350,8 @@ mod tests {
"/v1/chat/completions",
"gpt-4",
false,
None
None,
false
),
"/v1/chat/completions"
);
@ -355,7 +363,8 @@ mod tests {
"/v1/chat/completions",
"llama2",
false,
None
None,
false
),
"/openai/v1/chat/completions"
);
@ -367,7 +376,8 @@ mod tests {
"/v1/chat/completions",
"chatglm",
false,
None
None,
false
),
"/api/paas/v4/chat/completions"
);
@ -379,7 +389,8 @@ mod tests {
"/v1/chat/completions",
"qwen-turbo",
false,
None
None,
false
),
"/compatible-mode/v1/chat/completions"
);
@ -391,7 +402,8 @@ mod tests {
"/v1/chat/completions",
"gpt-4",
false,
None
None,
false
),
"/openai/deployments/gpt-4/chat/completions?api-version=2025-01-01-preview"
);
@ -403,12 +415,30 @@ mod tests {
"/v1/chat/completions",
"gemini-pro",
false,
None
None,
false
),
"/v1beta/openai/chat/completions"
);
}
#[test]
fn test_target_endpoint_unversioned_paths() {
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
assert_eq!(
api.target_endpoint_for_provider(
&ProviderId::OpenAI,
"/v1/chat/completions",
"sonar-pro",
false,
None,
true
),
"/chat/completions"
);
}
#[test]
fn test_target_endpoint_with_base_url_prefix() {
let api = SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions);
@ -420,7 +450,8 @@ mod tests {
"/v1/chat/completions",
"chatglm",
false,
Some("/api/coding/paas/v4")
Some("/api/coding/paas/v4"),
false
),
"/api/coding/paas/v4/chat/completions"
);
@ -432,7 +463,8 @@ mod tests {
"/v1/chat/completions",
"chatglm",
false,
Some("api/coding/paas/v4")
Some("api/coding/paas/v4"),
false
),
"/api/coding/paas/v4/chat/completions"
);
@ -444,7 +476,8 @@ mod tests {
"/v1/chat/completions",
"chatglm",
false,
Some("/api/coding/paas/v4/")
Some("/api/coding/paas/v4/"),
false
),
"/api/coding/paas/v4/chat/completions"
);
@ -456,7 +489,8 @@ mod tests {
"/v1/chat/completions",
"gpt-4",
false,
Some("/custom/api/v2")
Some("/custom/api/v2"),
false
),
"/custom/api/v2/chat/completions"
);
@ -468,7 +502,8 @@ mod tests {
"/v1/chat/completions",
"llama2",
false,
Some("/api/v2")
Some("/api/v2"),
false
),
"/api/v2/v1/chat/completions"
);
@ -485,7 +520,8 @@ mod tests {
"/v1/chat/completions",
"chatglm",
false,
Some("/")
Some("/"),
false
),
"/api/paas/v4/chat/completions"
);
@ -497,7 +533,8 @@ mod tests {
"/v1/chat/completions",
"chatglm",
false,
None
None,
false
),
"/api/paas/v4/chat/completions"
);
@ -514,7 +551,8 @@ mod tests {
"/v1/messages",
"us.amazon.nova-pro-v1:0",
false,
None
None,
false
),
"/model/us.amazon.nova-pro-v1:0/converse"
);
@ -526,7 +564,8 @@ mod tests {
"/v1/messages",
"us.amazon.nova-pro-v1:0",
true,
None
None,
false
),
"/model/us.amazon.nova-pro-v1:0/converse-stream"
);
@ -538,7 +577,8 @@ mod tests {
"/v1/messages",
"us.amazon.nova-pro-v1:0",
false,
Some("/custom/path")
Some("/custom/path"),
false
),
"/custom/path/model/us.amazon.nova-pro-v1:0/converse"
);
@ -550,7 +590,8 @@ mod tests {
"/v1/messages",
"us.amazon.nova-pro-v1:0",
true,
Some("/custom/path")
Some("/custom/path"),
false
),
"/custom/path/model/us.amazon.nova-pro-v1:0/converse-stream"
);
@ -567,7 +608,8 @@ mod tests {
"/v1/messages",
"claude-3-opus",
false,
None
None,
false
),
"/v1/messages"
);
@ -579,7 +621,8 @@ mod tests {
"/v1/messages",
"claude-3-opus",
false,
Some("/api/v2")
Some("/api/v2"),
false
),
"/api/v2/messages"
);
@ -596,7 +639,8 @@ mod tests {
"/custom/path",
"llama2",
false,
None
None,
false
),
"/v1/chat/completions"
);
@ -608,7 +652,8 @@ mod tests {
"/custom/path",
"chatglm",
false,
None
None,
false
),
"/v1/chat/completions"
);
@ -620,7 +665,8 @@ mod tests {
"/custom/path",
"chatglm",
false,
Some("/api/v2")
Some("/api/v2"),
false
),
"/api/v2/chat/completions"
);
@ -637,7 +683,8 @@ mod tests {
"/v1/chat/completions",
"gpt-4-deployment",
false,
None
None,
false
),
"/openai/deployments/gpt-4-deployment/chat/completions?api-version=2025-01-01-preview"
);
@ -649,7 +696,8 @@ mod tests {
"/v1/chat/completions",
"gpt-4-deployment",
false,
Some("/custom/azure/path")
Some("/custom/azure/path"),
false
),
"/custom/azure/path/gpt-4-deployment/chat/completions?api-version=2025-01-01-preview"
);
@ -664,7 +712,8 @@ mod tests {
"/v1/responses",
"grok-4-1-fast-reasoning",
false,
None
None,
false
),
"/v1/responses"
);

View file

@ -97,7 +97,7 @@ impl TryFrom<ResponsesInputConverter> for Vec<Message> {
MessageRole::User => Role::User,
MessageRole::Assistant => Role::Assistant,
MessageRole::System => Role::System,
MessageRole::Developer => Role::System, // Map developer to system
MessageRole::Developer => Role::Developer,
MessageRole::Tool => Role::Tool,
};
@ -281,7 +281,7 @@ impl TryFrom<Message> for MessagesMessage {
]),
});
}
Role::System => {
Role::System | Role::Developer => {
return Err(TransformError::UnsupportedConversion(
"System messages should be handled separately".to_string(),
));
@ -303,7 +303,7 @@ impl TryFrom<Message> for BedrockMessage {
Role::User => ConversationRole::User,
Role::Assistant => ConversationRole::Assistant,
Role::Tool => ConversationRole::User, // Tool results become user messages in Bedrock
Role::System => {
Role::System | Role::Developer => {
return Err(TransformError::UnsupportedConversion(
"System messages should be handled separately in Bedrock".to_string(),
));
@ -452,7 +452,7 @@ impl TryFrom<Message> for BedrockMessage {
},
});
}
Role::System => {
Role::System | Role::Developer => {
// Already handled above with early return
unreachable!()
}
@ -706,7 +706,7 @@ impl TryFrom<ChatCompletionsRequest> for AnthropicMessagesRequest {
for message in req.messages {
match message.role {
Role::System => {
Role::System | Role::Developer => {
system_prompt = Some(message.into());
}
_ => {
@ -755,7 +755,7 @@ impl TryFrom<ChatCompletionsRequest> for ConverseRequest {
for message in req.messages {
match message.role {
Role::System => {
Role::System | Role::Developer => {
let system_text = message.content.extract_text();
system_messages.push(SystemContentBlock::Text { text: system_text });
}

View file

@ -95,6 +95,7 @@ impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
Role::Assistant => "assistant".to_string(),
Role::System => "system".to_string(),
Role::Tool => "tool".to_string(),
Role::Developer => "developer".to_string(),
},
content,
});

View file

@ -122,6 +122,7 @@ impl StreamContext {
.unwrap_or(&"".to_string()),
self.streaming_response,
self.llm_provider().base_url_path_prefix.as_deref(),
self.llm_provider().name.starts_with("perplexity/"),
);
if target_endpoint != request_path {
self.set_http_request_header(":path", Some(&target_endpoint));