diff --git a/crates/brightstaff/src/handlers/agents/orchestrator.rs b/crates/brightstaff/src/handlers/agents/orchestrator.rs index 68cc3c6f..b3aa8116 100644 --- a/crates/brightstaff/src/handlers/agents/orchestrator.rs +++ b/crates/brightstaff/src/handlers/agents/orchestrator.rs @@ -89,7 +89,7 @@ pub async fn agent_chat( .unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR); response.headers_mut().insert( hyper::header::CONTENT_TYPE, - "application/json".parse().unwrap(), + hyper::header::HeaderValue::from_static("application/json"), ); return Ok(response); } @@ -102,16 +102,22 @@ pub async fn agent_chat( .await } -async fn handle_agent_chat_inner( +/// Parsed and validated agent request data. +struct AgentRequest { + client_request: ProviderRequestType, + messages: Vec, + request_headers: hyper::HeaderMap, + request_id: Option, +} + +/// Parse the incoming HTTP request, resolve the listener, and extract messages. +async fn parse_agent_request( request: Request, - state: Arc, - request_id: String, - custom_attrs: std::collections::HashMap, -) -> Result>, AgentFilterChainError> { - // Initialize services + state: &AppState, + request_id: &str, + custom_attrs: &std::collections::HashMap, +) -> Result<(AgentRequest, common::configuration::Listener, AgentSelector), AgentFilterChainError> { let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service)); - let mut pipeline_processor = PipelineProcessor::default(); - let response_handler = ResponseHandler::new(); // Extract listener name from headers let listener_name = request @@ -129,7 +135,7 @@ async fn handle_agent_chat_inner( get_active_span(|span| { span.update_name(listener.name.to_string()); - for (key, value) in &custom_attrs { + for (key, value) in custom_attrs { span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone())); } }); @@ -147,12 +153,10 @@ async fn handle_agent_chat_inner( let mut headers = request.headers().clone(); headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER); - // Set the request_id in headers if not already present if !headers.contains_key(common::consts::REQUEST_ID_HEADER) { - headers.insert( - common::consts::REQUEST_ID_HEADER, - hyper::header::HeaderValue::from_str(&request_id).unwrap(), - ); + if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) { + headers.insert(common::consts::REQUEST_ID_HEADER, val); + } } headers @@ -165,7 +169,6 @@ async fn handle_agent_chat_inner( "received request body" ); - // Determine the API type from the endpoint let api_type = SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| { let err_msg = format!("Unsupported endpoint: {}", request_path); @@ -173,25 +176,48 @@ async fn handle_agent_chat_inner( AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg)) })?; - let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) { - Ok(request) => request, - Err(err) => { + let client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) + .map_err(|err| { warn!("failed to parse request as ProviderRequestType: {}", err); - let err_msg = format!("Failed to parse request: {}", err); - return Err(AgentFilterChainError::RequestParsing( - serde_json::Error::custom(err_msg), - )); - } - }; + AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!( + "Failed to parse request: {}", + err + ))) + })?; - let message: Vec = client_request.get_messages(); + let messages: Vec = client_request.get_messages(); let request_id = request_headers .get(common::consts::REQUEST_ID_HEADER) .and_then(|val| val.to_str().ok()) .map(|s| s.to_string()); - // Create agent map for pipeline processing and agent selection + Ok(( + AgentRequest { + client_request, + messages, + request_headers, + request_id, + }, + listener, + agent_selector, + )) +} + +/// Select agents via the orchestrator model and record selection metrics. +async fn select_and_build_agent_map( + agent_selector: &AgentSelector, + state: &AppState, + messages: &[OpenAIMessage], + listener: &common::configuration::Listener, + request_id: Option, +) -> Result< + ( + Vec, + std::collections::HashMap, + ), + AgentFilterChainError, +> { let agent_map = { let agents = state.agents_list.read().await; let agents = agents.as_ref().ok_or_else(|| { @@ -200,13 +226,11 @@ async fn handle_agent_chat_inner( agent_selector.create_agent_map(agents) }; - // Select appropriate agents using arch orchestrator llm model let selection_start = Instant::now(); let selected_agents = agent_selector - .select_agents(&message, &listener, request_id.clone()) + .select_agents(messages, listener, request_id) .await?; - // Record selection attributes on the current orchestrator span let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0; get_active_span(|span| { span.set_attribute(opentelemetry::KeyValue::new( @@ -236,12 +260,25 @@ async fn handle_agent_chat_inner( "selected agents for execution" ); - // Execute agents sequentially, passing output from one to the next - let mut current_messages = message.clone(); + Ok((selected_agents, agent_map)) +} + +/// Execute the agent chain: run each selected agent sequentially, streaming +/// the final agent's response back to the client. +async fn execute_agent_chain( + selected_agents: &[common::configuration::AgentFilterChain], + agent_map: &std::collections::HashMap, + client_request: ProviderRequestType, + messages: Vec, + request_headers: &hyper::HeaderMap, + custom_attrs: &std::collections::HashMap, +) -> Result>, AgentFilterChainError> { + let mut pipeline_processor = PipelineProcessor::default(); + let response_handler = ResponseHandler::new(); + let mut current_messages = messages; let agent_count = selected_agents.len(); for (agent_index, selected_agent) in selected_agents.iter().enumerate() { - // Get agent name let agent_name = selected_agent.id.clone(); let is_last_agent = agent_index == agent_count - 1; @@ -252,17 +289,15 @@ async fn handle_agent_chat_inner( "processing agent" ); - // Process the filter chain let chat_history = pipeline_processor .process_filter_chain( ¤t_messages, selected_agent, - &agent_map, - &request_headers, + agent_map, + request_headers, ) .await?; - // Get agent details and invoke let agent = agent_map.get(&agent_name).ok_or_else(|| { AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!( "Selected agent '{}' not found in configuration", @@ -282,7 +317,7 @@ async fn handle_agent_chat_inner( set_service_name(operation_component::AGENT); get_active_span(|span| { span.update_name(format!("{} /v1/chat/completions", agent_name)); - for (key, value) in &custom_attrs { + for (key, value) in custom_attrs { span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone())); } }); @@ -292,28 +327,25 @@ async fn handle_agent_chat_inner( &chat_history, client_request.clone(), agent, - &request_headers, + request_headers, ) .await } .instrument(agent_span.clone()) .await?; - // If this is the last agent, return the streaming response if is_last_agent { info!( agent = %agent_name, "completed agent chain, returning response" ); - // Capture the orchestrator span (parent of the agent span) so it - // stays open for the full streaming duration alongside the agent span. let orchestrator_span = tracing::Span::current(); return async { response_handler .create_streaming_response( llm_response, - tracing::Span::current(), // agent span (inner) - orchestrator_span, // orchestrator span (outer) + tracing::Span::current(), + orchestrator_span, ) .await .map_err(AgentFilterChainError::from) @@ -322,7 +354,6 @@ async fn handle_agent_chat_inner( .await; } - // For intermediate agents, collect the full response and pass to next agent debug!(agent = %agent_name, "collecting response from intermediate agent"); let response_text = async { response_handler.collect_full_response(llm_response).await } .instrument(agent_span) @@ -334,14 +365,11 @@ async fn handle_agent_chat_inner( "agent completed, passing response to next agent" ); - // remove last message and add new one at the end let Some(last_message) = current_messages.pop() else { warn!(agent = %agent_name, "no messages in conversation history"); break; }; - // Create a new message with the agent's response as assistant message - // and add it to the conversation history current_messages.push(OpenAIMessage { role: hermesllm::apis::openai::Role::Assistant, content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)), @@ -353,6 +381,34 @@ async fn handle_agent_chat_inner( current_messages.push(last_message); } - // This should never be reached since we return in the last agent iteration unreachable!("Agent execution loop should have returned a response") } + +async fn handle_agent_chat_inner( + request: Request, + state: Arc, + request_id: String, + custom_attrs: std::collections::HashMap, +) -> Result>, AgentFilterChainError> { + let (agent_req, listener, agent_selector) = + parse_agent_request(request, &state, &request_id, &custom_attrs).await?; + + let (selected_agents, agent_map) = select_and_build_agent_map( + &agent_selector, + &state, + &agent_req.messages, + &listener, + agent_req.request_id, + ) + .await?; + + execute_agent_chain( + &selected_agents, + &agent_map, + agent_req.client_request, + agent_req.messages, + &agent_req.request_headers, + &custom_attrs, + ) + .await +} diff --git a/crates/brightstaff/src/handlers/agents/pipeline.rs b/crates/brightstaff/src/handlers/agents/pipeline.rs index 0b269ce1..ac71fe72 100644 --- a/crates/brightstaff/src/handlers/agents/pipeline.rs +++ b/crates/brightstaff/src/handlers/agents/pipeline.rs @@ -141,12 +141,11 @@ impl PipelineProcessor { Ok(chat_history_updated) } - /// Build common MCP headers for requests - fn build_mcp_headers( - &self, + /// Prepare headers shared by all agent/filter requests: removes + /// content-length, injects trace context, sets upstream host and retry. + fn build_agent_headers( request_headers: &HeaderMap, agent_id: &str, - session_id: Option<&str>, ) -> Result { let mut headers = request_headers.clone(); headers.remove(hyper::header::CONTENT_LENGTH); @@ -167,24 +166,34 @@ impl PipelineProcessor { headers.insert( ENVOY_RETRY_HEADER, - hyper::header::HeaderValue::from_str("3").unwrap(), + hyper::header::HeaderValue::from_static("3"), ); + Ok(headers) + } + + /// Build headers for MCP requests (adds Accept, Content-Type, optional session id). + fn build_mcp_headers( + &self, + request_headers: &HeaderMap, + agent_id: &str, + session_id: Option<&str>, + ) -> Result { + let mut headers = Self::build_agent_headers(request_headers, agent_id)?; + headers.insert( "Accept", hyper::header::HeaderValue::from_static("application/json, text/event-stream"), ); - headers.insert( "Content-Type", hyper::header::HeaderValue::from_static("application/json"), ); if let Some(sid) = session_id { - headers.insert( - "mcp-session-id", - hyper::header::HeaderValue::from_str(sid).unwrap(), - ); + if let Ok(val) = hyper::header::HeaderValue::from_str(sid) { + headers.insert("mcp-session-id", val); + } } Ok(headers) @@ -530,33 +539,11 @@ impl PipelineProcessor { }); // Build headers - let mut agent_headers = request_headers.clone(); - agent_headers.remove(hyper::header::CONTENT_LENGTH); - - // Inject OpenTelemetry trace context automatically - agent_headers.remove(TRACE_PARENT_HEADER); - global::get_text_map_propagator(|propagator| { - let cx = - tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); - propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers)); - }); - - agent_headers.insert( - ARCH_UPSTREAM_HOST_HEADER, - hyper::header::HeaderValue::from_str(&agent.id) - .map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?, - ); - - agent_headers.insert( - ENVOY_RETRY_HEADER, - hyper::header::HeaderValue::from_str("3").unwrap(), - ); - + let mut agent_headers = Self::build_agent_headers(request_headers, &agent.id)?; agent_headers.insert( "Accept", hyper::header::HeaderValue::from_static("application/json"), ); - agent_headers.insert( "Content-Type", hyper::header::HeaderValue::from_static("application/json"), @@ -629,27 +616,7 @@ impl PipelineProcessor { .map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?; debug!("sending request to terminal agent {}", terminal_agent.id); - let mut agent_headers = request_headers.clone(); - agent_headers.remove(hyper::header::CONTENT_LENGTH); - - // Inject OpenTelemetry trace context automatically - agent_headers.remove(TRACE_PARENT_HEADER); - global::get_text_map_propagator(|propagator| { - let cx = - tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); - propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers)); - }); - - agent_headers.insert( - ARCH_UPSTREAM_HOST_HEADER, - hyper::header::HeaderValue::from_str(&terminal_agent.id) - .map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?, - ); - - agent_headers.insert( - ENVOY_RETRY_HEADER, - hyper::header::HeaderValue::from_str("3").unwrap(), - ); + let agent_headers = Self::build_agent_headers(request_headers, &terminal_agent.id)?; let response = self .client diff --git a/crates/brightstaff/src/handlers/llm/mod.rs b/crates/brightstaff/src/handlers/llm/mod.rs index 724bb6f2..64e780a4 100644 --- a/crates/brightstaff/src/handlers/llm/mod.rs +++ b/crates/brightstaff/src/handlers/llm/mod.rs @@ -497,13 +497,16 @@ async fn send_upstream( "Routing to upstream" ); - request_headers.insert( - ARCH_PROVIDER_HINT_HEADER, - header::HeaderValue::from_str(resolved_model).unwrap(), - ); + if let Ok(val) = header::HeaderValue::from_str(resolved_model) { + request_headers.insert(ARCH_PROVIDER_HINT_HEADER, val); + } request_headers.insert( header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER), - header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(), + header::HeaderValue::from_static(if is_streaming_request { + "true" + } else { + "false" + }), ); request_headers.remove(header::CONTENT_LENGTH); @@ -535,9 +538,10 @@ async fn send_upstream( let response_headers = llm_response.headers().clone(); let upstream_status = llm_response.status(); let mut response = Response::builder().status(upstream_status); - let headers = response.headers_mut().unwrap(); - for (name, value) in response_headers.iter() { - headers.insert(name, value.clone()); + if let Some(headers) = response.headers_mut() { + for (name, value) in response_headers.iter() { + headers.insert(name, value.clone()); + } } let byte_stream = llm_response.bytes_stream(); @@ -637,8 +641,9 @@ async fn get_upstream_path( ) -> String { let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await; - let client_api = SupportedAPIsFromClient::from_endpoint(request_path) - .expect("Should have valid API endpoint"); + let Some(client_api) = SupportedAPIsFromClient::from_endpoint(request_path) else { + return request_path.to_string(); + }; client_api.target_endpoint_for_provider( &provider_id, diff --git a/crates/brightstaff/src/handlers/response.rs b/crates/brightstaff/src/handlers/response.rs index 7861ba16..0a6bbe4c 100644 --- a/crates/brightstaff/src/handlers/response.rs +++ b/crates/brightstaff/src/handlers/response.rs @@ -36,7 +36,7 @@ impl ResponseHandler { *response.status_mut() = hyper::StatusCode::BAD_REQUEST; response.headers_mut().insert( hyper::header::CONTENT_TYPE, - "application/json".parse().unwrap(), + hyper::header::HeaderValue::from_static("application/json"), ); response } @@ -55,7 +55,7 @@ impl ResponseHandler { *response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR; response.headers_mut().insert( hyper::header::CONTENT_TYPE, - "application/json".parse().unwrap(), + hyper::header::HeaderValue::from_static("application/json"), ); response } diff --git a/crates/brightstaff/src/router/llm.rs b/crates/brightstaff/src/router/llm.rs index 87330b29..2951eee6 100644 --- a/crates/brightstaff/src/router/llm.rs +++ b/crates/brightstaff/src/router/llm.rs @@ -83,7 +83,9 @@ impl RouterService { return Ok(None); } - if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2) + if usage_preferences + .as_ref() + .is_none_or(|prefs| prefs.len() < 2) && !self.llm_usage_defined { return Ok(None); @@ -108,18 +110,18 @@ impl RouterService { header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"), ); - headers.insert( - header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), - header::HeaderValue::from_str(&self.routing_provider_name).unwrap(), - ); - headers.insert( - header::HeaderName::from_static(TRACE_PARENT_HEADER), - header::HeaderValue::from_str(traceparent).unwrap(), - ); - headers.insert( - header::HeaderName::from_static(REQUEST_ID_HEADER), - header::HeaderValue::from_str(request_id).unwrap(), - ); + if let Ok(val) = header::HeaderValue::from_str(&self.routing_provider_name) { + headers.insert( + header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), + val, + ); + } + if let Ok(val) = header::HeaderValue::from_str(traceparent) { + headers.insert(header::HeaderName::from_static(TRACE_PARENT_HEADER), val); + } + if let Ok(val) = header::HeaderValue::from_str(request_id) { + headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val); + } headers.insert( header::HeaderName::from_static("model"), header::HeaderValue::from_static("arch-router"), diff --git a/crates/brightstaff/src/router/orchestrator.rs b/crates/brightstaff/src/router/orchestrator.rs index 19588725..42ade470 100644 --- a/crates/brightstaff/src/router/orchestrator.rs +++ b/crates/brightstaff/src/router/orchestrator.rs @@ -60,7 +60,10 @@ impl OrchestratorService { return Ok(None); } - if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() { + if usage_preferences + .as_ref() + .is_none_or(|prefs| prefs.is_empty()) + { return Ok(None); } @@ -85,7 +88,7 @@ impl OrchestratorService { ); headers.insert( header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), - header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(), + header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME), ); // Inject OpenTelemetry trace context from current span @@ -96,10 +99,9 @@ impl OrchestratorService { }); if let Some(ref request_id) = request_id { - headers.insert( - header::HeaderName::from_static(REQUEST_ID_HEADER), - header::HeaderValue::from_str(request_id).unwrap(), - ); + if let Ok(val) = header::HeaderValue::from_str(request_id) { + headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val); + } } headers.insert(