refactor: decompose orchestrator handler, deduplicate headers, fix unwraps

This commit is contained in:
Adil Hafeez 2026-03-06 23:04:35 +00:00
parent 2c7d3a9c6c
commit dd74df6ca8
6 changed files with 166 additions and 134 deletions

View file

@ -89,7 +89,7 @@ pub async fn agent_chat(
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR); .unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
response.headers_mut().insert( response.headers_mut().insert(
hyper::header::CONTENT_TYPE, hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(), hyper::header::HeaderValue::from_static("application/json"),
); );
return Ok(response); return Ok(response);
} }
@ -102,16 +102,22 @@ pub async fn agent_chat(
.await .await
} }
async fn handle_agent_chat_inner( /// Parsed and validated agent request data.
struct AgentRequest {
client_request: ProviderRequestType,
messages: Vec<OpenAIMessage>,
request_headers: hyper::HeaderMap,
request_id: Option<String>,
}
/// Parse the incoming HTTP request, resolve the listener, and extract messages.
async fn parse_agent_request(
request: Request<hyper::body::Incoming>, request: Request<hyper::body::Incoming>,
state: Arc<AppState>, state: &AppState,
request_id: String, request_id: &str,
custom_attrs: std::collections::HashMap<String, String>, custom_attrs: &std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> { ) -> Result<(AgentRequest, common::configuration::Listener, AgentSelector), AgentFilterChainError> {
// Initialize services
let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service)); let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service));
let mut pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
// Extract listener name from headers // Extract listener name from headers
let listener_name = request let listener_name = request
@ -129,7 +135,7 @@ async fn handle_agent_chat_inner(
get_active_span(|span| { get_active_span(|span| {
span.update_name(listener.name.to_string()); span.update_name(listener.name.to_string());
for (key, value) in &custom_attrs { for (key, value) in custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone())); span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
} }
}); });
@ -147,12 +153,10 @@ async fn handle_agent_chat_inner(
let mut headers = request.headers().clone(); let mut headers = request.headers().clone();
headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER); headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER);
// Set the request_id in headers if not already present
if !headers.contains_key(common::consts::REQUEST_ID_HEADER) { if !headers.contains_key(common::consts::REQUEST_ID_HEADER) {
headers.insert( if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
common::consts::REQUEST_ID_HEADER, headers.insert(common::consts::REQUEST_ID_HEADER, val);
hyper::header::HeaderValue::from_str(&request_id).unwrap(), }
);
} }
headers headers
@ -165,7 +169,6 @@ async fn handle_agent_chat_inner(
"received request body" "received request body"
); );
// Determine the API type from the endpoint
let api_type = let api_type =
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| { SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
let err_msg = format!("Unsupported endpoint: {}", request_path); let err_msg = format!("Unsupported endpoint: {}", request_path);
@ -173,25 +176,48 @@ async fn handle_agent_chat_inner(
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg)) AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
})?; })?;
let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) { let client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type))
Ok(request) => request, .map_err(|err| {
Err(err) => {
warn!("failed to parse request as ProviderRequestType: {}", err); warn!("failed to parse request as ProviderRequestType: {}", err);
let err_msg = format!("Failed to parse request: {}", err); AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!(
return Err(AgentFilterChainError::RequestParsing( "Failed to parse request: {}",
serde_json::Error::custom(err_msg), err
)); )))
} })?;
};
let message: Vec<OpenAIMessage> = client_request.get_messages(); let messages: Vec<OpenAIMessage> = client_request.get_messages();
let request_id = request_headers let request_id = request_headers
.get(common::consts::REQUEST_ID_HEADER) .get(common::consts::REQUEST_ID_HEADER)
.and_then(|val| val.to_str().ok()) .and_then(|val| val.to_str().ok())
.map(|s| s.to_string()); .map(|s| s.to_string());
// Create agent map for pipeline processing and agent selection Ok((
AgentRequest {
client_request,
messages,
request_headers,
request_id,
},
listener,
agent_selector,
))
}
/// Select agents via the orchestrator model and record selection metrics.
async fn select_and_build_agent_map(
agent_selector: &AgentSelector,
state: &AppState,
messages: &[OpenAIMessage],
listener: &common::configuration::Listener,
request_id: Option<String>,
) -> Result<
(
Vec<common::configuration::AgentFilterChain>,
std::collections::HashMap<String, common::configuration::Agent>,
),
AgentFilterChainError,
> {
let agent_map = { let agent_map = {
let agents = state.agents_list.read().await; let agents = state.agents_list.read().await;
let agents = agents.as_ref().ok_or_else(|| { let agents = agents.as_ref().ok_or_else(|| {
@ -200,13 +226,11 @@ async fn handle_agent_chat_inner(
agent_selector.create_agent_map(agents) agent_selector.create_agent_map(agents)
}; };
// Select appropriate agents using arch orchestrator llm model
let selection_start = Instant::now(); let selection_start = Instant::now();
let selected_agents = agent_selector let selected_agents = agent_selector
.select_agents(&message, &listener, request_id.clone()) .select_agents(messages, listener, request_id)
.await?; .await?;
// Record selection attributes on the current orchestrator span
let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0; let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0;
get_active_span(|span| { get_active_span(|span| {
span.set_attribute(opentelemetry::KeyValue::new( span.set_attribute(opentelemetry::KeyValue::new(
@ -236,12 +260,25 @@ async fn handle_agent_chat_inner(
"selected agents for execution" "selected agents for execution"
); );
// Execute agents sequentially, passing output from one to the next Ok((selected_agents, agent_map))
let mut current_messages = message.clone(); }
/// Execute the agent chain: run each selected agent sequentially, streaming
/// the final agent's response back to the client.
async fn execute_agent_chain(
selected_agents: &[common::configuration::AgentFilterChain],
agent_map: &std::collections::HashMap<String, common::configuration::Agent>,
client_request: ProviderRequestType,
messages: Vec<OpenAIMessage>,
request_headers: &hyper::HeaderMap,
custom_attrs: &std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
let mut pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
let mut current_messages = messages;
let agent_count = selected_agents.len(); let agent_count = selected_agents.len();
for (agent_index, selected_agent) in selected_agents.iter().enumerate() { for (agent_index, selected_agent) in selected_agents.iter().enumerate() {
// Get agent name
let agent_name = selected_agent.id.clone(); let agent_name = selected_agent.id.clone();
let is_last_agent = agent_index == agent_count - 1; let is_last_agent = agent_index == agent_count - 1;
@ -252,17 +289,15 @@ async fn handle_agent_chat_inner(
"processing agent" "processing agent"
); );
// Process the filter chain
let chat_history = pipeline_processor let chat_history = pipeline_processor
.process_filter_chain( .process_filter_chain(
&current_messages, &current_messages,
selected_agent, selected_agent,
&agent_map, agent_map,
&request_headers, request_headers,
) )
.await?; .await?;
// Get agent details and invoke
let agent = agent_map.get(&agent_name).ok_or_else(|| { let agent = agent_map.get(&agent_name).ok_or_else(|| {
AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!( AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!(
"Selected agent '{}' not found in configuration", "Selected agent '{}' not found in configuration",
@ -282,7 +317,7 @@ async fn handle_agent_chat_inner(
set_service_name(operation_component::AGENT); set_service_name(operation_component::AGENT);
get_active_span(|span| { get_active_span(|span| {
span.update_name(format!("{} /v1/chat/completions", agent_name)); span.update_name(format!("{} /v1/chat/completions", agent_name));
for (key, value) in &custom_attrs { for (key, value) in custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone())); span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
} }
}); });
@ -292,28 +327,25 @@ async fn handle_agent_chat_inner(
&chat_history, &chat_history,
client_request.clone(), client_request.clone(),
agent, agent,
&request_headers, request_headers,
) )
.await .await
} }
.instrument(agent_span.clone()) .instrument(agent_span.clone())
.await?; .await?;
// If this is the last agent, return the streaming response
if is_last_agent { if is_last_agent {
info!( info!(
agent = %agent_name, agent = %agent_name,
"completed agent chain, returning response" "completed agent chain, returning response"
); );
// Capture the orchestrator span (parent of the agent span) so it
// stays open for the full streaming duration alongside the agent span.
let orchestrator_span = tracing::Span::current(); let orchestrator_span = tracing::Span::current();
return async { return async {
response_handler response_handler
.create_streaming_response( .create_streaming_response(
llm_response, llm_response,
tracing::Span::current(), // agent span (inner) tracing::Span::current(),
orchestrator_span, // orchestrator span (outer) orchestrator_span,
) )
.await .await
.map_err(AgentFilterChainError::from) .map_err(AgentFilterChainError::from)
@ -322,7 +354,6 @@ async fn handle_agent_chat_inner(
.await; .await;
} }
// For intermediate agents, collect the full response and pass to next agent
debug!(agent = %agent_name, "collecting response from intermediate agent"); debug!(agent = %agent_name, "collecting response from intermediate agent");
let response_text = async { response_handler.collect_full_response(llm_response).await } let response_text = async { response_handler.collect_full_response(llm_response).await }
.instrument(agent_span) .instrument(agent_span)
@ -334,14 +365,11 @@ async fn handle_agent_chat_inner(
"agent completed, passing response to next agent" "agent completed, passing response to next agent"
); );
// remove last message and add new one at the end
let Some(last_message) = current_messages.pop() else { let Some(last_message) = current_messages.pop() else {
warn!(agent = %agent_name, "no messages in conversation history"); warn!(agent = %agent_name, "no messages in conversation history");
break; break;
}; };
// Create a new message with the agent's response as assistant message
// and add it to the conversation history
current_messages.push(OpenAIMessage { current_messages.push(OpenAIMessage {
role: hermesllm::apis::openai::Role::Assistant, role: hermesllm::apis::openai::Role::Assistant,
content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)), content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)),
@ -353,6 +381,34 @@ async fn handle_agent_chat_inner(
current_messages.push(last_message); current_messages.push(last_message);
} }
// This should never be reached since we return in the last agent iteration
unreachable!("Agent execution loop should have returned a response") unreachable!("Agent execution loop should have returned a response")
} }
async fn handle_agent_chat_inner(
request: Request<hyper::body::Incoming>,
state: Arc<AppState>,
request_id: String,
custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
let (agent_req, listener, agent_selector) =
parse_agent_request(request, &state, &request_id, &custom_attrs).await?;
let (selected_agents, agent_map) = select_and_build_agent_map(
&agent_selector,
&state,
&agent_req.messages,
&listener,
agent_req.request_id,
)
.await?;
execute_agent_chain(
&selected_agents,
&agent_map,
agent_req.client_request,
agent_req.messages,
&agent_req.request_headers,
&custom_attrs,
)
.await
}

View file

@ -141,12 +141,11 @@ impl PipelineProcessor {
Ok(chat_history_updated) Ok(chat_history_updated)
} }
/// Build common MCP headers for requests /// Prepare headers shared by all agent/filter requests: removes
fn build_mcp_headers( /// content-length, injects trace context, sets upstream host and retry.
&self, fn build_agent_headers(
request_headers: &HeaderMap, request_headers: &HeaderMap,
agent_id: &str, agent_id: &str,
session_id: Option<&str>,
) -> Result<HeaderMap, PipelineError> { ) -> Result<HeaderMap, PipelineError> {
let mut headers = request_headers.clone(); let mut headers = request_headers.clone();
headers.remove(hyper::header::CONTENT_LENGTH); headers.remove(hyper::header::CONTENT_LENGTH);
@ -167,24 +166,34 @@ impl PipelineProcessor {
headers.insert( headers.insert(
ENVOY_RETRY_HEADER, ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(), hyper::header::HeaderValue::from_static("3"),
); );
Ok(headers)
}
/// Build headers for MCP requests (adds Accept, Content-Type, optional session id).
fn build_mcp_headers(
&self,
request_headers: &HeaderMap,
agent_id: &str,
session_id: Option<&str>,
) -> Result<HeaderMap, PipelineError> {
let mut headers = Self::build_agent_headers(request_headers, agent_id)?;
headers.insert( headers.insert(
"Accept", "Accept",
hyper::header::HeaderValue::from_static("application/json, text/event-stream"), hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
); );
headers.insert( headers.insert(
"Content-Type", "Content-Type",
hyper::header::HeaderValue::from_static("application/json"), hyper::header::HeaderValue::from_static("application/json"),
); );
if let Some(sid) = session_id { if let Some(sid) = session_id {
headers.insert( if let Ok(val) = hyper::header::HeaderValue::from_str(sid) {
"mcp-session-id", headers.insert("mcp-session-id", val);
hyper::header::HeaderValue::from_str(sid).unwrap(), }
);
} }
Ok(headers) Ok(headers)
@ -530,33 +539,11 @@ impl PipelineProcessor {
}); });
// Build headers // Build headers
let mut agent_headers = request_headers.clone(); let mut agent_headers = Self::build_agent_headers(request_headers, &agent.id)?;
agent_headers.remove(hyper::header::CONTENT_LENGTH);
// Inject OpenTelemetry trace context automatically
agent_headers.remove(TRACE_PARENT_HEADER);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
});
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
agent_headers.insert( agent_headers.insert(
"Accept", "Accept",
hyper::header::HeaderValue::from_static("application/json"), hyper::header::HeaderValue::from_static("application/json"),
); );
agent_headers.insert( agent_headers.insert(
"Content-Type", "Content-Type",
hyper::header::HeaderValue::from_static("application/json"), hyper::header::HeaderValue::from_static("application/json"),
@ -629,27 +616,7 @@ impl PipelineProcessor {
.map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?; .map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?;
debug!("sending request to terminal agent {}", terminal_agent.id); debug!("sending request to terminal agent {}", terminal_agent.id);
let mut agent_headers = request_headers.clone(); let agent_headers = Self::build_agent_headers(request_headers, &terminal_agent.id)?;
agent_headers.remove(hyper::header::CONTENT_LENGTH);
// Inject OpenTelemetry trace context automatically
agent_headers.remove(TRACE_PARENT_HEADER);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
});
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&terminal_agent.id)
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let response = self let response = self
.client .client

View file

@ -497,13 +497,16 @@ async fn send_upstream(
"Routing to upstream" "Routing to upstream"
); );
request_headers.insert( if let Ok(val) = header::HeaderValue::from_str(resolved_model) {
ARCH_PROVIDER_HINT_HEADER, request_headers.insert(ARCH_PROVIDER_HINT_HEADER, val);
header::HeaderValue::from_str(resolved_model).unwrap(), }
);
request_headers.insert( request_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER), header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(), header::HeaderValue::from_static(if is_streaming_request {
"true"
} else {
"false"
}),
); );
request_headers.remove(header::CONTENT_LENGTH); request_headers.remove(header::CONTENT_LENGTH);
@ -535,10 +538,11 @@ async fn send_upstream(
let response_headers = llm_response.headers().clone(); let response_headers = llm_response.headers().clone();
let upstream_status = llm_response.status(); let upstream_status = llm_response.status();
let mut response = Response::builder().status(upstream_status); let mut response = Response::builder().status(upstream_status);
let headers = response.headers_mut().unwrap(); if let Some(headers) = response.headers_mut() {
for (name, value) in response_headers.iter() { for (name, value) in response_headers.iter() {
headers.insert(name, value.clone()); headers.insert(name, value.clone());
} }
}
let byte_stream = llm_response.bytes_stream(); let byte_stream = llm_response.bytes_stream();
@ -637,8 +641,9 @@ async fn get_upstream_path(
) -> String { ) -> String {
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await; let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
let client_api = SupportedAPIsFromClient::from_endpoint(request_path) let Some(client_api) = SupportedAPIsFromClient::from_endpoint(request_path) else {
.expect("Should have valid API endpoint"); return request_path.to_string();
};
client_api.target_endpoint_for_provider( client_api.target_endpoint_for_provider(
&provider_id, &provider_id,

View file

@ -36,7 +36,7 @@ impl ResponseHandler {
*response.status_mut() = hyper::StatusCode::BAD_REQUEST; *response.status_mut() = hyper::StatusCode::BAD_REQUEST;
response.headers_mut().insert( response.headers_mut().insert(
hyper::header::CONTENT_TYPE, hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(), hyper::header::HeaderValue::from_static("application/json"),
); );
response response
} }
@ -55,7 +55,7 @@ impl ResponseHandler {
*response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR; *response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert( response.headers_mut().insert(
hyper::header::CONTENT_TYPE, hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(), hyper::header::HeaderValue::from_static("application/json"),
); );
response response
} }

View file

@ -83,7 +83,9 @@ impl RouterService {
return Ok(None); return Ok(None);
} }
if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2) if usage_preferences
.as_ref()
.is_none_or(|prefs| prefs.len() < 2)
&& !self.llm_usage_defined && !self.llm_usage_defined
{ {
return Ok(None); return Ok(None);
@ -108,18 +110,18 @@ impl RouterService {
header::CONTENT_TYPE, header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"), header::HeaderValue::from_static("application/json"),
); );
if let Ok(val) = header::HeaderValue::from_str(&self.routing_provider_name) {
headers.insert( headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(&self.routing_provider_name).unwrap(), val,
);
headers.insert(
header::HeaderName::from_static(TRACE_PARENT_HEADER),
header::HeaderValue::from_str(traceparent).unwrap(),
);
headers.insert(
header::HeaderName::from_static(REQUEST_ID_HEADER),
header::HeaderValue::from_str(request_id).unwrap(),
); );
}
if let Ok(val) = header::HeaderValue::from_str(traceparent) {
headers.insert(header::HeaderName::from_static(TRACE_PARENT_HEADER), val);
}
if let Ok(val) = header::HeaderValue::from_str(request_id) {
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
}
headers.insert( headers.insert(
header::HeaderName::from_static("model"), header::HeaderName::from_static("model"),
header::HeaderValue::from_static("arch-router"), header::HeaderValue::from_static("arch-router"),

View file

@ -60,7 +60,10 @@ impl OrchestratorService {
return Ok(None); return Ok(None);
} }
if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() { if usage_preferences
.as_ref()
.is_none_or(|prefs| prefs.is_empty())
{
return Ok(None); return Ok(None);
} }
@ -85,7 +88,7 @@ impl OrchestratorService {
); );
headers.insert( headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER), header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(), header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME),
); );
// Inject OpenTelemetry trace context from current span // Inject OpenTelemetry trace context from current span
@ -96,10 +99,9 @@ impl OrchestratorService {
}); });
if let Some(ref request_id) = request_id { if let Some(ref request_id) = request_id {
headers.insert( if let Ok(val) = header::HeaderValue::from_str(request_id) {
header::HeaderName::from_static(REQUEST_ID_HEADER), headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
header::HeaderValue::from_str(request_id).unwrap(), }
);
} }
headers.insert( headers.insert(