diff --git a/crates/brightstaff/src/app_state.rs b/crates/brightstaff/src/app_state.rs new file mode 100644 index 00000000..57707f6e --- /dev/null +++ b/crates/brightstaff/src/app_state.rs @@ -0,0 +1,29 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use common::configuration::{Agent, FilterPipeline, Listener, ModelAlias, SpanAttributes}; +use common::llm_providers::LlmProviders; +use tokio::sync::RwLock; + +use crate::router::llm::RouterService; +use crate::router::orchestrator::OrchestratorService; +use crate::state::StateStorage; + +/// Shared application state bundled into a single Arc-wrapped struct. +/// +/// Instead of cloning 8+ individual `Arc`s per connection, a single +/// `Arc` is cloned once and passed to the request handler. +pub struct AppState { + pub router_service: Arc, + pub orchestrator_service: Arc, + pub model_aliases: Option>, + pub llm_providers: Arc>, + pub agents_list: Option>, + pub listeners: Vec, + pub state_storage: Option>, + pub llm_provider_url: String, + pub span_attributes: Option, + /// Shared HTTP client for upstream LLM requests (connection pooling / keep-alive). + pub http_client: reqwest::Client, + pub filter_pipeline: Arc, +} diff --git a/crates/brightstaff/src/handlers/agents/errors.rs b/crates/brightstaff/src/handlers/agents/errors.rs new file mode 100644 index 00000000..478b4380 --- /dev/null +++ b/crates/brightstaff/src/handlers/agents/errors.rs @@ -0,0 +1,41 @@ +use bytes::Bytes; +use http_body_util::combinators::BoxBody; +use hyper::Response; +use serde_json::json; +use tracing::{info, warn}; + +use crate::handlers::response::ResponseHandler; + +/// Build a JSON error response from an `AgentFilterChainError`, logging the +/// full error chain along the way. +/// +/// Returns `Ok(Response)` so it can be used directly as a handler return value. +pub fn build_error_chain_response( + err: &E, +) -> Result>, hyper::Error> { + let mut error_chain = Vec::new(); + let mut current: &dyn std::error::Error = err; + loop { + error_chain.push(current.to_string()); + match current.source() { + Some(source) => current = source, + None => break, + } + } + + warn!(error_chain = ?error_chain, "agent chat error chain"); + warn!(root_error = ?err, "root error"); + + let error_json = json!({ + "error": { + "type": "AgentFilterChainError", + "message": err.to_string(), + "error_chain": error_chain, + "debug_info": format!("{:?}", err) + } + }); + + info!(error = %error_json, "structured error info"); + + Ok(ResponseHandler::create_json_error_response(&error_json)) +} diff --git a/crates/brightstaff/src/handlers/jsonrpc.rs b/crates/brightstaff/src/handlers/agents/jsonrpc.rs similarity index 100% rename from crates/brightstaff/src/handlers/jsonrpc.rs rename to crates/brightstaff/src/handlers/agents/jsonrpc.rs diff --git a/crates/brightstaff/src/handlers/agents/mod.rs b/crates/brightstaff/src/handlers/agents/mod.rs new file mode 100644 index 00000000..5b507907 --- /dev/null +++ b/crates/brightstaff/src/handlers/agents/mod.rs @@ -0,0 +1,5 @@ +pub mod errors; +pub mod jsonrpc; +pub mod orchestrator; +pub mod pipeline; +pub mod selector; diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agents/orchestrator.rs similarity index 53% rename from crates/brightstaff/src/handlers/agent_chat_completions.rs rename to crates/brightstaff/src/handlers/agents/orchestrator.rs index 44bb3235..8ece914e 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agents/orchestrator.rs @@ -2,63 +2,56 @@ use std::sync::Arc; use std::time::Instant; use bytes::Bytes; -use common::configuration::SpanAttributes; -use common::errors::BrightStaffError; -use common::llm_providers::LlmProviders; use hermesllm::apis::OpenAIMessage; use hermesllm::clients::SupportedAPIsFromClient; use hermesllm::providers::request::ProviderRequest; use hermesllm::ProviderRequestType; use http_body_util::combinators::BoxBody; use http_body_util::BodyExt; -use hyper::{Request, Response, StatusCode}; +use hyper::{Request, Response}; use opentelemetry::trace::get_active_span; -use serde::ser::Error as SerError; -use tokio::sync::RwLock; use tracing::{debug, info, info_span, warn, Instrument}; -use super::agent_selector::{AgentSelectionError, AgentSelector}; -use super::pipeline_processor::{PipelineError, PipelineProcessor}; -use super::response_handler::ResponseHandler; -use crate::router::plano_orchestrator::OrchestratorService; +use super::errors::build_error_chain_response; +use super::pipeline::{PipelineError, PipelineProcessor}; +use super::selector::{AgentSelectionError, AgentSelector}; +use crate::app_state::AppState; +use crate::handlers::extract_request_id; +use crate::handlers::response::ResponseHandler; use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name}; /// Main errors for agent chat completions #[derive(Debug, thiserror::Error)] pub enum AgentFilterChainError { - #[error("Forwarded error: {0}")] - Brightstaff(#[from] BrightStaffError), #[error("Agent selection error: {0}")] Selection(#[from] AgentSelectionError), #[error("Pipeline processing error: {0}")] Pipeline(#[from] PipelineError), + #[error("Response handling error: {0}")] + Response(#[from] common::errors::BrightStaffError), #[error("Request parsing error: {0}")] - RequestParsing(#[from] serde_json::Error), + RequestParsing(String), #[error("HTTP error: {0}")] Http(#[from] hyper::Error), + #[error("Unsupported endpoint: {0}")] + UnsupportedEndpoint(String), + #[error("No agents configured")] + NoAgentsConfigured, + #[error("Agent '{0}' not found in configuration")] + AgentNotFound(String), + #[error("No messages in conversation history")] + EmptyHistory, + #[error("Agent chain completed without producing a response")] + IncompleteChain, } pub async fn agent_chat( request: Request, - orchestrator_service: Arc, - _: String, - agents_list: Arc>>>, - listeners: Arc>>, - span_attributes: Arc>, - llm_providers: Arc>, + state: Arc, ) -> Result>, hyper::Error> { + let request_id = extract_request_id(&request); let custom_attrs = - collect_custom_trace_attributes(request.headers(), span_attributes.as_ref().as_ref()); - // Extract request_id from headers or generate a new one - let request_id: String = match request - .headers() - .get(common::consts::REQUEST_ID_HEADER) - .and_then(|h| h.to_str().ok()) - .map(|s| s.to_string()) - { - Some(id) => id, - None => uuid::Uuid::new_v4().to_string(), - }; + collect_custom_trace_attributes(request.headers(), state.span_attributes.as_ref()); // Create a span with request_id that will be included in all log lines let request_span = info_span!( @@ -74,17 +67,7 @@ pub async fn agent_chat( // Set service name for orchestrator operations set_service_name(operation_component::ORCHESTRATOR); - match handle_agent_chat_inner( - request, - orchestrator_service, - agents_list, - listeners, - llm_providers, - request_id, - custom_attrs, - ) - .await - { + match handle_agent_chat_inner(request, state, request_id, custom_attrs).await { Ok(response) => Ok(response), Err(err) => { // Check if this is a client error from the pipeline that should be cascaded @@ -101,7 +84,6 @@ pub async fn agent_chat( "client error from agent" ); - // Create error response with the original status code and body let error_json = serde_json::json!({ "error": "ClientError", "agent": agent, @@ -109,52 +91,19 @@ pub async fn agent_chat( "agent_response": body }); - let status_code = hyper::StatusCode::from_u16(*status) - .unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR); - let json_string = error_json.to_string(); - return Ok(BrightStaffError::ForwardedError { - status_code, - message: json_string, - } - .into_response()); + let mut response = + Response::new(ResponseHandler::create_full_body(json_string)); + *response.status_mut() = hyper::StatusCode::from_u16(*status) + .unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR); + response.headers_mut().insert( + hyper::header::CONTENT_TYPE, + hyper::header::HeaderValue::from_static("application/json"), + ); + return Ok(response); } - // Print detailed error information with full error chain for other errors - let mut error_chain = Vec::new(); - let mut current_error: &dyn std::error::Error = &err; - - // Collect the full error chain - loop { - error_chain.push(current_error.to_string()); - match current_error.source() { - Some(source) => current_error = source, - None => break, - } - } - - // Log the complete error chain - warn!(error_chain = ?error_chain, "agent chat error chain"); - warn!(root_error = ?err, "root error"); - - // Create structured error response as JSON - let error_json = serde_json::json!({ - "error": { - "type": "AgentFilterChainError", - "message": err.to_string(), - "error_chain": error_chain, - "debug_info": format!("{:?}", err) - } - }); - - // Log the error for debugging - info!(error = %error_json, "structured error info"); - - Ok(BrightStaffError::ForwardedError { - status_code: StatusCode::BAD_REQUEST, - message: error_json.to_string(), - } - .into_response()) + build_error_chain_response(&err) } } } @@ -162,19 +111,22 @@ pub async fn agent_chat( .await } -async fn handle_agent_chat_inner( +/// Parsed and validated agent request data. +struct AgentRequest { + client_request: ProviderRequestType, + messages: Vec, + request_headers: hyper::HeaderMap, + request_id: Option, +} + +/// Parse the incoming HTTP request, resolve the listener, and extract messages. +async fn parse_agent_request( request: Request, - orchestrator_service: Arc, - agents_list: Arc>>>, - listeners: Arc>>, - llm_providers: Arc>, - request_id: String, - custom_attrs: std::collections::HashMap, -) -> Result>, AgentFilterChainError> { - // Initialize services - let agent_selector = AgentSelector::new(orchestrator_service); - let mut pipeline_processor = PipelineProcessor::default(); - let response_handler = ResponseHandler::new(); + state: &AppState, + request_id: &str, + custom_attrs: &std::collections::HashMap, +) -> Result<(AgentRequest, common::configuration::Listener, AgentSelector), AgentFilterChainError> { + let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service)); // Extract listener name from headers let listener_name = request @@ -183,16 +135,11 @@ async fn handle_agent_chat_inner( .and_then(|name| name.to_str().ok()); // Find the appropriate listener - let listener: common::configuration::Listener = { - let listeners = listeners.read().await; - agent_selector - .find_listener(listener_name, &listeners) - .await? - }; + let listener = agent_selector.find_listener(listener_name, &state.listeners)?; get_active_span(|span| { span.update_name(listener.name.to_string()); - for (key, value) in &custom_attrs { + for (key, value) in custom_attrs { span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone())); } }); @@ -200,24 +147,20 @@ async fn handle_agent_chat_inner( info!(listener = %listener.name, "handling request"); // Parse request body - let request_path = request - .uri() - .path() - .to_string() + let full_path = request.uri().path().to_string(); + let request_path = full_path .strip_prefix("/agents") - .unwrap() + .unwrap_or(&full_path) .to_string(); let request_headers = { let mut headers = request.headers().clone(); headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER); - // Set the request_id in headers if not already present if !headers.contains_key(common::consts::REQUEST_ID_HEADER) { - headers.insert( - common::consts::REQUEST_ID_HEADER, - hyper::header::HeaderValue::from_str(&request_id).unwrap(), - ); + if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) { + headers.insert(common::consts::REQUEST_ID_HEADER, val); + } } headers @@ -230,63 +173,62 @@ async fn handle_agent_chat_inner( "received request body" ); - // Determine the API type from the endpoint let api_type = SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| { - let err_msg = format!("Unsupported endpoint: {}", request_path); - warn!("{}", err_msg); - AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg)) + warn!(path = %request_path, "unsupported endpoint"); + AgentFilterChainError::UnsupportedEndpoint(request_path.clone()) })?; - let mut client_request = - match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) { - Ok(request) => request, - Err(err) => { - warn!("failed to parse request as ProviderRequestType: {}", err); - let err_msg = format!("Failed to parse request: {}", err); - return Err(AgentFilterChainError::RequestParsing( - serde_json::Error::custom(err_msg), - )); - } - }; + let client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) + .map_err(|err| { + warn!(error = %err, "failed to parse request as ProviderRequestType"); + AgentFilterChainError::RequestParsing(format!("Failed to parse request: {}", err)) + })?; - // If model is not specified in the request, resolve from default provider - if client_request.model().is_empty() { - match llm_providers.read().await.default() { - Some(default_provider) => { - let default_model = default_provider.name.clone(); - info!(default_model = %default_model, "no model specified in request, using default provider"); - client_request.set_model(default_model); - } - None => { - let err_msg = "No model specified in request and no default provider configured"; - warn!("{}", err_msg); - return Ok(BrightStaffError::NoModelSpecified.into_response()); - } - } - } - - let message: Vec = client_request.get_messages(); + let messages: Vec = client_request.get_messages(); let request_id = request_headers .get(common::consts::REQUEST_ID_HEADER) .and_then(|val| val.to_str().ok()) .map(|s| s.to_string()); - // Create agent map for pipeline processing and agent selection - let agent_map = { - let agents = agents_list.read().await; - let agents = agents.as_ref().unwrap(); - agent_selector.create_agent_map(agents) - }; + Ok(( + AgentRequest { + client_request, + messages, + request_headers, + request_id, + }, + listener, + agent_selector, + )) +} + +/// Select agents via the orchestrator model and record selection metrics. +async fn select_and_build_agent_map( + agent_selector: &AgentSelector, + state: &AppState, + messages: &[OpenAIMessage], + listener: &common::configuration::Listener, + request_id: Option, +) -> Result< + ( + Vec, + std::collections::HashMap, + ), + AgentFilterChainError, +> { + let agents = state + .agents_list + .as_ref() + .ok_or(AgentFilterChainError::NoAgentsConfigured)?; + let agent_map = agent_selector.create_agent_map(agents); - // Select appropriate agents using arch orchestrator llm model let selection_start = Instant::now(); let selected_agents = agent_selector - .select_agents(&message, &listener, request_id.clone()) + .select_agents(messages, listener, request_id) .await?; - // Record selection attributes on the current orchestrator span let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0; get_active_span(|span| { span.set_attribute(opentelemetry::KeyValue::new( @@ -316,12 +258,25 @@ async fn handle_agent_chat_inner( "selected agents for execution" ); - // Execute agents sequentially, passing output from one to the next - let mut current_messages = message.clone(); + Ok((selected_agents, agent_map)) +} + +/// Execute the agent chain: run each selected agent sequentially, streaming +/// the final agent's response back to the client. +async fn execute_agent_chain( + selected_agents: &[common::configuration::AgentFilterChain], + agent_map: &std::collections::HashMap, + client_request: ProviderRequestType, + messages: Vec, + request_headers: &hyper::HeaderMap, + custom_attrs: &std::collections::HashMap, +) -> Result>, AgentFilterChainError> { + let mut pipeline_processor = PipelineProcessor::default(); + let response_handler = ResponseHandler::new(); + let mut current_messages = messages; let agent_count = selected_agents.len(); for (agent_index, selected_agent) in selected_agents.iter().enumerate() { - // Get agent name let agent_name = selected_agent.id.clone(); let is_last_agent = agent_index == agent_count - 1; @@ -332,8 +287,6 @@ async fn handle_agent_chat_inner( "processing agent" ); - // Process input filters — serialize current request as OpenAI chat completions body, - // pass raw bytes through each filter, then extract updated messages from the result. let chat_history = if selected_agent .input_filters .as_ref() @@ -351,8 +304,8 @@ async fn handle_agent_chat_inner( .process_raw_filter_chain( &filter_bytes, selected_agent, - &agent_map, - &request_headers, + agent_map, + request_headers, "/v1/chat/completions", ) .await?; @@ -365,8 +318,9 @@ async fn handle_agent_chat_inner( current_messages.clone() }; - // Get agent details and invoke - let agent = agent_map.get(&agent_name).unwrap(); + let agent = agent_map + .get(&agent_name) + .ok_or_else(|| AgentFilterChainError::AgentNotFound(agent_name.clone()))?; debug!(agent = %agent_name, "invoking agent"); @@ -380,7 +334,7 @@ async fn handle_agent_chat_inner( set_service_name(operation_component::AGENT); get_active_span(|span| { span.update_name(format!("{} /v1/chat/completions", agent_name)); - for (key, value) in &custom_attrs { + for (key, value) in custom_attrs { span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone())); } }); @@ -390,28 +344,25 @@ async fn handle_agent_chat_inner( &chat_history, client_request.clone(), agent, - &request_headers, + request_headers, ) .await } .instrument(agent_span.clone()) .await?; - // If this is the last agent, return the streaming response if is_last_agent { info!( agent = %agent_name, "completed agent chain, returning response" ); - // Capture the orchestrator span (parent of the agent span) so it - // stays open for the full streaming duration alongside the agent span. let orchestrator_span = tracing::Span::current(); return async { response_handler .create_streaming_response( llm_response, - tracing::Span::current(), // agent span (inner) - orchestrator_span, // orchestrator span (outer) + tracing::Span::current(), + orchestrator_span, ) .await .map_err(AgentFilterChainError::from) @@ -420,7 +371,6 @@ async fn handle_agent_chat_inner( .await; } - // For intermediate agents, collect the full response and pass to next agent debug!(agent = %agent_name, "collecting response from intermediate agent"); let response_text = async { response_handler.collect_full_response(llm_response).await } .instrument(agent_span) @@ -432,11 +382,11 @@ async fn handle_agent_chat_inner( "agent completed, passing response to next agent" ); - // remove last message and add new one at the end - let last_message = current_messages.pop().unwrap(); + let Some(last_message) = current_messages.pop() else { + warn!(agent = %agent_name, "no messages in conversation history"); + return Err(AgentFilterChainError::EmptyHistory); + }; - // Create a new message with the agent's response as assistant message - // and add it to the conversation history current_messages.push(OpenAIMessage { role: hermesllm::apis::openai::Role::Assistant, content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)), @@ -448,6 +398,34 @@ async fn handle_agent_chat_inner( current_messages.push(last_message); } - // This should never be reached since we return in the last agent iteration - unreachable!("Agent execution loop should have returned a response") + Err(AgentFilterChainError::IncompleteChain) +} + +async fn handle_agent_chat_inner( + request: Request, + state: Arc, + request_id: String, + custom_attrs: std::collections::HashMap, +) -> Result>, AgentFilterChainError> { + let (agent_req, listener, agent_selector) = + parse_agent_request(request, &state, &request_id, &custom_attrs).await?; + + let (selected_agents, agent_map) = select_and_build_agent_map( + &agent_selector, + &state, + &agent_req.messages, + &listener, + agent_req.request_id, + ) + .await?; + + execute_agent_chain( + &selected_agents, + &agent_map, + agent_req.client_request, + agent_req.messages, + &agent_req.request_headers, + &custom_attrs, + ) + .await } diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/agents/pipeline.rs similarity index 90% rename from crates/brightstaff/src/handlers/pipeline_processor.rs rename to crates/brightstaff/src/handlers/agents/pipeline.rs index 4cb8531f..50058441 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/agents/pipeline.rs @@ -12,7 +12,7 @@ use opentelemetry::global; use opentelemetry_http::HeaderInjector; use tracing::{debug, info, instrument, warn}; -use crate::handlers::jsonrpc::{ +use super::jsonrpc::{ JsonRpcId, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, JSON_RPC_VERSION, MCP_INITIALIZE, MCP_INITIALIZE_NOTIFICATION, TOOL_CALL_METHOD, }; @@ -78,12 +78,11 @@ impl PipelineProcessor { } } - /// Build common MCP headers for requests - fn build_mcp_headers( - &self, + /// Prepare headers shared by all agent/filter requests: removes + /// content-length, injects trace context, sets upstream host and retry. + fn build_agent_headers( request_headers: &HeaderMap, agent_id: &str, - session_id: Option<&str>, ) -> Result { let mut headers = request_headers.clone(); headers.remove(hyper::header::CONTENT_LENGTH); @@ -104,24 +103,34 @@ impl PipelineProcessor { headers.insert( ENVOY_RETRY_HEADER, - hyper::header::HeaderValue::from_str("3").unwrap(), + hyper::header::HeaderValue::from_static("3"), ); + Ok(headers) + } + + /// Build headers for MCP requests (adds Accept, Content-Type, optional session id). + fn build_mcp_headers( + &self, + request_headers: &HeaderMap, + agent_id: &str, + session_id: Option<&str>, + ) -> Result { + let mut headers = Self::build_agent_headers(request_headers, agent_id)?; + headers.insert( "Accept", hyper::header::HeaderValue::from_static("application/json, text/event-stream"), ); - headers.insert( "Content-Type", hyper::header::HeaderValue::from_static("application/json"), ); if let Some(sid) = session_id { - headers.insert( - "mcp-session-id", - hyper::header::HeaderValue::from_str(sid).unwrap(), - ); + if let Ok(val) = hyper::header::HeaderValue::from_str(sid) { + headers.insert("mcp-session-id", val); + } } Ok(headers) @@ -243,7 +252,7 @@ impl PipelineProcessor { let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) { session_id.clone() } else { - let session_id = self.get_new_session_id(&agent.id, request_headers).await; + let session_id = self.get_new_session_id(&agent.id, request_headers).await?; self.agent_id_session_map .insert(agent.id.clone(), session_id.clone()); session_id @@ -394,18 +403,19 @@ impl PipelineProcessor { Ok(()) } - async fn get_new_session_id(&self, agent_id: &str, request_headers: &HeaderMap) -> String { + async fn get_new_session_id( + &self, + agent_id: &str, + request_headers: &HeaderMap, + ) -> Result { info!("initializing MCP session for agent {}", agent_id); let initialize_request = self.build_initialize_request(); - let headers = self - .build_mcp_headers(request_headers, agent_id, None) - .expect("Failed to build headers for initialization"); + let headers = self.build_mcp_headers(request_headers, agent_id, None)?; let response = self .send_mcp_request(&initialize_request, &headers, agent_id) - .await - .expect("Failed to initialize MCP session"); + .await?; info!("initialize response status: {}", response.status()); @@ -413,8 +423,13 @@ impl PipelineProcessor { .headers() .get("mcp-session-id") .and_then(|v| v.to_str().ok()) - .expect("No mcp-session-id in response") - .to_string(); + .map(|s| s.to_string()) + .ok_or_else(|| { + PipelineError::NoContentInResponse(format!( + "No mcp-session-id header in initialize response from agent {}", + agent_id + )) + })?; info!( "created new MCP session for agent {}: {}", @@ -423,10 +438,9 @@ impl PipelineProcessor { // Send initialized notification self.send_initialized_notification(agent_id, &session_id, &headers) - .await - .expect("Failed to send initialized notification"); + .await?; - session_id + Ok(session_id) } /// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back. @@ -454,25 +468,7 @@ impl PipelineProcessor { span.update_name(format!("execute_raw_filter ({})", agent.id)); }); - let mut agent_headers = request_headers.clone(); - agent_headers.remove(hyper::header::CONTENT_LENGTH); - - agent_headers.remove(TRACE_PARENT_HEADER); - global::get_text_map_propagator(|propagator| { - let cx = - tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); - propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers)); - }); - - agent_headers.insert( - ARCH_UPSTREAM_HOST_HEADER, - hyper::header::HeaderValue::from_str(&agent.id) - .map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?, - ); - agent_headers.insert( - ENVOY_RETRY_HEADER, - hyper::header::HeaderValue::from_str("3").unwrap(), - ); + let mut agent_headers = Self::build_agent_headers(request_headers, &agent.id)?; agent_headers.insert( "Accept", hyper::header::HeaderValue::from_static("application/json"), @@ -578,36 +574,15 @@ impl PipelineProcessor { terminal_agent: &Agent, request_headers: &HeaderMap, ) -> Result { - // let mut request = original_request.clone(); original_request.set_messages(messages); let request_url = "/v1/chat/completions"; - let request_body = ProviderRequestType::to_bytes(&original_request).unwrap(); - // let request_body = serde_json::to_string(&request)?; + let request_body = ProviderRequestType::to_bytes(&original_request) + .map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?; debug!("sending request to terminal agent {}", terminal_agent.id); - let mut agent_headers = request_headers.clone(); - agent_headers.remove(hyper::header::CONTENT_LENGTH); - - // Inject OpenTelemetry trace context automatically - agent_headers.remove(TRACE_PARENT_HEADER); - global::get_text_map_propagator(|propagator| { - let cx = - tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); - propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers)); - }); - - agent_headers.insert( - ARCH_UPSTREAM_HOST_HEADER, - hyper::header::HeaderValue::from_str(&terminal_agent.id) - .map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?, - ); - - agent_headers.insert( - ENVOY_RETRY_HEADER, - hyper::header::HeaderValue::from_str("3").unwrap(), - ); + let agent_headers = Self::build_agent_headers(request_headers, &terminal_agent.id)?; let response = self .client diff --git a/crates/brightstaff/src/handlers/agent_selector.rs b/crates/brightstaff/src/handlers/agents/selector.rs similarity index 94% rename from crates/brightstaff/src/handlers/agent_selector.rs rename to crates/brightstaff/src/handlers/agents/selector.rs index 2341e156..8225a003 100644 --- a/crates/brightstaff/src/handlers/agent_selector.rs +++ b/crates/brightstaff/src/handlers/agents/selector.rs @@ -7,7 +7,7 @@ use common::configuration::{ use hermesllm::apis::openai::Message; use tracing::{debug, warn}; -use crate::router::plano_orchestrator::OrchestratorService; +use crate::router::orchestrator::OrchestratorService; /// Errors that can occur during agent selection #[derive(Debug, thiserror::Error)] @@ -37,7 +37,7 @@ impl AgentSelector { } /// Find listener by name from the request headers - pub async fn find_listener( + pub fn find_listener( &self, listener_name: Option<&str>, listeners: &[common::configuration::Listener], @@ -84,7 +84,7 @@ impl AgentSelector { } /// Convert agent descriptions to orchestration preferences - async fn convert_agent_description_to_orchestration_preferences( + fn convert_agent_description_to_orchestration_preferences( &self, agents: &[AgentFilterChain], ) -> Vec { @@ -121,9 +121,7 @@ impl AgentSelector { return Ok(vec![agents[0].clone()]); } - let usage_preferences = self - .convert_agent_description_to_orchestration_preferences(agents) - .await; + let usage_preferences = self.convert_agent_description_to_orchestration_preferences(agents); debug!( "Agents usage preferences for orchestration: {}", serde_json::to_string(&usage_preferences).unwrap_or_default() @@ -222,9 +220,7 @@ mod tests { let listener2 = create_test_listener("other-listener", vec![]); let listeners = vec![listener1.clone(), listener2]; - let result = selector - .find_listener(Some("test-listener"), &listeners) - .await; + let result = selector.find_listener(Some("test-listener"), &listeners); assert!(result.is_ok()); assert_eq!(result.unwrap().name, "test-listener"); @@ -237,9 +233,7 @@ mod tests { let listeners = vec![create_test_listener("other-listener", vec![])]; - let result = selector - .find_listener(Some("nonexistent"), &listeners) - .await; + let result = selector.find_listener(Some("nonexistent"), &listeners); assert!(result.is_err()); matches!( diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index c5bfb1b2..499fbfca 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -3,12 +3,11 @@ use std::sync::Arc; use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; use hyper::header::HeaderMap; -use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector}; -use crate::handlers::pipeline_processor::PipelineProcessor; -use crate::router::plano_orchestrator::OrchestratorService; -use common::errors::BrightStaffError; -use http_body_util::BodyExt; -use hyper::StatusCode; +use crate::handlers::agents::pipeline::PipelineProcessor; +use crate::handlers::agents::selector::{AgentSelectionError, AgentSelector}; +use crate::handlers::response::ResponseHandler; +use crate::router::orchestrator::OrchestratorService; + /// Integration test that demonstrates the modular agent chat flow /// This test shows how the three main components work together: /// 1. AgentSelector - selects the appropriate agents based on orchestration @@ -86,9 +85,7 @@ mod tests { let messages = vec![create_test_message(Role::User, "Hello world!")]; // Test 1: Agent Selection - let selected_listener = agent_selector - .find_listener(Some("test-listener"), &listeners) - .await; + let selected_listener = agent_selector.find_listener(Some("test-listener"), &listeners); assert!(selected_listener.is_ok()); let listener = selected_listener.unwrap(); @@ -142,24 +139,8 @@ mod tests { } // Test 4: Error Response Creation - let err = BrightStaffError::ModelNotFound("gpt-5-secret".to_string()); - let response = err.into_response(); - - assert_eq!(response.status(), StatusCode::NOT_FOUND); - - // Helper to extract body as JSON - let body_bytes = response.into_body().collect().await.unwrap().to_bytes(); - let body: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); - - assert_eq!(body["error"]["code"], "ModelNotFound"); - assert_eq!( - body["error"]["details"]["rejected_model_id"], - "gpt-5-secret" - ); - assert!(body["error"]["message"] - .as_str() - .unwrap() - .contains("gpt-5-secret")); + let error_response = ResponseHandler::create_bad_request("Test error"); + assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST); println!("✅ All modular components working correctly!"); } @@ -170,7 +151,7 @@ mod tests { let agent_selector = AgentSelector::new(router_service); // Test listener not found - let result = agent_selector.find_listener(Some("nonexistent"), &[]).await; + let result = agent_selector.find_listener(Some("nonexistent"), &[]); assert!(result.is_err()); assert!(matches!( @@ -178,21 +159,12 @@ mod tests { AgentSelectionError::ListenerNotFound(_) )); - let technical_reason = "Database connection timed out"; - let err = BrightStaffError::InternalServerError(technical_reason.to_string()); - - let response = err.into_response(); - - // --- 1. EXTRACT BYTES --- - let body_bytes = response.into_body().collect().await.unwrap().to_bytes(); - - // --- 2. DECLARE body_json HERE --- - let body_json: serde_json::Value = - serde_json::from_slice(&body_bytes).expect("Failed to parse JSON body"); - - // --- 3. USE body_json --- - assert_eq!(body_json["error"]["code"], "InternalServerError"); - assert_eq!(body_json["error"]["details"]["reason"], technical_reason); + // Test error response creation + let error_response = ResponseHandler::create_internal_error("Pipeline failed"); + assert_eq!( + error_response.status(), + hyper::StatusCode::INTERNAL_SERVER_ERROR + ); println!("✅ Error handling working correctly!"); } diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs deleted file mode 100644 index 02e4fcf3..00000000 --- a/crates/brightstaff/src/handlers/llm.rs +++ /dev/null @@ -1,686 +0,0 @@ -use bytes::Bytes; -use common::configuration::{FilterPipeline, ModelAlias, SpanAttributes}; -use common::consts::{ - ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, -}; -use common::llm_providers::LlmProviders; -use hermesllm::apis::openai_responses::InputParam; -use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; -use hermesllm::{ProviderRequest, ProviderRequestType}; -use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full}; -use hyper::header::{self}; -use hyper::{Request, Response, StatusCode}; -use opentelemetry::global; -use opentelemetry::trace::get_active_span; -use opentelemetry_http::HeaderInjector; -use std::collections::HashMap; -use std::sync::Arc; -use tokio::sync::RwLock; -use tracing::{debug, info, info_span, warn, Instrument}; - -use super::pipeline_processor::PipelineProcessor; - -use crate::handlers::router_chat::router_chat_get_upstream_model; -use crate::handlers::streaming::{ - create_streaming_response, create_streaming_response_with_output_filter, truncate_message, - ObservableStreamProcessor, StreamProcessor, -}; -use crate::router::llm_router::RouterService; -use crate::state::response_state_processor::ResponsesStateProcessor; -use crate::state::{ - extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError, -}; -use crate::tracing::{ - collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name, -}; - -use common::errors::BrightStaffError; - -#[allow(clippy::too_many_arguments)] -pub async fn llm_chat( - request: Request, - router_service: Arc, - full_qualified_llm_provider_url: String, - model_aliases: Arc>>, - llm_providers: Arc>, - span_attributes: Arc>, - state_storage: Option>, - filter_pipeline: Arc, -) -> Result>, hyper::Error> { - let request_path = request.uri().path().to_string(); - let request_headers = request.headers().clone(); - let request_id: String = match request_headers - .get(REQUEST_ID_HEADER) - .and_then(|h| h.to_str().ok()) - .map(|s| s.to_string()) - { - Some(id) => id, - None => uuid::Uuid::new_v4().to_string(), - }; - let custom_attrs = - collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref()); - - // Create a span with request_id that will be included in all log lines - let request_span = info_span!( - "llm", - component = "llm", - request_id = %request_id, - http.method = %request.method(), - http.path = %request_path, - llm.model = tracing::field::Empty, - llm.tools = tracing::field::Empty, - llm.user_message_preview = tracing::field::Empty, - llm.temperature = tracing::field::Empty, - ); - - // Execute the rest of the handler inside the span - llm_chat_inner( - request, - router_service, - full_qualified_llm_provider_url, - model_aliases, - llm_providers, - custom_attrs, - state_storage, - request_id, - request_path, - request_headers, - filter_pipeline, - ) - .instrument(request_span) - .await -} - -#[allow(clippy::too_many_arguments)] -async fn llm_chat_inner( - request: Request, - router_service: Arc, - full_qualified_llm_provider_url: String, - model_aliases: Arc>>, - llm_providers: Arc>, - custom_attrs: HashMap, - state_storage: Option>, - request_id: String, - request_path: String, - mut request_headers: hyper::HeaderMap, - filter_pipeline: Arc, -) -> Result>, hyper::Error> { - // Set service name for LLM operations - set_service_name(operation_component::LLM); - get_active_span(|span| { - for (key, value) in &custom_attrs { - span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone())); - } - }); - - // Extract or generate traceparent - this establishes the trace context for all spans - let traceparent: String = match request_headers - .get(TRACE_PARENT_HEADER) - .and_then(|h| h.to_str().ok()) - .map(|s| s.to_string()) - { - Some(tp) => tp, - None => { - use uuid::Uuid; - let trace_id = Uuid::new_v4().to_string().replace("-", ""); - let generated_tp = format!("00-{}-0000000000000000-01", trace_id); - warn!( - generated_traceparent = %generated_tp, - "TRACE_PARENT header missing, generated new traceparent" - ); - generated_tp - } - }; - - let raw_bytes = request.collect().await?.to_bytes(); - - debug!( - body = %String::from_utf8_lossy(&raw_bytes), - "request body received" - ); - - // Extract routing_policy from request body if present - let (chat_request_bytes, inline_routing_policy) = - match crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false) { - Ok(result) => result, - Err(err) => { - warn!(error = %err, "failed to parse request JSON"); - return Ok(BrightStaffError::InvalidRequest(format!( - "Failed to parse request: {}", - err - )) - .into_response()); - } - }; - - let mut client_request = match ProviderRequestType::try_from(( - &chat_request_bytes[..], - &SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(), - )) { - Ok(request) => request, - Err(err) => { - warn!( - error = %err, - "failed to parse request as ProviderRequestType" - ); - return Ok(BrightStaffError::InvalidRequest(format!( - "Failed to parse request: {}", - err - )) - .into_response()); - } - }; - - // === v1/responses state management: Extract input items early === - let mut original_input_items = Vec::new(); - let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str()); - let is_responses_api_client = matches!( - client_api, - Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)) - ); - - // If model is not specified in the request, resolve from default provider - let model_from_request = client_request.model().to_string(); - let model_from_request = if model_from_request.is_empty() { - match llm_providers.read().await.default() { - Some(default_provider) => { - let default_model = default_provider.name.clone(); - info!(default_model = %default_model, "no model specified in request, using default provider"); - client_request.set_model(default_model.clone()); - default_model - } - None => { - let err_msg = "No model specified in request and no default provider configured"; - warn!("{}", err_msg); - return Ok(BrightStaffError::NoModelSpecified.into_response()); - } - } - } else { - model_from_request - }; - - // Model alias resolution: update model field in client_request immediately - // This ensures all downstream objects use the resolved model - let temperature = client_request.get_temperature(); - let is_streaming_request = client_request.is_streaming(); - let alias_resolved_model = resolve_model_alias(&model_from_request, &model_aliases); - let (provider_id, _) = get_provider_info(&llm_providers, &alias_resolved_model).await; - - // Validate that the requested model exists in configuration - // This matches the validation in llm_gateway routing.rs - if llm_providers - .read() - .await - .get(&alias_resolved_model) - .is_none() - { - warn!(model = %alias_resolved_model, "model not found in configured providers"); - return Ok(BrightStaffError::ModelNotFound(alias_resolved_model).into_response()); - } - - // Handle provider/model slug format (e.g., "openai/gpt-4") - // Extract just the model name for upstream (providers don't understand the slug) - let model_name_only = if let Some((_, model)) = alias_resolved_model.split_once('/') { - model.to_string() - } else { - alias_resolved_model.clone() - }; - - // Extract tool names and user message preview for span attributes - let tool_names = client_request.get_tool_names(); - let user_message_preview = client_request - .get_recent_user_message() - .map(|msg| truncate_message(&msg, 50)); - let span = tracing::Span::current(); - if let Some(temp) = temperature { - span.record(tracing_llm::TEMPERATURE, tracing::field::display(temp)); - } - if let Some(tools) = &tool_names { - let formatted_tools = tools - .iter() - .map(|name| format!("{}(...)", name)) - .collect::>() - .join("\n"); - span.record(tracing_llm::TOOLS, formatted_tools.as_str()); - } - if let Some(preview) = &user_message_preview { - span.record(tracing_llm::USER_MESSAGE_PREVIEW, preview.as_str()); - } - - // Extract messages for signal analysis (clone before moving client_request) - let messages_for_signals = Some(client_request.get_messages()); - - // Set the model to just the model name (without provider prefix) - // This ensures upstream receives "gpt-4" not "openai/gpt-4" - client_request.set_model(model_name_only.clone()); - if client_request.remove_metadata_key("plano_preference_config") { - debug!("removed plano_preference_config from metadata"); - } - - // === Input filters processing for model listener === - // Filters receive the raw request bytes and return (possibly modified) raw bytes. - // The returned bytes are re-parsed into a ProviderRequestType to continue the request. - { - if let Some(ref input_chain) = filter_pipeline.input { - if !input_chain.is_empty() { - debug!(input_filters = ?input_chain.filter_ids, "processing model listener input filters"); - - let chain = input_chain.to_agent_filter_chain("model_listener"); - - let mut pipeline_processor = PipelineProcessor::default(); - match pipeline_processor - .process_raw_filter_chain( - &chat_request_bytes, - &chain, - &input_chain.agents, - &request_headers, - &request_path, - ) - .await - { - Ok(filtered_bytes) => { - match ProviderRequestType::try_from(( - &filtered_bytes[..], - &SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(), - )) { - Ok(updated_request) => { - client_request = updated_request; - info!("input filter chain processed successfully"); - } - Err(parse_err) => { - warn!(error = %parse_err, "input filter returned invalid request JSON"); - return Ok(BrightStaffError::InvalidRequest(format!( - "Input filter returned invalid request: {}", - parse_err - )) - .into_response()); - } - } - } - Err(super::pipeline_processor::PipelineError::ClientError { - agent, - status, - body, - }) => { - warn!( - agent = %agent, - status = %status, - body = %body, - "client error from filter chain" - ); - let error_json = serde_json::json!({ - "error": "FilterChainError", - "agent": agent, - "status": status, - "agent_response": body - }); - let mut error_response = Response::new(full(error_json.to_string())); - *error_response.status_mut() = - StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST); - error_response.headers_mut().insert( - hyper::header::CONTENT_TYPE, - "application/json".parse().unwrap(), - ); - return Ok(error_response); - } - Err(err) => { - warn!(error = %err, "filter chain processing failed"); - let err_msg = format!("Filter chain processing failed: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - } - } - } - } - - if let Some(ref client_api_kind) = client_api { - let upstream_api = - provider_id.compatible_api_for_client(client_api_kind, is_streaming_request); - client_request.normalize_for_upstream(provider_id, &upstream_api); - } - // === v1/responses state management: Determine upstream API and combine input if needed === - // Do this BEFORE routing since routing consumes the request - // Only process state if state_storage is configured - let mut should_manage_state = false; - if is_responses_api_client { - if let ( - ProviderRequestType::ResponsesAPIRequest(ref mut responses_req), - Some(ref state_store), - ) = (&mut client_request, &state_storage) - { - // Extract original input once - original_input_items = extract_input_items(&responses_req.input); - - // Get the upstream path and check if it's ResponsesAPI - let upstream_path = get_upstream_path( - &llm_providers, - &alias_resolved_model, - &request_path, - &alias_resolved_model, - is_streaming_request, - ) - .await; - - let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path); - - // Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation) - should_manage_state = !matches!( - upstream_api, - Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)) - ); - - if should_manage_state { - // Retrieve and combine conversation history if previous_response_id exists - if let Some(ref prev_resp_id) = responses_req.previous_response_id { - match retrieve_and_combine_input( - state_store.clone(), - prev_resp_id, - original_input_items, // Pass ownership instead of cloning - ) - .await - { - Ok(combined_input) => { - // Update both the request and original_input_items - responses_req.input = InputParam::Items(combined_input.clone()); - original_input_items = combined_input; - info!( - items = original_input_items.len(), - "updated request with conversation history" - ); - } - Err(StateStorageError::NotFound(_)) => { - // Return 409 Conflict when previous_response_id not found - warn!(previous_response_id = %prev_resp_id, "previous response_id not found"); - return Ok(BrightStaffError::ConversationStateNotFound( - prev_resp_id.to_string(), - ) - .into_response()); - } - Err(e) => { - // Log warning but continue on other storage errors - warn!( - previous_response_id = %prev_resp_id, - error = %e, - "failed to retrieve conversation state" - ); - // Restore original_input_items since we passed ownership - original_input_items = extract_input_items(&responses_req.input); - } - } - } - } else { - debug!("upstream supports ResponsesAPI natively"); - } - } - } - - // Serialize request for upstream BEFORE router consumes it - let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap(); - - // Determine routing using the dedicated router_chat module - // This gets its own span for latency and error tracking - let routing_span = info_span!( - "routing", - component = "routing", - http.method = "POST", - http.target = %request_path, - model.requested = %model_from_request, - model.alias_resolved = %alias_resolved_model, - route.selected_model = tracing::field::Empty, - routing.determination_ms = tracing::field::Empty, - ); - let routing_result = match async { - set_service_name(operation_component::ROUTING); - router_chat_get_upstream_model( - router_service, - client_request, // Pass the original request - router_chat will convert it - &traceparent, - &request_path, - &request_id, - inline_routing_policy, - ) - .await - } - .instrument(routing_span) - .await - { - Ok(result) => result, - Err(err) => { - return Ok(BrightStaffError::ForwardedError { - status_code: err.status_code, - message: err.message, - } - .into_response()); - } - }; - - // Determine final model to use - // Router returns "none" as a sentinel value when it doesn't select a specific model - let router_selected_model = routing_result.model_name; - let resolved_model = if router_selected_model != "none" { - // Router selected a specific model via routing preferences - router_selected_model - } else { - // Router returned "none" sentinel, use validated resolved_model from request - alias_resolved_model.clone() - }; - tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str()); - - let span_name = if model_from_request == resolved_model { - format!("POST {} {}", request_path, resolved_model) - } else { - format!( - "POST {} {} -> {}", - request_path, model_from_request, resolved_model - ) - }; - get_active_span(|span| { - span.update_name(span_name.clone()); - }); - - debug!( - url = %full_qualified_llm_provider_url, - provider_hint = %resolved_model, - upstream_model = %model_name_only, - "Routing to upstream" - ); - - request_headers.insert( - ARCH_PROVIDER_HINT_HEADER, - header::HeaderValue::from_str(&resolved_model).unwrap(), - ); - - request_headers.insert( - header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER), - header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(), - ); - // remove content-length header if it exists - request_headers.remove(header::CONTENT_LENGTH); - - // Inject current LLM span's trace context so upstream spans are children of plano(llm) - global::get_text_map_propagator(|propagator| { - let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); - propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers)); - }); - - // Output filters run for any API shape that reaches this handler (e.g. /v1/chat/completions, - // /v1/messages, /v1/responses). Brightstaff does inbound translation and llm_gateway does - // outbound translation; filters receive raw response bytes and request path. - let has_output_filter = filter_pipeline.has_output_filters(); - - // Save request headers for output filters (before they're consumed by upstream request) - let output_filter_request_headers = if has_output_filter { - Some(request_headers.clone()) - } else { - None - }; - - // Capture start time right before sending request to upstream - let request_start_time = std::time::Instant::now(); - let _request_start_system_time = std::time::SystemTime::now(); - - let llm_response = match reqwest::Client::new() - .post(&full_qualified_llm_provider_url) - .headers(request_headers) - .body(client_request_bytes_for_upstream) - .send() - .await - { - Ok(res) => res, - Err(err) => { - return Ok(BrightStaffError::InternalServerError(format!( - "Failed to send request: {}", - err - )) - .into_response()); - } - }; - - // copy over the headers and status code from the original response - let response_headers = llm_response.headers().clone(); - let upstream_status = llm_response.status(); - let mut response = Response::builder().status(upstream_status); - let headers = response.headers_mut().unwrap(); - for (header_name, header_value) in response_headers.iter() { - headers.insert(header_name, header_value.clone()); - } - - // Build LLM span with actual status code using constants - let byte_stream = llm_response.bytes_stream(); - - // Create base processor for metrics and tracing - let base_processor = ObservableStreamProcessor::new( - operation_component::LLM, - span_name, - request_start_time, - messages_for_signals, - ); - - // === v1/responses state management: Wrap with ResponsesStateProcessor === - // Pick the right processor: state-aware if needed, otherwise base metrics-only. - let processor: Box = if let (true, false, Some(state_store)) = ( - should_manage_state, - original_input_items.is_empty(), - state_storage, - ) { - let content_encoding = response_headers - .get("content-encoding") - .and_then(|v| v.to_str().ok()) - .map(|s| s.to_string()); - - Box::new(ResponsesStateProcessor::new( - base_processor, - state_store, - original_input_items, - alias_resolved_model.clone(), - resolved_model.clone(), - is_streaming_request, - false, // Not OpenAI upstream since should_manage_state is true - content_encoding, - request_id, - )) - } else { - Box::new(base_processor) - }; - - // Apply output filters if configured, then build the streaming response. - let streaming_response = if has_output_filter { - let output_chain = filter_pipeline.output.as_ref().unwrap().clone(); - create_streaming_response_with_output_filter( - byte_stream, - processor, - output_chain, - output_filter_request_headers.unwrap(), - request_path.clone(), - ) - } else { - create_streaming_response(byte_stream, processor) - }; - - match response.body(streaming_response.body) { - Ok(response) => Ok(response), - Err(err) => Ok(BrightStaffError::InternalServerError(format!( - "Failed to create response: {}", - err - )) - .into_response()), - } -} -/// Resolves model aliases by looking up the requested model in the model_aliases map. -/// Returns the target model if an alias is found, otherwise returns the original model. -fn resolve_model_alias( - model_from_request: &str, - model_aliases: &Arc>>, -) -> String { - if let Some(aliases) = model_aliases.as_ref() { - if let Some(model_alias) = aliases.get(model_from_request) { - debug!( - "Model Alias: 'From {}' -> 'To {}'", - model_from_request, model_alias.target - ); - return model_alias.target.clone(); - } - } - model_from_request.to_string() -} - -/// Calculates the upstream path for the provider based on the model name. -/// Looks up provider configuration, gets the ProviderId and base_url_path_prefix, -/// then uses target_endpoint_for_provider to calculate the correct upstream path. -async fn get_upstream_path( - llm_providers: &Arc>, - model_name: &str, - request_path: &str, - resolved_model: &str, - is_streaming: bool, -) -> String { - let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await; - - // Calculate the upstream path using the proper API - let client_api = SupportedAPIsFromClient::from_endpoint(request_path) - .expect("Should have valid API endpoint"); - - client_api.target_endpoint_for_provider( - &provider_id, - request_path, - resolved_model, - is_streaming, - base_url_path_prefix.as_deref(), - ) -} - -/// Helper function to get provider info (ProviderId and base_url_path_prefix) -async fn get_provider_info( - llm_providers: &Arc>, - model_name: &str, -) -> (hermesllm::ProviderId, Option) { - let providers_lock = llm_providers.read().await; - - // Try to find by model name or provider name using LlmProviders::get - // This handles both "gpt-4" and "openai/gpt-4" formats - if let Some(provider) = providers_lock.get(model_name) { - let provider_id = provider.provider_interface.to_provider_id(); - let prefix = provider.base_url_path_prefix.clone(); - return (provider_id, prefix); - } - - // Fall back to default provider - if let Some(provider) = providers_lock.default() { - let provider_id = provider.provider_interface.to_provider_id(); - let prefix = provider.base_url_path_prefix.clone(); - (provider_id, prefix) - } else { - // Last resort: use OpenAI as hardcoded fallback - warn!("No default provider found, falling back to OpenAI"); - (hermesllm::ProviderId::OpenAI, None) - } -} - -fn full>(chunk: T) -> BoxBody { - Full::new(chunk.into()) - .map_err(|never| match never {}) - .boxed() -} diff --git a/crates/brightstaff/src/handlers/llm/mod.rs b/crates/brightstaff/src/handlers/llm/mod.rs new file mode 100644 index 00000000..9d4a2dfb --- /dev/null +++ b/crates/brightstaff/src/handlers/llm/mod.rs @@ -0,0 +1,780 @@ +use bytes::Bytes; +use common::configuration::{FilterPipeline, ModelAlias}; +use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER}; +use common::llm_providers::LlmProviders; +use hermesllm::apis::openai::Message; +use hermesllm::apis::openai_responses::InputParam; +use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; +use hermesllm::{ProviderRequest, ProviderRequestType}; +use http_body_util::combinators::BoxBody; +use http_body_util::BodyExt; +use hyper::header::{self}; +use hyper::{Request, Response, StatusCode}; +use opentelemetry::global; +use opentelemetry::trace::get_active_span; +use opentelemetry_http::HeaderInjector; +use std::collections::HashMap; +use std::sync::Arc; +use tokio::sync::RwLock; +use tracing::{debug, info, info_span, warn, Instrument}; + +pub(crate) mod model_selection; + +use crate::app_state::AppState; +use crate::handlers::agents::pipeline::PipelineProcessor; +use crate::handlers::extract_or_generate_traceparent; +use crate::handlers::extract_request_id; +use crate::handlers::full; +use crate::state::response_state_processor::ResponsesStateProcessor; +use crate::state::{ + extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError, +}; +use crate::streaming::{ + create_streaming_response, create_streaming_response_with_output_filter, truncate_message, + ObservableStreamProcessor, StreamProcessor, +}; +use crate::tracing::{ + collect_custom_trace_attributes, llm as tracing_llm, operation_component, set_service_name, +}; +use model_selection::router_chat_get_upstream_model; + +pub async fn llm_chat( + request: Request, + state: Arc, +) -> Result>, hyper::Error> { + let request_path = request.uri().path().to_string(); + let request_headers = request.headers().clone(); + let request_id = extract_request_id(&request); + let custom_attrs = + collect_custom_trace_attributes(&request_headers, state.span_attributes.as_ref()); + + // Create a span with request_id that will be included in all log lines + let request_span = info_span!( + "llm", + component = "llm", + request_id = %request_id, + http.method = %request.method(), + http.path = %request_path, + llm.model = tracing::field::Empty, + llm.tools = tracing::field::Empty, + llm.user_message_preview = tracing::field::Empty, + llm.temperature = tracing::field::Empty, + ); + + // Execute the rest of the handler inside the span + llm_chat_inner( + request, + state, + custom_attrs, + request_id, + request_path, + request_headers, + ) + .instrument(request_span) + .await +} + +async fn llm_chat_inner( + request: Request, + state: Arc, + custom_attrs: HashMap, + request_id: String, + request_path: String, + mut request_headers: hyper::HeaderMap, +) -> Result>, hyper::Error> { + // Set service name for LLM operations + set_service_name(operation_component::LLM); + get_active_span(|span| { + for (key, value) in &custom_attrs { + span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone())); + } + }); + + let traceparent = extract_or_generate_traceparent(&request_headers); + + let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path); + + // --- Phase 1: Parse and validate the incoming request --- + let parsed = match parse_and_validate_request( + request, + &request_path, + &state.model_aliases, + &state.llm_providers, + ) + .await + { + Ok(p) => p, + Err(response) => return Ok(response), + }; + + let PreparedRequest { + mut client_request, + chat_request_bytes, + model_from_request, + alias_resolved_model, + model_name_only, + is_streaming_request, + is_responses_api_client, + messages_for_signals, + temperature, + tool_names, + user_message_preview, + inline_routing_policy, + client_api, + provider_id, + } = parsed; + + // Record LLM-specific span attributes + let span = tracing::Span::current(); + if let Some(temp) = temperature { + span.record(tracing_llm::TEMPERATURE, tracing::field::display(temp)); + } + if let Some(tools) = &tool_names { + let formatted_tools = tools + .iter() + .map(|name| format!("{}(...)", name)) + .collect::>() + .join("\n"); + span.record(tracing_llm::TOOLS, formatted_tools.as_str()); + } + if let Some(preview) = &user_message_preview { + span.record(tracing_llm::USER_MESSAGE_PREVIEW, preview.as_str()); + } + + // --- Phase 1b: Input filter processing for model listener --- + if let Some(ref input_chain) = state.filter_pipeline.input { + if !input_chain.is_empty() { + debug!(input_filters = ?input_chain.filter_ids, "processing model listener input filters"); + let chain = input_chain.to_agent_filter_chain("model_listener"); + let mut pipeline_processor = PipelineProcessor::default(); + match pipeline_processor + .process_raw_filter_chain( + &chat_request_bytes, + &chain, + &input_chain.agents, + &request_headers, + &request_path, + ) + .await + { + Ok(filtered_bytes) => { + let api_type = SupportedAPIsFromClient::from_endpoint(request_path.as_str()) + .expect("endpoint validated in parse_and_validate_request"); + match ProviderRequestType::try_from((&filtered_bytes[..], &api_type)) { + Ok(updated_request) => { + client_request = updated_request; + info!("input filter chain processed successfully"); + } + Err(parse_err) => { + warn!(error = %parse_err, "input filter returned invalid request JSON"); + return Ok(common::errors::BrightStaffError::InvalidRequest(format!( + "Input filter returned invalid request: {}", + parse_err + )) + .into_response()); + } + } + } + Err(super::agents::pipeline::PipelineError::ClientError { + agent, + status, + body, + }) => { + warn!(agent = %agent, status = %status, body = %body, "client error from filter chain"); + let error_json = serde_json::json!({ + "error": "FilterChainError", + "agent": agent, + "status": status, + "agent_response": body + }); + let mut error_response = Response::new(full(error_json.to_string())); + *error_response.status_mut() = + StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST); + error_response.headers_mut().insert( + hyper::header::CONTENT_TYPE, + hyper::header::HeaderValue::from_static("application/json"), + ); + return Ok(error_response); + } + Err(err) => { + warn!(error = %err, "filter chain processing failed"); + let mut internal_error = + Response::new(full(format!("Filter chain processing failed: {}", err))); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(internal_error); + } + } + } + } + + // Normalize for upstream after input filters + if let Some(ref client_api_kind) = client_api { + let upstream_api = + provider_id.compatible_api_for_client(client_api_kind, is_streaming_request); + client_request.normalize_for_upstream(provider_id, &upstream_api); + } + + // --- Phase 2: Resolve conversation state (v1/responses API) --- + let state_ctx = match resolve_conversation_state( + &mut client_request, + is_responses_api_client, + &state.state_storage, + &state.llm_providers, + &alias_resolved_model, + &request_path, + is_streaming_request, + ) + .await + { + Ok(ctx) => ctx, + Err(response) => return Ok(response), + }; + + // Serialize request for upstream BEFORE router consumes it + let client_request_bytes_for_upstream: Bytes = + match ProviderRequestType::to_bytes(&client_request) { + Ok(bytes) => bytes.into(), + Err(err) => { + warn!(error = %err, "failed to serialize request for upstream"); + let mut r = Response::new(full(format!("Failed to serialize request: {}", err))); + *r.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(r); + } + }; + + // --- Phase 3: Route the request --- + let routing_span = info_span!( + "routing", + component = "routing", + http.method = "POST", + http.target = %request_path, + model.requested = %model_from_request, + model.alias_resolved = %alias_resolved_model, + route.selected_model = tracing::field::Empty, + routing.determination_ms = tracing::field::Empty, + ); + let routing_result = match async { + set_service_name(operation_component::ROUTING); + router_chat_get_upstream_model( + Arc::clone(&state.router_service), + client_request, + &traceparent, + &request_path, + &request_id, + inline_routing_policy, + ) + .await + } + .instrument(routing_span) + .await + { + Ok(result) => result, + Err(err) => { + let mut internal_error = Response::new(full(err.message)); + *internal_error.status_mut() = err.status_code; + return Ok(internal_error); + } + }; + + // Determine final model (router returns "none" when it doesn't select a specific model) + let router_selected_model = routing_result.model_name; + let resolved_model = if router_selected_model != "none" { + router_selected_model + } else { + alias_resolved_model.clone() + }; + tracing::Span::current().record(tracing_llm::MODEL_NAME, resolved_model.as_str()); + + // --- Phase 4: Forward to upstream and stream back --- + send_upstream( + &state.http_client, + &full_qualified_llm_provider_url, + &mut request_headers, + client_request_bytes_for_upstream, + &model_from_request, + &alias_resolved_model, + &resolved_model, + &model_name_only, + &request_path, + is_streaming_request, + messages_for_signals, + state_ctx, + state.state_storage.clone(), + request_id, + &state.filter_pipeline, + ) + .await +} + +// --------------------------------------------------------------------------- +// Phase 1 — Parse & validate the incoming request +// --------------------------------------------------------------------------- + +/// All pre-validated request data extracted from the raw HTTP request. +struct PreparedRequest { + client_request: ProviderRequestType, + chat_request_bytes: Bytes, + model_from_request: String, + alias_resolved_model: String, + model_name_only: String, + is_streaming_request: bool, + is_responses_api_client: bool, + messages_for_signals: Option>, + temperature: Option, + tool_names: Option>, + user_message_preview: Option, + inline_routing_policy: Option>, + client_api: Option, + provider_id: hermesllm::ProviderId, +} + +/// Parse the body, resolve the model alias, and validate the model exists. +/// +/// Returns `Err(Response)` for early-exit error responses (400 etc.). +async fn parse_and_validate_request( + request: Request, + request_path: &str, + model_aliases: &Option>, + llm_providers: &Arc>, +) -> Result>> { + let raw_bytes = request + .collect() + .await + .map_err(|_| { + let mut r = Response::new(full("Failed to read request body")); + *r.status_mut() = StatusCode::BAD_REQUEST; + r + })? + .to_bytes(); + + debug!( + body = %String::from_utf8_lossy(&raw_bytes), + "request body received" + ); + + // Extract routing_policy from request body if present + let (chat_request_bytes, inline_routing_policy) = + crate::handlers::routing_service::extract_routing_policy(&raw_bytes, false).map_err( + |err| { + warn!(error = %err, "failed to parse request JSON"); + let mut r = Response::new(full(format!("Failed to parse request: {}", err))); + *r.status_mut() = StatusCode::BAD_REQUEST; + r + }, + )?; + + let api_type = SupportedAPIsFromClient::from_endpoint(request_path).ok_or_else(|| { + warn!(path = %request_path, "unsupported endpoint"); + let mut r = Response::new(full(format!("Unsupported endpoint: {}", request_path))); + *r.status_mut() = StatusCode::BAD_REQUEST; + r + })?; + + let mut client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) + .map_err(|err| { + warn!(error = %err, "failed to parse request as ProviderRequestType"); + let mut r = Response::new(full(format!("Failed to parse request: {}", err))); + *r.status_mut() = StatusCode::BAD_REQUEST; + r + })?; + + let is_responses_api_client = + matches!(api_type, SupportedAPIsFromClient::OpenAIResponsesAPI(_)); + let client_api = Some(api_type); + + let model_from_request = client_request.model().to_string(); + let temperature = client_request.get_temperature(); + let is_streaming_request = client_request.is_streaming(); + let alias_resolved_model = resolve_model_alias(&model_from_request, model_aliases); + let (provider_id, _) = get_provider_info(llm_providers, &alias_resolved_model).await; + + // Validate model exists in configuration + if llm_providers + .read() + .await + .get(&alias_resolved_model) + .is_none() + { + let err_msg = format!( + "Model '{}' not found in configured providers", + alias_resolved_model + ); + warn!(model = %alias_resolved_model, "model not found in configured providers"); + let mut r = Response::new(full(err_msg)); + *r.status_mut() = StatusCode::BAD_REQUEST; + return Err(r); + } + + // Strip provider prefix for upstream (e.g. "openai/gpt-4" → "gpt-4") + let model_name_only = alias_resolved_model + .split_once('/') + .map(|(_, model)| model.to_string()) + .unwrap_or_else(|| alias_resolved_model.clone()); + + // Extract span attributes and messages before mutating client_request + let tool_names = client_request.get_tool_names(); + let user_message_preview = client_request + .get_recent_user_message() + .map(|msg| truncate_message(&msg, 50)); + let messages_for_signals = Some(client_request.get_messages()); + + // Set the upstream model name and strip routing metadata + client_request.set_model(model_name_only.clone()); + if client_request.remove_metadata_key("archgw_preference_config") { + debug!("removed archgw_preference_config from metadata"); + } + if client_request.remove_metadata_key("plano_preference_config") { + debug!("removed plano_preference_config from metadata"); + } + + Ok(PreparedRequest { + client_request, + chat_request_bytes, + model_from_request, + alias_resolved_model, + model_name_only, + is_streaming_request, + is_responses_api_client, + messages_for_signals, + temperature, + tool_names, + user_message_preview, + inline_routing_policy, + client_api, + provider_id, + }) +} + +// --------------------------------------------------------------------------- +// Phase 2 — Resolve conversation state (v1/responses API) +// --------------------------------------------------------------------------- + +/// Holds the state management context resolved from a v1/responses request. +struct ConversationStateContext { + should_manage_state: bool, + original_input_items: Vec, +} + +/// If the client uses the v1/responses API and the upstream provider doesn't +/// support it natively, we manage conversation state ourselves. +/// +/// This resolves `previous_response_id`, merges conversation history, and +/// updates the request in place. +/// +/// Returns `Err(Response)` for early-exit (e.g. 409 Conflict). +async fn resolve_conversation_state( + client_request: &mut ProviderRequestType, + is_responses_api_client: bool, + state_storage: &Option>, + llm_providers: &Arc>, + alias_resolved_model: &str, + request_path: &str, + is_streaming_request: bool, +) -> Result>> { + if !is_responses_api_client { + return Ok(ConversationStateContext { + should_manage_state: false, + original_input_items: Vec::new(), + }); + } + + let (responses_req, state_store) = match (client_request, state_storage) { + (ProviderRequestType::ResponsesAPIRequest(ref mut req), Some(store)) => (req, store), + _ => { + return Ok(ConversationStateContext { + should_manage_state: false, + original_input_items: Vec::new(), + }); + } + }; + + let mut original_input_items = extract_input_items(&responses_req.input); + + // Check whether the upstream supports v1/responses natively + let upstream_path = get_upstream_path( + llm_providers, + alias_resolved_model, + request_path, + alias_resolved_model, + is_streaming_request, + ) + .await; + + let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path); + let should_manage_state = !matches!( + upstream_api, + Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)) + ); + + if !should_manage_state { + debug!("upstream supports ResponsesAPI natively"); + return Ok(ConversationStateContext { + should_manage_state: false, + original_input_items, + }); + } + + // Retrieve and combine conversation history if previous_response_id exists + if let Some(ref prev_resp_id) = responses_req.previous_response_id { + match retrieve_and_combine_input(state_store.clone(), prev_resp_id, original_input_items) + .await + { + Ok(combined_input) => { + responses_req.input = InputParam::Items(combined_input.clone()); + original_input_items = combined_input; + info!( + items = original_input_items.len(), + "updated request with conversation history" + ); + } + Err(StateStorageError::NotFound(_)) => { + warn!(previous_response_id = %prev_resp_id, "previous response_id not found"); + let err_msg = format!( + "Conversation state not found for previous_response_id: {}", + prev_resp_id + ); + let mut r = Response::new(full(err_msg)); + *r.status_mut() = StatusCode::CONFLICT; + return Err(r); + } + Err(e) => { + warn!( + previous_response_id = %prev_resp_id, + error = %e, + "failed to retrieve conversation state" + ); + // Restore original_input_items since we passed ownership + original_input_items = extract_input_items(&responses_req.input); + } + } + } + + Ok(ConversationStateContext { + should_manage_state, + original_input_items, + }) +} + +// --------------------------------------------------------------------------- +// Phase 4 — Forward to upstream and stream the response back +// --------------------------------------------------------------------------- + +#[allow(clippy::too_many_arguments)] +async fn send_upstream( + http_client: &reqwest::Client, + upstream_url: &str, + request_headers: &mut hyper::HeaderMap, + body: bytes::Bytes, + model_from_request: &str, + alias_resolved_model: &str, + resolved_model: &str, + model_name_only: &str, + request_path: &str, + is_streaming_request: bool, + messages_for_signals: Option>, + state_ctx: ConversationStateContext, + state_storage: Option>, + request_id: String, + filter_pipeline: &Arc, +) -> Result>, hyper::Error> { + let span_name = if model_from_request == resolved_model { + format!("POST {} {}", request_path, resolved_model) + } else { + format!( + "POST {} {} -> {}", + request_path, model_from_request, resolved_model + ) + }; + get_active_span(|span| { + span.update_name(span_name.clone()); + }); + + debug!( + url = %upstream_url, + provider_hint = %resolved_model, + upstream_model = %model_name_only, + "Routing to upstream" + ); + + if let Ok(val) = header::HeaderValue::from_str(resolved_model) { + request_headers.insert(ARCH_PROVIDER_HINT_HEADER, val); + } + request_headers.insert( + header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER), + header::HeaderValue::from_static(if is_streaming_request { + "true" + } else { + "false" + }), + ); + request_headers.remove(header::CONTENT_LENGTH); + + // Inject current span's trace context so upstream spans are children of plano(llm) + global::get_text_map_propagator(|propagator| { + let cx = tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); + propagator.inject_context(&cx, &mut HeaderInjector(request_headers)); + }); + + let request_start_time = std::time::Instant::now(); + + let llm_response = match http_client + .post(upstream_url) + .headers(request_headers.clone()) + .body(body) + .send() + .await + { + Ok(res) => res, + Err(err) => { + let err_msg = format!("Failed to send request: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(internal_error); + } + }; + + // Propagate upstream headers and status + let response_headers = llm_response.headers().clone(); + let upstream_status = llm_response.status(); + let mut response = Response::builder().status(upstream_status); + if let Some(headers) = response.headers_mut() { + for (name, value) in response_headers.iter() { + headers.insert(name, value.clone()); + } + } + + let byte_stream = llm_response.bytes_stream(); + + // Create base processor for metrics and tracing + let base_processor = ObservableStreamProcessor::new( + operation_component::LLM, + span_name, + request_start_time, + messages_for_signals, + ); + + let output_filter_request_headers = if filter_pipeline.has_output_filters() { + Some(request_headers.clone()) + } else { + None + }; + + // Pick the right processor: state-aware if needed, otherwise base metrics-only. + let processor: Box = if let (true, false, Some(state_store)) = ( + state_ctx.should_manage_state, + state_ctx.original_input_items.is_empty(), + state_storage, + ) { + let content_encoding = response_headers + .get("content-encoding") + .and_then(|v| v.to_str().ok()) + .map(|s| s.to_string()); + + Box::new(ResponsesStateProcessor::new( + base_processor, + state_store, + state_ctx.original_input_items, + alias_resolved_model.to_string(), + resolved_model.to_string(), + is_streaming_request, + false, + content_encoding, + request_id, + )) + } else { + Box::new(base_processor) + }; + + let streaming_response = if let (Some(output_chain), Some(filter_headers)) = ( + filter_pipeline.output.as_ref().filter(|c| !c.is_empty()), + output_filter_request_headers, + ) { + create_streaming_response_with_output_filter( + byte_stream, + processor, + output_chain.clone(), + filter_headers, + request_path.to_string(), + ) + } else { + create_streaming_response(byte_stream, processor) + }; + + match response.body(streaming_response.body) { + Ok(response) => Ok(response), + Err(err) => { + let err_msg = format!("Failed to create response: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + Ok(internal_error) + } + } +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Resolves model aliases by looking up the requested model in the model_aliases map. +/// Returns the target model if an alias is found, otherwise returns the original model. +fn resolve_model_alias( + model_from_request: &str, + model_aliases: &Option>, +) -> String { + if let Some(aliases) = model_aliases.as_ref() { + if let Some(model_alias) = aliases.get(model_from_request) { + debug!( + "Model Alias: 'From {}' -> 'To {}'", + model_from_request, model_alias.target + ); + return model_alias.target.clone(); + } + } + model_from_request.to_string() +} + +/// Calculates the upstream path for the provider based on the model name. +async fn get_upstream_path( + llm_providers: &Arc>, + model_name: &str, + request_path: &str, + resolved_model: &str, + is_streaming: bool, +) -> String { + let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await; + + let Some(client_api) = SupportedAPIsFromClient::from_endpoint(request_path) else { + return request_path.to_string(); + }; + + client_api.target_endpoint_for_provider( + &provider_id, + request_path, + resolved_model, + is_streaming, + base_url_path_prefix.as_deref(), + ) +} + +/// Helper to get provider info (ProviderId and base_url_path_prefix). +async fn get_provider_info( + llm_providers: &Arc>, + model_name: &str, +) -> (hermesllm::ProviderId, Option) { + let providers_lock = llm_providers.read().await; + + if let Some(provider) = providers_lock.get(model_name) { + let provider_id = provider.provider_interface.to_provider_id(); + let prefix = provider.base_url_path_prefix.clone(); + return (provider_id, prefix); + } + + if let Some(provider) = providers_lock.default() { + let provider_id = provider.provider_interface.to_provider_id(); + let prefix = provider.base_url_path_prefix.clone(); + (provider_id, prefix) + } else { + warn!("No default provider found, falling back to OpenAI"); + (hermesllm::ProviderId::OpenAI, None) + } +} diff --git a/crates/brightstaff/src/handlers/router_chat.rs b/crates/brightstaff/src/handlers/llm/model_selection.rs similarity index 92% rename from crates/brightstaff/src/handlers/router_chat.rs rename to crates/brightstaff/src/handlers/llm/model_selection.rs index 910e5408..455b7c0e 100644 --- a/crates/brightstaff/src/handlers/router_chat.rs +++ b/crates/brightstaff/src/handlers/llm/model_selection.rs @@ -5,7 +5,8 @@ use hyper::StatusCode; use std::sync::Arc; use tracing::{debug, info, warn}; -use crate::router::llm_router::RouterService; +use crate::router::llm::RouterService; +use crate::streaming::truncate_message; use crate::tracing::routing; pub struct RoutingResult { @@ -103,16 +104,7 @@ pub async fn router_chat_get_upstream_model( .map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n")) }); - const MAX_MESSAGE_LENGTH: usize = 50; - let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH { - let truncated: String = latest_message_for_log - .chars() - .take(MAX_MESSAGE_LENGTH) - .collect(); - format!("{}...", truncated) - } else { - latest_message_for_log - }; + let latest_message_for_log = truncate_message(&latest_message_for_log, 50); info!( has_usage_preferences = usage_preferences.is_some(), diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index b2161e43..485a0438 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,14 +1,57 @@ -pub mod agent_chat_completions; -pub mod agent_selector; +pub mod agents; pub mod function_calling; -pub mod jsonrpc; pub mod llm; pub mod models; -pub mod pipeline_processor; -pub mod response_handler; -pub mod router_chat; +pub mod response; pub mod routing_service; -pub mod streaming; #[cfg(test)] mod integration_tests; + +use bytes::Bytes; +use common::consts::TRACE_PARENT_HEADER; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Empty, Full}; +use hyper::Request; +use tracing::warn; + +/// Wrap a chunk into a `BoxBody` for hyper responses. +pub fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()) + .map_err(|never| match never {}) + .boxed() +} + +/// An empty HTTP body (used for 404 / OPTIONS responses). +pub fn empty() -> BoxBody { + Empty::::new() + .map_err(|never| match never {}) + .boxed() +} + +/// Extract request ID from incoming request headers, or generate a new UUID v4. +pub fn extract_request_id(request: &Request) -> String { + request + .headers() + .get(common::consts::REQUEST_ID_HEADER) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()) +} + +/// Extract or generate a W3C `traceparent` header value. +pub fn extract_or_generate_traceparent(headers: &hyper::HeaderMap) -> String { + headers + .get(TRACE_PARENT_HEADER) + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_else(|| { + let trace_id = uuid::Uuid::new_v4().to_string().replace("-", ""); + let tp = format!("00-{}-0000000000000000-01", trace_id); + warn!( + generated_traceparent = %tp, + "TRACE_PARENT header missing, generated new traceparent" + ); + tp + }) +} diff --git a/crates/brightstaff/src/handlers/models.rs b/crates/brightstaff/src/handlers/models.rs index a29d5e90..9fd5fe07 100644 --- a/crates/brightstaff/src/handlers/models.rs +++ b/crates/brightstaff/src/handlers/models.rs @@ -1,10 +1,11 @@ use bytes::Bytes; use common::llm_providers::LlmProviders; -use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use http_body_util::combinators::BoxBody; use hyper::{Response, StatusCode}; -use serde_json; use std::sync::Arc; +use super::full; + pub async fn list_models( llm_providers: Arc>, ) -> Response> { @@ -12,27 +13,15 @@ pub async fn list_models( let models = prov.to_models(); match serde_json::to_string(&models) { - Ok(json) => { - let body = Full::new(Bytes::from(json)) - .map_err(|never| match never {}) - .boxed(); - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(body) - .unwrap() - } - Err(_) => { - let body = Full::new(Bytes::from_static( - b"{\"error\":\"Failed to serialize models\"}", - )) - .map_err(|never| match never {}) - .boxed(); - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header("Content-Type", "application/json") - .body(body) - .unwrap() - } + Ok(json) => Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(full(json)) + .unwrap(), + Err(_) => Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header("Content-Type", "application/json") + .body(full("{\"error\":\"Failed to serialize models\"}")) + .unwrap(), } } diff --git a/crates/brightstaff/src/handlers/response_handler.rs b/crates/brightstaff/src/handlers/response.rs similarity index 77% rename from crates/brightstaff/src/handlers/response_handler.rs rename to crates/brightstaff/src/handlers/response.rs index 7331ab4c..4db2f8a8 100644 --- a/crates/brightstaff/src/handlers/response_handler.rs +++ b/crates/brightstaff/src/handlers/response.rs @@ -4,7 +4,7 @@ use hermesllm::apis::OpenAIApi; use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; use hermesllm::SseEvent; use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full, StreamBody}; +use http_body_util::StreamBody; use hyper::body::Frame; use hyper::Response; use tokio::sync::mpsc; @@ -12,6 +12,8 @@ use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; use tracing::{info, warn, Instrument}; +use super::full; + /// Service for handling HTTP responses and streaming pub struct ResponseHandler; @@ -22,9 +24,40 @@ impl ResponseHandler { /// Create a full response body from bytes pub fn create_full_body>(chunk: T) -> BoxBody { - Full::new(chunk.into()) - .map_err(|never| match never {}) - .boxed() + full(chunk) + } + + /// Create a JSON error response with BAD_REQUEST status + pub fn create_json_error_response( + json: &serde_json::Value, + ) -> Response> { + let body = Self::create_full_body(json.to_string()); + let mut response = Response::new(body); + *response.status_mut() = hyper::StatusCode::BAD_REQUEST; + response.headers_mut().insert( + hyper::header::CONTENT_TYPE, + hyper::header::HeaderValue::from_static("application/json"), + ); + response + } + + /// Create a BAD_REQUEST error response with a message + pub fn create_bad_request(message: &str) -> Response> { + let json = serde_json::json!({"error": message}); + Self::create_json_error_response(&json) + } + + /// Create an INTERNAL_SERVER_ERROR response with a message + pub fn create_internal_error(message: &str) -> Response> { + let json = serde_json::json!({"error": message}); + let body = Self::create_full_body(json.to_string()); + let mut response = Response::new(body); + *response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR; + response.headers_mut().insert( + hyper::header::CONTENT_TYPE, + hyper::header::HeaderValue::from_static("application/json"), + ); + response } /// Create a streaming response from a reqwest response. @@ -112,7 +145,9 @@ impl ResponseHandler { let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).unwrap(); + let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).map_err(|e| { + BrightStaffError::StreamError(format!("Failed to parse SSE stream: {}", e)) + })?; let mut accumulated_text = String::new(); for sse_event in sse_iter { @@ -122,7 +157,13 @@ impl ResponseHandler { } let transformed_event = - SseEvent::try_from((sse_event, &client_api, &upstream_api)).unwrap(); + match SseEvent::try_from((sse_event, &client_api, &upstream_api)) { + Ok(event) => event, + Err(e) => { + warn!(error = ?e, "failed to transform SSE event, skipping"); + continue; + } + }; // Try to get provider response and extract content delta match transformed_event.provider_response() { diff --git a/crates/brightstaff/src/handlers/routing_service.rs b/crates/brightstaff/src/handlers/routing_service.rs index 4eae4685..ec09f06f 100644 --- a/crates/brightstaff/src/handlers/routing_service.rs +++ b/crates/brightstaff/src/handlers/routing_service.rs @@ -1,6 +1,6 @@ use bytes::Bytes; use common::configuration::{ModelUsagePreference, SpanAttributes}; -use common::consts::{REQUEST_ID_HEADER, TRACE_PARENT_HEADER}; +use common::consts::REQUEST_ID_HEADER; use common::errors::BrightStaffError; use hermesllm::clients::SupportedAPIsFromClient; use hermesllm::ProviderRequestType; @@ -10,8 +10,9 @@ use hyper::{Request, Response, StatusCode}; use std::sync::Arc; use tracing::{debug, info, info_span, warn, Instrument}; -use crate::handlers::router_chat::router_chat_get_upstream_model; -use crate::router::llm_router::RouterService; +use super::extract_or_generate_traceparent; +use crate::handlers::llm::model_selection::router_chat_get_upstream_model; +use crate::router::llm::RouterService; use crate::tracing::{collect_custom_trace_attributes, operation_component, set_service_name}; const ROUTING_POLICY_SIZE_WARNING_BYTES: usize = 5120; @@ -72,7 +73,7 @@ pub async fn routing_decision( request: Request, router_service: Arc, request_path: String, - span_attributes: Arc>, + span_attributes: &Option, ) -> Result>, hyper::Error> { let request_headers = request.headers().clone(); let request_id: String = request_headers @@ -81,8 +82,7 @@ pub async fn routing_decision( .map(|s| s.to_string()) .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()); - let custom_attrs = - collect_custom_trace_attributes(&request_headers, span_attributes.as_ref().as_ref()); + let custom_attrs = collect_custom_trace_attributes(&request_headers, span_attributes.as_ref()); let request_span = info_span!( "routing_decision", @@ -119,23 +119,7 @@ async fn routing_decision_inner( } }); - // Extract or generate traceparent - let traceparent: String = match request_headers - .get(TRACE_PARENT_HEADER) - .and_then(|h| h.to_str().ok()) - .map(|s| s.to_string()) - { - Some(tp) => tp, - None => { - let trace_id = uuid::Uuid::new_v4().to_string().replace("-", ""); - let generated_tp = format!("00-{}-0000000000000000-01", trace_id); - warn!( - generated_traceparent = %generated_tp, - "TRACE_PARENT header missing, generated new traceparent" - ); - generated_tp - } - }; + let traceparent = extract_or_generate_traceparent(&request_headers); // Extract trace_id from traceparent (format: 00-{trace_id}-{span_id}-{flags}) let trace_id = traceparent diff --git a/crates/brightstaff/src/lib.rs b/crates/brightstaff/src/lib.rs index 47da64aa..b4ab82a9 100644 --- a/crates/brightstaff/src/lib.rs +++ b/crates/brightstaff/src/lib.rs @@ -1,6 +1,7 @@ +pub mod app_state; pub mod handlers; pub mod router; pub mod signals; pub mod state; +pub mod streaming; pub mod tracing; -pub mod utils; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 391aed03..60a69bca 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,28 +1,31 @@ -use brightstaff::handlers::agent_chat_completions::agent_chat; +use brightstaff::app_state::AppState; +use brightstaff::handlers::agents::orchestrator::agent_chat; +use brightstaff::handlers::empty; use brightstaff::handlers::function_calling::function_calling_chat_handler; use brightstaff::handlers::llm::llm_chat; use brightstaff::handlers::models::list_models; use brightstaff::handlers::routing_service::routing_decision; -use brightstaff::router::llm_router::RouterService; -use brightstaff::router::plano_orchestrator::OrchestratorService; +use brightstaff::router::llm::RouterService; +use brightstaff::router::orchestrator::OrchestratorService; use brightstaff::state::memory::MemoryConversationalStorage; use brightstaff::state::postgresql::PostgreSQLConversationStorage; use brightstaff::state::StateStorage; -use brightstaff::utils::tracing::init_tracer; +use brightstaff::tracing::init_tracer; use bytes::Bytes; use common::configuration::{ Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain, }; use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH}; use common::llm_providers::LlmProviders; -use http_body_util::{combinators::BoxBody, BodyExt, Empty}; +use http_body_util::combinators::BoxBody; use hyper::body::Incoming; +use hyper::header::HeaderValue; use hyper::server::conn::http1; use hyper::service::service_fn; use hyper::{Method, Request, Response, StatusCode}; use hyper_util::rt::TokioIo; +use opentelemetry::global; use opentelemetry::trace::FutureExt; -use opentelemetry::{global, Context}; use opentelemetry_http::HeaderExtractor; use std::collections::HashMap; use std::sync::Arc; @@ -31,82 +34,94 @@ use tokio::net::TcpListener; use tokio::sync::RwLock; use tracing::{debug, info, warn}; -pub mod router; - const BIND_ADDRESS: &str = "0.0.0.0:9091"; const DEFAULT_ROUTING_LLM_PROVIDER: &str = "arch-router"; const DEFAULT_ROUTING_MODEL_NAME: &str = "Arch-Router"; const DEFAULT_ORCHESTRATOR_LLM_PROVIDER: &str = "plano-orchestrator"; const DEFAULT_ORCHESTRATOR_MODEL_NAME: &str = "Plano-Orchestrator"; -// Utility function to extract the context from the incoming request headers -fn extract_context_from_request(req: &Request) -> Context { - global::get_text_map_propagator(|propagator| { - propagator.extract(&HeaderExtractor(req.headers())) - }) +/// CORS pre-flight response for the models endpoint. +fn cors_preflight() -> Result>, hyper::Error> { + let mut response = Response::new(empty()); + *response.status_mut() = StatusCode::NO_CONTENT; + let h = response.headers_mut(); + h.insert("Allow", HeaderValue::from_static("GET, OPTIONS")); + h.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*")); + h.insert( + "Access-Control-Allow-Headers", + HeaderValue::from_static("Authorization, Content-Type"), + ); + h.insert( + "Access-Control-Allow-Methods", + HeaderValue::from_static("GET, POST, OPTIONS"), + ); + h.insert("Content-Type", HeaderValue::from_static("application/json")); + Ok(response) } -fn empty() -> BoxBody { - Empty::::new() - .map_err(|never| match never {}) - .boxed() -} +// --------------------------------------------------------------------------- +// Configuration loading +// --------------------------------------------------------------------------- -#[tokio::main] -async fn main() -> Result<(), Box> { - let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string()); - - // loading plano_config.yaml file (before tracing init so we can read tracing config) - let plano_config_path = env::var("PLANO_CONFIG_PATH_RENDERED") +/// Load and parse the YAML configuration file. +/// +/// The path is read from `PLANO_CONFIG_PATH_RENDERED` (env) or falls back to +/// `./plano_config_rendered.yaml`. +fn load_config() -> Result> { + let path = env::var("PLANO_CONFIG_PATH_RENDERED") .unwrap_or_else(|_| "./plano_config_rendered.yaml".to_string()); - eprintln!("loading plano_config.yaml from {}", plano_config_path); + eprintln!("loading plano_config.yaml from {}", path); - let config_contents = - fs::read_to_string(&plano_config_path).expect("Failed to read plano_config.yaml"); + let contents = fs::read_to_string(&path).map_err(|e| format!("failed to read {path}: {e}"))?; let config: Configuration = - serde_yaml::from_str(&config_contents).expect("Failed to parse plano_config.yaml"); + serde_yaml::from_str(&contents).map_err(|e| format!("failed to parse {path}: {e}"))?; - // Initialize tracing using config.yaml tracing section - let _tracer_provider = init_tracer(config.tracing.as_ref()); - info!(path = %plano_config_path, "loaded plano_config.yaml"); + Ok(config) +} - let plano_config = Arc::new(config); +// --------------------------------------------------------------------------- +// Application state initialization +// --------------------------------------------------------------------------- - // combine agents and filters into a single list of agents - let all_agents: Vec = plano_config +/// Build the shared [`AppState`] from a parsed [`Configuration`]. +async fn init_app_state( + config: &Configuration, +) -> Result> { + let llm_provider_url = + env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string()); + + // Combine agents and filters into a single list + let all_agents: Vec = config .agents .as_deref() .unwrap_or_default() .iter() - .chain(plano_config.filters.as_deref().unwrap_or_default()) + .chain(config.filters.as_deref().unwrap_or_default()) .cloned() .collect(); - // Build global agent map for resolving filter chain references let global_agent_map: HashMap = all_agents .iter() .map(|a| (a.id.clone(), a.clone())) .collect(); - // Create expanded provider list for /v1/models endpoint - let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone()) - .expect("Failed to create LlmProviders"); - let llm_providers = Arc::new(RwLock::new(llm_providers)); - let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents))); + let llm_providers = LlmProviders::try_from(config.model_providers.clone()) + .map_err(|e| format!("failed to create LlmProviders: {e}"))?; - // Resolve model listener filter chain and agents at startup - let model_listener_count = plano_config + let model_listener_count = config .listeners .iter() .filter(|l| l.listener_type == ListenerType::Model) .count(); - assert!( - model_listener_count <= 1, - "only one model listener is allowed, found {}", - model_listener_count - ); - let model_listener = plano_config + if model_listener_count > 1 { + return Err(format!( + "only one model listener is allowed, found {}", + model_listener_count + ) + .into()); + } + let model_listener = config .listeners .iter() .find(|l| l.listener_type == ListenerType::Model); @@ -114,7 +129,11 @@ async fn main() -> Result<(), Box> { filter_ids.map(|ids| { let agents = ids .iter() - .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone()))) + .filter_map(|id| { + global_agent_map + .get(id) + .map(|a: &Agent| (id.clone(), a.clone())) + }) .collect(); ResolvedFilterChain { filter_ids: ids, @@ -126,14 +145,9 @@ async fn main() -> Result<(), Box> { input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())), output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())), }); - let listeners = Arc::new(RwLock::new(plano_config.listeners.clone())); - let llm_provider_url = - env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string()); - let listener = TcpListener::bind(bind_address).await?; - let overrides = plano_config.overrides.clone().unwrap_or_default(); + let overrides = config.overrides.clone().unwrap_or_default(); - // Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests let routing_model_name: String = overrides .llm_routing_model .as_deref() @@ -141,21 +155,20 @@ async fn main() -> Result<(), Box> { .unwrap_or(DEFAULT_ROUTING_MODEL_NAME) .to_string(); - let routing_llm_provider = plano_config + let routing_llm_provider = config .model_providers .iter() .find(|p| p.model.as_deref() == Some(routing_model_name.as_str())) .map(|p| p.name.clone()) .unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string()); - let router_service: Arc = Arc::new(RouterService::new( - plano_config.model_providers.clone(), + let router_service = Arc::new(RouterService::new( + config.model_providers.clone(), format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"), routing_model_name, routing_llm_provider, )); - // Strip provider prefix (e.g. "arch/") to get the model ID used in upstream requests let orchestrator_model_name: String = overrides .agent_orchestration_model .as_deref() @@ -163,213 +176,205 @@ async fn main() -> Result<(), Box> { .unwrap_or(DEFAULT_ORCHESTRATOR_MODEL_NAME) .to_string(); - let orchestrator_llm_provider: String = plano_config + let orchestrator_llm_provider: String = config .model_providers .iter() .find(|p| p.model.as_deref() == Some(orchestrator_model_name.as_str())) .map(|p| p.name.clone()) .unwrap_or_else(|| DEFAULT_ORCHESTRATOR_LLM_PROVIDER.to_string()); - let orchestrator_service: Arc = Arc::new(OrchestratorService::new( + let orchestrator_service = Arc::new(OrchestratorService::new( format!("{llm_provider_url}{CHAT_COMPLETIONS_PATH}"), orchestrator_model_name, orchestrator_llm_provider, )); - let model_aliases = Arc::new(plano_config.model_aliases.clone()); - let span_attributes = Arc::new( - plano_config - .tracing - .as_ref() - .and_then(|tracing| tracing.span_attributes.clone()), - ); + let state_storage = init_state_storage(config).await?; - // Initialize trace collector and start background flusher - // Tracing is enabled if the tracing config is present in plano_config.yaml - // Pass Some(true/false) to override, or None to use env var OTEL_TRACING_ENABLED - // OpenTelemetry automatic instrumentation is configured in utils/tracing.rs + let span_attributes = config + .tracing + .as_ref() + .and_then(|tracing| tracing.span_attributes.clone()); - // Initialize conversation state storage for v1/responses - // Configurable via plano_config.yaml state_storage section - // If not configured, state management is disabled - // Environment variables are substituted by envsubst before config is read - let state_storage: Option> = - if let Some(storage_config) = &plano_config.state_storage { - let storage: Arc = match storage_config.storage_type { - common::configuration::StateStorageType::Memory => { - info!( - storage_type = "memory", - "initialized conversation state storage" - ); - Arc::new(MemoryConversationalStorage::new()) - } - common::configuration::StateStorageType::Postgres => { - let connection_string = storage_config - .connection_string - .as_ref() - .expect("connection_string is required for postgres state_storage"); + Ok(AppState { + router_service, + orchestrator_service, + model_aliases: config.model_aliases.clone(), + llm_providers: Arc::new(RwLock::new(llm_providers)), + agents_list: Some(all_agents), + listeners: config.listeners.clone(), + state_storage, + llm_provider_url, + span_attributes, + http_client: reqwest::Client::new(), + filter_pipeline, + }) +} - debug!(connection_string = %connection_string, "postgres connection"); - info!( - storage_type = "postgres", - "initializing conversation state storage" - ); - Arc::new( - PostgreSQLConversationStorage::new(connection_string.clone()) - .await - .expect("Failed to initialize Postgres state storage"), - ) - } - }; - Some(storage) - } else { - info!("no state_storage configured, conversation state management disabled"); - None - }; +/// Initialize the conversation state storage backend (if configured). +async fn init_state_storage( + config: &Configuration, +) -> Result>, Box> { + let Some(storage_config) = &config.state_storage else { + info!("no state_storage configured, conversation state management disabled"); + return Ok(None); + }; - loop { - let (stream, _) = listener.accept().await?; - let peer_addr = stream.peer_addr()?; - let io = TokioIo::new(stream); + let storage: Arc = match storage_config.storage_type { + common::configuration::StateStorageType::Memory => { + info!( + storage_type = "memory", + "initialized conversation state storage" + ); + Arc::new(MemoryConversationalStorage::new()) + } + common::configuration::StateStorageType::Postgres => { + let connection_string = storage_config + .connection_string + .as_ref() + .ok_or("connection_string is required for postgres state_storage")?; - let router_service: Arc = Arc::clone(&router_service); - let orchestrator_service: Arc = Arc::clone(&orchestrator_service); - let model_aliases: Arc< - Option>, - > = Arc::clone(&model_aliases); - let llm_provider_url = llm_provider_url.clone(); + debug!(connection_string = %connection_string, "postgres connection"); + info!( + storage_type = "postgres", + "initializing conversation state storage" + ); - let llm_providers = llm_providers.clone(); - let agents_list = combined_agents_filters_list.clone(); - let filter_pipeline = filter_pipeline.clone(); - let listeners = listeners.clone(); - let span_attributes = span_attributes.clone(); - let state_storage = state_storage.clone(); - let service = service_fn(move |req| { - let router_service = Arc::clone(&router_service); - let orchestrator_service = Arc::clone(&orchestrator_service); - let parent_cx = extract_context_from_request(&req); - let llm_provider_url = llm_provider_url.clone(); - let llm_providers = llm_providers.clone(); - let model_aliases = Arc::clone(&model_aliases); - let agents_list = agents_list.clone(); - let filter_pipeline = filter_pipeline.clone(); - let listeners = listeners.clone(); - let span_attributes = span_attributes.clone(); - let state_storage = state_storage.clone(); + Arc::new( + PostgreSQLConversationStorage::new(connection_string.clone()) + .await + .map_err(|e| format!("failed to initialize Postgres state storage: {e}"))?, + ) + } + }; - async move { - let path = req.uri().path().to_string(); - // Check if path starts with /agents - if path.starts_with("/agents") { - // Check if it matches one of the agent API paths - let stripped_path = path.strip_prefix("/agents").unwrap(); - if matches!( - stripped_path, - CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH - ) { - let fully_qualified_url = format!("{}{}", llm_provider_url, stripped_path); - return agent_chat( - req, - orchestrator_service, - fully_qualified_url, - agents_list, - listeners, - span_attributes, - llm_providers, - ) - .with_context(parent_cx) - .await; - } - } - if let Some(stripped_path) = path.strip_prefix("/routing") { - let stripped_path = stripped_path.to_string(); - if matches!( - stripped_path.as_str(), - CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH - ) { - return routing_decision( - req, - router_service, - stripped_path, - span_attributes, - ) - .with_context(parent_cx) - .await; - } - } - match (req.method(), path.as_str()) { - ( - &Method::POST, - CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH, - ) => { - let fully_qualified_url = format!("{}{}", llm_provider_url, path); - llm_chat( - req, - router_service, - fully_qualified_url, - model_aliases, - llm_providers, - span_attributes, - state_storage, - filter_pipeline, - ) - .with_context(parent_cx) - .await - } - (&Method::POST, "/function_calling") => { - let fully_qualified_url = - format!("{}{}", llm_provider_url, "/v1/chat/completions"); - function_calling_chat_handler(req, fully_qualified_url) - .with_context(parent_cx) - .await - } - (&Method::GET, "/v1/models" | "/agents/v1/models") => { - Ok(list_models(llm_providers).await) - } - // hack for now to get openw-web-ui to work - (&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => { - let mut response = Response::new(empty()); - *response.status_mut() = StatusCode::NO_CONTENT; - response - .headers_mut() - .insert("Allow", "GET, OPTIONS".parse().unwrap()); - response - .headers_mut() - .insert("Access-Control-Allow-Origin", "*".parse().unwrap()); - response.headers_mut().insert( - "Access-Control-Allow-Headers", - "Authorization, Content-Type".parse().unwrap(), - ); - response.headers_mut().insert( - "Access-Control-Allow-Methods", - "GET, POST, OPTIONS".parse().unwrap(), - ); - response - .headers_mut() - .insert("Content-Type", "application/json".parse().unwrap()); + Ok(Some(storage)) +} - Ok(response) - } - _ => { - debug!(method = %req.method(), path = %path, "no route found"); - let mut not_found = Response::new(empty()); - *not_found.status_mut() = StatusCode::NOT_FOUND; - Ok(not_found) - } - } - } - }); +// --------------------------------------------------------------------------- +// Request routing +// --------------------------------------------------------------------------- - tokio::task::spawn(async move { - debug!(peer = ?peer_addr, "accepted connection"); - if let Err(err) = http1::Builder::new() - // .serve_connection(io, service_fn(chat_completion)) - .serve_connection(io, service) +/// Route an incoming HTTP request to the appropriate handler. +async fn route( + req: Request, + state: Arc, +) -> Result>, hyper::Error> { + let parent_cx = global::get_text_map_propagator(|p| p.extract(&HeaderExtractor(req.headers()))); + let path = req.uri().path().to_string(); + + // --- Agent routes (/agents/...) --- + if let Some(stripped) = path.strip_prefix("/agents") { + if matches!( + stripped, + CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH + ) { + return agent_chat(req, Arc::clone(&state)) + .with_context(parent_cx) + .await; + } + } + + // --- Routing decision routes (/routing/...) --- + if let Some(stripped) = path.strip_prefix("/routing") { + let stripped = stripped.to_string(); + if matches!( + stripped.as_str(), + CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH + ) { + return routing_decision( + req, + Arc::clone(&state.router_service), + stripped, + &state.span_attributes, + ) + .with_context(parent_cx) + .await; + } + } + + // --- Standard routes --- + match (req.method(), path.as_str()) { + (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => { + llm_chat(req, Arc::clone(&state)) + .with_context(parent_cx) .await - { - warn!(error = ?err, "error serving connection"); - } - }); + } + (&Method::POST, "/function_calling") => { + let url = format!("{}/v1/chat/completions", state.llm_provider_url); + function_calling_chat_handler(req, url) + .with_context(parent_cx) + .await + } + (&Method::GET, "/v1/models" | "/agents/v1/models") => { + Ok(list_models(Arc::clone(&state.llm_providers)).await) + } + (&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => cors_preflight(), + _ => { + debug!(method = %req.method(), path = %path, "no route found"); + let mut not_found = Response::new(empty()); + *not_found.status_mut() = StatusCode::NOT_FOUND; + Ok(not_found) + } } } + +// --------------------------------------------------------------------------- +// Server loop +// --------------------------------------------------------------------------- + +/// Accept connections and spawn a task per connection. +/// +/// Listens for `SIGINT` / `ctrl-c` and shuts down gracefully, allowing +/// in-flight connections to finish. +async fn run_server(state: Arc) -> Result<(), Box> { + let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string()); + let listener = TcpListener::bind(&bind_address).await?; + info!(address = %bind_address, "server listening"); + + let shutdown = tokio::signal::ctrl_c(); + tokio::pin!(shutdown); + + loop { + tokio::select! { + result = listener.accept() => { + let (stream, _) = result?; + let peer_addr = stream.peer_addr()?; + let io = TokioIo::new(stream); + let state = Arc::clone(&state); + + tokio::task::spawn(async move { + debug!(peer = ?peer_addr, "accepted connection"); + + let service = service_fn(move |req| { + let state = Arc::clone(&state); + async move { route(req, state).await } + }); + + if let Err(err) = http1::Builder::new().serve_connection(io, service).await { + warn!(error = ?err, "error serving connection"); + } + }); + } + _ = &mut shutdown => { + info!("received shutdown signal, stopping server"); + break; + } + } + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// Entry point +// --------------------------------------------------------------------------- + +#[tokio::main] +async fn main() -> Result<(), Box> { + let config = load_config()?; + let _tracer_provider = init_tracer(config.tracing.as_ref()); + info!("loaded plano_config.yaml"); + let state = Arc::new(init_app_state(&config).await?); + run_server(state).await +} diff --git a/crates/brightstaff/src/router/http.rs b/crates/brightstaff/src/router/http.rs new file mode 100644 index 00000000..ad1b711c --- /dev/null +++ b/crates/brightstaff/src/router/http.rs @@ -0,0 +1,48 @@ +use hermesllm::apis::openai::ChatCompletionsResponse; +use hyper::header; +use thiserror::Error; +use tracing::warn; + +#[derive(Debug, Error)] +pub enum HttpError { + #[error("Failed to send request: {0}")] + Request(#[from] reqwest::Error), + + #[error("Failed to parse JSON response: {0}")] + Json(serde_json::Error, String), +} + +/// Sends a POST request to the given URL and extracts the text content +/// from the first choice of the `ChatCompletionsResponse`. +/// +/// Returns `Some((content, elapsed))` on success, or `None` if the response +/// had no choices or the first choice had no content. +pub async fn post_and_extract_content( + client: &reqwest::Client, + url: &str, + headers: header::HeaderMap, + body: String, +) -> Result, HttpError> { + let start_time = std::time::Instant::now(); + + let res = client.post(url).headers(headers).body(body).send().await?; + + let body = res.text().await?; + let elapsed = start_time.elapsed(); + + let response: ChatCompletionsResponse = serde_json::from_str(&body).map_err(|err| { + warn!(error = %err, body = %body, "failed to parse json response"); + HttpError::Json(err, format!("Failed to parse JSON: {}", body)) + })?; + + if response.choices.is_empty() { + warn!(body = %body, "no choices in response"); + return Ok(None); + } + + Ok(response.choices[0] + .message + .content + .as_ref() + .map(|c| (c.clone(), elapsed))) +} diff --git a/crates/brightstaff/src/router/llm.rs b/crates/brightstaff/src/router/llm.rs new file mode 100644 index 00000000..7d27e80a --- /dev/null +++ b/crates/brightstaff/src/router/llm.rs @@ -0,0 +1,148 @@ +use std::{collections::HashMap, sync::Arc}; + +use common::{ + configuration::{LlmProvider, ModelUsagePreference, RoutingPreference}, + consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER}, +}; +use hermesllm::apis::openai::Message; +use hyper::header; +use thiserror::Error; +use tracing::{debug, info}; + +use super::http::{self, post_and_extract_content}; +use super::router_model::RouterModel; + +use crate::router::router_model_v1; + +pub struct RouterService { + router_url: String, + client: reqwest::Client, + router_model: Arc, + routing_provider_name: String, + llm_usage_defined: bool, +} + +#[derive(Debug, Error)] +pub enum RoutingError { + #[error(transparent)] + Http(#[from] http::HttpError), + + #[error("Router model error: {0}")] + RouterModelError(#[from] super::router_model::RoutingModelError), +} + +pub type Result = std::result::Result; + +impl RouterService { + pub fn new( + providers: Vec, + router_url: String, + routing_model_name: String, + routing_provider_name: String, + ) -> Self { + let providers_with_usage = providers + .iter() + .filter(|provider| provider.routing_preferences.is_some()) + .cloned() + .collect::>(); + + let llm_routes: HashMap> = providers_with_usage + .iter() + .filter_map(|provider| { + provider + .routing_preferences + .as_ref() + .map(|prefs| (provider.name.clone(), prefs.clone())) + }) + .collect(); + + let router_model = Arc::new(router_model_v1::RouterModelV1::new( + llm_routes, + routing_model_name, + router_model_v1::MAX_TOKEN_LEN, + )); + + RouterService { + router_url, + client: reqwest::Client::new(), + router_model, + routing_provider_name, + llm_usage_defined: !providers_with_usage.is_empty(), + } + } + + pub async fn determine_route( + &self, + messages: &[Message], + traceparent: &str, + usage_preferences: Option>, + request_id: &str, + ) -> Result> { + if messages.is_empty() { + return Ok(None); + } + + if usage_preferences + .as_ref() + .is_none_or(|prefs| prefs.len() < 2) + && !self.llm_usage_defined + { + return Ok(None); + } + + let router_request = self + .router_model + .generate_request(messages, &usage_preferences); + + debug!( + model = %self.router_model.get_model_name(), + endpoint = %self.router_url, + "sending request to arch-router" + ); + + let body = serde_json::to_string(&router_request) + .map_err(super::router_model::RoutingModelError::from)?; + debug!(body = %body, "arch router request"); + + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + if let Ok(val) = header::HeaderValue::from_str(&self.routing_provider_name) { + headers.insert( + header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), + val, + ); + } + if let Ok(val) = header::HeaderValue::from_str(traceparent) { + headers.insert(header::HeaderName::from_static(TRACE_PARENT_HEADER), val); + } + if let Ok(val) = header::HeaderValue::from_str(request_id) { + headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val); + } + headers.insert( + header::HeaderName::from_static("model"), + header::HeaderValue::from_static("arch-router"), + ); + + let Some((content, elapsed)) = + post_and_extract_content(&self.client, &self.router_url, headers, body).await? + else { + return Ok(None); + }; + + let parsed = self + .router_model + .parse_response(&content, &usage_preferences)?; + + info!( + content = %content.replace("\n", "\\n"), + selected_model = ?parsed, + response_time_ms = elapsed.as_millis(), + "arch-router determined route" + ); + + Ok(parsed) + } +} diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs deleted file mode 100644 index ec3fe3ab..00000000 --- a/crates/brightstaff/src/router/llm_router.rs +++ /dev/null @@ -1,187 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use common::{ - configuration::{LlmProvider, ModelUsagePreference, RoutingPreference}, - consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER}, -}; -use hermesllm::apis::openai::{ChatCompletionsResponse, Message}; -use hyper::header; -use thiserror::Error; -use tracing::{debug, info, warn}; - -use crate::router::router_model_v1::{self}; - -use super::router_model::RouterModel; - -pub struct RouterService { - router_url: String, - client: reqwest::Client, - router_model: Arc, - #[allow(dead_code)] - routing_provider_name: String, - llm_usage_defined: bool, -} - -#[derive(Debug, Error)] -pub enum RoutingError { - #[error("Failed to send request: {0}")] - RequestError(#[from] reqwest::Error), - - #[error("Failed to parse JSON: {0}, JSON: {1}")] - JsonError(serde_json::Error, String), - - #[error("Router model error: {0}")] - RouterModelError(#[from] super::router_model::RoutingModelError), -} - -pub type Result = std::result::Result; - -impl RouterService { - pub fn new( - providers: Vec, - router_url: String, - routing_model_name: String, - routing_provider_name: String, - ) -> Self { - let providers_with_usage = providers - .iter() - .filter(|provider| provider.routing_preferences.is_some()) - .cloned() - .collect::>(); - - let llm_routes: HashMap> = providers_with_usage - .iter() - .filter_map(|provider| { - provider - .routing_preferences - .as_ref() - .map(|prefs| (provider.name.clone(), prefs.clone())) - }) - .collect(); - - let router_model = Arc::new(router_model_v1::RouterModelV1::new( - llm_routes, - routing_model_name, - router_model_v1::MAX_TOKEN_LEN, - )); - - RouterService { - router_url, - client: reqwest::Client::new(), - router_model, - routing_provider_name, - llm_usage_defined: !providers_with_usage.is_empty(), - } - } - - pub async fn determine_route( - &self, - messages: &[Message], - traceparent: &str, - usage_preferences: Option>, - request_id: &str, - ) -> Result> { - if messages.is_empty() { - return Ok(None); - } - - if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2) - && !self.llm_usage_defined - { - return Ok(None); - } - - let router_request = self - .router_model - .generate_request(messages, &usage_preferences); - - debug!( - model = %self.router_model.get_model_name(), - endpoint = %self.router_url, - "sending request to arch-router" - ); - - debug!( - body = %serde_json::to_string(&router_request).unwrap(), - "arch router request" - ); - - let mut llm_route_request_headers = header::HeaderMap::new(); - llm_route_request_headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ); - - llm_route_request_headers.insert( - header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), - header::HeaderValue::from_str(&self.routing_provider_name).unwrap(), - ); - - llm_route_request_headers.insert( - header::HeaderName::from_static(TRACE_PARENT_HEADER), - header::HeaderValue::from_str(traceparent).unwrap(), - ); - - llm_route_request_headers.insert( - header::HeaderName::from_static(REQUEST_ID_HEADER), - header::HeaderValue::from_str(request_id).unwrap(), - ); - - llm_route_request_headers.insert( - header::HeaderName::from_static("model"), - header::HeaderValue::from_static("arch-router"), - ); - - let start_time = std::time::Instant::now(); - let res = self - .client - .post(&self.router_url) - .headers(llm_route_request_headers) - .body(serde_json::to_string(&router_request).unwrap()) - .send() - .await?; - - let body = res.text().await?; - let router_response_time = start_time.elapsed(); - - let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) { - Ok(response) => response, - Err(err) => { - warn!( - error = %err, - body = %serde_json::to_string(&body).unwrap(), - "failed to parse json response" - ); - return Err(RoutingError::JsonError( - err, - format!("Failed to parse JSON: {}", body), - )); - } - }; - - if chat_completion_response.choices.is_empty() { - warn!(body = %body, "no choices in router response"); - return Ok(None); - } - - if let Some(content) = &chat_completion_response.choices[0].message.content { - let parsed_response = self - .router_model - .parse_response(content, &usage_preferences)?; - info!( - content = %content.replace("\n", "\\n"), - selected_model = ?parsed_response, - response_time_ms = router_response_time.as_millis(), - "arch-router determined route" - ); - - if let Some(ref parsed_response) = parsed_response { - return Ok(Some(parsed_response.clone())); - } - - Ok(None) - } else { - Ok(None) - } - } -} diff --git a/crates/brightstaff/src/router/mod.rs b/crates/brightstaff/src/router/mod.rs index 9b1abbea..b010d80c 100644 --- a/crates/brightstaff/src/router/mod.rs +++ b/crates/brightstaff/src/router/mod.rs @@ -1,6 +1,7 @@ -pub mod llm_router; +pub(crate) mod http; +pub mod llm; +pub mod orchestrator; pub mod orchestrator_model; pub mod orchestrator_model_v1; -pub mod plano_orchestrator; pub mod router_model; pub mod router_model_v1; diff --git a/crates/brightstaff/src/router/orchestrator.rs b/crates/brightstaff/src/router/orchestrator.rs new file mode 100644 index 00000000..9ff76371 --- /dev/null +++ b/crates/brightstaff/src/router/orchestrator.rs @@ -0,0 +1,139 @@ +use std::{collections::HashMap, sync::Arc}; + +use common::{ + configuration::{AgentUsagePreference, OrchestrationPreference}, + consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER}, +}; +use hermesllm::apis::openai::Message; +use hyper::header; +use opentelemetry::global; +use opentelemetry_http::HeaderInjector; +use thiserror::Error; +use tracing::{debug, info}; + +use super::http::{self, post_and_extract_content}; +use super::orchestrator_model::OrchestratorModel; + +use crate::router::orchestrator_model_v1; + +pub struct OrchestratorService { + orchestrator_url: String, + client: reqwest::Client, + orchestrator_model: Arc, + orchestrator_provider_name: String, +} + +#[derive(Debug, Error)] +pub enum OrchestrationError { + #[error(transparent)] + Http(#[from] http::HttpError), + + #[error("Orchestrator model error: {0}")] + OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError), +} + +pub type Result = std::result::Result; + +impl OrchestratorService { + pub fn new( + orchestrator_url: String, + orchestration_model_name: String, + orchestrator_provider_name: String, + ) -> Self { + let agent_orchestrations: HashMap> = HashMap::new(); + + let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new( + agent_orchestrations, + orchestration_model_name.clone(), + orchestrator_model_v1::MAX_TOKEN_LEN, + )); + + OrchestratorService { + orchestrator_url, + client: reqwest::Client::new(), + orchestrator_model, + orchestrator_provider_name, + } + } + + pub async fn determine_orchestration( + &self, + messages: &[Message], + usage_preferences: Option>, + request_id: Option, + ) -> Result>> { + if messages.is_empty() { + return Ok(None); + } + + if usage_preferences + .as_ref() + .is_none_or(|prefs| prefs.is_empty()) + { + return Ok(None); + } + + let orchestrator_request = self + .orchestrator_model + .generate_request(messages, &usage_preferences); + + debug!( + model = %self.orchestrator_model.get_model_name(), + endpoint = %self.orchestrator_url, + "sending request to arch-orchestrator" + ); + + let body = serde_json::to_string(&orchestrator_request) + .map_err(super::orchestrator_model::OrchestratorModelError::from)?; + debug!(body = %body, "arch orchestrator request"); + + let mut headers = header::HeaderMap::new(); + headers.insert( + header::CONTENT_TYPE, + header::HeaderValue::from_static("application/json"), + ); + headers.insert( + header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), + header::HeaderValue::from_str(&self.orchestrator_provider_name) + .unwrap_or_else(|_| header::HeaderValue::from_static("plano-orchestrator")), + ); + + // Inject OpenTelemetry trace context from current span + global::get_text_map_propagator(|propagator| { + let cx = + tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); + propagator.inject_context(&cx, &mut HeaderInjector(&mut headers)); + }); + + if let Some(ref request_id) = request_id { + if let Ok(val) = header::HeaderValue::from_str(request_id) { + headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val); + } + } + + headers.insert( + header::HeaderName::from_static("model"), + header::HeaderValue::from_str(&self.orchestrator_provider_name) + .unwrap_or_else(|_| header::HeaderValue::from_static("plano-orchestrator")), + ); + + let Some((content, elapsed)) = + post_and_extract_content(&self.client, &self.orchestrator_url, headers, body).await? + else { + return Ok(None); + }; + + let parsed = self + .orchestrator_model + .parse_response(&content, &usage_preferences)?; + + info!( + content = %content.replace("\n", "\\n"), + selected_routes = ?parsed, + response_time_ms = elapsed.as_millis(), + "arch-orchestrator determined routes" + ); + + Ok(parsed) + } +} diff --git a/crates/brightstaff/src/router/plano_orchestrator.rs b/crates/brightstaff/src/router/plano_orchestrator.rs deleted file mode 100644 index 12140570..00000000 --- a/crates/brightstaff/src/router/plano_orchestrator.rs +++ /dev/null @@ -1,174 +0,0 @@ -use std::{collections::HashMap, sync::Arc}; - -use common::{ - configuration::{AgentUsagePreference, OrchestrationPreference}, - consts::{ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER}, -}; -use hermesllm::apis::openai::{ChatCompletionsResponse, Message}; -use hyper::header; -use opentelemetry::global; -use opentelemetry_http::HeaderInjector; -use thiserror::Error; -use tracing::{debug, info, warn}; - -use crate::router::orchestrator_model_v1::{self}; - -use super::orchestrator_model::OrchestratorModel; - -pub struct OrchestratorService { - orchestrator_url: String, - client: reqwest::Client, - orchestrator_model: Arc, - orchestrator_provider_name: String, -} - -#[derive(Debug, Error)] -pub enum OrchestrationError { - #[error("Failed to send request: {0}")] - RequestError(#[from] reqwest::Error), - - #[error("Failed to parse JSON: {0}, JSON: {1}")] - JsonError(serde_json::Error, String), - - #[error("Orchestrator model error: {0}")] - OrchestratorModelError(#[from] super::orchestrator_model::OrchestratorModelError), -} - -pub type Result = std::result::Result; - -impl OrchestratorService { - pub fn new( - orchestrator_url: String, - orchestration_model_name: String, - orchestrator_provider_name: String, - ) -> Self { - // Empty agent orchestrations - will be provided via usage_preferences in requests - let agent_orchestrations: HashMap> = HashMap::new(); - - let orchestrator_model = Arc::new(orchestrator_model_v1::OrchestratorModelV1::new( - agent_orchestrations, - orchestration_model_name.clone(), - orchestrator_model_v1::MAX_TOKEN_LEN, - )); - - OrchestratorService { - orchestrator_url, - client: reqwest::Client::new(), - orchestrator_model, - orchestrator_provider_name, - } - } - - pub async fn determine_orchestration( - &self, - messages: &[Message], - usage_preferences: Option>, - request_id: Option, - ) -> Result>> { - if messages.is_empty() { - return Ok(None); - } - - // Require usage_preferences to be provided - if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() { - return Ok(None); - } - - let orchestrator_request = self - .orchestrator_model - .generate_request(messages, &usage_preferences); - - debug!( - model = %self.orchestrator_model.get_model_name(), - endpoint = %self.orchestrator_url, - "sending request to plano-orchestrator" - ); - - debug!( - body = %serde_json::to_string(&orchestrator_request).unwrap(), - "plano orchestrator request" - ); - - let mut orchestration_request_headers = header::HeaderMap::new(); - orchestration_request_headers.insert( - header::CONTENT_TYPE, - header::HeaderValue::from_static("application/json"), - ); - - orchestration_request_headers.insert( - header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), - header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(), - ); - - // Inject OpenTelemetry trace context from current span - global::get_text_map_propagator(|propagator| { - let cx = - tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); - propagator.inject_context(&cx, &mut HeaderInjector(&mut orchestration_request_headers)); - }); - - if let Some(request_id) = request_id { - orchestration_request_headers.insert( - header::HeaderName::from_static(REQUEST_ID_HEADER), - header::HeaderValue::from_str(&request_id).unwrap(), - ); - } - - orchestration_request_headers.insert( - header::HeaderName::from_static("model"), - header::HeaderValue::from_str(&self.orchestrator_provider_name).unwrap(), - ); - - let start_time = std::time::Instant::now(); - let res = self - .client - .post(&self.orchestrator_url) - .headers(orchestration_request_headers) - .body(serde_json::to_string(&orchestrator_request).unwrap()) - .send() - .await?; - - let body = res.text().await?; - let orchestrator_response_time = start_time.elapsed(); - - let chat_completion_response: ChatCompletionsResponse = match serde_json::from_str(&body) { - Ok(response) => response, - Err(err) => { - warn!( - error = %err, - body = %serde_json::to_string(&body).unwrap(), - "failed to parse json response" - ); - return Err(OrchestrationError::JsonError( - err, - format!("Failed to parse JSON: {}", body), - )); - } - }; - - if chat_completion_response.choices.is_empty() { - warn!(body = %body, "no choices in orchestrator response"); - return Ok(None); - } - - if let Some(content) = &chat_completion_response.choices[0].message.content { - let parsed_response = self - .orchestrator_model - .parse_response(content, &usage_preferences)?; - info!( - content = %content.replace("\n", "\\n"), - selected_routes = ?parsed_response, - response_time_ms = orchestrator_response_time.as_millis(), - "arch-orchestrator determined routes" - ); - - if let Some(ref parsed_response) = parsed_response { - return Ok(Some(parsed_response.clone())); - } - - Ok(None) - } else { - Ok(None) - } - } -} diff --git a/crates/brightstaff/src/state/mod.rs b/crates/brightstaff/src/state/mod.rs index 3d59f359..43454ee2 100644 --- a/crates/brightstaff/src/state/mod.rs +++ b/crates/brightstaff/src/state/mod.rs @@ -88,35 +88,18 @@ pub trait StateStorage: Send + Sync { combined_input.extend(current_input); debug!( - "PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}", - prev_state.response_id, - prev_count, - current_count, - combined_input.len(), - serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string()) + response_id = %prev_state.response_id, + prev_items = prev_count, + current_items = current_count, + total_items = combined_input.len(), + combined_json = %serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string()), + "merged conversation state" ); combined_input } } -/// Storage backend type enum -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StorageBackend { - Memory, - Supabase, -} - -impl StorageBackend { - pub fn parse_backend(s: &str) -> Option { - match s.to_lowercase().as_str() { - "memory" => Some(StorageBackend::Memory), - "supabase" => Some(StorageBackend::Supabase), - _ => None, - } - } -} - // === Utility functions for state management === /// Extract input items from InputParam, converting text to structured format diff --git a/crates/brightstaff/src/state/response_state_processor.rs b/crates/brightstaff/src/state/response_state_processor.rs index 6f6c7b62..bca2991a 100644 --- a/crates/brightstaff/src/state/response_state_processor.rs +++ b/crates/brightstaff/src/state/response_state_processor.rs @@ -7,8 +7,8 @@ use std::io::Read; use std::sync::Arc; use tracing::{debug, info, warn}; -use crate::handlers::streaming::StreamProcessor; use crate::state::{OpenAIConversationState, StateStorage}; +use crate::streaming::StreamProcessor; /// Processor that wraps another processor and handles v1/responses state management /// Captures response_id and output from streaming responses, stores state after completion diff --git a/crates/brightstaff/src/handlers/streaming.rs b/crates/brightstaff/src/streaming.rs similarity index 99% rename from crates/brightstaff/src/handlers/streaming.rs rename to crates/brightstaff/src/streaming.rs index 0cea182e..f7af8ae0 100644 --- a/crates/brightstaff/src/handlers/streaming.rs +++ b/crates/brightstaff/src/streaming.rs @@ -13,7 +13,7 @@ use tokio_stream::StreamExt; use tracing::{debug, info, warn, Instrument}; use tracing_opentelemetry::OpenTelemetrySpanExt; -use super::pipeline_processor::{PipelineError, PipelineProcessor}; +use crate::handlers::agents::pipeline::{PipelineError, PipelineProcessor}; const STREAM_BUFFER_SIZE: usize = 16; use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER}; diff --git a/crates/brightstaff/src/tracing/custom_attributes.rs b/crates/brightstaff/src/tracing/custom_attributes.rs index 24abc72b..7d4244d2 100644 --- a/crates/brightstaff/src/tracing/custom_attributes.rs +++ b/crates/brightstaff/src/tracing/custom_attributes.rs @@ -52,6 +52,7 @@ pub fn collect_custom_trace_attributes( attributes } +#[allow(dead_code)] pub fn append_span_attributes( mut span_builder: SpanBuilder, attributes: &HashMap, diff --git a/crates/brightstaff/src/utils/tracing.rs b/crates/brightstaff/src/tracing/init.rs similarity index 94% rename from crates/brightstaff/src/utils/tracing.rs rename to crates/brightstaff/src/tracing/init.rs index 21882303..ed351148 100644 --- a/crates/brightstaff/src/utils/tracing.rs +++ b/crates/brightstaff/src/tracing/init.rs @@ -11,7 +11,7 @@ use tracing_subscriber::registry::LookupSpan; use tracing_subscriber::util::SubscriberInitExt; use tracing_subscriber::EnvFilter; -use crate::tracing::ServiceNameOverrideExporter; +use super::ServiceNameOverrideExporter; use common::configuration::Tracing; struct BracketedTime; @@ -96,17 +96,16 @@ pub fn init_tracer(tracing_config: Option<&Tracing>) -> &'static SdkTracerProvid tracing_enabled, otel_endpoint, random_sampling ); - // Create OTLP exporter to send spans to collector - if tracing_enabled { - // Set service name via environment if not already set + // Create OTLP exporter to send spans to collector. + // Use `if let` to destructure the endpoint, avoiding an unwrap. + if let Some(endpoint) = otel_endpoint.as_deref().filter(|_| tracing_enabled) { if std::env::var("OTEL_SERVICE_NAME").is_err() { std::env::set_var("OTEL_SERVICE_NAME", "plano"); } - // Create ServiceNameOverrideExporter to support per-span service names // This allows spans to have different service names (e.g., plano(orchestrator), // plano(filter), plano(llm)) by setting the "service.name.override" attribute - let exporter = ServiceNameOverrideExporter::new(otel_endpoint.as_ref().unwrap()); + let exporter = ServiceNameOverrideExporter::new(endpoint); let provider = SdkTracerProvider::builder() .with_batch_exporter(exporter) diff --git a/crates/brightstaff/src/tracing/mod.rs b/crates/brightstaff/src/tracing/mod.rs index 1fa8a7e2..644db31a 100644 --- a/crates/brightstaff/src/tracing/mod.rs +++ b/crates/brightstaff/src/tracing/mod.rs @@ -1,11 +1,13 @@ mod constants; mod custom_attributes; +mod init; mod service_name_exporter; pub use constants::{ error, http, llm, operation_component, routing, signals, OperationNameBuilder, }; -pub use custom_attributes::{append_span_attributes, collect_custom_trace_attributes}; +pub use custom_attributes::collect_custom_trace_attributes; +pub use init::init_tracer; pub use service_name_exporter::{ServiceNameOverrideExporter, SERVICE_NAME_OVERRIDE_KEY}; use opentelemetry::trace::get_active_span; diff --git a/crates/brightstaff/src/utils/mod.rs b/crates/brightstaff/src/utils/mod.rs deleted file mode 100644 index 5ee45fbc..00000000 --- a/crates/brightstaff/src/utils/mod.rs +++ /dev/null @@ -1 +0,0 @@ -pub mod tracing;