merge main into model-listener-filter-chain

Adil Hafeez 2026-03-10 06:52:19 +00:00
commit aeb8aa9a54
99 changed files with 5792 additions and 655 deletions

crates/Cargo.lock generated

@ -436,11 +436,14 @@ name = "common"
version = "0.1.0"
dependencies = [
"axum",
"bytes",
"derivative",
"duration-string",
"governor",
"hermesllm",
"hex",
"http-body-util",
"hyper 1.6.0",
"log",
"pretty_assertions",
"proxy-wasm",


@ -2,6 +2,8 @@ use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use common::configuration::SpanAttributes;
use common::errors::BrightStaffError;
use common::llm_providers::LlmProviders;
use hermesllm::apis::OpenAIMessage;
use hermesllm::clients::SupportedAPIsFromClient;
@ -19,17 +21,17 @@ use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::plano_orchestrator::OrchestratorService;
use crate::tracing::{operation_component, set_service_name};
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
pub enum AgentFilterChainError {
#[error("Forwarded error: {0}")]
Brightstaff(#[from] BrightStaffError),
#[error("Agent selection error: {0}")]
Selection(#[from] AgentSelectionError),
#[error("Pipeline processing error: {0}")]
Pipeline(#[from] PipelineError),
#[error("Response handling error: {0}")]
Response(#[from] super::response_handler::ResponseError),
#[error("Request parsing error: {0}")]
RequestParsing(#[from] serde_json::Error),
#[error("HTTP error: {0}")]
@ -42,8 +44,11 @@ pub async fn agent_chat(
_: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
span_attributes: Arc<Option<SpanAttributes>>,
llm_providers: Arc<RwLock<LlmProviders>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let custom_attrs =
collect_custom_trace_attributes(request.headers(), span_attributes.as_ref().as_ref());
// Extract request_id from headers or generate a new one
let request_id: String = match request
.headers()
@ -76,6 +81,7 @@ pub async fn agent_chat(
listeners,
llm_providers,
request_id,
custom_attrs,
)
.await
{
@ -103,16 +109,15 @@ pub async fn agent_chat(
"agent_response": body
});
let status_code = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
let json_string = error_json.to_string();
let mut response =
Response::new(ResponseHandler::create_full_body(json_string));
*response.status_mut() = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::BAD_REQUEST);
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
return Ok(response);
return Ok(BrightStaffError::ForwardedError {
status_code,
message: json_string,
}
.into_response());
}
// Print detailed error information with full error chain for other errors
@ -145,8 +150,11 @@ pub async fn agent_chat(
// Log the error for debugging
info!(error = %error_json, "structured error info");
// Return JSON error response
Ok(ResponseHandler::create_json_error_response(&error_json))
Ok(BrightStaffError::ForwardedError {
status_code: StatusCode::BAD_REQUEST,
message: error_json.to_string(),
}
.into_response())
}
}
}
@ -161,6 +169,7 @@ async fn handle_agent_chat_inner(
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
request_id: String,
custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
// Initialize services
let agent_selector = AgentSelector::new(orchestrator_service);
@ -183,6 +192,9 @@ async fn handle_agent_chat_inner(
get_active_span(|span| {
span.update_name(listener.name.to_string());
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
info!(listener = %listener.name, "handling request");
@ -249,10 +261,7 @@ async fn handle_agent_chat_inner(
None => {
let err_msg = "No model specified in request and no default provider configured";
warn!("{}", err_msg);
let mut bad_request =
Response::new(ResponseHandler::create_full_body(err_msg.to_string()));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
return Ok(BrightStaffError::NoModelSpecified.into_response());
}
}
}
@ -348,6 +357,9 @@ async fn handle_agent_chat_inner(
set_service_name(operation_component::AGENT);
get_active_span(|span| {
span.update_name(format!("{} /v1/chat/completions", agent_name));
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
pipeline_processor


@ -5,9 +5,10 @@ use hyper::header::HeaderMap;
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
use crate::handlers::pipeline_processor::PipelineProcessor;
use crate::handlers::response_handler::ResponseHandler;
use crate::router::plano_orchestrator::OrchestratorService;
use common::errors::BrightStaffError;
use http_body_util::BodyExt;
use hyper::StatusCode;
/// Integration test that demonstrates the modular agent chat flow
/// This test shows how the three main components work together:
/// 1. AgentSelector - selects the appropriate agents based on orchestration
@ -129,8 +130,24 @@ mod tests {
}
// Test 4: Error Response Creation
let error_response = ResponseHandler::create_bad_request("Test error");
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
let err = BrightStaffError::ModelNotFound("gpt-5-secret".to_string());
let response = err.into_response();
assert_eq!(response.status(), StatusCode::NOT_FOUND);
// Helper to extract body as JSON
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["error"]["code"], "ModelNotFound");
assert_eq!(
body["error"]["details"]["rejected_model_id"],
"gpt-5-secret"
);
assert!(body["error"]["message"]
.as_str()
.unwrap()
.contains("gpt-5-secret"));
println!("✅ All modular components working correctly!");
}
@ -149,12 +166,21 @@ mod tests {
AgentSelectionError::ListenerNotFound(_)
));
// Test error response creation
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
assert_eq!(
error_response.status(),
hyper::StatusCode::INTERNAL_SERVER_ERROR
);
let technical_reason = "Database connection timed out";
let err = BrightStaffError::InternalServerError(technical_reason.to_string());
let response = err.into_response();
// Extract the response body bytes
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
// Parse the body as JSON
let body_json: serde_json::Value =
serde_json::from_slice(&body_bytes).expect("Failed to parse JSON body");
// Assert on the structured error payload
assert_eq!(body_json["error"]["code"], "InternalServerError");
assert_eq!(body_json["error"]["details"]["reason"], technical_reason);
println!("✅ Error handling working correctly!");
}


@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain, Listener, ModelAlias};
use common::configuration::{Agent, AgentFilterChain, Listener, ModelAlias, SpanAttributes};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
@ -30,13 +30,11 @@ use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
};
use crate::tracing::{llm as tracing_llm, operation_component, set_service_name};
use crate::tracing::{
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
};
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
use common::errors::BrightStaffError;
#[allow(clippy::too_many_arguments)]
pub async fn llm_chat(
@ -45,6 +43,7 @@ pub async fn llm_chat(
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>,
listeners: Arc<RwLock<Vec<Listener>>>,
agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
@ -59,6 +58,8 @@ pub async fn llm_chat(
Some(id) => id,
None => uuid::Uuid::new_v4().to_string(),
};
let custom_attrs =
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
// Create a span with request_id that will be included in all log lines
let request_span = info_span!(
@ -80,6 +81,7 @@ pub async fn llm_chat(
full_qualified_llm_provider_url,
model_aliases,
llm_providers,
custom_attrs,
state_storage,
request_id,
request_path,
@ -98,6 +100,7 @@ async fn llm_chat_inner(
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
custom_attrs: HashMap<String, String>,
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
request_path: String,
@ -107,6 +110,11 @@ async fn llm_chat_inner(
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
get_active_span(|span| {
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
// Extract or generate traceparent - this establishes the trace context for all spans
let traceparent: String = match request_headers
@ -144,10 +152,11 @@ async fn llm_chat_inner(
error = %err,
"failed to parse request as ProviderRequestType"
);
let err_msg = format!("Failed to parse request: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request: {}",
err
))
.into_response());
}
};
@ -172,9 +181,7 @@ async fn llm_chat_inner(
None => {
let err_msg = "No model specified in request and no default provider configured";
warn!("{}", err_msg);
let mut bad_request = Response::new(full(err_msg.to_string()));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
return Ok(BrightStaffError::NoModelSpecified.into_response());
}
}
} else {
@ -195,14 +202,8 @@ async fn llm_chat_inner(
.get(&alias_resolved_model)
.is_none()
{
let err_msg = format!(
"Model '{}' not found in configured providers",
alias_resolved_model
);
warn!(model = %alias_resolved_model, "model not found in configured providers");
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
return Ok(BrightStaffError::ModelNotFound(alias_resolved_model).into_response());
}
// Handle provider/model slug format (e.g., "openai/gpt-4")
@ -261,12 +262,7 @@ async fn llm_chat_inner(
let agents_guard = agents_list.read().await;
let agent_map: HashMap<String, Agent> = agents_guard
.as_ref()
.map(|agents| {
agents
.iter()
.map(|a| (a.id.clone(), a.clone()))
.collect()
})
.map(|agents| agents.iter().map(|a| (a.id.clone(), a.clone())).collect())
.unwrap_or_default();
// Create a temporary AgentFilterChain to reuse PipelineProcessor
@ -387,13 +383,10 @@ async fn llm_chat_inner(
Err(StateStorageError::NotFound(_)) => {
// Return 409 Conflict when previous_response_id not found
warn!(previous_response_id = %prev_resp_id, "previous response_id not found");
let err_msg = format!(
"Conversation state not found for previous_response_id: {}",
prev_resp_id
);
let mut conflict_response = Response::new(full(err_msg));
*conflict_response.status_mut() = StatusCode::CONFLICT;
return Ok(conflict_response);
return Ok(BrightStaffError::ConversationStateNotFound(
prev_resp_id.to_string(),
)
.into_response());
}
Err(e) => {
// Log warning but continue on other storage errors
@ -444,9 +437,11 @@ async fn llm_chat_inner(
{
Ok(result) => result,
Err(err) => {
let mut internal_error = Response::new(full(err.message));
*internal_error.status_mut() = err.status_code;
return Ok(internal_error);
return Ok(BrightStaffError::ForwardedError {
status_code: err.status_code,
message: err.message,
}
.into_response());
}
};
@ -512,10 +507,11 @@ async fn llm_chat_inner(
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
return Ok(BrightStaffError::InternalServerError(format!(
"Failed to send request: {}",
err
))
.into_response());
}
};
@ -572,12 +568,11 @@ async fn llm_chat_inner(
match response.body(streaming_response.body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
Err(err) => Ok(BrightStaffError::InternalServerError(format!(
"Failed to create response: {}",
err
))
.into_response()),
}
}
@ -650,3 +645,9 @@ async fn get_provider_info(
(hermesllm::ProviderId::OpenAI, None)
}
}
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}


@ -7,6 +7,7 @@ pub mod models;
pub mod pipeline_processor;
pub mod response_handler;
pub mod router_chat;
pub mod routing_service;
pub mod utils;
#[cfg(test)]


@ -1,25 +1,17 @@
use bytes::Bytes;
use common::errors::BrightStaffError;
use hermesllm::apis::OpenAIApi;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::SseEvent;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::{Response, StatusCode};
use hyper::Response;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{info, warn, Instrument};
/// Errors that can occur during response handling
#[derive(Debug, thiserror::Error)]
pub enum ResponseError {
#[error("Failed to create response: {0}")]
ResponseCreationFailed(#[from] hyper::http::Error),
#[error("Stream error: {0}")]
StreamError(String),
}
/// Service for handling HTTP responses and streaming
pub struct ResponseHandler;
@ -35,40 +27,6 @@ impl ResponseHandler {
.boxed()
}
/// Create an error response with a given status code and message
pub fn create_error_response(
status: StatusCode,
message: &str,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let mut response = Response::new(Self::create_full_body(message.to_string()));
*response.status_mut() = status;
response
}
/// Create a bad request response
pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Self::create_error_response(StatusCode::BAD_REQUEST, message)
}
/// Create an internal server error response
pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Self::create_error_response(StatusCode::INTERNAL_SERVER_ERROR, message)
}
/// Create a JSON error response
pub fn create_json_error_response(
error_json: &serde_json::Value,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let json_string = error_json.to_string();
let mut response = Response::new(Self::create_full_body(json_string));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
response
}
/// Create a streaming response from a reqwest response.
/// The spawned streaming task is instrumented with both `agent_span` and `orchestrator_span`
/// so their durations reflect the actual time spent streaming to the client.
@ -77,13 +35,13 @@ impl ResponseHandler {
llm_response: reqwest::Response,
agent_span: tracing::Span,
orchestrator_span: tracing::Span,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ResponseError> {
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, BrightStaffError> {
// Copy headers from the original response
let response_headers = llm_response.headers();
let mut response_builder = Response::builder();
let headers = response_builder.headers_mut().ok_or_else(|| {
ResponseError::StreamError("Failed to get mutable headers".to_string())
BrightStaffError::StreamError("Failed to get mutable headers".to_string())
})?;
for (header_name, header_value) in response_headers.iter() {
@ -123,7 +81,7 @@ impl ResponseHandler {
response_builder
.body(stream_body)
.map_err(ResponseError::from)
.map_err(BrightStaffError::from)
}
/// Collect the full response body as a string
@ -136,7 +94,7 @@ impl ResponseHandler {
pub async fn collect_full_response(
&self,
llm_response: reqwest::Response,
) -> Result<String, ResponseError> {
) -> Result<String, BrightStaffError> {
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
let response_headers = llm_response.headers();
@ -144,10 +102,9 @@ impl ResponseHandler {
.get(hyper::header::CONTENT_TYPE)
.is_some_and(|v| v.to_str().unwrap_or("").contains("text/event-stream"));
let response_bytes = llm_response
.bytes()
.await
.map_err(|e| ResponseError::StreamError(format!("Failed to read response: {}", e)))?;
let response_bytes = llm_response.bytes().await.map_err(|e| {
BrightStaffError::StreamError(format!("Failed to read response: {}", e))
})?;
if is_sse_streaming {
let client_api =
@ -185,7 +142,7 @@ impl ResponseHandler {
} else {
// If not SSE, treat as regular text response
let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| {
ResponseError::StreamError(format!("Failed to decode response: {}", e))
BrightStaffError::StreamError(format!("Failed to decode response: {}", e))
})?;
Ok(response_text)
@ -204,42 +161,6 @@ mod tests {
use super::*;
use hyper::StatusCode;
#[test]
fn test_create_bad_request() {
let response = ResponseHandler::create_bad_request("Invalid request");
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
#[test]
fn test_create_internal_error() {
let response = ResponseHandler::create_internal_error("Server error");
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[test]
fn test_create_error_response() {
let response =
ResponseHandler::create_error_response(StatusCode::NOT_FOUND, "Resource not found");
assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
#[test]
fn test_create_json_error_response() {
let error_json = serde_json::json!({
"error": {
"type": "TestError",
"message": "Test error message"
}
});
let response = ResponseHandler::create_json_error_response(&error_json);
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
response.headers().get("content-type").unwrap(),
"application/json"
);
}
#[tokio::test]
async fn test_create_streaming_response_with_mock() {
use mockito::Server;


@ -10,6 +10,7 @@ use crate::tracing::routing;
pub struct RoutingResult {
pub model_name: String,
pub route_name: Option<String>,
}
pub struct RoutingError {
@ -133,9 +134,12 @@ pub async fn router_chat_get_upstream_model(
match routing_result {
Ok(route) => match route {
Some((_, model_name)) => {
Some((route_name, model_name)) => {
current_span.record("route.selected_model", model_name.as_str());
Ok(RoutingResult { model_name })
Ok(RoutingResult {
model_name,
route_name: Some(route_name),
})
}
None => {
// No route determined, return sentinel value "none"
@ -145,6 +149,7 @@ pub async fn router_chat_get_upstream_model(
Ok(RoutingResult {
model_name: "none".to_string(),
route_name: None,
})
}
},


@ -0,0 +1,163 @@
use bytes::Bytes;
use common::configuration::SpanAttributes;
use common::consts::{REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
use common::errors::BrightStaffError;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::ProviderRequestType;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use std::sync::Arc;
use tracing::{debug, info, info_span, warn, Instrument};
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::router::llm_router::RouterService;
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
#[derive(serde::Serialize)]
struct RoutingDecisionResponse {
model: String,
route: Option<String>,
trace_id: String,
}
pub async fn routing_decision(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
request_path: String,
span_attributes: Arc<Option<SpanAttributes>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_headers = request.headers().clone();
let request_id: String = request_headers
.get(REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let custom_attrs =
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
let request_span = info_span!(
"routing_decision",
component = "routing",
request_id = %request_id,
http.method = %request.method(),
http.path = %request_path,
);
routing_decision_inner(
request,
router_service,
request_id,
request_path,
request_headers,
custom_attrs,
)
.instrument(request_span)
.await
}
async fn routing_decision_inner(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
request_id: String,
request_path: String,
request_headers: hyper::HeaderMap,
custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
set_service_name(operation_component::ROUTING);
opentelemetry::trace::get_active_span(|span| {
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
// Extract or generate traceparent
let traceparent: String = match request_headers
.get(TRACE_PARENT_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
{
Some(tp) => tp,
None => {
let trace_id = uuid::Uuid::new_v4().to_string().replace("-", "");
let generated_tp = format!("00-{}-0000000000000000-01", trace_id);
warn!(
generated_traceparent = %generated_tp,
"TRACE_PARENT header missing, generated new traceparent"
);
generated_tp
}
};
// Extract trace_id from traceparent (format: 00-{trace_id}-{span_id}-{flags})
let trace_id = traceparent
.split('-')
.nth(1)
.unwrap_or("unknown")
.to_string();
// Parse request body
let chat_request_bytes = request.collect().await?.to_bytes();
debug!(
body = %String::from_utf8_lossy(&chat_request_bytes),
"routing decision request body received"
);
let client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!(error = %err, "failed to parse request for routing decision");
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request: {}",
err
))
.into_response());
}
};
// Call the existing routing logic
let routing_result = router_chat_get_upstream_model(
router_service,
client_request,
&traceparent,
&request_path,
&request_id,
)
.await;
match routing_result {
Ok(result) => {
let response = RoutingDecisionResponse {
model: result.model_name,
route: result.route_name,
trace_id,
};
info!(
model = %response.model,
route = ?response.route,
"routing decision completed"
);
let json = serde_json::to_string(&response).unwrap();
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap())
}
Err(err) => {
warn!(error = %err.message, "routing decision failed");
Ok(BrightStaffError::InternalServerError(err.message).into_response())
}
}
}
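For context, here is a hedged sketch of how a client might call this endpoint once it is mounted under the `/routing` prefix (wired up in the main.rs hunk further down). The bind address, port, payload, and reqwest's `json` feature are assumptions for illustration; only the path suffix and the `model`/`route`/`trace_id` response fields come from this change.

```rust
// Hedged sketch: querying the routing-decision endpoint.
// Host/port and the request payload are illustrative assumptions;
// the response fields mirror RoutingDecisionResponse above.
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct RoutingDecision {
    model: String,
    route: Option<String>,
    trace_id: String,
}

async fn fetch_routing_decision() -> Result<RoutingDecision, reqwest::Error> {
    let body = serde_json::json!({
        "model": "auto", // illustrative; routing resolves the upstream model
        "messages": [{ "role": "user", "content": "refactor this function" }]
    });
    reqwest::Client::new()
        // assumed listener address; the /routing prefix is stripped before dispatch
        .post("http://127.0.0.1:10000/routing/v1/chat/completions")
        .json(&body)
        .send()
        .await?
        .json::<RoutingDecision>()
        .await
}
```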


@ -2,6 +2,7 @@ use brightstaff::handlers::agent_chat_completions::agent_chat;
use brightstaff::handlers::function_calling::function_calling_chat_handler;
use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::routing_service::routing_decision;
use brightstaff::router::llm_router::RouterService;
use brightstaff::router::plano_orchestrator::OrchestratorService;
use brightstaff::state::memory::MemoryConversationalStorage;
@ -114,6 +115,12 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
));
let model_aliases = Arc::new(plano_config.model_aliases.clone());
let span_attributes = Arc::new(
plano_config
.tracing
.as_ref()
.and_then(|tracing| tracing.span_attributes.clone()),
);
// Initialize trace collector and start background flusher
// Tracing is enabled if the tracing config is present in plano_config.yaml
@ -173,6 +180,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let agents_list = combined_agents_filters_list.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
@ -183,10 +191,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
async move {
let path = req.uri().path();
let path = req.uri().path().to_string();
// Check if path starts with /agents
if path.starts_with("/agents") {
// Check if it matches one of the agent API paths
@ -202,13 +211,30 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
fully_qualified_url,
agents_list,
listeners,
span_attributes,
llm_providers,
)
.with_context(parent_cx)
.await;
}
}
match (req.method(), path) {
if let Some(stripped_path) = path.strip_prefix("/routing") {
let stripped_path = stripped_path.to_string();
if matches!(
stripped_path.as_str(),
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
return routing_decision(
req,
router_service,
stripped_path,
span_attributes,
)
.with_context(parent_cx)
.await;
}
}
match (req.method(), path.as_str()) {
(
&Method::POST,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
@ -220,6 +246,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
fully_qualified_url,
model_aliases,
llm_providers,
span_attributes,
state_storage,
listeners,
agents_list,
@ -262,7 +289,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(response)
}
_ => {
debug!(method = %req.method(), path = %req.uri().path(), "no route found");
debug!(method = %req.method(), path = %path, "no route found");
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)


@ -0,0 +1,156 @@
use std::collections::HashMap;
use common::configuration::SpanAttributes;
use common::traces::SpanBuilder;
use hyper::header::HeaderMap;
pub fn collect_custom_trace_attributes(
headers: &HeaderMap,
span_attributes: Option<&SpanAttributes>,
) -> HashMap<String, String> {
let mut attributes = HashMap::new();
let Some(span_attributes) = span_attributes else {
return attributes;
};
if let Some(static_attributes) = span_attributes.static_attributes.as_ref() {
for (key, value) in static_attributes {
attributes.insert(key.clone(), value.clone());
}
}
let Some(header_prefixes) = span_attributes.header_prefixes.as_deref() else {
return attributes;
};
if header_prefixes.is_empty() {
return attributes;
}
for (name, value) in headers.iter() {
let header_name = name.as_str();
let matched_prefix = header_prefixes
.iter()
.find(|prefix| header_name.starts_with(prefix.as_str()))
.map(String::as_str);
let Some(prefix) = matched_prefix else {
continue;
};
let Some(raw_value) = value.to_str().ok().map(str::trim) else {
continue;
};
let suffix = header_name.strip_prefix(prefix).unwrap_or("");
let suffix_key = suffix.trim_start_matches('-').replace('-', ".");
if suffix_key.is_empty() {
continue;
}
attributes.insert(suffix_key, raw_value.to_string());
}
attributes
}
pub fn append_span_attributes(
mut span_builder: SpanBuilder,
attributes: &HashMap<String, String>,
) -> SpanBuilder {
for (key, value) in attributes {
span_builder = span_builder.with_attribute(key, value);
}
span_builder
}
#[cfg(test)]
mod tests {
use super::collect_custom_trace_attributes;
use common::configuration::SpanAttributes;
use hyper::header::{HeaderMap, HeaderValue};
use std::collections::HashMap;
#[test]
fn extracts_headers_by_prefix() {
let mut headers = HeaderMap::new();
headers.insert("x-katanemo-tenant-id", HeaderValue::from_static("ten_456"));
headers.insert("x-katanemo-user-id", HeaderValue::from_static("usr_789"));
headers.insert("x-katanemo-admin-level", HeaderValue::from_static("3"));
headers.insert("x-other-id", HeaderValue::from_static("ignored"));
let attrs = collect_custom_trace_attributes(
&headers,
Some(&SpanAttributes {
header_prefixes: Some(vec!["x-katanemo-".to_string()]),
static_attributes: None,
}),
);
assert_eq!(attrs.get("tenant.id"), Some(&"ten_456".to_string()));
assert_eq!(attrs.get("user.id"), Some(&"usr_789".to_string()));
assert_eq!(attrs.get("admin.level"), Some(&"3".to_string()));
assert!(!attrs.contains_key("other.id"));
}
#[test]
fn returns_empty_when_prefixes_missing_or_empty() {
let mut headers = HeaderMap::new();
headers.insert("x-katanemo-tenant-id", HeaderValue::from_static("ten_456"));
let attrs_none = collect_custom_trace_attributes(
&headers,
Some(&SpanAttributes {
header_prefixes: None,
static_attributes: None,
}),
);
assert!(attrs_none.is_empty());
let attrs_empty = collect_custom_trace_attributes(
&headers,
Some(&SpanAttributes {
header_prefixes: Some(Vec::new()),
static_attributes: None,
}),
);
assert!(attrs_empty.is_empty());
}
#[test]
fn supports_multiple_prefixes() {
let mut headers = HeaderMap::new();
headers.insert("x-katanemo-tenant-id", HeaderValue::from_static("ten_456"));
headers.insert("x-tenant-user-id", HeaderValue::from_static("usr_789"));
let attrs = collect_custom_trace_attributes(
&headers,
Some(&SpanAttributes {
header_prefixes: Some(vec!["x-katanemo-".to_string(), "x-tenant-".to_string()]),
static_attributes: None,
}),
);
assert_eq!(attrs.get("tenant.id"), Some(&"ten_456".to_string()));
assert_eq!(attrs.get("user.id"), Some(&"usr_789".to_string()));
}
#[test]
fn header_attributes_override_static_attributes() {
let mut headers = HeaderMap::new();
headers.insert("x-katanemo-tenant-id", HeaderValue::from_static("ten_456"));
let mut static_attributes = HashMap::new();
static_attributes.insert("tenant.id".to_string(), "ten_static".to_string());
static_attributes.insert("environment".to_string(), "prod".to_string());
let attrs = collect_custom_trace_attributes(
&headers,
Some(&SpanAttributes {
header_prefixes: Some(vec!["x-katanemo-".to_string()]),
static_attributes: Some(static_attributes),
}),
);
assert_eq!(attrs.get("tenant.id"), Some(&"ten_456".to_string()));
assert_eq!(attrs.get("environment"), Some(&"prod".to_string()));
}
}


@ -1,9 +1,11 @@
mod constants;
mod custom_attributes;
mod service_name_exporter;
pub use constants::{
error, http, llm, operation_component, routing, signals, OperationNameBuilder,
};
pub use custom_attributes::{append_span_attributes, collect_custom_trace_attributes};
pub use service_name_exporter::{ServiceNameOverrideExporter, SERVICE_NAME_OVERRIDE_KEY};
use opentelemetry::trace::get_active_span;


@ -20,6 +20,9 @@ urlencoding = "2.1.3"
url = "2.5.4"
hermesllm = { version = "0.1.0", path = "../hermesllm" }
serde_with = "3.13.0"
hyper = "1.0"
bytes = "1.0"
http-body-util = "0.1"
[features]
default = []
@ -30,3 +33,6 @@ serde_json = "1.0.64"
serial_test = "3.2"
axum = "0.7"
tokio = { version = "1.44", features = ["sync", "time", "macros", "rt"] }
hyper = { version = "1.0", features = ["full"] }
bytes = "1.0"
http-body-util = "0.1"


@ -93,6 +93,14 @@ pub struct Tracing {
pub trace_arch_internal: Option<bool>,
pub random_sampling: Option<u32>,
pub opentracing_grpc_endpoint: Option<String>,
pub span_attributes: Option<SpanAttributes>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct SpanAttributes {
pub header_prefixes: Option<Vec<String>>,
#[serde(rename = "static")]
pub static_attributes: Option<HashMap<String, String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, Default)]
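A minimal sketch of how this block deserializes, shown through serde_json rather than the YAML actually used in plano_config.yaml; the values are illustrative, and the point is that the `static` key maps onto `static_attributes` via the rename above.

```rust
// Hedged sketch: the `static` key lands in `static_attributes` thanks to the
// serde rename; prefixes and values here are illustrative only.
use common::configuration::SpanAttributes;

fn span_attributes_example() -> Result<(), serde_json::Error> {
    let fragment = r#"{
        "header_prefixes": ["x-katanemo-"],
        "static": { "environment": "prod" }
    }"#;
    let attrs: SpanAttributes = serde_json::from_str(fragment)?;
    assert_eq!(
        attrs.header_prefixes.as_deref(),
        Some(&["x-katanemo-".to_string()][..])
    );
    assert_eq!(
        attrs
            .static_attributes
            .as_ref()
            .and_then(|m| m.get("environment")),
        Some(&"prod".to_string())
    );
    Ok(())
}
```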


@ -1,9 +1,13 @@
use proxy_wasm::types::Status;
use crate::{api::open_ai::ChatCompletionChunkResponseError, ratelimit};
use bytes::Bytes;
use hermesllm::apis::openai::OpenAIError;
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Error as HyperError, Response, StatusCode};
use proxy_wasm::types::Status;
use serde_json::json;
use thiserror::Error;
#[derive(thiserror::Error, Debug)]
#[derive(Error, Debug)]
pub enum ClientError {
#[error("Error dispatching HTTP call to `{upstream_name}/{path}`, error: {internal_status:?}")]
DispatchError {
@ -13,7 +17,7 @@ pub enum ClientError {
},
}
#[derive(thiserror::Error, Debug)]
#[derive(Error, Debug)]
pub enum ServerError {
#[error(transparent)]
HttpDispatch(ClientError),
@ -43,3 +47,174 @@ pub enum ServerError {
#[error("error parsing openai message: {0}")]
OpenAIPError(#[from] OpenAIError),
}
// -----------------------------------------------------------------------------
// BrightStaff Errors (Standardized)
// -----------------------------------------------------------------------------
#[derive(Debug, Error)]
pub enum BrightStaffError {
#[error("The requested model '{0}' does not exist")]
ModelNotFound(String),
#[error("No model specified in request and no default provider configured")]
NoModelSpecified,
#[error("Conversation state not found for previous_response_id: {0}")]
ConversationStateNotFound(String),
#[error("Internal server error")]
InternalServerError(String),
#[error("Invalid request")]
InvalidRequest(String),
#[error("{message}")]
ForwardedError {
status_code: StatusCode,
message: String,
},
#[error("Stream error: {0}")]
StreamError(String),
#[error("Failed to create response: {0}")]
ResponseCreationFailed(#[from] hyper::http::Error),
}
impl BrightStaffError {
pub fn into_response(self) -> Response<BoxBody<Bytes, HyperError>> {
let (status, code, details) = match &self {
BrightStaffError::ModelNotFound(model_name) => (
StatusCode::NOT_FOUND,
"ModelNotFound",
json!({ "rejected_model_id": model_name }),
),
BrightStaffError::NoModelSpecified => {
(StatusCode::BAD_REQUEST, "NoModelSpecified", json!({}))
}
BrightStaffError::ConversationStateNotFound(prev_resp_id) => (
StatusCode::CONFLICT,
"ConversationStateNotFound",
json!({ "previous_response_id": prev_resp_id }),
),
BrightStaffError::InternalServerError(reason) => (
StatusCode::INTERNAL_SERVER_ERROR,
"InternalServerError",
// Passing the reason into details for easier debugging
json!({ "reason": reason }),
),
BrightStaffError::InvalidRequest(reason) => (
StatusCode::BAD_REQUEST,
"InvalidRequest",
json!({ "reason": reason }),
),
BrightStaffError::ForwardedError {
status_code,
message,
} => (*status_code, "ForwardedError", json!({ "reason": message })),
BrightStaffError::StreamError(reason) => (
StatusCode::BAD_REQUEST,
"StreamError",
json!({ "reason": reason }),
),
BrightStaffError::ResponseCreationFailed(reason) => (
StatusCode::BAD_REQUEST,
"ResponseCreationFailed",
json!({ "reason": reason.to_string() }),
),
};
let body_json = json!({
"error": {
"code": code,
"message": self.to_string(),
"details": details
}
});
// 1. Create the concrete body
let full_body = Full::new(Bytes::from(body_json.to_string()));
// 2. Convert it to BoxBody
// We map_err because Full never fails, but BoxBody expects a HyperError
let boxed_body = full_body
.map_err(|never| match never {}) // This handles the "Infallible" error type
.boxed();
Response::builder()
.status(status)
.header("content-type", "application/json")
.body(boxed_body)
.unwrap_or_else(|_| {
Response::new(
Full::new(Bytes::from("Internal Error"))
.map_err(|never| match never {})
.boxed(),
)
})
}
}
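For reference, a minimal sketch of the envelope this mapping produces for a variant the tests below do not exercise: `ConversationStateNotFound` returns 409 CONFLICT with a body like the following (the id is illustrative).

```rust
// Sketch: per the mapping above, ConversationStateNotFound("resp_123".into())
// serializes to this JSON body alongside a 409 status (id is illustrative).
use serde_json::json;

fn conversation_state_not_found_body() -> serde_json::Value {
    json!({
        "error": {
            "code": "ConversationStateNotFound",
            "message": "Conversation state not found for previous_response_id: resp_123",
            "details": { "previous_response_id": "resp_123" }
        }
    })
}
```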
#[cfg(test)]
mod tests {
use super::*;
use http_body_util::BodyExt; // For .collect().await
#[tokio::test]
async fn test_model_not_found_format() {
let err = BrightStaffError::ModelNotFound("gpt-5-secret".to_string());
let response = err.into_response();
assert_eq!(response.status(), StatusCode::NOT_FOUND);
// Helper to extract body as JSON
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["error"]["code"], "ModelNotFound");
assert_eq!(
body["error"]["details"]["rejected_model_id"],
"gpt-5-secret"
);
assert!(body["error"]["message"]
.as_str()
.unwrap()
.contains("gpt-5-secret"));
}
#[tokio::test]
async fn test_forwarded_error_preserves_status() {
let err = BrightStaffError::ForwardedError {
status_code: StatusCode::TOO_MANY_REQUESTS,
message: "Rate limit exceeded on agent side".to_string(),
};
let response = err.into_response();
assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["error"]["code"], "ForwardedError");
}
#[tokio::test]
async fn test_hyper_error_wrapping() {
// Manually trigger a hyper http::Error by building a response with an invalid status code
let hyper_err = hyper::http::Response::builder()
.status(1000) // Invalid status
.body(())
.unwrap_err();
let err = BrightStaffError::ResponseCreationFailed(hyper_err);
let response = err.into_response();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
}