refactor brightstaff

Adil Hafeez 2026-02-10 00:45:34 -08:00
parent b9f01c8471
commit 3fdd8a3a35
24 changed files with 1102 additions and 1086 deletions

View file

@@ -0,0 +1,25 @@
use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{Agent, Listener, ModelAlias};
use common::llm_providers::LlmProviders;
use tokio::sync::RwLock;
use crate::router::llm::RouterService;
use crate::router::orchestrator::OrchestratorService;
use crate::state::StateStorage;
/// Shared application state bundled into a single Arc-wrapped struct.
///
/// Instead of cloning 8+ individual `Arc`s per connection, a single
/// `Arc<AppState>` is cloned once and passed to the request handler.
pub struct AppState {
pub router_service: Arc<RouterService>,
pub orchestrator_service: Arc<OrchestratorService>,
pub model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
pub llm_providers: Arc<RwLock<LlmProviders>>,
pub agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
pub listeners: Arc<RwLock<Vec<Listener>>>,
pub state_storage: Option<Arc<dyn StateStorage>>,
pub llm_provider_url: String,
}
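A minimal sketch of the intended usage, assuming the accept loop in main.rs holds an `Arc<AppState>` (hypothetical helper name, not code from this commit):

use std::sync::Arc;

// Hypothetical: one pointer bump per connection replaces the previous
// per-field clones (router_service, orchestrator_service, llm_providers, ...).
fn state_for_connection(app_state: &Arc<AppState>) -> Arc<AppState> {
    Arc::clone(app_state)
}

// Inside the connection's service_fn closure the same Arc is cloned once more
// per request and handed to the handler, e.g. llm_chat or agent_chat.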

View file

@@ -0,0 +1,4 @@
pub mod jsonrpc;
pub mod orchestrator;
pub mod pipeline;
pub mod selector;

View file

@@ -2,36 +2,34 @@ use std::sync::Arc;
use std::time::Instant;
use bytes::Bytes;
use common::configuration::SpanAttributes;
use common::errors::BrightStaffError;
use common::llm_providers::LlmProviders;
use hermesllm::apis::OpenAIMessage;
use hermesllm::clients::SupportedAPIsFromClient;
use hermesllm::providers::request::ProviderRequest;
use hermesllm::ProviderRequestType;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::{Request, Response, StatusCode};
use hyper::{Request, Response};
use opentelemetry::trace::get_active_span;
use serde::ser::Error as SerError;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::plano_orchestrator::OrchestratorService;
use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name};
use super::pipeline::{PipelineError, PipelineProcessor};
use super::selector::{AgentSelectionError, AgentSelector};
use crate::handlers::errors::build_error_chain_response;
use crate::handlers::request::extract_request_id;
use crate::handlers::response::ResponseHandler;
use crate::router::orchestrator::OrchestratorService;
use crate::tracing::{operation_component, set_service_name};
/// Errors that can occur during agent chat completions
#[derive(Debug, thiserror::Error)]
pub enum AgentFilterChainError {
#[error("Forwarded error: {0}")]
Brightstaff(#[from] BrightStaffError),
#[error("Agent selection error: {0}")]
Selection(#[from] AgentSelectionError),
#[error("Pipeline processing error: {0}")]
Pipeline(#[from] PipelineError),
#[error("Response handling error: {0}")]
Response(#[from] crate::handlers::response::ResponseError),
#[error("Request parsing error: {0}")]
RequestParsing(#[from] serde_json::Error),
#[error("HTTP error: {0}")]
@@ -41,24 +39,10 @@ pub enum AgentFilterChainError {
pub async fn agent_chat(
request: Request<hyper::body::Incoming>,
orchestrator_service: Arc<OrchestratorService>,
_: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
span_attributes: Arc<Option<SpanAttributes>>,
llm_providers: Arc<RwLock<LlmProviders>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let custom_attrs =
collect_custom_trace_attributes(request.headers(), span_attributes.as_ref().as_ref());
// Extract request_id from headers or generate a new one
let request_id: String = match request
.headers()
.get(common::consts::REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
{
Some(id) => id,
None => uuid::Uuid::new_v4().to_string(),
};
let request_id = extract_request_id(&request);
// Create a span with request_id that will be included in all log lines
let request_span = info_span!(
@@ -79,9 +63,7 @@ pub async fn agent_chat(
orchestrator_service,
agents_list,
listeners,
llm_providers,
request_id,
custom_attrs,
)
.await
{
@@ -101,7 +83,6 @@ pub async fn agent_chat(
"client error from agent"
);
// Create error response with the original status code and body
let error_json = serde_json::json!({
"error": "ClientError",
"agent": agent,
@@ -109,52 +90,19 @@ pub async fn agent_chat(
"agent_response": body
});
let status_code = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
let json_string = error_json.to_string();
return Ok(BrightStaffError::ForwardedError {
status_code,
message: json_string,
}
.into_response());
let mut response =
Response::new(ResponseHandler::create_full_body(json_string));
*response.status_mut() = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::BAD_REQUEST);
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
return Ok(response);
}
// Print detailed error information with full error chain for other errors
let mut error_chain = Vec::new();
let mut current_error: &dyn std::error::Error = &err;
// Collect the full error chain
loop {
error_chain.push(current_error.to_string());
match current_error.source() {
Some(source) => current_error = source,
None => break,
}
}
// Log the complete error chain
warn!(error_chain = ?error_chain, "agent chat error chain");
warn!(root_error = ?err, "root error");
// Create structured error response as JSON
let error_json = serde_json::json!({
"error": {
"type": "AgentFilterChainError",
"message": err.to_string(),
"error_chain": error_chain,
"debug_info": format!("{:?}", err)
}
});
// Log the error for debugging
info!(error = %error_json, "structured error info");
Ok(BrightStaffError::ForwardedError {
status_code: StatusCode::BAD_REQUEST,
message: error_json.to_string(),
}
.into_response())
build_error_chain_response(&err)
}
}
}
@@ -167,9 +115,7 @@ async fn handle_agent_chat_inner(
orchestrator_service: Arc<OrchestratorService>,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
request_id: String,
custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
// Initialize services
let agent_selector = AgentSelector::new(orchestrator_service);
@@ -192,9 +138,6 @@ async fn handle_agent_chat_inner(
get_active_span(|span| {
span.update_name(listener.name.to_string());
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
info!(listener = %listener.name, "handling request");
@@ -238,33 +181,16 @@ async fn handle_agent_chat_inner(
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
})?;
let mut client_request =
match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
Ok(request) => request,
Err(err) => {
warn!("failed to parse request as ProviderRequestType: {}", err);
let err_msg = format!("Failed to parse request: {}", err);
return Err(AgentFilterChainError::RequestParsing(
serde_json::Error::custom(err_msg),
));
}
};
// If model is not specified in the request, resolve from default provider
if client_request.model().is_empty() {
match llm_providers.read().await.default() {
Some(default_provider) => {
let default_model = default_provider.name.clone();
info!(default_model = %default_model, "no model specified in request, using default provider");
client_request.set_model(default_model);
}
None => {
let err_msg = "No model specified in request and no default provider configured";
warn!("{}", err_msg);
return Ok(BrightStaffError::NoModelSpecified.into_response());
}
let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
Ok(request) => request,
Err(err) => {
warn!("failed to parse request as ProviderRequestType: {}", err);
let err_msg = format!("Failed to parse request: {}", err);
return Err(AgentFilterChainError::RequestParsing(
serde_json::Error::custom(err_msg),
));
}
}
};
let message: Vec<OpenAIMessage> = client_request.get_messages();
@@ -357,9 +283,6 @@ async fn handle_agent_chat_inner(
set_service_name(operation_component::AGENT);
get_active_span(|span| {
span.update_name(format!("{} /v1/chat/completions", agent_name));
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
pipeline_processor

View file

@@ -11,7 +11,7 @@ use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use tracing::{debug, info, instrument, warn};
use crate::handlers::jsonrpc::{
use super::jsonrpc::{
JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JSON_RPC_VERSION,
MCP_INITIALIZE, MCP_INITIALIZE_NOTIFICATION, TOOL_CALL_METHOD,
};

View file

@@ -7,7 +7,7 @@ use common::configuration::{
use hermesllm::apis::openai::Message;
use tracing::{debug, warn};
use crate::router::plano_orchestrator::OrchestratorService;
use crate::router::orchestrator::OrchestratorService;
/// Errors that can occur during agent selection
#[derive(Debug, thiserror::Error)]

View file

@@ -0,0 +1,41 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::Response;
use serde_json::json;
use tracing::{info, warn};
use super::response::ResponseHandler;
/// Build a JSON error response from an `AgentFilterChainError`, logging the
/// full error chain along the way.
///
/// Returns `Ok(Response)` so it can be used directly as a handler return value.
pub fn build_error_chain_response<E: std::error::Error>(
err: &E,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let mut error_chain = Vec::new();
let mut current: &dyn std::error::Error = err;
loop {
error_chain.push(current.to_string());
match current.source() {
Some(source) => current = source,
None => break,
}
}
warn!(error_chain = ?error_chain, "agent chat error chain");
warn!(root_error = ?err, "root error");
let error_json = json!({
"error": {
"type": "AgentFilterChainError",
"message": err.to_string(),
"error_chain": error_chain,
"debug_info": format!("{:?}", err)
}
});
info!(error = %error_json, "structured error info");
Ok(ResponseHandler::create_json_error_response(&error_json))
}
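A hedged usage sketch: the orchestrator handler above calls this from its error arm, roughly as follows (illustrative wrapper name; the error only needs to implement std::error::Error):

use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use hyper::Response;

// Illustrative wrapper showing the intended call shape:
//   Err(err) => build_error_chain_response(&err),
pub fn respond_with_error_chain<E: std::error::Error>(
    err: E,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    build_error_chain_response(&err)
}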

View file

@@ -3,12 +3,11 @@ use std::sync::Arc;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use hyper::header::HeaderMap;
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
use crate::handlers::pipeline_processor::PipelineProcessor;
use crate::router::plano_orchestrator::OrchestratorService;
use common::errors::BrightStaffError;
use http_body_util::BodyExt;
use hyper::StatusCode;
use crate::handlers::agents::pipeline::PipelineProcessor;
use crate::handlers::agents::selector::{AgentSelectionError, AgentSelector};
use crate::handlers::response::ResponseHandler;
use crate::router::orchestrator::OrchestratorService;
/// Integration test that demonstrates the modular agent chat flow
/// This test shows how the three main components work together:
/// 1. AgentSelector - selects the appropriate agents based on orchestration
@@ -129,24 +128,8 @@ mod tests {
}
// Test 4: Error Response Creation
let err = BrightStaffError::ModelNotFound("gpt-5-secret".to_string());
let response = err.into_response();
assert_eq!(response.status(), StatusCode::NOT_FOUND);
// Helper to extract body as JSON
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
assert_eq!(body["error"]["code"], "ModelNotFound");
assert_eq!(
body["error"]["details"]["rejected_model_id"],
"gpt-5-secret"
);
assert!(body["error"]["message"]
.as_str()
.unwrap()
.contains("gpt-5-secret"));
let error_response = ResponseHandler::create_bad_request("Test error");
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
println!("✅ All modular components working correctly!");
}
@@ -165,21 +148,12 @@ mod tests {
AgentSelectionError::ListenerNotFound(_)
));
let technical_reason = "Database connection timed out";
let err = BrightStaffError::InternalServerError(technical_reason.to_string());
let response = err.into_response();
// --- 1. EXTRACT BYTES ---
let body_bytes = response.into_body().collect().await.unwrap().to_bytes();
// --- 2. DECLARE body_json HERE ---
let body_json: serde_json::Value =
serde_json::from_slice(&body_bytes).expect("Failed to parse JSON body");
// --- 3. USE body_json ---
assert_eq!(body_json["error"]["code"], "InternalServerError");
assert_eq!(body_json["error"]["details"]["reason"], technical_reason);
// Test error response creation
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
assert_eq!(
error_response.status(),
hyper::StatusCode::INTERNAL_SERVER_ERROR
);
println!("✅ Error handling working correctly!");
}

View file

@@ -1,553 +0,0 @@
use bytes::Bytes;
use common::configuration::{ModelAlias, SpanAttributes};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
use common::llm_providers::LlmProviders;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::header::{self};
use hyper::{Request, Response};
use opentelemetry::global;
use opentelemetry::trace::get_active_span;
use opentelemetry_http::HeaderInjector;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
use crate::handlers::router_chat::router_chat_get_upstream_model;
use crate::handlers::utils::{
create_streaming_response, truncate_message, ObservableStreamProcessor,
};
use crate::router::llm_router::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
};
use crate::tracing::{
collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name,
};
use common::errors::BrightStaffError;
pub async fn llm_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
let request_id: String = match request_headers
.get(REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
{
Some(id) => id,
None => uuid::Uuid::new_v4().to_string(),
};
let custom_attrs =
collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref());
// Create a span with request_id that will be included in all log lines
let request_span = info_span!(
"llm",
component = "llm",
request_id = %request_id,
http.method = %request.method(),
http.path = %request_path,
llm.model = tracing::field::Empty,
llm.tools = tracing::field::Empty,
llm.user_message_preview = tracing::field::Empty,
llm.temperature = tracing::field::Empty,
);
// Execute the rest of the handler inside the span
llm_chat_inner(
request,
router_service,
full_qualified_llm_provider_url,
model_aliases,
llm_providers,
custom_attrs,
state_storage,
request_id,
request_path,
request_headers,
)
.instrument(request_span)
.await
}
#[allow(clippy::too_many_arguments)]
async fn llm_chat_inner(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
custom_attrs: HashMap<String, String>,
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
request_path: String,
mut request_headers: hyper::HeaderMap,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
get_active_span(|span| {
for (key, value) in &custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
// Extract or generate traceparent - this establishes the trace context for all spans
let traceparent: String = match request_headers
.get(TRACE_PARENT_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
{
Some(tp) => tp,
None => {
use uuid::Uuid;
let trace_id = Uuid::new_v4().to_string().replace("-", "");
let generated_tp = format!("00-{}-0000000000000000-01", trace_id);
warn!(
generated_traceparent = %generated_tp,
"TRACE_PARENT header missing, generated new traceparent"
);
generated_tp
}
};
let chat_request_bytes = request.collect().await?.to_bytes();
debug!(
body = %String::from_utf8_lossy(&chat_request_bytes),
"request body received"
);
let mut client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!(
error = %err,
"failed to parse request as ProviderRequestType"
);
return Ok(BrightStaffError::InvalidRequest(format!(
"Failed to parse request: {}",
err
))
.into_response());
}
};
// === v1/responses state management: Extract input items early ===
let mut original_input_items = Vec::new();
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
let is_responses_api_client = matches!(
client_api,
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
);
// If model is not specified in the request, resolve from default provider
let model_from_request = client_request.model().to_string();
let model_from_request = if model_from_request.is_empty() {
match llm_providers.read().await.default() {
Some(default_provider) => {
let default_model = default_provider.name.clone();
info!(default_model = %default_model, "no model specified in request, using default provider");
client_request.set_model(default_model.clone());
default_model
}
None => {
let err_msg = "No model specified in request and no default provider configured";
warn!("{}", err_msg);
return Ok(BrightStaffError::NoModelSpecified.into_response());
}
}
} else {
model_from_request
};
// Model alias resolution: update model field in client_request immediately
// This ensures all downstream objects use the resolved model
let temperature = client_request.get_temperature();
let is_streaming_request = client_request.is_streaming();
let alias_resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
// Validate that the requested model exists in configuration
// This matches the validation in llm_gateway routing.rs
if llm_providers
.read()
.await
.get(&alias_resolved_model)
.is_none()
{
warn!(model = %alias_resolved_model, "model not found in configured providers");
return Ok(BrightStaffError::ModelNotFound(alias_resolved_model).into_response());
}
// Handle provider/model slug format (e.g., "openai/gpt-4")
// Extract just the model name for upstream (providers don't understand the slug)
let model_name_only = if let Some((_, model)) = alias_resolved_model.split_once('/') {
model.to_string()
} else {
alias_resolved_model.clone()
};
// Extract tool names and user message preview for span attributes
let tool_names = client_request.get_tool_names();
let user_message_preview = client_request
.get_recent_user_message()
.map(|msg| truncate_message(&msg, 50));
let span = tracing::Span::current();
if let Some(temp) = temperature {
span.record(tracing_llm::TEMPERATURE, tracing::field::display(temp));
}
if let Some(tools) = &tool_names {
let formatted_tools = tools
.iter()
.map(|name| format!("{}(...)", name))
.collect::<Vec<_>>()
.join("\n");
span.record(tracing_llm::TOOLS, formatted_tools.as_str());
}
if let Some(preview) = &user_message_preview {
span.record(tracing_llm::USER_MESSAGE_PREVIEW, preview.as_str());
}
// Extract messages for signal analysis (clone before moving client_request)
let messages_for_signals = Some(client_request.get_messages());
// Set the model to just the model name (without provider prefix)
// This ensures upstream receives "gpt-4" not "openai/gpt-4"
client_request.set_model(model_name_only.clone());
if client_request.remove_metadata_key("plano_preference_config") {
debug!("removed plano_preference_config from metadata");
}
// === v1/responses state management: Determine upstream API and combine input if needed ===
// Do this BEFORE routing since routing consumes the request
// Only process state if state_storage is configured
let mut should_manage_state = false;
if is_responses_api_client {
if let (
ProviderRequestType::ResponsesAPIRequest(ref mut responses_req),
Some(ref state_store),
) = (&mut client_request, &state_storage)
{
// Extract original input once
original_input_items = extract_input_items(&responses_req.input);
// Get the upstream path and check if it's ResponsesAPI
let upstream_path = get_upstream_path(
&llm_providers,
&alias_resolved_model,
&request_path,
&alias_resolved_model,
is_streaming_request,
)
.await;
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
should_manage_state = !matches!(
upstream_api,
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
);
if should_manage_state {
// Retrieve and combine conversation history if previous_response_id exists
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
match retrieve_and_combine_input(
state_store.clone(),
prev_resp_id,
original_input_items, // Pass ownership instead of cloning
)
.await
{
Ok(combined_input) => {
// Update both the request and original_input_items
responses_req.input = InputParam::Items(combined_input.clone());
original_input_items = combined_input;
info!(
items = original_input_items.len(),
"updated request with conversation history"
);
}
Err(StateStorageError::NotFound(_)) => {
// Return 409 Conflict when previous_response_id not found
warn!(previous_response_id = %prev_resp_id, "previous response_id not found");
return Ok(BrightStaffError::ConversationStateNotFound(
prev_resp_id.to_string(),
)
.into_response());
}
Err(e) => {
// Log warning but continue on other storage errors
warn!(
previous_response_id = %prev_resp_id,
error = %e,
"failed to retrieve conversation state"
);
// Restore original_input_items since we passed ownership
original_input_items = extract_input_items(&responses_req.input);
}
}
}
} else {
debug!("upstream supports ResponsesAPI natively");
}
}
}
// Serialize request for upstream BEFORE router consumes it
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
// Determine routing using the dedicated router_chat module
// This gets its own span for latency and error tracking
let routing_span = info_span!(
"routing",
component = "routing",
http.method = "POST",
http.target = %request_path,
model.requested = %model_from_request,
model.alias_resolved = %alias_resolved_model,
route.selected_model = tracing::field::Empty,
routing.determination_ms = tracing::field::Empty,
);
let routing_result = match async {
set_service_name(operation_component::ROUTING);
router_chat_get_upstream_model(
router_service,
client_request, // Pass the original request - router_chat will convert it
&traceparent,
&request_path,
&request_id,
)
.await
}
.instrument(routing_span)
.await
{
Ok(result) => result,
Err(err) => {
return Ok(BrightStaffError::ForwardedError {
status_code: err.status_code,
message: err.message,
}
.into_response());
}
};
// Determine final model to use
// Router returns "none" as a sentinel value when it doesn't select a specific model
let router_selected_model = routing_result.model_name;
let resolved_model = if router_selected_model != "none" {
// Router selected a specific model via routing preferences
router_selected_model
} else {
// Router returned "none" sentinel, use validated resolved_model from request
alias_resolved_model.clone()
};
tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str());
let span_name = if model_from_request == resolved_model {
format!("POST {} {}", request_path, resolved_model)
} else {
format!(
"POST {} {} -> {}",
request_path, model_from_request, resolved_model
)
};
get_active_span(|span| {
span.update_name(span_name.clone());
});
debug!(
url = %full_qualified_llm_provider_url,
provider_hint = %resolved_model,
upstream_model = %model_name_only,
"Routing to upstream"
);
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&resolved_model).unwrap(),
);
request_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
);
// remove content-length header if it exists
request_headers.remove(header::CONTENT_LENGTH);
// Inject current LLM span's trace context so upstream spans are children of plano(llm)
global::get_text_map_propagator(|propagator| {
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
});
// Capture start time right before sending request to upstream
let request_start_time = std::time::Instant::now();
let _request_start_system_time = std::time::SystemTime::now();
let llm_response = match reqwest::Client::new()
.post(&full_qualified_llm_provider_url)
.headers(request_headers)
.body(client_request_bytes_for_upstream)
.send()
.await
{
Ok(res) => res,
Err(err) => {
return Ok(BrightStaffError::InternalServerError(format!(
"Failed to send request: {}",
err
))
.into_response());
}
};
// copy over the headers and status code from the original response
let response_headers = llm_response.headers().clone();
let upstream_status = llm_response.status();
let mut response = Response::builder().status(upstream_status);
let headers = response.headers_mut().unwrap();
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
// Build LLM span with actual status code using constants
let byte_stream = llm_response.bytes_stream();
// Create base processor for metrics and tracing
let base_processor = ObservableStreamProcessor::new(
operation_component::LLM,
span_name,
request_start_time,
messages_for_signals,
);
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
let streaming_response = if let (true, false, Some(state_store)) = (
should_manage_state,
original_input_items.is_empty(),
state_storage,
) {
// Extract Content-Encoding header to handle decompression for state parsing
let content_encoding = response_headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
// Wrap with state management processor to store state after response completes
let state_processor = ResponsesStateProcessor::new(
base_processor,
state_store,
original_input_items,
alias_resolved_model.clone(),
resolved_model.clone(),
is_streaming_request,
false, // Not OpenAI upstream since should_manage_state is true
content_encoding,
request_id,
);
create_streaming_response(byte_stream, state_processor, 16)
} else {
// Use base processor without state management
create_streaming_response(byte_stream, base_processor, 16)
};
match response.body(streaming_response.body) {
Ok(response) => Ok(response),
Err(err) => Ok(BrightStaffError::InternalServerError(format!(
"Failed to create response: {}",
err
))
.into_response()),
}
}
/// Resolves model aliases by looking up the requested model in the model_aliases map.
/// Returns the target model if an alias is found, otherwise returns the original model.
fn resolve_model_alias(
model_from_request: &str,
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
) -> String {
if let Some(aliases) = model_aliases.as_ref() {
if let Some(model_alias) = aliases.get(model_from_request) {
debug!(
"Model Alias: 'From {}' -> 'To {}'",
model_from_request, model_alias.target
);
return model_alias.target.clone();
}
}
model_from_request.to_string()
}
/// Calculates the upstream path for the provider based on the model name.
/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix,
/// then uses target_endpoint_for_provider to calculate the correct upstream path.
async fn get_upstream_path(
llm_providers: &Arc<RwLock<LlmProviders>>,
model_name: &str,
request_path: &str,
resolved_model: &str,
is_streaming: bool,
) -> String {
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
// Calculate the upstream path using the proper API
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
.expect("Should have valid API endpoint");
client_api.target_endpoint_for_provider(
&provider_id,
request_path,
resolved_model,
is_streaming,
base_url_path_prefix.as_deref(),
)
}
/// Helper function to get provider info (ProviderId and base_url_path_prefix)
async fn get_provider_info(
llm_providers: &Arc<RwLock<LlmProviders>>,
model_name: &str,
) -> (hermesllm::ProviderId, Option<String>) {
let providers_lock = llm_providers.read().await;
// Try to find by model name or provider name using LlmProviders::get
// This handles both "gpt-4" and "openai/gpt-4" formats
if let Some(provider) = providers_lock.get(model_name) {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
return (provider_id, prefix);
}
// Fall back to default provider
if let Some(provider) = providers_lock.default() {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
(provider_id, prefix)
} else {
// Last resort: use OpenAI as hardcoded fallback
warn!("No default provider found, falling back to OpenAI");
(hermesllm::ProviderId::OpenAI, None)
}
}

View file

@@ -0,0 +1,618 @@
use bytes::Bytes;
use common::configuration::ModelAlias;
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, TRACE_PARENT_HEADER};
use common::llm_providers::LlmProviders;
use hermesllm::apis::openai::Message;
use hermesllm::apis::openai_responses::InputParam;
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full};
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use opentelemetry::global;
use opentelemetry::trace::get_active_span;
use opentelemetry_http::HeaderInjector;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, info_span, warn, Instrument};
mod router;
use crate::handlers::request::extract_request_id;
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor};
use crate::router::llm::RouterService;
use crate::state::response_state_processor::ResponsesStateProcessor;
use crate::state::{
extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError,
};
use crate::tracing::{operation_component, set_service_name};
use router::router_chat_get_upstream_model;
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
pub async fn llm_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
state_storage: Option<Arc<dyn StateStorage>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
let request_id = extract_request_id(&request);
// Create a span with request_id that will be included in all log lines
let request_span = info_span!(
"llm",
component = "llm",
request_id = %request_id,
http.method = %request.method(),
http.path = %request_path,
);
// Execute the rest of the handler inside the span
llm_chat_inner(
request,
router_service,
full_qualified_llm_provider_url,
model_aliases,
llm_providers,
state_storage,
request_id,
request_path,
request_headers,
)
.instrument(request_span)
.await
}
#[allow(clippy::too_many_arguments)]
async fn llm_chat_inner(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: Arc<RwLock<LlmProviders>>,
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
request_path: String,
mut request_headers: hyper::HeaderMap,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
let traceparent = extract_or_generate_traceparent(&request_headers);
// --- Phase 1: Parse and validate the incoming request ---
let parsed =
match parse_and_validate_request(request, &request_path, &model_aliases, &llm_providers)
.await
{
Ok(p) => p,
Err(response) => return Ok(response),
};
let PreparedRequest {
mut client_request,
model_from_request,
alias_resolved_model,
model_name_only,
is_streaming_request,
is_responses_api_client,
messages_for_signals,
} = parsed;
// --- Phase 2: Resolve conversation state (v1/responses API) ---
let state_ctx = match resolve_conversation_state(
&mut client_request,
is_responses_api_client,
&state_storage,
&llm_providers,
&alias_resolved_model,
&request_path,
is_streaming_request,
)
.await
{
Ok(ctx) => ctx,
Err(response) => return Ok(response),
};
// Serialize request for upstream BEFORE router consumes it
let client_request_bytes_for_upstream: Bytes = ProviderRequestType::to_bytes(&client_request)
.unwrap()
.into();
// --- Phase 3: Route the request ---
let routing_span = info_span!(
"routing",
component = "routing",
http.method = "POST",
http.target = %request_path,
model.requested = %model_from_request,
model.alias_resolved = %alias_resolved_model,
route.selected_model = tracing::field::Empty,
routing.determination_ms = tracing::field::Empty,
);
let routing_result = match async {
set_service_name(operation_component::ROUTING);
router_chat_get_upstream_model(
router_service,
client_request,
&traceparent,
&request_path,
&request_id,
)
.await
}
.instrument(routing_span)
.await
{
Ok(result) => result,
Err(err) => {
let mut internal_error = Response::new(full(err.message));
*internal_error.status_mut() = err.status_code;
return Ok(internal_error);
}
};
// Determine final model (router returns "none" when it doesn't select a specific model)
let router_selected_model = routing_result.model_name;
let resolved_model = if router_selected_model != "none" {
router_selected_model
} else {
alias_resolved_model.clone()
};
// --- Phase 4: Forward to upstream and stream back ---
send_upstream(
&full_qualified_llm_provider_url,
&mut request_headers,
client_request_bytes_for_upstream,
&model_from_request,
&alias_resolved_model,
&resolved_model,
&model_name_only,
&request_path,
is_streaming_request,
messages_for_signals,
state_ctx,
state_storage,
request_id,
)
.await
}
// ---------------------------------------------------------------------------
// Phase 1 — Parse & validate the incoming request
// ---------------------------------------------------------------------------
/// All pre-validated request data extracted from the raw HTTP request.
struct PreparedRequest {
client_request: ProviderRequestType,
model_from_request: String,
alias_resolved_model: String,
model_name_only: String,
is_streaming_request: bool,
is_responses_api_client: bool,
messages_for_signals: Option<Vec<Message>>,
}
/// Parse the body, resolve the model alias, and validate the model exists.
///
/// Returns `Err(Response)` for early-exit error responses (400 etc.).
async fn parse_and_validate_request(
request: Request<hyper::body::Incoming>,
request_path: &str,
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
llm_providers: &Arc<RwLock<LlmProviders>>,
) -> Result<PreparedRequest, Response<BoxBody<Bytes, hyper::Error>>> {
let chat_request_bytes = request
.collect()
.await
.map_err(|_| {
let mut r = Response::new(full("Failed to read request body"));
*r.status_mut() = StatusCode::BAD_REQUEST;
r
})?
.to_bytes();
debug!(
body = %String::from_utf8_lossy(&chat_request_bytes),
"request body received"
);
let mut client_request = ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIsFromClient::from_endpoint(request_path).unwrap(),
))
.map_err(|err| {
warn!(error = %err, "failed to parse request as ProviderRequestType");
let mut r = Response::new(full(format!("Failed to parse request: {}", err)));
*r.status_mut() = StatusCode::BAD_REQUEST;
r
})?;
let client_api = SupportedAPIsFromClient::from_endpoint(request_path);
let is_responses_api_client = matches!(
client_api,
Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
);
let model_from_request = client_request.model().to_string();
let is_streaming_request = client_request.is_streaming();
let alias_resolved_model = resolve_model_alias(&model_from_request, model_aliases);
// Validate model exists in configuration
if llm_providers
.read()
.await
.get(&alias_resolved_model)
.is_none()
{
let err_msg = format!(
"Model '{}' not found in configured providers",
alias_resolved_model
);
warn!(model = %alias_resolved_model, "model not found in configured providers");
let mut r = Response::new(full(err_msg));
*r.status_mut() = StatusCode::BAD_REQUEST;
return Err(r);
}
// Strip provider prefix for upstream (e.g. "openai/gpt-4" → "gpt-4")
let model_name_only = alias_resolved_model
.split_once('/')
.map(|(_, model)| model.to_string())
.unwrap_or_else(|| alias_resolved_model.clone());
// Extract messages for signal analysis before mutating client_request
let messages_for_signals = Some(client_request.get_messages());
// Set the upstream model name and strip routing metadata
client_request.set_model(model_name_only.clone());
if client_request.remove_metadata_key("archgw_preference_config") {
debug!("removed archgw_preference_config from metadata");
}
Ok(PreparedRequest {
client_request,
model_from_request,
alias_resolved_model,
model_name_only,
is_streaming_request,
is_responses_api_client,
messages_for_signals,
})
}
// ---------------------------------------------------------------------------
// Phase 2 — Resolve conversation state (v1/responses API)
// ---------------------------------------------------------------------------
/// Holds the state management context resolved from a v1/responses request.
struct ConversationStateContext {
should_manage_state: bool,
original_input_items: Vec<hermesllm::apis::openai_responses::InputItem>,
}
/// If the client uses the v1/responses API and the upstream provider doesn't
/// support it natively, we manage conversation state ourselves.
///
/// This resolves `previous_response_id`, merges conversation history, and
/// updates the request in place.
///
/// Returns `Err(Response)` for early-exit (e.g. 409 Conflict).
async fn resolve_conversation_state(
client_request: &mut ProviderRequestType,
is_responses_api_client: bool,
state_storage: &Option<Arc<dyn StateStorage>>,
llm_providers: &Arc<RwLock<LlmProviders>>,
alias_resolved_model: &str,
request_path: &str,
is_streaming_request: bool,
) -> Result<ConversationStateContext, Response<BoxBody<Bytes, hyper::Error>>> {
if !is_responses_api_client {
return Ok(ConversationStateContext {
should_manage_state: false,
original_input_items: Vec::new(),
});
}
let (responses_req, state_store) = match (client_request, state_storage) {
(ProviderRequestType::ResponsesAPIRequest(ref mut req), Some(store)) => (req, store),
_ => {
return Ok(ConversationStateContext {
should_manage_state: false,
original_input_items: Vec::new(),
});
}
};
let mut original_input_items = extract_input_items(&responses_req.input);
// Check whether the upstream supports v1/responses natively
let upstream_path = get_upstream_path(
llm_providers,
alias_resolved_model,
request_path,
alias_resolved_model,
is_streaming_request,
)
.await;
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
let should_manage_state = !matches!(
upstream_api,
Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_))
);
if !should_manage_state {
debug!("upstream supports ResponsesAPI natively");
return Ok(ConversationStateContext {
should_manage_state: false,
original_input_items,
});
}
// Retrieve and combine conversation history if previous_response_id exists
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
match retrieve_and_combine_input(state_store.clone(), prev_resp_id, original_input_items)
.await
{
Ok(combined_input) => {
responses_req.input = InputParam::Items(combined_input.clone());
original_input_items = combined_input;
info!(
items = original_input_items.len(),
"updated request with conversation history"
);
}
Err(StateStorageError::NotFound(_)) => {
warn!(previous_response_id = %prev_resp_id, "previous response_id not found");
let err_msg = format!(
"Conversation state not found for previous_response_id: {}",
prev_resp_id
);
let mut r = Response::new(full(err_msg));
*r.status_mut() = StatusCode::CONFLICT;
return Err(r);
}
Err(e) => {
warn!(
previous_response_id = %prev_resp_id,
error = %e,
"failed to retrieve conversation state"
);
// Restore original_input_items since we passed ownership
original_input_items = extract_input_items(&responses_req.input);
}
}
}
Ok(ConversationStateContext {
should_manage_state,
original_input_items,
})
}
// ---------------------------------------------------------------------------
// Phase 4 — Forward to upstream and stream the response back
// ---------------------------------------------------------------------------
#[allow(clippy::too_many_arguments)]
async fn send_upstream(
upstream_url: &str,
request_headers: &mut hyper::HeaderMap,
body: bytes::Bytes,
model_from_request: &str,
alias_resolved_model: &str,
resolved_model: &str,
model_name_only: &str,
request_path: &str,
is_streaming_request: bool,
messages_for_signals: Option<Vec<Message>>,
state_ctx: ConversationStateContext,
state_storage: Option<Arc<dyn StateStorage>>,
request_id: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let span_name = if model_from_request == resolved_model {
format!("POST {} {}", request_path, resolved_model)
} else {
format!(
"POST {} {} -> {}",
request_path, model_from_request, resolved_model
)
};
get_active_span(|span| {
span.update_name(span_name.clone());
});
debug!(
url = %upstream_url,
provider_hint = %resolved_model,
upstream_model = %model_name_only,
"Routing to upstream"
);
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(resolved_model).unwrap(),
);
request_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
);
request_headers.remove(header::CONTENT_LENGTH);
// Inject current span's trace context so upstream spans are children of plano(llm)
global::get_text_map_propagator(|propagator| {
let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(request_headers));
});
let request_start_time = std::time::Instant::now();
let llm_response = match reqwest::Client::new()
.post(upstream_url)
.headers(request_headers.clone())
.body(body)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
// Propagate upstream headers and status
let response_headers = llm_response.headers().clone();
let upstream_status = llm_response.status();
let mut response = Response::builder().status(upstream_status);
let headers = response.headers_mut().unwrap();
for (name, value) in response_headers.iter() {
headers.insert(name, value.clone());
}
let byte_stream = llm_response.bytes_stream();
// Create base processor for metrics and tracing
let base_processor = ObservableStreamProcessor::new(
operation_component::LLM,
span_name,
request_start_time,
messages_for_signals,
);
// Wrap with state management processor when needed
let streaming_response = if let (true, false, Some(state_store)) = (
state_ctx.should_manage_state,
state_ctx.original_input_items.is_empty(),
state_storage,
) {
let content_encoding = response_headers
.get("content-encoding")
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let state_processor = ResponsesStateProcessor::new(
base_processor,
state_store,
state_ctx.original_input_items,
alias_resolved_model.to_string(),
resolved_model.to_string(),
is_streaming_request,
false,
content_encoding,
request_id,
);
create_streaming_response(byte_stream, state_processor, 16)
} else {
create_streaming_response(byte_stream, base_processor, 16)
};
match response.body(streaming_response.body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
}
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Extract or generate a W3C `traceparent` header value.
fn extract_or_generate_traceparent(headers: &hyper::HeaderMap) -> String {
headers
.get(TRACE_PARENT_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| {
let trace_id = uuid::Uuid::new_v4().to_string().replace("-", "");
let tp = format!("00-{}-0000000000000000-01", trace_id);
warn!(
generated_traceparent = %tp,
"TRACE_PARENT header missing, generated new traceparent"
);
tp
})
}
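A quick sanity check of the fallback shape produced when the header is missing (illustrative test module, assuming it sits next to the function):

#[cfg(test)]
mod traceparent_tests {
    use super::extract_or_generate_traceparent;

    #[test]
    fn generated_value_has_w3c_shape() {
        let headers = hyper::HeaderMap::new();
        let tp = extract_or_generate_traceparent(&headers);
        // version-traceid-parentid-flags, e.g. 00-<32 hex>-0000000000000000-01
        let parts: Vec<&str> = tp.split('-').collect();
        assert_eq!(parts.len(), 4);
        assert_eq!(parts[0], "00");
        assert_eq!(parts[1].len(), 32);
        assert_eq!(parts[2], "0000000000000000");
        assert_eq!(parts[3], "01");
    }
}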
/// Resolves model aliases by looking up the requested model in the model_aliases map.
/// Returns the target model if an alias is found, otherwise returns the original model.
fn resolve_model_alias(
model_from_request: &str,
model_aliases: &Arc<Option<HashMap<String, ModelAlias>>>,
) -> String {
if let Some(aliases) = model_aliases.as_ref() {
if let Some(model_alias) = aliases.get(model_from_request) {
debug!(
"Model Alias: 'From {}' -> 'To {}'",
model_from_request, model_alias.target
);
return model_alias.target.clone();
}
}
model_from_request.to_string()
}
/// Calculates the upstream path for the provider based on the model name.
async fn get_upstream_path(
llm_providers: &Arc<RwLock<LlmProviders>>,
model_name: &str,
request_path: &str,
resolved_model: &str,
is_streaming: bool,
) -> String {
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
.expect("Should have valid API endpoint");
client_api.target_endpoint_for_provider(
&provider_id,
request_path,
resolved_model,
is_streaming,
base_url_path_prefix.as_deref(),
)
}
/// Helper to get provider info (ProviderId and base_url_path_prefix).
async fn get_provider_info(
llm_providers: &Arc<RwLock<LlmProviders>>,
model_name: &str,
) -> (hermesllm::ProviderId, Option<String>) {
let providers_lock = llm_providers.read().await;
if let Some(provider) = providers_lock.get(model_name) {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
return (provider_id, prefix);
}
if let Some(provider) = providers_lock.default() {
let provider_id = provider.provider_interface.to_provider_id();
let prefix = provider.base_url_path_prefix.clone();
(provider_id, prefix)
} else {
warn!("No default provider found, falling back to OpenAI");
(hermesllm::ProviderId::OpenAI, None)
}
}

View file

@@ -5,7 +5,7 @@ use hyper::StatusCode;
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use crate::router::llm::RouterService;
use crate::tracing::routing;
pub struct RoutingResult {

View file

@@ -1,12 +1,10 @@
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod agents;
pub mod errors;
pub mod function_calling;
pub mod jsonrpc;
pub mod llm;
pub mod models;
pub mod pipeline_processor;
pub mod response_handler;
pub mod router_chat;
pub mod request;
pub mod response;
pub mod utils;
#[cfg(test)]

View file

@@ -0,0 +1,11 @@
use hyper::Request;
/// Extract request ID from incoming request headers, or generate a new UUID v4.
pub fn extract_request_id<T>(request: &Request<T>) -> String {
request
.headers()
.get(common::consts::REQUEST_ID_HEADER)
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string())
}
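An illustrative test of both branches, assuming `REQUEST_ID_HEADER` is a plain header-name constant accepted by hyper's request builder:

#[cfg(test)]
mod tests {
    use super::extract_request_id;
    use hyper::Request;

    #[test]
    fn uses_header_when_present() {
        let req = Request::builder()
            .header(common::consts::REQUEST_ID_HEADER, "req-123")
            .body(())
            .unwrap();
        assert_eq!(extract_request_id(&req), "req-123");
    }

    #[test]
    fn generates_uuid_when_missing() {
        let req = Request::builder().body(()).unwrap();
        // A generated UUID v4 string is 36 characters including hyphens.
        assert_eq!(extract_request_id(&req).len(), 36);
    }
}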

View file

@@ -1,6 +1,6 @@
pub mod app_state;
pub mod handlers;
pub mod router;
pub mod signals;
pub mod state;
pub mod tracing;
pub mod utils;

View file

@@ -1,15 +1,16 @@
use brightstaff::handlers::agent_chat_completions::agent_chat;
use brightstaff::app_state::AppState;
use brightstaff::handlers::agents::orchestrator::agent_chat;
use brightstaff::handlers::function_calling::function_calling_chat_handler;
use brightstaff::handlers::llm::llm_chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
use brightstaff::router::plano_orchestrator::OrchestratorService;
use brightstaff::router::llm::RouterService;
use brightstaff::router::orchestrator::OrchestratorService;
use brightstaff::state::memory::MemoryConversationalStorage;
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use brightstaff::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration};
use common::configuration::Configuration;
use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
};
@@ -29,265 +30,293 @@ use tokio::net::TcpListener;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
pub mod router;
const BIND_ADDRESS: &str = "0.0.0.0:9091";
const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router";
const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router";
// Utility function to extract the context from the incoming request headers
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/// Extract the OpenTelemetry context propagated via HTTP headers.
fn extract_context_from_request(req: &Request<Incoming>) -> Context {
global::get_text_map_propagator(|propagator| {
propagator.extract(&HeaderExtractor(req.headers()))
})
}
/// An empty HTTP body (used for 404 / OPTIONS responses).
fn empty() -> BoxBody<Bytes, hyper::Error> {
Empty::<Bytes>::new()
.map_err(|never| match never {})
.boxed()
}
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
/// CORS pre-flight response for the models endpoint.
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let mut response = Response::new(empty());
*response.status_mut() = StatusCode::NO_CONTENT;
let headers = response.headers_mut();
headers.insert("Allow", "GET, OPTIONS".parse().unwrap());
headers.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
headers.insert(
"Access-Control-Allow-Headers",
"Authorization, Content-Type".parse().unwrap(),
);
headers.insert(
"Access-Control-Allow-Methods",
"GET, POST, OPTIONS".parse().unwrap(),
);
headers.insert("Content-Type", "application/json".parse().unwrap());
Ok(response)
}
// loading plano_config.yaml file (before tracing init so we can read tracing config)
let plano_config_path = env::var("PLANO_CONFIG_PATH_RENDERED")
.unwrap_or_else(|_| "./plano_config_rendered.yaml".to_string());
eprintln!("loading plano_config.yaml from {}", plano_config_path);
// ---------------------------------------------------------------------------
// Configuration loading
// ---------------------------------------------------------------------------
let config_contents =
fs::read_to_string(&plano_config_path).expect("Failed to read plano_config.yaml");
/// Load and parse the YAML configuration file.
///
/// The path is read from `ARCH_CONFIG_PATH_RENDERED` (env) or falls back to
/// `./arch_config_rendered.yaml`.
fn load_config() -> Result<Configuration, Box<dyn std::error::Error + Send + Sync>> {
let path = env::var("ARCH_CONFIG_PATH_RENDERED")
.unwrap_or_else(|_| "./arch_config_rendered.yaml".to_string());
eprintln!("loading arch_config.yaml from {}", path);
let contents = fs::read_to_string(&path).map_err(|e| format!("failed to read {path}: {e}"))?;
let config: Configuration =
serde_yaml::from_str(&config_contents).expect("Failed to parse plano_config.yaml");
serde_yaml::from_str(&contents).map_err(|e| format!("failed to parse {path}: {e}"))?;
// Initialize tracing using config.yaml tracing section
let _tracer_provider = init_tracer(config.tracing.as_ref());
info!(path = %plano_config_path, "loaded plano_config.yaml");
Ok(config)
}
let plano_config = Arc::new(config);
// ---------------------------------------------------------------------------
// Application state initialization
// ---------------------------------------------------------------------------
// combine agents and filters into a single list of agents
let all_agents: Vec<Agent> = plano_config
/// Build the shared [`AppState`] from a parsed [`Configuration`].
async fn init_app_state(
config: &Configuration,
) -> Result<AppState, Box<dyn std::error::Error + Send + Sync>> {
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
// Combine agents and filters into a single list
let all_agents = config
.agents
.as_deref()
.unwrap_or_default()
.iter()
.chain(plano_config.filters.as_deref().unwrap_or_default())
.chain(config.filters.as_deref().unwrap_or_default())
.cloned()
.collect();
// Create expanded provider list for /v1/models endpoint
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
.expect("Failed to create LlmProviders");
let llm_providers = Arc::new(RwLock::new(llm_providers));
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
let llm_providers = LlmProviders::try_from(config.model_providers.clone())
.map_err(|e| format!("failed to create LlmProviders: {e}"))?;
let listener = TcpListener::bind(bind_address).await?;
let routing_model_name: String = plano_config
let routing_model_name = config
.routing
.as_ref()
.and_then(|r| r.model.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_MODEL_NAME.to_string());
let routing_llm_provider = plano_config
let routing_llm_provider = config
.routing
.as_ref()
.and_then(|r| r.model_provider.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
plano_config.model_providers.clone(),
let router_service = Arc::new(RouterService::new(
config.model_providers.clone(),
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
routing_model_name,
routing_llm_provider,
));
let orchestrator_service: Arc<OrchestratorService> = Arc::new(OrchestratorService::new(
let orchestrator_service = Arc::new(OrchestratorService::new(
format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"),
PLANO_ORCHESTRATOR_MODEL_NAME.to_string(),
));
let model_aliases = Arc::new(plano_config.model_aliases.clone());
let span_attributes = Arc::new(
plano_config
.tracing
.as_ref()
.and_then(|tracing| tracing.span_attributes.clone()),
);
let state_storage = init_state_storage(config).await?;
// Initialize trace collector and start background flusher
// Tracing is enabled if the tracing config is present in plano_config.yaml
// Pass Some(true/false) to override, or None to use env var OTEL_TRACING_ENABLED
// OpenTelemetry automatic instrumentation is configured in utils/tracing.rs
Ok(AppState {
router_service,
orchestrator_service,
model_aliases: Arc::new(config.model_aliases.clone()),
llm_providers: Arc::new(RwLock::new(llm_providers)),
agents_list: Arc::new(RwLock::new(Some(all_agents))),
listeners: Arc::new(RwLock::new(config.listeners.clone())),
state_storage,
llm_provider_url,
})
}
// Conversation state storage backs the /v1/responses API.
// It is configured via the state_storage section of the config; if it is not
// configured, conversation state management is disabled. Environment variables
// are substituted by envsubst before the config is read.
/// Initialize the conversation state storage backend (if configured).
async fn init_state_storage(
config: &Configuration,
) -> Result<Option<Arc<dyn StateStorage>>, Box<dyn std::error::Error + Send + Sync>> {
let Some(storage_config) = &config.state_storage else {
info!("no state_storage configured, conversation state management disabled");
return Ok(None);
};
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
common::configuration::StateStorageType::Memory => {
info!(
storage_type = "memory",
"initialized conversation state storage"
);
Arc::new(MemoryConversationalStorage::new())
}
common::configuration::StateStorageType::Postgres => {
let connection_string = storage_config
.connection_string
.as_ref()
.ok_or("connection_string is required for postgres state_storage")?;
debug!(connection_string = %connection_string, "postgres connection");
info!(
storage_type = "postgres",
"initializing conversation state storage"
);
Arc::new(
PostgreSQLConversationStorage::new(connection_string.clone())
.await
.map_err(|e| format!("failed to initialize Postgres state storage: {e}"))?,
)
}
};
Ok(Some(storage))
}
// ---------------------------------------------------------------------------
// Request routing
// ---------------------------------------------------------------------------
/// Route an incoming HTTP request to the appropriate handler.
async fn route(
req: Request<Incoming>,
state: Arc<AppState>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let parent_cx = extract_context_from_request(&req);
let path = req.uri().path().to_string();
// --- Agent routes (/agents/...) ---
if let Some(stripped) = path.strip_prefix("/agents") {
if matches!(
stripped,
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH
) {
return agent_chat(
req,
Arc::clone(&state.orchestrator_service),
Arc::clone(&state.agents_list),
Arc::clone(&state.listeners),
)
.with_context(parent_cx)
.await;
}
}
// --- Standard routes ---
match (req.method(), path.as_str()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
let url = format!("{}{}", state.llm_provider_url, path);
llm_chat(
req,
Arc::clone(&state.router_service),
url,
Arc::clone(&state.model_aliases),
Arc::clone(&state.llm_providers),
state.state_storage.clone(),
)
.with_context(parent_cx)
.await
}
(&Method::POST, "/function_calling") => {
let url = format!("{}/v1/chat/completions", state.llm_provider_url);
function_calling_chat_handler(req, url)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
Ok(list_models(Arc::clone(&state.llm_providers)).await)
}
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(),
_ => {
debug!(method = %req.method(), path = %path, "no route found");
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)
}
}
}
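// `cors_preflight()` is referenced by the OPTIONS arm above but its definition is
// not part of this hunk. A minimal sketch of what such a helper could look like,
// assuming it mirrors the permissive headers the pre-refactor inline OPTIONS
// handler returned (signature and body here are assumptions, not repository code):
fn cors_preflight() -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
    // 204 with permissive CORS headers so browser clients (e.g. open-web-ui) can
    // call /v1/models without failing the preflight check.
    let mut response = Response::new(empty());
    *response.status_mut() = StatusCode::NO_CONTENT;
    let headers = response.headers_mut();
    headers.insert("Allow", "GET, OPTIONS".parse().unwrap());
    headers.insert("Access-Control-Allow-Origin", "*".parse().unwrap());
    headers.insert(
        "Access-Control-Allow-Headers",
        "Authorization, Content-Type".parse().unwrap(),
    );
    headers.insert(
        "Access-Control-Allow-Methods",
        "GET, POST, OPTIONS".parse().unwrap(),
    );
    headers.insert("Content-Type", "application/json".parse().unwrap());
    Ok(response)
}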
// ---------------------------------------------------------------------------
// Server loop
// ---------------------------------------------------------------------------
/// Accept connections and spawn a task per connection.
///
/// Listens for `SIGINT` / `ctrl-c` and shuts down gracefully, allowing
/// in-flight connections to finish.
async fn run_server(state: Arc<AppState>) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
let listener = TcpListener::bind(&bind_address).await?;
info!(address = %bind_address, "server listening");
let shutdown = tokio::signal::ctrl_c();
tokio::pin!(shutdown);
loop {
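// Race the next accept() against the shutdown signal: on ctrl-c we stop
// accepting new connections and break out of the loop.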
tokio::select! {
result = listener.accept() => {
let (stream, _) = result?;
let peer_addr = stream.peer_addr()?;
let io = TokioIo::new(stream);
let state = Arc::clone(&state);
tokio::task::spawn(async move {
debug!(peer = ?peer_addr, "accepted connection");
let service = service_fn(move |req| {
let state = Arc::clone(&state);
async move { route(req, state).await }
});
if let Err(err) = http1::Builder::new().serve_connection(io, service).await {
warn!(error = ?err, "error serving connection");
}
});
}
_ = &mut shutdown => {
info!("received shutdown signal, stopping server");
break;
}
}
}
Ok(())
}
// ---------------------------------------------------------------------------
// Entry point
// ---------------------------------------------------------------------------
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let config = load_config()?;
let _tracer_provider = init_tracer(config.tracing.as_ref());
let arch_config_path = env::var("ARCH_CONFIG_PATH_RENDERED")
.unwrap_or_else(|_| "./arch_config_rendered.yaml".to_string());
info!(path = %arch_config_path, "loaded arch_config.yaml");
let state = Arc::new(init_app_state(&config).await?);
run_server(state).await
}

View file

@@ -0,0 +1,48 @@

use hermesllm::apis::openai::ChatCompletionsResponse;
use hyper::header;
use thiserror::Error;
use tracing::warn;
#[derive(Debug, Error)]
pub enum HttpError {
#[error("Failed to send request: {0}")]
Request(#[from] reqwest::Error),
#[error("Failed to parse JSON response: {0}")]
Json(serde_json::Error, String),
}
/// Sends a POST request to the given URL and extracts the text content
/// from the first choice of the `ChatCompletionsResponse`.
///
/// Returns `Some((content, elapsed))` on success, or `None` if the response
/// had no choices or the first choice had no content.
pub async fn post_and_extract_content(
client: &reqwest::Client,
url: &str,
headers: header::HeaderMap,
body: String,
) -> Result<Option<(String, std::time::Duration)>, HttpError> {
let start_time = std::time::Instant::now();
let res = client.post(url).headers(headers).body(body).send().await?;
let body = res.text().await?;
let elapsed = start_time.elapsed();
let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| {
warn!(error = %err, body = %body, "failed to parse json response");
HttpError::Json(err, format!("Failed to parse JSON: {}", body))
})?;
if response.choices.is_empty() {
warn!(body = %body, "no choices in response");
return Ok(None);
}
Ok(response.choices[0]
.message
.content
.as_ref()
.map(|c| (c.clone(), elapsed)))
}
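// Example call site (illustrative only: the URL, request body, and surrounding
// function are assumptions, not code from this repository):
async fn example_usage(client: &reqwest::Client) -> Result<(), HttpError> {
    let mut headers = header::HeaderMap::new();
    headers.insert(
        header::CONTENT_TYPE,
        header::HeaderValue::from_static("application/json"),
    );
    let body = r#"{"model":"arch-router","messages":[{"role":"user","content":"hi"}]}"#;
    let url = "http://localhost:12001/v1/chat/completions";
    match post_and_extract_content(client, url, headers, body.to_string()).await? {
        Some((content, elapsed)) => {
            println!("first choice content ({} ms): {content}", elapsed.as_millis());
        }
        None => println!("response had no usable content"),
    }
    Ok(())
}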

View file

@@ -4,15 +4,16 @@ use common::{
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER},
};
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
use hermesllm::apis::openai::Message;
use hyper::header;
use thiserror::Error;
use tracing::{debug, info, warn};
use crate::router::router_model_v1::{self};
use tracing::{debug, info};
use super::http::{self, post_and_extract_content};
use super::router_model::RouterModel;
use crate::router::router_model_v1;
pub struct RouterService {
router_url: String,
client: reqwest::Client,
@@ -24,11 +25,8 @@ pub struct RouterService {
#[derive(Debug, Error)]
pub enum RoutingError {
#[error("Failed to send request: {0}")]
RequestError(#[from] reqwest::Error),
#[error("Failed to parse JSON: {0}, JSON: {1}")]
JsonError(serde_json::Error, String),
#[error(transparent)]
Http(#[from] http::HttpError),
#[error("Router model error: {0}")]
RouterModelError(#[from] super::router_model::RoutingModelError),
@@ -101,87 +99,48 @@ impl RouterService {
"sending request to arch-router"
);
debug!(
body = %serde_json::to_string(&router_request).unwrap(),
"arch router request"
);
let body = serde_json::to_string(&router_request).unwrap();
debug!(body = %body, "arch router request");
let mut llm_route_request_headers = header::HeaderMap::new();
llm_route_request_headers.insert(
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
llm_route_request_headers.insert(
headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(&self.routing_provider_name).unwrap(),
);
llm_route_request_headers.insert(
headers.insert(
header::HeaderName::from_static(TRACE_PARENT_HEADER),
header::HeaderValue::from_str(traceparent).unwrap(),
);
llm_route_request_headers.insert(
headers.insert(
header::HeaderName::from_static(REQUEST_ID_HEADER),
header::HeaderValue::from_str(request_id).unwrap(),
);
llm_route_request_headers.insert(
headers.insert(
header::HeaderName::from_static("model"),
header::HeaderValue::from_static("arch-router"),
);
let start_time = std::time::Instant::now();
let res = self
.client
.post(&self.router_url)
.headers(llm_route_request_headers)
.body(serde_json::to_string(&router_request).unwrap())
.send()
.await?;
let body = res.text().await?;
let router_response_time = start_time.elapsed();
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
Ok(response) => response,
Err(err) => {
warn!(
error = %err,
body = %serde_json::to_string(&body).unwrap(),
"failed to parse json response"
);
return Err(RoutingError::JsonError(
err,
format!("Failed to parse JSON: {}", body),
));
}
let Some((content, elapsed)) =
post_and_extract_content(&self.client, &self.router_url, headers, body).await?
else {
return Ok(None);
};
if chat_completion_response.choices.is_empty() {
warn!(body = %body, "no choices in router response");
return Ok(None);
}
let parsed = self
.router_model
.parse_response(&content, &usage_preferences)?;
if let Some(content) = &chat_completion_response.choices[0].message.content {
let parsed_response = self
.router_model
.parse_response(content, &usage_preferences)?;
info!(
content = %content.replace("\n", "\\n"),
selected_model = ?parsed_response,
response_time_ms = router_response_time.as_millis(),
"arch-router determined route"
);
info!(
content = %content.replace("\n", "\\n"),
selected_model = ?parsed,
response_time_ms = elapsed.as_millis(),
"arch-router determined route"
);
if let Some(ref parsed_response) = parsed_response {
return Ok(Some(parsed_response.clone()));
}
Ok(None)
} else {
Ok(None)
}
Ok(parsed)
}
}

View file

@@ -1,6 +1,7 @@
pub mod llm_router;
pub(crate) mod http;
pub mod llm;
pub mod orchestrator;
pub mod orchestrator_model;
pub mod orchestrator_model_v1;
pub mod plano_orchestrator;
pub mod router_model;
pub mod router_model_v1;

View file

@@ -4,17 +4,18 @@ use common::{
configuration::{AgentUsagePreference, OrchestrationPreference},
consts::{ARCH_PROVIDER_HINT_HEADER, PLANO_ORCHESTRATOR_MODEL_NAME, REQUEST_ID_HEADER},
};
use hermesllm::apis::openai::{ChatCompletionsResponse, Message};
use hermesllm::apis::openai::Message;
use hyper::header;
use opentelemetry::global;
use opentelemetry_http::HeaderInjector;
use thiserror::Error;
use tracing::{debug, info, warn};
use crate::router::orchestrator_model_v1::{self};
use tracing::{debug, info};
use super::http::{self, post_and_extract_content};
use super::orchestrator_model::OrchestratorModel;
use crate::router::orchestrator_model_v1;
pub struct OrchestratorService {
orchestrator_url: String,
client: reqwest::Client,
@@ -23,11 +24,8 @@ pub struct OrchestratorService {
#[derive(Debug, Error)]
pub enum OrchestrationError {
#[error("Failed to send request: {0}")]
RequestError(#[from] reqwest::Error),
#[error("Failed to parse JSON: {0}, JSON: {1}")]
JsonError(serde_json::Error, String),
#[error(transparent)]
Http(#[from] http::HttpError),
#[error("Orchestrator model error: {0}")]
OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError),
@@ -37,7 +35,6 @@ pub type Result<T> = std::result::Result<T, OrchestrationError>;
impl OrchestratorService {
pub fn new(orchestrator_url: String, orchestration_model_name: String) -> Self {
// Empty agent orchestrations - will be provided via usage_preferences in requests
let agent_orchestrations: HashMap<String, Vec<OrchestrationPreference>> = HashMap::new();
let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new(
@@ -63,7 +60,6 @@ impl OrchestratorService {
return Ok(None);
}
// Require usage_preferences to be provided
if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() {
return Ok(None);
}
@@ -78,18 +74,15 @@
"sending request to arch-orchestrator"
);
debug!(
body = %serde_json::to_string(&orchestrator_request).unwrap(),
"arch orchestrator request"
);
let body = serde_json::to_string(&orchestrator_request).unwrap();
debug!(body = %body, "arch orchestrator request");
let mut orchestration_request_headers = header::HeaderMap::new();
orchestration_request_headers.insert(
let mut headers = header::HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
orchestration_request_headers.insert(
headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(),
);
@@ -98,71 +91,38 @@ impl OrchestratorService {
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut orchestration_request_headers));
propagator.inject_context(&cx, &mut HeaderInjector(&mut headers));
});
if let Some(request_id) = request_id {
orchestration_request_headers.insert(
if let Some(ref request_id) = request_id {
headers.insert(
header::HeaderName::from_static(REQUEST_ID_HEADER),
header::HeaderValue::from_str(&request_id).unwrap(),
header::HeaderValue::from_str(request_id).unwrap(),
);
}
orchestration_request_headers.insert(
headers.insert(
header::HeaderName::from_static("model"),
header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME),
);
let start_time = std::time::Instant::now();
let res = self
.client
.post(&self.orchestrator_url)
.headers(orchestration_request_headers)
.body(serde_json::to_string(&orchestrator_request).unwrap())
.send()
.await?;
let body = res.text().await?;
let orchestrator_response_time = start_time.elapsed();
let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) {
Ok(response) => response,
Err(err) => {
warn!(
error = %err,
body = %serde_json::to_string(&body).unwrap(),
"failed to parse json response"
);
return Err(OrchestrationError::JsonError(
err,
format!("Failed to parse JSON: {}", body),
));
}
let Some((content, elapsed)) =
post_and_extract_content(&self.client, &self.orchestrator_url, headers, body).await?
else {
return Ok(None);
};
if chat_completion_response.choices.is_empty() {
warn!(body = %body, "no choices in orchestrator response");
return Ok(None);
}
let parsed = self
.orchestrator_model
.parse_response(&content, &usage_preferences)?;
if let Some(content) = &chat_completion_response.choices[0].message.content {
let parsed_response = self
.orchestrator_model
.parse_response(content, &usage_preferences)?;
info!(
content = %content.replace("\n", "\\n"),
selected_routes = ?parsed_response,
response_time_ms = orchestrator_response_time.as_millis(),
"arch-orchestrator determined routes"
);
info!(
content = %content.replace("\n", "\\n"),
selected_routes = ?parsed,
response_time_ms = elapsed.as_millis(),
"arch-orchestrator determined routes"
);
if let Some(ref parsed_response) = parsed_response {
return Ok(Some(parsed_response.clone()));
}
Ok(None)
} else {
Ok(None)
}
Ok(parsed)
}
}

View file

@@ -100,23 +100,6 @@ pub trait StateStorage: Send + Sync {
}
}
/// Storage backend type enum
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageBackend {
Memory,
Supabase,
}
impl StorageBackend {
pub fn parse_backend(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"memory" => Some(StorageBackend::Memory),
"supabase" => Some(StorageBackend::Supabase),
_ => None,
}
}
}
// === Utility functions for state management ===
/// Extract input items from InputParam, converting text to structured format

View file

@@ -11,7 +11,7 @@ use tracing_subscriber::registry::LookupSpan;
use tracing_subscriber::util::SubscriberInitExt;
use tracing_subscriber::EnvFilter;
use crate::tracing::ServiceNameOverrideExporter;
use super::ServiceNameOverrideExporter;
use common::configuration::Tracing;
struct BracketedTime;
@@ -96,17 +96,13 @@ pub fn init_tracer(tracing_config: Option<&Tracing>) -> &'static SdkTracerProvid
tracing_enabled, otel_endpoint, random_sampling
);
// Create OTLP exporter to send spans to collector
if tracing_enabled {
// Set service name via environment if not already set
if std::env::var("OTEL_SERVICE_NAME").is_err() {
std::env::set_var("OTEL_SERVICE_NAME", "plano");
}
// Create OTLP exporter to send spans to collector.
// Use `if let` to destructure the endpoint, avoiding an unwrap.
if let Some(endpoint) = otel_endpoint.as_deref().filter(|_| tracing_enabled) {
// Create ServiceNameOverrideExporter to support per-span service names
// This allows spans to have different service names (e.g., plano(orchestrator),
// plano(filter), plano(llm)) by setting the "service.name.override" attribute
let exporter = ServiceNameOverrideExporter::new(otel_endpoint.as_ref().unwrap());
let exporter = ServiceNameOverrideExporter::new(endpoint);
let provider = SdkTracerProvider::builder()
.with_batch_exporter(exporter)

View file

@@ -1,11 +1,11 @@
mod constants;
mod custom_attributes;
mod init;
mod service_name_exporter;
pub use constants::{
error, http, llm, operation_component, routing, signals, OperationNameBuilder,
};
pub use custom_attributes::{append_span_attributes, collect_custom_trace_attributes};
pub use init::init_tracer;
pub use service_name_exporter::{ServiceNameOverrideExporter, SERVICE_NAME_OVERRIDE_KEY};
use opentelemetry::trace::get_active_span;
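// The `set_service_name` helper itself is not shown in this diff. A minimal sketch
// of how it could tag the active span with the override consumed by
// ServiceNameOverrideExporter, assuming SERVICE_NAME_OVERRIDE_KEY equals the
// "service.name.override" attribute described in init.rs (details are assumptions):
pub fn set_service_name(name: &'static str) {
    get_active_span(|span| {
        // SERVICE_NAME_OVERRIDE_KEY (re-exported above) is assumed to equal this literal.
        span.set_attribute(opentelemetry::KeyValue::new("service.name.override", name));
    });
}
// e.g. call set_service_name("plano(orchestrator)") before forwarding to the orchestrator.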

View file

@@ -1 +0,0 @@
pub mod tracing;