merge main and resolve conflicts

This commit is contained in:
Adil Hafeez 2026-03-11 18:57:36 +00:00
commit bcb7f60005
26 changed files with 2145 additions and 213 deletions

View file

@ -18,7 +18,7 @@ use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
mod router;
pub(crate) mod router;
use crate::app_state::AppState;
use crate::handlers::request::extract_request_id;
@ -120,6 +120,7 @@ async fn llm_chat_inner(
temperature,
tool_names,
user_message_preview,
inline_routing_policy,
} = parsed;
// Record LLM-specific span attributes
@ -186,6 +187,7 @@ async fn llm_chat_inner(
&traceparent,
&request_path,
&request_id,
inline_routing_policy,
)
.await
}
@ -245,6 +247,7 @@ struct PreparedRequest {
temperature: Option<f32>,
tool_names: Option<Vec<String>>,
user_message_preview: Option<String>,
inline_routing_policy: Option<Vec<common::configuration::ModelUsagePreference>>,
}
/// Parse the body, resolve the model alias, and validate the model exists.
@ -256,7 +259,7 @@ async fn parse_and_validate_request(
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: &Arc<RwLock<LlmProviders>>,
) -> Result<PreparedRequest, Response<BoxBody<Bytes, hyper::Error>>> {
let chat_request_bytes = request
let raw_bytes = request
.collect()
.await
.map_err(|_| {
@ -267,10 +270,21 @@ async fn parse_and_validate_request(
.to_bytes();
debug!(
body = %String::from_utf8_lossy(&chat_request_bytes),
body = %String::from_utf8_lossy(&raw_bytes),
"request body received"
);
// Extract routing_policy from request body if present
let (chat_request_bytes, inline_routing_policy) =
crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err(
|err| {
warn!(error = %err, "failed to parse request JSON");
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
*r.status_mut() = StatusCode::BAD_REQUEST;
r
},
)?;
let api_type = SupportedAPIsFromClient::from_endpoint(request_path).ok_or_else(|| {
warn!(path = %request_path, "unsupported endpoint");
let mut r = Response::new(full(format!("Unsupported endpoint: {}", request_path)));
@ -296,6 +310,7 @@ async fn parse_and_validate_request(
let temperature = client_request.get_temperature();
let is_streaming_request = client_request.is_streaming();
let alias_resolved_model = resolve_model_alias(&model_from_request, model_aliases);
let (provider_id, _) = get_provider_info(llm_providers, &alias_resolved_model).await;
// Validate model exists in configuration
if llm_providers
@ -332,6 +347,14 @@ async fn parse_and_validate_request(
if client_request.remove_metadata_key("archgw_preference_config") {
debug!("removed archgw_preference_config from metadata");
}
if client_request.remove_metadata_key("plano_preference_config") {
debug!("removed plano_preference_config from metadata");
}
if let Some(ref client_api_kind) = client_api {
let upstream_api =
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
client_request.normalize_for_upstream(provider_id, &upstream_api);
}
Ok(PreparedRequest {
client_request,
@ -344,6 +367,7 @@ async fn parse_and_validate_request(
temperature,
tool_names,
user_message_preview,
inline_routing_policy,
})
}

View file

@ -10,6 +10,7 @@ use crate::tracing::routing;
pub struct RoutingResult {
pub model_name: String,
pub route_name: Option<String>,
}
pub struct RoutingError {
@ -37,6 +38,7 @@ pub async fn router_chat_get_upstream_model(
traceparent: &str,
request_path: &str,
request_id: &str,
inline_usage_preferences: Option<Vec<ModelUsagePreference>>,
) -> Result<RoutingResult, RoutingError> {
// Clone metadata for routing before converting (which consumes client_request)
let routing_metadata = client_request.metadata().clone();
@ -75,16 +77,21 @@ pub async fn router_chat_get_upstream_model(
"router request"
);
// Extract usage preferences from metadata
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("plano_preference_config")
.map(|value| value.to_string())
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
// Use inline preferences if provided, otherwise fall back to metadata extraction
let usage_preferences: Option<Vec<ModelUsagePreference>> = if inline_usage_preferences.is_some()
{
inline_usage_preferences
} else {
let usage_preferences_str: Option<String> =
routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("plano_preference_config")
.map(|value| value.to_string())
});
usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok())
};
// Prepare log message with latest message from chat request
let latest_message_for_log = chat_request
@ -133,9 +140,12 @@ pub async fn router_chat_get_upstream_model(
match routing_result {
Ok(route) => match route {
Some((_, model_name)) => {
Some((route_name, model_name)) => {
current_span.record("route.selected_model", model_name.as_str());
Ok(RoutingResult { model_name })
Ok(RoutingResult {
model_name,
route_name: Some(route_name),
})
}
None => {
// No route determined, return sentinel value "none"
@ -145,6 +155,7 @@ pub async fn router_chat_get_upstream_model(
Ok(RoutingResult {
model_name: "none".to_string(),
route_name: None,
})
}
},

View file

@ -5,6 +5,7 @@ pub mod llm;
pub mod models;
pub mod request;
pub mod response;
pub mod routing_service;
pub mod utils;
#[cfg(test)]

View file

@ -0,0 +1,357 @@
use bytes::Bytes;
use common::configuration::{ModelUsagePreference, SpanAttributes};
use common::consts::{REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
use common::errors::BrightStaffError;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::ProviderRequestType;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use std::sync::Arc;
use tracing::{debug, info, info_span, warn, Instrument};
use crate::handlers::llm::router::router_chat_get_upstream_model;
use crate::router::llm::RouterService;
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120;
/// Extracts `routing_policy` from a JSON body, returning the cleaned body bytes
/// and parsed preferences. The `routing_policy` field is removed from the JSON
/// before re-serializing so downstream parsers don't see the non-standard field.
///
/// If `warn_on_size` is true, logs a warning when the serialized policy exceeds 5KB.
pub fn extract_routing_policy(
raw_bytes: &[u8],
warn_on_size: bool,
) -> Result<(Bytes, Option<Vec<ModelUsagePreference>>), String> {
let mut json_body: serde_json::Value = serde_json::from_slice(raw_bytes)
.map_err(|err| format!("Failed to parse JSON: {}", err))?;
let preferences = json_body
.as_object_mut()
.and_then(|obj| obj.remove("routing_policy"))
.and_then(|policy_value| {
if warn_on_size {
let policy_str = serde_json::to_string(&policy_value).unwrap_or_default();
if policy_str.len() > ROUTING_POLICY_SIZE_WARNING_BYTES {
warn!(
size_bytes = policy_str.len(),
limit_bytes = ROUTING_POLICY_SIZE_WARNING_BYTES,
"routing_policy exceeds recommended size limit"
);
}
}
match serde_json::from_value::<Vec<ModelUsagePreference>>(policy_value) {
Ok(prefs) => {
info!(
num_models = prefs.len(),
"using inline routing_policy from request body"
);
Some(prefs)
}
Err(err) => {
warn!(error = %err, "failed to parse routing_policy");
None
}
}
});
let bytes = Bytes::from(serde_json::to_vec(&json_body).unwrap());
Ok((bytes, preferences))
}
#[derive(serde::Serialize)]
struct RoutingDecisionResponse {
model: String,
route: Option<String>,
trace_id: String,
}
pub async fn routing_decision(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
request_path: String,
span_attributes: Arc<Option<SpanAttributes>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_headers = request.headers().clone();
let request_id: String = request_headers
.get(REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let custom_attrs =
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
let request_span = info_span!(
"routing_decision",
component = "routing",
request_id = %request_id,
http.method = %request.method(),
http.path = %request_path,
);
routing_decision_inner(
request,
router_service,
request_id,
request_path,
request_headers,
custom_attrs,
)
.instrument(request_span)
.await
}
async fn routing_decision_inner(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
request_id: String,
request_path: String,
request_headers: hyper::HeaderMap,
custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
set_service_name(operation_component::ROUTING);
opentelemetry::trace::get_active_span(|span| {
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
// Extract or generate traceparent
let traceparent: String = match request_headers
.get(TRACE_PARENT_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
{
Some(tp) => tp,
None => {
let trace_id = uuid::Uuid::new_v4().to_string().replace("-", "");
let generated_tp = format!("00-{}-0000000000000000-01", trace_id);
warn!(
generated_traceparent = %generated_tp,
"TRACE_PARENT header missing, generated new traceparent"
);
generated_tp
}
};
// Extract trace_id from traceparent (format: 00-{trace_id}-{span_id}-{flags})
let trace_id = traceparent
.split('-')
.nth(1)
.unwrap_or("unknown")
.to_string();
// Parse request body
let raw_bytes = request.collect().await?.to_bytes();
debug!(
body = %String::from_utf8_lossy(&raw_bytes),
"routing decision request body received"
);
// Extract routing_policy from request body before parsing as ProviderRequestType
let (chat_request_bytes, inline_preferences) = match extract_routing_policy(&raw_bytes, true) {
Ok(result) => result,
Err(err) => {
warn!(error = %err, "failed to parse request JSON");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request JSON: {}",
err
))
.into_response());
}
};
let client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!(error = %err, "failed to parse request for routing decision");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request: {}",
err
))
.into_response());
}
};
// Call the existing routing logic with inline preferences
let routing_result = router_chat_get_upstream_model(
router_service,
client_request,
&traceparent,
&request_path,
&request_id,
inline_preferences,
)
.await;
match routing_result {
Ok(result) => {
let response = RoutingDecisionResponse {
model: result.model_name,
route: result.route_name,
trace_id,
};
info!(
model = %response.model,
route = ?response.route,
"routing decision completed"
);
let json = serde_json::to_string(&response).unwrap();
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap())
}
Err(err) => {
warn!(error = %err.message, "routing decision failed");
Ok(BrightStaffError::InternalServerError(err.message).into_response())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_chat_body(extra_fields: &str) -> Vec<u8> {
let extra = if extra_fields.is_empty() {
String::new()
} else {
format!(", {}", extra_fields)
};
format!(
r#"{{"model": "gpt-4o-mini", "messages": [{{"role": "user", "content": "hello"}}]{}}}"#,
extra
)
.into_bytes()
}
#[test]
fn extract_routing_policy_no_policy() {
let body = make_chat_body("");
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
assert!(prefs.is_none());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn extract_routing_policy_valid_policy() {
let policy = r#""routing_policy": [
{
"model": "openai/gpt-4o",
"routing_preferences": [
{"name": "coding", "description": "code generation tasks"}
]
},
{
"model": "openai/gpt-4o-mini",
"routing_preferences": [
{"name": "general", "description": "general questions"}
]
}
]"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("should have parsed preferences");
assert_eq!(prefs.len(), 2);
assert_eq!(prefs[0].model, "openai/gpt-4o");
assert_eq!(prefs[0].routing_preferences[0].name, "coding");
assert_eq!(prefs[1].model, "openai/gpt-4o-mini");
assert_eq!(prefs[1].routing_preferences[0].name, "general");
// routing_policy should be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
assert_eq!(cleaned_json["model"], "gpt-4o-mini");
}
#[test]
fn extract_routing_policy_invalid_policy_returns_none() {
// routing_policy is present but has wrong shape
let policy = r#""routing_policy": "not-an-array""#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
// Invalid policy should be ignored (returns None), not error
assert!(prefs.is_none());
// routing_policy should still be stripped from cleaned body
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn extract_routing_policy_invalid_json_returns_error() {
let body = b"not valid json";
let result = extract_routing_policy(body, false);
assert!(result.is_err());
assert!(result.unwrap_err().contains("Failed to parse JSON"));
}
#[test]
fn extract_routing_policy_empty_array() {
let policy = r#""routing_policy": []"#;
let body = make_chat_body(policy);
let (_, prefs) = extract_routing_policy(&body, false).unwrap();
let prefs = prefs.expect("empty array is valid");
assert_eq!(prefs.len(), 0);
}
#[test]
fn extract_routing_policy_preserves_other_fields() {
let policy = r#""routing_policy": [{"model": "gpt-4o", "routing_preferences": [{"name": "test", "description": "test"}]}], "temperature": 0.5, "max_tokens": 100"#;
let body = make_chat_body(policy);
let (cleaned, prefs) = extract_routing_policy(&body, false).unwrap();
assert!(prefs.is_some());
let cleaned_json: serde_json::Value = serde_json::from_slice(&cleaned).unwrap();
assert_eq!(cleaned_json["temperature"], 0.5);
assert_eq!(cleaned_json["max_tokens"], 100);
assert!(cleaned_json.get("routing_policy").is_none());
}
#[test]
fn routing_decision_response_serialization() {
let response = RoutingDecisionResponse {
model: "openai/gpt-4o".to_string(),
route: Some("code_generation".to_string()),
trace_id: "abc123".to_string(),
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "openai/gpt-4o");
assert_eq!(parsed["route"], "code_generation");
assert_eq!(parsed["trace_id"], "abc123");
}
#[test]
fn routing_decision_response_serialization_no_route() {
let response = RoutingDecisionResponse {
model: "none".to_string(),
route: None,
trace_id: "abc123".to_string(),
};
let json = serde_json::to_string(&response).unwrap();
let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
assert_eq!(parsed["model"], "none");
assert!(parsed["route"].is_null());
}
}

View file

@ -3,6 +3,7 @@ use brightstaff::handlers::agents::orchestrator::agent_chat;
use brightstaff::handlers::function_calling::function_calling_chat_handler;
use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::router::llm::RouterService;
use brightstaff::router::orchestrator::OrchestratorService;
use brightstaff::state::memory::MemoryConversationalStorage;
@ -221,6 +222,24 @@ async fn route(
}
}
// --- Routing decision routes (/routing/...) ---
if let Some(stripped) = path.strip_prefix("/routing") {
let stripped = stripped.to_string();
if matches!(
stripped.as_str(),
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
return routing_decision(
req,
Arc::clone(&state.router_service),
stripped,
Arc::clone(&state.span_attributes),
)
.with_context(parent_cx)
.await;
}
}
// --- Standard routes ---
match (req.method(), path.as_str()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {

View file

@ -112,6 +112,7 @@ pub fn extract_input_items(input: &InputParam) -> Vec<InputItem> {
}]),
})]
}
InputParam::SingleItem(item) => vec![item.clone()],
InputParam::Items(items) => items.clone(),
}
}
@ -128,3 +129,101 @@ pub async fn retrieve_and_combine_input(
let combined_input = storage.merge(&prev_state, current_input);
Ok(combined_input)
}
#[cfg(test)]
mod tests {
use super::extract_input_items;
use hermesllm::apis::openai_responses::{
InputContent, InputItem, InputMessage, InputParam, MessageContent, MessageRole,
};
#[test]
fn test_extract_input_items_converts_text_to_user_message_item() {
let extracted = extract_input_items(&InputParam::Text("hello world".to_string()));
assert_eq!(extracted.len(), 1);
let InputItem::Message(message) = &extracted[0] else {
panic!("expected InputItem::Message");
};
assert!(matches!(message.role, MessageRole::User));
let MessageContent::Items(items) = &message.content else {
panic!("expected MessageContent::Items");
};
assert_eq!(items.len(), 1);
let InputContent::InputText { text } = &items[0] else {
panic!("expected InputContent::InputText");
};
assert_eq!(text, "hello world");
}
#[test]
fn test_extract_input_items_preserves_single_item() {
let item = InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: MessageContent::Items(vec![InputContent::InputText {
text: "assistant note".to_string(),
}]),
});
let extracted = extract_input_items(&InputParam::SingleItem(item.clone()));
assert_eq!(extracted.len(), 1);
let InputItem::Message(message) = &extracted[0] else {
panic!("expected InputItem::Message");
};
assert!(matches!(message.role, MessageRole::Assistant));
let MessageContent::Items(items) = &message.content else {
panic!("expected MessageContent::Items");
};
let InputContent::InputText { text } = &items[0] else {
panic!("expected InputContent::InputText");
};
assert_eq!(text, "assistant note");
}
#[test]
fn test_extract_input_items_preserves_items_list() {
let items = vec![
InputItem::Message(InputMessage {
role: MessageRole::User,
content: MessageContent::Items(vec![InputContent::InputText {
text: "first".to_string(),
}]),
}),
InputItem::Message(InputMessage {
role: MessageRole::Assistant,
content: MessageContent::Items(vec![InputContent::InputText {
text: "second".to_string(),
}]),
}),
];
let extracted = extract_input_items(&InputParam::Items(items.clone()));
assert_eq!(extracted.len(), items.len());
let InputItem::Message(first) = &extracted[0] else {
panic!("expected first item to be message");
};
assert!(matches!(first.role, MessageRole::User));
let MessageContent::Items(first_items) = &first.content else {
panic!("expected MessageContent::Items");
};
let InputContent::InputText { text: first_text } = &first_items[0] else {
panic!("expected InputContent::InputText");
};
assert_eq!(first_text, "first");
let InputItem::Message(second) = &extracted[1] else {
panic!("expected second item to be message");
};
assert!(matches!(second.role, MessageRole::Assistant));
let MessageContent::Items(second_items) = &second.content else {
panic!("expected MessageContent::Items");
};
let InputContent::InputText { text: second_text } = &second_items[0] else {
panic!("expected InputContent::InputText");
};
assert_eq!(second_text, "second");
}
}