fix: address review findings from refactoring PR

- Replace unreachable!() with proper error return in orchestrator agent chain
- Remove incorrect #[allow(dead_code)] on routing_provider_name
- Change SerError alias to _ (trait import for method resolution only)
- Remove dead commented-out code in pipeline.rs
- Replace unwrap()s with expect/if-let in LLM handler filter paths
- Make find_listener synchronous (no await needed)
- Unify message truncation logic via streaming::truncate_message

Made-with: Cursor
This commit is contained in:
Adil Hafeez 2026-03-18 18:26:05 -07:00
parent 8ed4b36087
commit 4845d83100
7 changed files with 36 additions and 46 deletions

View file

@ -10,7 +10,7 @@ use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::{Request, Response};
use opentelemetry::trace::get_active_span;
use serde::ser::Error as SerError;
use serde::ser::Error as _;
use tracing::{debug, info, info_span, warn, Instrument};
use super::pipeline::{PipelineError, PipelineProcessor};
@ -128,9 +128,7 @@ async fn parse_agent_request(
// Find the appropriate listener
let listener: common::configuration::Listener = {
let listeners = state.listeners.read().await;
agent_selector
.find_listener(listener_name, &listeners)
.await?
agent_selector.find_listener(listener_name, &listeners)?
};
get_active_span(|span| {
@ -389,7 +387,9 @@ async fn execute_agent_chain(
let Some(last_message) = current_messages.pop() else {
warn!(agent = %agent_name, "no messages in conversation history");
break;
return Err(AgentFilterChainError::RequestParsing(
serde_json::Error::custom("No messages in conversation history after agent response"),
));
};
current_messages.push(OpenAIMessage {
@ -403,7 +403,9 @@ async fn execute_agent_chain(
current_messages.push(last_message);
}
unreachable!("Agent execution loop should have returned a response")
Err(AgentFilterChainError::RequestParsing(
serde_json::Error::custom("Agent chain completed without producing a response"),
))
}
async fn handle_agent_chat_inner(

View file

@ -574,7 +574,6 @@ impl PipelineProcessor {
terminal_agent: &Agent,
request_headers: &HeaderMap,
) -> Result<reqwest::Response, PipelineError> {
// let mut request = original_request.clone();
original_request.set_messages(messages);
let request_url = "/v1/chat/completions";

View file

@ -37,7 +37,7 @@ impl AgentSelector {
}
/// Find listener by name from the request headers
pub async fn find_listener(
pub fn find_listener(
&self,
listener_name: Option<&str>,
listeners: &[common::configuration::Listener],
@ -222,9 +222,7 @@ mod tests {
let listener2 = create_test_listener("other-listener", vec![]);
let listeners = vec![listener1.clone(), listener2];
let result = selector
.find_listener(Some("test-listener"), &listeners)
.await;
let result = selector.find_listener(Some("test-listener"), &listeners);
assert!(result.is_ok());
assert_eq!(result.unwrap().name, "test-listener");
@ -237,9 +235,7 @@ mod tests {
let listeners = vec![create_test_listener("other-listener", vec![])];
let result = selector
.find_listener(Some("nonexistent"), &listeners)
.await;
let result = selector.find_listener(Some("nonexistent"), &listeners);
assert!(result.is_err());
matches!(

View file

@ -85,9 +85,7 @@ mod tests {
let messages = vec![create_test_message(Role::User, "Hello world!")];
// Test 1: Agent Selection
let selected_listener = agent_selector
.find_listener(Some("test-listener"), &listeners)
.await;
let selected_listener = agent_selector.find_listener(Some("test-listener"), &listeners);
assert!(selected_listener.is_ok());
let listener = selected_listener.unwrap();
@ -153,7 +151,7 @@ mod tests {
let agent_selector = AgentSelector::new(router_service);
// Test listener not found
let result = agent_selector.find_listener(Some("nonexistent"), &[]).await;
let result = agent_selector.find_listener(Some("nonexistent"), &[]);
assert!(result.is_err());
assert!(matches!(

View file

@ -162,8 +162,10 @@ async fn llm_chat_inner(
.await
{
Ok(filtered_bytes) => {
let api_type =
SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap();
let api_type = SupportedAPIsFromClient::from_endpoint(
request_path.as_str(),
)
.expect("endpoint validated in parse_and_validate_request");
match ProviderRequestType::try_from((&filtered_bytes[..], &api_type)) {
Ok(updated_request) => {
client_request = updated_request;
@ -198,7 +200,7 @@ async fn llm_chat_inner(
StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST);
error_response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
hyper::header::HeaderValue::from_static("application/json"),
);
return Ok(error_response);
}
@ -661,8 +663,7 @@ async fn send_upstream(
messages_for_signals,
);
let has_output_filter = filter_pipeline.has_output_filters();
let output_filter_request_headers = if has_output_filter {
let output_filter_request_headers = if filter_pipeline.has_output_filters() {
Some(request_headers.clone())
} else {
None
@ -694,18 +695,21 @@ async fn send_upstream(
Box::new(base_processor)
};
let streaming_response = if has_output_filter {
let output_chain = filter_pipeline.output.as_ref().unwrap().clone();
create_streaming_response_with_output_filter(
byte_stream,
processor,
output_chain,
output_filter_request_headers.unwrap(),
request_path.to_string(),
)
} else {
create_streaming_response(byte_stream, processor)
};
let streaming_response =
if let (Some(output_chain), Some(filter_headers)) = (
filter_pipeline.output.as_ref().filter(|c| !c.is_empty()),
output_filter_request_headers,
) {
create_streaming_response_with_output_filter(
byte_stream,
processor,
output_chain.clone(),
filter_headers,
request_path.to_string(),
)
} else {
create_streaming_response(byte_stream, processor)
};
match response.body(streaming_response.body) {
Ok(response) => Ok(response),

View file

@ -5,6 +5,7 @@ use hyper::StatusCode;
use std::sync::Arc;
use tracing::{debug, info, warn};
use crate::handlers::streaming::truncate_message;
use crate::router::llm::RouterService;
use crate::tracing::routing;
@ -103,16 +104,7 @@ pub async fn router_chat_get_upstream_model(
.map_or("None".to_string(), |c| c.to_string().replace('\n', "\\n"))
});
const MAX_MESSAGE_LENGTH: usize = 50;
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
let truncated: String = latest_message_for_log
.chars()
.take(MAX_MESSAGE_LENGTH)
.collect();
format!("{}...", truncated)
} else {
latest_message_for_log
};
let latest_message_for_log = truncate_message(&latest_message_for_log, 50);
info!(
has_usage_preferences = usage_preferences.is_some(),

View file

@ -18,7 +18,6 @@ pub struct RouterService {
router_url: String,
client: reqwest::Client,
router_model: Arc<dyn RouterModel>,
#[allow(dead_code)]
routing_provider_name: String,
llm_usage_defined: bool,
}