diff --git a/crates/brightstaff/src/app_state.rs b/crates/brightstaff/src/app_state.rs index b7732c9d..8b781ee9 100644 --- a/crates/brightstaff/src/app_state.rs +++ b/crates/brightstaff/src/app_state.rs @@ -22,4 +22,6 @@ pub struct AppState { pub listeners: Arc>>, pub state_storage: Option>, pub llm_provider_url: String, + /// Shared HTTP client for upstream LLM requests (connection pooling / keep-alive). + pub http_client: reqwest::Client, } diff --git a/crates/brightstaff/src/handlers/agents/orchestrator.rs b/crates/brightstaff/src/handlers/agents/orchestrator.rs index ff4fad88..e6ec26fc 100644 --- a/crates/brightstaff/src/handlers/agents/orchestrator.rs +++ b/crates/brightstaff/src/handlers/agents/orchestrator.rs @@ -15,10 +15,10 @@ use tracing::{debug, info, info_span, warn, Instrument}; use super::pipeline::{PipelineError, PipelineProcessor}; use super::selector::{AgentSelectionError, AgentSelector}; +use crate::app_state::AppState; use crate::handlers::errors::build_error_chain_response; use crate::handlers::request::extract_request_id; use crate::handlers::response::ResponseHandler; -use crate::router::orchestrator::OrchestratorService; use crate::tracing::{operation_component, set_service_name}; /// Main errors for agent chat completions @@ -38,9 +38,7 @@ pub enum AgentFilterChainError { pub async fn agent_chat( request: Request, - orchestrator_service: Arc, - agents_list: Arc>>>, - listeners: Arc>>, + state: Arc, ) -> Result>, hyper::Error> { let request_id = extract_request_id(&request); @@ -58,15 +56,7 @@ pub async fn agent_chat( // Set service name for orchestrator operations set_service_name(operation_component::ORCHESTRATOR); - match handle_agent_chat_inner( - request, - orchestrator_service, - agents_list, - listeners, - request_id, - ) - .await - { + match handle_agent_chat_inner(request, state, request_id).await { Ok(response) => Ok(response), Err(err) => { // Check if this is a client error from the pipeline that should be cascaded @@ -112,13 +102,11 @@ pub async fn agent_chat( async fn handle_agent_chat_inner( request: Request, - orchestrator_service: Arc, - agents_list: Arc>>>, - listeners: Arc>>, + state: Arc, request_id: String, ) -> Result>, AgentFilterChainError> { // Initialize services - let agent_selector = AgentSelector::new(orchestrator_service); + let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service)); let mut pipeline_processor = PipelineProcessor::default(); let response_handler = ResponseHandler::new(); @@ -130,7 +118,7 @@ async fn handle_agent_chat_inner( // Find the appropriate listener let listener: common::configuration::Listener = { - let listeners = listeners.read().await; + let listeners = state.listeners.read().await; agent_selector .find_listener(listener_name, &listeners) .await? @@ -143,12 +131,10 @@ async fn handle_agent_chat_inner( info!(listener = %listener.name, "handling request"); // Parse request body - let request_path = request - .uri() - .path() - .to_string() + let full_path = request.uri().path().to_string(); + let request_path = full_path .strip_prefix("/agents") - .unwrap() + .unwrap_or(&full_path) .to_string(); let request_headers = { @@ -201,7 +187,7 @@ async fn handle_agent_chat_inner( // Create agent map for pipeline processing and agent selection let agent_map = { - let agents = agents_list.read().await; + let agents = state.agents_list.read().await; let agents = agents.as_ref().ok_or_else(|| { AgentFilterChainError::RequestParsing(serde_json::Error::custom("No agents configured")) })?; @@ -340,7 +326,10 @@ async fn handle_agent_chat_inner( ); // remove last message and add new one at the end - let last_message = current_messages.pop().unwrap(); + let Some(last_message) = current_messages.pop() else { + warn!(agent = %agent_name, "no messages in conversation history"); + break; + }; // Create a new message with the agent's response as assistant message // and add it to the conversation history diff --git a/crates/brightstaff/src/handlers/agents/pipeline.rs b/crates/brightstaff/src/handlers/agents/pipeline.rs index d0d05889..0b269ce1 100644 --- a/crates/brightstaff/src/handlers/agents/pipeline.rs +++ b/crates/brightstaff/src/handlers/agents/pipeline.rs @@ -310,7 +310,7 @@ impl PipelineProcessor { let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) { session_id.clone() } else { - let session_id = self.get_new_session_id(&agent.id, request_headers).await; + let session_id = self.get_new_session_id(&agent.id, request_headers).await?; self.agent_id_session_map .insert(agent.id.clone(), session_id.clone()); session_id @@ -464,18 +464,19 @@ impl PipelineProcessor { Ok(()) } - async fn get_new_session_id(&self, agent_id: &str, request_headers: &HeaderMap) -> String { + async fn get_new_session_id( + &self, + agent_id: &str, + request_headers: &HeaderMap, + ) -> Result { info!("initializing MCP session for agent {}", agent_id); let initialize_request = self.build_initialize_request(); - let headers = self - .build_mcp_headers(request_headers, agent_id, None) - .expect("Failed to build headers for initialization"); + let headers = self.build_mcp_headers(request_headers, agent_id, None)?; let response = self .send_mcp_request(&initialize_request, &headers, agent_id) - .await - .expect("Failed to initialize MCP session"); + .await?; info!("initialize response status: {}", response.status()); @@ -483,8 +484,13 @@ impl PipelineProcessor { .headers() .get("mcp-session-id") .and_then(|v| v.to_str().ok()) - .expect("No mcp-session-id in response") - .to_string(); + .map(|s| s.to_string()) + .ok_or_else(|| { + PipelineError::NoContentInResponse(format!( + "No mcp-session-id header in initialize response from agent {}", + agent_id + )) + })?; info!( "created new MCP session for agent {}: {}", @@ -493,10 +499,9 @@ impl PipelineProcessor { // Send initialized notification self.send_initialized_notification(agent_id, &session_id, &headers) - .await - .expect("Failed to send initialized notification"); + .await?; - session_id + Ok(session_id) } /// Execute a HTTP-based filter agent @@ -620,8 +625,8 @@ impl PipelineProcessor { let request_url = "/v1/chat/completions"; - let request_body = ProviderRequestType::to_bytes(&original_request).unwrap(); - // let request_body = serde_json::to_string(&request)?; + let request_body = ProviderRequestType::to_bytes(&original_request) + .map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?; debug!("sending request to terminal agent {}", terminal_agent.id); let mut agent_headers = request_headers.clone(); diff --git a/crates/brightstaff/src/handlers/llm/mod.rs b/crates/brightstaff/src/handlers/llm/mod.rs index d4f7e763..bc4ff811 100644 --- a/crates/brightstaff/src/handlers/llm/mod.rs +++ b/crates/brightstaff/src/handlers/llm/mod.rs @@ -20,11 +20,11 @@ use tracing::{debug, info, info_span, warn, Instrument}; mod router; +use crate::app_state::AppState; use crate::handlers::request::extract_request_id; use crate::handlers::utils::{ create_streaming_response, truncate_message, ObservableStreamProcessor, }; -use crate::router::llm::RouterService; use crate::state::response_state_processor::ResponsesStateProcessor; use crate::state::{ extract_input_items, retrieve_and_combine_input, StateStorage, StateStorageError, @@ -40,11 +40,7 @@ fn full>(chunk: T) -> BoxBody { pub async fn llm_chat( request: Request, - router_service: Arc, - full_qualified_llm_provider_url: String, - model_aliases: Arc>>, - llm_providers: Arc>, - state_storage: Option>, + state: Arc, ) -> Result>, hyper::Error> { let request_path = request.uri().path().to_string(); let request_headers = request.headers().clone(); @@ -64,29 +60,14 @@ pub async fn llm_chat( ); // Execute the rest of the handler inside the span - llm_chat_inner( - request, - router_service, - full_qualified_llm_provider_url, - model_aliases, - llm_providers, - state_storage, - request_id, - request_path, - request_headers, - ) - .instrument(request_span) - .await + llm_chat_inner(request, state, request_id, request_path, request_headers) + .instrument(request_span) + .await } -#[allow(clippy::too_many_arguments)] async fn llm_chat_inner( request: Request, - router_service: Arc, - full_qualified_llm_provider_url: String, - model_aliases: Arc>>, - llm_providers: Arc>, - state_storage: Option>, + state: Arc, request_id: String, request_path: String, mut request_headers: hyper::HeaderMap, @@ -96,14 +77,20 @@ async fn llm_chat_inner( let traceparent = extract_or_generate_traceparent(&request_headers); + let full_qualified_llm_provider_url = format!("{}{}", state.llm_provider_url, request_path); + // --- Phase 1: Parse and validate the incoming request --- - let parsed = - match parse_and_validate_request(request, &request_path, &model_aliases, &llm_providers) - .await - { - Ok(p) => p, - Err(response) => return Ok(response), - }; + let parsed = match parse_and_validate_request( + request, + &request_path, + &state.model_aliases, + &state.llm_providers, + ) + .await + { + Ok(p) => p, + Err(response) => return Ok(response), + }; let PreparedRequest { mut client_request, @@ -139,8 +126,8 @@ async fn llm_chat_inner( let state_ctx = match resolve_conversation_state( &mut client_request, is_responses_api_client, - &state_storage, - &llm_providers, + &state.state_storage, + &state.llm_providers, &alias_resolved_model, &request_path, is_streaming_request, @@ -177,7 +164,7 @@ async fn llm_chat_inner( let routing_result = match async { set_service_name(operation_component::ROUTING); router_chat_get_upstream_model( - router_service, + Arc::clone(&state.router_service), client_request, &traceparent, &request_path, @@ -207,6 +194,7 @@ async fn llm_chat_inner( // --- Phase 4: Forward to upstream and stream back --- send_upstream( + &state.http_client, &full_qualified_llm_provider_url, &mut request_headers, client_request_bytes_for_upstream, @@ -218,7 +206,7 @@ async fn llm_chat_inner( is_streaming_request, messages_for_signals, state_ctx, - state_storage, + state.state_storage.clone(), request_id, ) .await @@ -458,6 +446,7 @@ async fn resolve_conversation_state( #[allow(clippy::too_many_arguments)] async fn send_upstream( + http_client: &reqwest::Client, upstream_url: &str, request_headers: &mut hyper::HeaderMap, body: bytes::Bytes, @@ -509,7 +498,7 @@ async fn send_upstream( let request_start_time = std::time::Instant::now(); - let llm_response = match reqwest::Client::new() + let llm_response = match http_client .post(upstream_url) .headers(request_headers.clone()) .body(body) diff --git a/crates/brightstaff/src/handlers/response.rs b/crates/brightstaff/src/handlers/response.rs index 7331ab4c..e9862d2f 100644 --- a/crates/brightstaff/src/handlers/response.rs +++ b/crates/brightstaff/src/handlers/response.rs @@ -112,7 +112,9 @@ impl ResponseHandler { let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).unwrap(); + let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).map_err(|e| { + ResponseError::StreamError(format!("Failed to parse SSE stream: {}", e)) + })?; let mut accumulated_text = String::new(); for sse_event in sse_iter { @@ -122,7 +124,13 @@ impl ResponseHandler { } let transformed_event = - SseEvent::try_from((sse_event, &client_api, &upstream_api)).unwrap(); + match SseEvent::try_from((sse_event, &client_api, &upstream_api)) { + Ok(event) => event, + Err(e) => { + warn!(error = ?e, "failed to transform SSE event, skipping"); + continue; + } + }; // Try to get provider response and extract content delta match transformed_event.provider_response() { diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 174c6130..14235e06 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -145,6 +145,7 @@ async fn init_app_state( listeners: Arc::new(RwLock::new(config.listeners.clone())), state_storage, llm_provider_url, + http_client: reqwest::Client::new(), }) } @@ -206,31 +207,18 @@ async fn route( stripped, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH ) { - return agent_chat( - req, - Arc::clone(&state.orchestrator_service), - Arc::clone(&state.agents_list), - Arc::clone(&state.listeners), - ) - .with_context(parent_cx) - .await; + return agent_chat(req, Arc::clone(&state)) + .with_context(parent_cx) + .await; } } // --- Standard routes --- match (req.method(), path.as_str()) { (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => { - let url = format!("{}{}", state.llm_provider_url, path); - llm_chat( - req, - Arc::clone(&state.router_service), - url, - Arc::clone(&state.model_aliases), - Arc::clone(&state.llm_providers), - state.state_storage.clone(), - ) - .with_context(parent_cx) - .await + llm_chat(req, Arc::clone(&state)) + .with_context(parent_cx) + .await } (&Method::POST, "/function_calling") => { let url = format!("{}/v1/chat/completions", state.llm_provider_url); diff --git a/crates/brightstaff/src/router/llm.rs b/crates/brightstaff/src/router/llm.rs index 6067a1b8..87330b29 100644 --- a/crates/brightstaff/src/router/llm.rs +++ b/crates/brightstaff/src/router/llm.rs @@ -99,7 +99,8 @@ impl RouterService { "sending request to arch-router" ); - let body = serde_json::to_string(&router_request).unwrap(); + let body = serde_json::to_string(&router_request) + .map_err(super::router_model::RoutingModelError::from)?; debug!(body = %body, "arch router request"); let mut headers = header::HeaderMap::new(); diff --git a/crates/brightstaff/src/router/orchestrator.rs b/crates/brightstaff/src/router/orchestrator.rs index f587e749..19588725 100644 --- a/crates/brightstaff/src/router/orchestrator.rs +++ b/crates/brightstaff/src/router/orchestrator.rs @@ -74,7 +74,8 @@ impl OrchestratorService { "sending request to arch-orchestrator" ); - let body = serde_json::to_string(&orchestrator_request).unwrap(); + let body = serde_json::to_string(&orchestrator_request) + .map_err(super::orchestrator_model::OrchestratorModelError::from)?; debug!(body = %body, "arch orchestrator request"); let mut headers = header::HeaderMap::new(); diff --git a/crates/brightstaff/src/state/mod.rs b/crates/brightstaff/src/state/mod.rs index a39a4576..da8c4077 100644 --- a/crates/brightstaff/src/state/mod.rs +++ b/crates/brightstaff/src/state/mod.rs @@ -88,12 +88,11 @@ pub trait StateStorage: Send + Sync { combined_input.extend(current_input); debug!( - "PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}", - prev_state.response_id, - prev_count, - current_count, - combined_input.len(), - serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string()) + response_id = %prev_state.response_id, + prev_items = prev_count, + current_items = current_count, + total_items = combined_input.len(), + "merged conversation state" ); combined_input