refactor: decompose orchestrator handler, deduplicate headers, fix unwraps

This commit is contained in:
Adil Hafeez 2026-03-06 23:04:35 +00:00
parent 2c7d3a9c6c
commit dd74df6ca8
6 changed files with 166 additions and 134 deletions

View file

@ -89,7 +89,7 @@ pub async fn agent_chat(
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
hyper::header::HeaderValue::from_static("application/json"),
);
return Ok(response);
}
@ -102,16 +102,22 @@ pub async fn agent_chat(
.await
}
async fn handle_agent_chat_inner(
/// Parsed and validated agent request data.
///
/// Produced by `parse_agent_request` and consumed field by field in
/// `handle_agent_chat_inner`, which hands the pieces to agent selection and
/// to the agent execution chain.
struct AgentRequest {
// Client payload decoded from the request body bytes for the API type
// detected from the request path.
client_request: ProviderRequestType,
// Conversation messages obtained from `client_request.get_messages()`.
messages: Vec<OpenAIMessage>,
// Headers passed along to filter-chain processing and agent invocation.
// NOTE(review): appears to be the incoming headers with the Envoy
// original-path header removed and a request id ensured — confirm against
// `parse_agent_request`.
request_headers: hyper::HeaderMap,
// Request-id header value read back from `request_headers`; `None` when the
// header is absent or its value cannot be read as a string.
request_id: Option<String>,
}
/// Parse the incoming HTTP request, resolve the listener, and extract messages.
async fn parse_agent_request(
request: Request<hyper::body::Incoming>,
state: Arc<AppState>,
request_id: String,
custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
// Initialize services
state: &AppState,
request_id: &str,
custom_attrs: &std::collections::HashMap<String, String>,
) -> Result<(AgentRequest, common::configuration::Listener, AgentSelector), AgentFilterChainError> {
let agent_selector = AgentSelector::new(Arc::clone(&state.orchestrator_service));
let mut pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
// Extract listener name from headers
let listener_name = request
@ -129,7 +135,7 @@ async fn handle_agent_chat_inner(
get_active_span(|span| {
span.update_name(listener.name.to_string());
for (key, value) in &custom_attrs {
for (key, value) in custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
@ -147,12 +153,10 @@ async fn handle_agent_chat_inner(
let mut headers = request.headers().clone();
headers.remove(common::consts::ENVOY_ORIGINAL_PATH_HEADER);
// Set the request_id in headers if not already present
if !headers.contains_key(common::consts::REQUEST_ID_HEADER) {
headers.insert(
common::consts::REQUEST_ID_HEADER,
hyper::header::HeaderValue::from_str(&request_id).unwrap(),
);
if let Ok(val) = hyper::header::HeaderValue::from_str(request_id) {
headers.insert(common::consts::REQUEST_ID_HEADER, val);
}
}
headers
@ -165,7 +169,6 @@ async fn handle_agent_chat_inner(
"received request body"
);
// Determine the API type from the endpoint
let api_type =
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).ok_or_else(|| {
let err_msg = format!("Unsupported endpoint: {}", request_path);
@ -173,25 +176,48 @@ async fn handle_agent_chat_inner(
AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
})?;
let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
Ok(request) => request,
Err(err) => {
let client_request = ProviderRequestType::try_from((&chat_request_bytes[..], &api_type))
.map_err(|err| {
warn!("failed to parse request as ProviderRequestType: {}", err);
let err_msg = format!("Failed to parse request: {}", err);
return Err(AgentFilterChainError::RequestParsing(
serde_json::Error::custom(err_msg),
));
}
};
AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!(
"Failed to parse request: {}",
err
)))
})?;
let message: Vec<OpenAIMessage> = client_request.get_messages();
let messages: Vec<OpenAIMessage> = client_request.get_messages();
let request_id = request_headers
.get(common::consts::REQUEST_ID_HEADER)
.and_then(|val| val.to_str().ok())
.map(|s| s.to_string());
// Create agent map for pipeline processing and agent selection
Ok((
AgentRequest {
client_request,
messages,
request_headers,
request_id,
},
listener,
agent_selector,
))
}
/// Select agents via the orchestrator model and record selection metrics.
async fn select_and_build_agent_map(
agent_selector: &AgentSelector,
state: &AppState,
messages: &[OpenAIMessage],
listener: &common::configuration::Listener,
request_id: Option<String>,
) -> Result<
(
Vec<common::configuration::AgentFilterChain>,
std::collections::HashMap<String, common::configuration::Agent>,
),
AgentFilterChainError,
> {
let agent_map = {
let agents = state.agents_list.read().await;
let agents = agents.as_ref().ok_or_else(|| {
@ -200,13 +226,11 @@ async fn handle_agent_chat_inner(
agent_selector.create_agent_map(agents)
};
// Select appropriate agents using arch orchestrator llm model
let selection_start = Instant::now();
let selected_agents = agent_selector
.select_agents(&message, &listener, request_id.clone())
.select_agents(messages, listener, request_id)
.await?;
// Record selection attributes on the current orchestrator span
let selection_elapsed_ms = selection_start.elapsed().as_secs_f64() * 1000.0;
get_active_span(|span| {
span.set_attribute(opentelemetry::KeyValue::new(
@ -236,12 +260,25 @@ async fn handle_agent_chat_inner(
"selected agents for execution"
);
// Execute agents sequentially, passing output from one to the next
let mut current_messages = message.clone();
Ok((selected_agents, agent_map))
}
/// Execute the agent chain: run each selected agent sequentially, streaming
/// the final agent's response back to the client.
async fn execute_agent_chain(
selected_agents: &[common::configuration::AgentFilterChain],
agent_map: &std::collections::HashMap<String, common::configuration::Agent>,
client_request: ProviderRequestType,
messages: Vec<OpenAIMessage>,
request_headers: &hyper::HeaderMap,
custom_attrs: &std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
let mut pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
let mut current_messages = messages;
let agent_count = selected_agents.len();
for (agent_index, selected_agent) in selected_agents.iter().enumerate() {
// Get agent name
let agent_name = selected_agent.id.clone();
let is_last_agent = agent_index == agent_count - 1;
@ -252,17 +289,15 @@ async fn handle_agent_chat_inner(
"processing agent"
);
// Process the filter chain
let chat_history = pipeline_processor
.process_filter_chain(
&current_messages,
selected_agent,
&agent_map,
&request_headers,
agent_map,
request_headers,
)
.await?;
// Get agent details and invoke
let agent = agent_map.get(&agent_name).ok_or_else(|| {
AgentFilterChainError::RequestParsing(serde_json::Error::custom(format!(
"Selected agent '{}' not found in configuration",
@ -282,7 +317,7 @@ async fn handle_agent_chat_inner(
set_service_name(operation_component::AGENT);
get_active_span(|span| {
span.update_name(format!("{} /v1/chat/completions", agent_name));
for (key, value) in &custom_attrs {
for (key, value) in custom_attrs {
span.set_attribute(opentelemetry::KeyValue::new(key.clone(), value.clone()));
}
});
@ -292,28 +327,25 @@ async fn handle_agent_chat_inner(
&chat_history,
client_request.clone(),
agent,
&request_headers,
request_headers,
)
.await
}
.instrument(agent_span.clone())
.await?;
// If this is the last agent, return the streaming response
if is_last_agent {
info!(
agent = %agent_name,
"completed agent chain, returning response"
);
// Capture the orchestrator span (parent of the agent span) so it
// stays open for the full streaming duration alongside the agent span.
let orchestrator_span = tracing::Span::current();
return async {
response_handler
.create_streaming_response(
llm_response,
tracing::Span::current(), // agent span (inner)
orchestrator_span, // orchestrator span (outer)
tracing::Span::current(),
orchestrator_span,
)
.await
.map_err(AgentFilterChainError::from)
@ -322,7 +354,6 @@ async fn handle_agent_chat_inner(
.await;
}
// For intermediate agents, collect the full response and pass to next agent
debug!(agent = %agent_name, "collecting response from intermediate agent");
let response_text = async { response_handler.collect_full_response(llm_response).await }
.instrument(agent_span)
@ -334,14 +365,11 @@ async fn handle_agent_chat_inner(
"agent completed, passing response to next agent"
);
// remove last message and add new one at the end
let Some(last_message) = current_messages.pop() else {
warn!(agent = %agent_name, "no messages in conversation history");
break;
};
// Create a new message with the agent's response as assistant message
// and add it to the conversation history
current_messages.push(OpenAIMessage {
role: hermesllm::apis::openai::Role::Assistant,
content: Some(hermesllm::apis::openai::MessageContent::Text(response_text)),
@ -353,6 +381,34 @@ async fn handle_agent_chat_inner(
current_messages.push(last_message);
}
// This should never be reached since we return in the last agent iteration
unreachable!("Agent execution loop should have returned a response")
}
/// Drives one agent chat request through its three stages.
///
/// 1. `parse_agent_request` parses the incoming HTTP request, resolves the
///    listener and extracts the messages.
/// 2. `select_and_build_agent_map` selects the agents to run and builds the
///    agent lookup map.
/// 3. `execute_agent_chain` runs the selected agents in order and yields the
///    HTTP response.
///
/// The first `AgentFilterChainError` raised by any stage is returned as-is.
async fn handle_agent_chat_inner(
    request: Request<hyper::body::Incoming>,
    state: Arc<AppState>,
    request_id: String,
    custom_attrs: std::collections::HashMap<String, String>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
    let app_state = &*state;
    let span_attrs = &custom_attrs;

    // Stage 1: parse and validate the request.
    let parsing = parse_agent_request(request, app_state, &request_id, span_attrs);
    let (parsed, listener, selector) = parsing.await?;

    // Stage 2: agent selection. The request id recovered from the parsed
    // headers is moved into the selector call; the messages are only borrowed
    // because stage 3 takes ownership of them.
    let selection = select_and_build_agent_map(
        &selector,
        app_state,
        &parsed.messages,
        &listener,
        parsed.request_id,
    );
    let (chain, agents) = selection.await?;

    // Stage 3: run the chain; its result is the handler's result.
    execute_agent_chain(
        &chain,
        &agents,
        parsed.client_request,
        parsed.messages,
        &parsed.request_headers,
        span_attrs,
    )
    .await
}

View file

@ -141,12 +141,11 @@ impl PipelineProcessor {
Ok(chat_history_updated)
}
/// Build common MCP headers for requests
fn build_mcp_headers(
&self,
/// Prepare headers shared by all agent/filter requests: removes
/// content-length, injects trace context, sets upstream host and retry.
fn build_agent_headers(
request_headers: &HeaderMap,
agent_id: &str,
session_id: Option<&str>,
) -> Result<HeaderMap, PipelineError> {
let mut headers = request_headers.clone();
headers.remove(hyper::header::CONTENT_LENGTH);
@ -167,24 +166,34 @@ impl PipelineProcessor {
headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
hyper::header::HeaderValue::from_static("3"),
);
Ok(headers)
}
/// Build headers for MCP requests (adds Accept, Content-Type, optional session id).
fn build_mcp_headers(
&self,
request_headers: &HeaderMap,
agent_id: &str,
session_id: Option<&str>,
) -> Result<HeaderMap, PipelineError> {
let mut headers = Self::build_agent_headers(request_headers, agent_id)?;
headers.insert(
"Accept",
hyper::header::HeaderValue::from_static("application/json, text/event-stream"),
);
headers.insert(
"Content-Type",
hyper::header::HeaderValue::from_static("application/json"),
);
if let Some(sid) = session_id {
headers.insert(
"mcp-session-id",
hyper::header::HeaderValue::from_str(sid).unwrap(),
);
if let Ok(val) = hyper::header::HeaderValue::from_str(sid) {
headers.insert("mcp-session-id", val);
}
}
Ok(headers)
@ -530,33 +539,11 @@ impl PipelineProcessor {
});
// Build headers
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
// Inject OpenTelemetry trace context automatically
agent_headers.remove(TRACE_PARENT_HEADER);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
});
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let mut agent_headers = Self::build_agent_headers(request_headers, &agent.id)?;
agent_headers.insert(
"Accept",
hyper::header::HeaderValue::from_static("application/json"),
);
agent_headers.insert(
"Content-Type",
hyper::header::HeaderValue::from_static("application/json"),
@ -629,27 +616,7 @@ impl PipelineProcessor {
.map_err(|e| PipelineError::NoContentInResponse(e.to_string()))?;
debug!("sending request to terminal agent {}", terminal_agent.id);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
// Inject OpenTelemetry trace context automatically
agent_headers.remove(TRACE_PARENT_HEADER);
global::get_text_map_propagator(|propagator| {
let cx =
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
});
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&terminal_agent.id)
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let agent_headers = Self::build_agent_headers(request_headers, &terminal_agent.id)?;
let response = self
.client

View file

@ -497,13 +497,16 @@ async fn send_upstream(
"Routing to upstream"
);
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(resolved_model).unwrap(),
);
if let Ok(val) = header::HeaderValue::from_str(resolved_model) {
request_headers.insert(ARCH_PROVIDER_HINT_HEADER, val);
}
request_headers.insert(
header::HeaderName::from_static(ARCH_IS_STREAMING_HEADER),
header::HeaderValue::from_str(&is_streaming_request.to_string()).unwrap(),
header::HeaderValue::from_static(if is_streaming_request {
"true"
} else {
"false"
}),
);
request_headers.remove(header::CONTENT_LENGTH);
@ -535,9 +538,10 @@ async fn send_upstream(
let response_headers = llm_response.headers().clone();
let upstream_status = llm_response.status();
let mut response = Response::builder().status(upstream_status);
let headers = response.headers_mut().unwrap();
for (name, value) in response_headers.iter() {
headers.insert(name, value.clone());
if let Some(headers) = response.headers_mut() {
for (name, value) in response_headers.iter() {
headers.insert(name, value.clone());
}
}
let byte_stream = llm_response.bytes_stream();
@ -637,8 +641,9 @@ async fn get_upstream_path(
) -> String {
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
.expect("Should have valid API endpoint");
let Some(client_api) = SupportedAPIsFromClient::from_endpoint(request_path) else {
return request_path.to_string();
};
client_api.target_endpoint_for_provider(
&provider_id,

View file

@ -36,7 +36,7 @@ impl ResponseHandler {
*response.status_mut() = hyper::StatusCode::BAD_REQUEST;
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
hyper::header::HeaderValue::from_static("application/json"),
);
response
}
@ -55,7 +55,7 @@ impl ResponseHandler {
*response.status_mut() = hyper::StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
hyper::header::HeaderValue::from_static("application/json"),
);
response
}

View file

@ -83,7 +83,9 @@ impl RouterService {
return Ok(None);
}
if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2)
if usage_preferences
.as_ref()
.is_none_or(|prefs| prefs.len() < 2)
&& !self.llm_usage_defined
{
return Ok(None);
@ -108,18 +110,18 @@ impl RouterService {
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
);
headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(&self.routing_provider_name).unwrap(),
);
headers.insert(
header::HeaderName::from_static(TRACE_PARENT_HEADER),
header::HeaderValue::from_str(traceparent).unwrap(),
);
headers.insert(
header::HeaderName::from_static(REQUEST_ID_HEADER),
header::HeaderValue::from_str(request_id).unwrap(),
);
if let Ok(val) = header::HeaderValue::from_str(&self.routing_provider_name) {
headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
val,
);
}
if let Ok(val) = header::HeaderValue::from_str(traceparent) {
headers.insert(header::HeaderName::from_static(TRACE_PARENT_HEADER), val);
}
if let Ok(val) = header::HeaderValue::from_str(request_id) {
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
}
headers.insert(
header::HeaderName::from_static("model"),
header::HeaderValue::from_static("arch-router"),

View file

@ -60,7 +60,10 @@ impl OrchestratorService {
return Ok(None);
}
if usage_preferences.is_none() || usage_preferences.as_ref().unwrap().is_empty() {
if usage_preferences
.as_ref()
.is_none_or(|prefs| prefs.is_empty())
{
return Ok(None);
}
@ -85,7 +88,7 @@ impl OrchestratorService {
);
headers.insert(
header::HeaderName::from_static(ARCH_PROVIDER_HINT_HEADER),
header::HeaderValue::from_str(PLANO_ORCHESTRATOR_MODEL_NAME).unwrap(),
header::HeaderValue::from_static(PLANO_ORCHESTRATOR_MODEL_NAME),
);
// Inject OpenTelemetry trace context from current span
@ -96,10 +99,9 @@ impl OrchestratorService {
});
if let Some(ref request_id) = request_id {
headers.insert(
header::HeaderName::from_static(REQUEST_ID_HEADER),
header::HeaderValue::from_str(request_id).unwrap(),
);
if let Ok(val) = header::HeaderValue::from_str(request_id) {
headers.insert(header::HeaderName::from_static(REQUEST_ID_HEADER), val);
}
}
headers.insert(