add tests and some refactor

This commit is contained in:
Adil Hafeez 2025-09-17 11:03:51 -07:00
parent 08471d8adf
commit 2229f0d4d4
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
8 changed files with 1004 additions and 276 deletions

50
crates/Cargo.lock generated
View file

@ -104,6 +104,16 @@ version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223"
[[package]]
name = "assert-json-diff"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "async-trait"
version = "0.1.88"
@ -216,6 +226,7 @@ dependencies = [
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"mockito",
"opentelemetry",
"opentelemetry-http",
"opentelemetry-otlp",
@ -321,6 +332,15 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15"
[[package]]
name = "colored"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "common"
version = "0.1.0"
@ -1731,6 +1751,30 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "mockito"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48"
dependencies = [
"assert-json-diff",
"bytes",
"colored",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"log",
"rand 0.9.1",
"regex",
"serde_json",
"serde_urlencoded",
"similar",
"tokio",
]
[[package]]
name = "more-asserts"
version = "0.3.1"
@ -2817,6 +2861,12 @@ dependencies = [
"libc",
]
[[package]]
name = "similar"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
[[package]]
name = "slab"
version = "0.4.9"

View file

@ -28,6 +28,12 @@ serde_with = "3.13.0"
serde_yaml = "0.9.34"
thiserror = "2.0.12"
tokio = { version = "1.44.2", features = ["full"] }
tokio-stream = "0.1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
mockito = "1.0"
tokio-stream = "0.1.17"
tracing = "0.1.41"
tracing-opentelemetry = "0.30.0"

View file

@ -1,25 +1,30 @@
use std::sync::Arc;
use bytes::Bytes;
use common::configuration::{AgentPipeline, ModelUsagePreference, RoutingPreference};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use hermesllm::apis::openai::ChatCompletionsRequest;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use http_body_util::BodyExt;
use hyper::{Request, Response};
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use super::agent_selector::{AgentSelector, AgentSelectionError};
use super::pipeline_processor::{PipelineProcessor, PipelineError};
use super::response_handler::ResponseHandler;
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
pub enum AgentChatError {
#[error("Agent selection error: {0}")]
Selection(#[from] AgentSelectionError),
#[error("Pipeline processing error: {0}")]
Pipeline(#[from] PipelineError),
#[error("Response handling error: {0}")]
Response(#[from] super::response_handler::ResponseError),
#[error("Request parsing error: {0}")]
RequestParsing(#[from] serde_json::Error),
#[error("HTTP error: {0}")]
Http(#[from] hyper::Error),
}
pub async fn agent_chat(
@ -29,25 +34,44 @@ pub async fn agent_chat(
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// find listener that is running at port 8001 for agents
let listener_name = request.headers().get("x-arch-agent-listener-name");
match handle_agent_chat(request, router_service, agents_list, listeners).await {
Ok(response) => Ok(response),
Err(err) => {
warn!("Agent chat error: {}", err);
Ok(ResponseHandler::create_internal_error(&format!("Internal error: {}", err)))
}
}
}
async fn handle_agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentChatError> {
// Initialize services
let agent_selector = AgentSelector::new(router_service);
let pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
// Extract listener name from headers
let listener_name = request
.headers()
.get("x-arch-agent-listener-name")
.and_then(|name| name.to_str().ok());
// Find the appropriate listener
let listener = {
let listeners = listeners.read().await;
listeners
.iter()
.find(|l| {
listener_name
.and_then(|name| name.to_str().ok())
.map(|name| l.name == name)
.unwrap_or(false)
})
.cloned()
}
.unwrap();
agent_selector
.find_listener(listener_name, &listeners)
.await?
};
info!("Handling request for listener: {}", listener.name);
let mut request_headers = request.headers().clone();
// Parse request body
let request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!(
@ -56,269 +80,64 @@ pub async fn agent_chat(
);
let chat_completions_request: ChatCompletionsRequest =
match serde_json::from_slice(&chat_request_bytes) {
Ok(req) => req,
Err(err) => {
warn!(
"Failed to parse request body as ChatCompletionsRequest: {}",
err
);
let err_msg = format!("Failed to parse request body: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
};
let agent_name_map = {
let agents = agents_list.read().await;
let agents = agents.as_ref().unwrap();
let mut map = std::collections::HashMap::new();
for agent in agents.iter() {
map.insert(agent.name.clone(), agent.clone());
}
map
};
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
warn!("Failed to parse request body as ChatCompletionsRequest: {}", err);
AgentChatError::RequestParsing(err)
})?;
// Extract trace parent for routing
let trace_parent = request_headers
.iter()
.find(|(ty, _)| ty.as_str() == "traceparent")
.find(|(key, _)| key.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let agents_usage_preferences: Vec<ModelUsagePreference> =
convert_agent_description_to_routing_preferences(&listener.agents.as_ref().unwrap());
// Select appropriate agent using arch router llm model
let selected_agent = agent_selector
.select_agent(
&chat_completions_request.messages,
&listener,
trace_parent,
)
.await?;
debug!(
"Agents usage preferences for agent routing: {:?}",
agents_usage_preferences
);
debug!("Processing agent pipeline: {}", selected_agent.name);
let agent_pipeline = match agents_usage_preferences.len() > 1 {
false => {
debug!("Only one agent available, skipping routing");
listener.agents.as_ref().unwrap()[0].clone()
}
true => {
let selected_agent = match router_service
.determine_route(
&chat_completions_request.messages,
trace_parent.clone(),
Some(agents_usage_preferences),
)
.await
{
Ok(route) => {
match route {
Some((_, agent_name)) => {
debug!("Determined agent: {}", agent_name);
listener
.agents
.as_ref()
.unwrap()
.iter()
.find(|a| a.name == agent_name)
.cloned()
// selected agent must exist in the agent map
.unwrap()
}
None => {
debug!("No agent determined using routing preferences, using default agent");
listener
.agents
.as_ref()
.unwrap()
.iter()
.find(|a| a.default.unwrap_or(false))
.cloned()
.unwrap_or_else(|| {
warn!(
"No default agent found, routing request to first agent: {}",
listener.agents.as_ref().unwrap()[0].name
);
listener.agents.as_ref().unwrap()[0].clone()
})
}
}
}
Err(err) => {
let err_msg = format!("Failed to determine route: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
selected_agent
}
// Create agent map for pipeline processing
let agent_map = {
let agents = agents_list.read().await;
let agents = agents.as_ref().unwrap();
agent_selector.create_agent_map(agents)
};
debug!("Processing agent pipeline: {}", agent_pipeline.name);
// Process the filter chain
let processed_messages = pipeline_processor
.process_filter_chain(
&chat_completions_request,
&selected_agent,
&agent_map,
&request_headers,
)
.await?;
let mut chat_completions_history = chat_completions_request.messages.clone();
// Get terminal agent and send final response
let terminal_agent_name = selected_agent.filter_chain.last().unwrap();
let terminal_agent = agent_map.get(terminal_agent_name).unwrap();
request_headers.remove(header::CONTENT_LENGTH);
let filter_chain_without_terminal_agent =
&agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len() - 1];
for agent_name in filter_chain_without_terminal_agent {
debug!("Processing agent: {}", agent_name);
let agent = agent_name_map.get(agent_name).unwrap();
debug!("Agent details: {:?}", agent);
let mut request = chat_completions_request.clone();
request.messages = chat_completions_history.clone();
let request_str = serde_json::to_string(&request).unwrap();
debug!("Sending request to agent {}", agent_name);
let mut agent_request_headers = request_headers.clone();
agent_request_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(agent.name.as_str()).unwrap(),
);
let response = match reqwest::Client::new()
.post("http://localhost:11000/v1/chat/completions")
.headers(agent_request_headers)
.body(request_str)
.send()
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let response_bytes = match response.bytes().await {
Ok(bytes) => bytes,
Err(err) => {
let err_msg = format!("Failed to read response bytes: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let chat_completions_response: hermesllm::apis::openai::ChatCompletionsResponse =
match serde_json::from_slice(&response_bytes) {
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to parse response body: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
let response_str = chat_completions_response.choices[0]
.message
.content
.clone()
.unwrap();
debug!("Received response from agent {}", agent_name);
chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]);
}
let terminal_agent_name = agent_pipeline.filter_chain.last().unwrap();
let terminal_agent = agent_name_map.get(terminal_agent_name).unwrap();
debug!("Processing terminal agent: {}", terminal_agent_name);
debug!("Terminal agent details: {:?}", terminal_agent);
let mut request = chat_completions_request.clone();
request.messages = chat_completions_history.clone();
let llm_response = pipeline_processor
.send_terminal_request(
&processed_messages,
&chat_completions_request,
terminal_agent,
&request_headers,
)
.await?;
let request_str = serde_json::to_string(&request).unwrap();
debug!("Sending request to agent {}", terminal_agent_name);
let mut agent_request_headers = request_headers.clone();
agent_request_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(terminal_agent.name.as_str()).unwrap(),
);
let llm_response = match reqwest::Client::new()
.post("http://localhost:11000/v1/chat/completions")
.headers(agent_request_headers)
.body(request_str)
.send()
// Create streaming response
response_handler
.create_streaming_response(llm_response)
.await
{
Ok(res) => res,
Err(err) => {
let err_msg = format!("Failed to send request: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
return Ok(internal_error);
}
};
// copy over the headers from the original response
let response_headers = llm_response.headers().clone();
let mut response = Response::builder();
let headers = response.headers_mut().unwrap();
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
// channel to create async stream
let (tx, rx) = mpsc::channel::<Bytes>(16);
// Spawn a task to send data as it becomes available
tokio::spawn(async move {
let mut byte_stream = llm_response.bytes_stream();
while let Some(item) = byte_stream.next().await {
let item = match item {
Ok(item) => item,
Err(err) => {
warn!("Error receiving chunk: {:?}", err);
break;
}
};
if tx.send(item).await.is_err() {
warn!("Receiver dropped");
break;
}
}
});
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
match response.body(stream_body) {
Ok(response) => Ok(response),
Err(err) => {
let err_msg = format!("Failed to create response: {}", err);
let mut internal_error = Response::new(full(err_msg));
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
Ok(internal_error)
}
}
}
fn convert_agent_description_to_routing_preferences(
agents: &Vec<AgentPipeline>,
) -> Vec<ModelUsagePreference> {
agents
.iter()
.map(|agent| ModelUsagePreference {
model: agent.name.clone(),
routing_preferences: vec![RoutingPreference {
name: agent.name.clone(),
description: agent
.description
.as_ref()
.unwrap_or(&"".to_string())
.clone(),
}],
})
.collect()
.map_err(AgentChatError::from)
}

View file

@ -0,0 +1,296 @@
use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{
Agent, AgentPipeline, Listener, ModelUsagePreference, RoutingPreference,
};
use hermesllm::apis::openai::Message;
use tracing::{debug, warn};
use crate::router::llm_router::RouterService;
/// Errors that can occur during agent selection
#[derive(Debug, thiserror::Error)]
pub enum AgentSelectionError {
#[error("Listener not found for name: {0}")]
ListenerNotFound(String),
#[error("No agents configured for listener: {0}")]
NoAgentsConfigured(String),
#[error("Routing service error: {0}")]
RoutingError(String),
#[error("Default agent not found for listener: {0}")]
DefaultAgentNotFound(String),
}
/// Service for selecting agents based on routing preferences and listener configuration
pub struct AgentSelector {
router_service: Arc<RouterService>,
}
impl AgentSelector {
pub fn new(router_service: Arc<RouterService>) -> Self {
Self { router_service }
}
/// Find listener by name from the request headers
pub async fn find_listener(
&self,
listener_name: Option<&str>,
listeners: &[common::configuration::Listener],
) -> Result<Listener, AgentSelectionError> {
let listener = listeners
.iter()
.find(|l| listener_name.map(|name| l.name == name).unwrap_or(false))
.cloned()
.ok_or_else(|| {
AgentSelectionError::ListenerNotFound(
listener_name.unwrap_or("unknown").to_string(),
)
})?;
Ok(listener)
}
/// Create agent name to agent mapping for efficient lookup
pub fn create_agent_map(&self, agents: &[Agent]) -> HashMap<String, Agent> {
agents
.iter()
.map(|agent| (agent.name.clone(), agent.clone()))
.collect()
}
/// Select appropriate agent based on routing preferences
pub async fn select_agent(
&self,
messages: &[Message],
listener: &Listener,
trace_parent: Option<String>,
) -> Result<AgentPipeline, AgentSelectionError> {
let agents = listener
.agents
.as_ref()
.ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?;
// If only one agent, skip routing
if agents.len() == 1 {
debug!("Only one agent available, skipping routing");
return Ok(agents[0].clone());
}
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
debug!(
"Agents usage preferences for agent routing str: {}",
serde_json::to_string(&usage_preferences).unwrap_or_default()
);
match self
.router_service
.determine_route(messages, trace_parent, Some(usage_preferences))
.await
{
Ok(Some((_, agent_name))) => {
debug!("Determined agent: {}", agent_name);
let selected_agent = agents
.iter()
.find(|a| a.name == agent_name)
.cloned()
.ok_or_else(|| {
AgentSelectionError::RoutingError(format!(
"Selected agent '{}' not found in listener agents",
agent_name
))
})?;
Ok(selected_agent)
}
Ok(None) => {
debug!("No agent determined using routing preferences, using default agent");
self.get_default_agent(agents, &listener.name)
}
Err(err) => Err(AgentSelectionError::RoutingError(err.to_string())),
}
}
/// Get the default agent or the first agent if no default is specified
fn get_default_agent(
&self,
agents: &[AgentPipeline],
listener_name: &str,
) -> Result<AgentPipeline, AgentSelectionError> {
agents
.iter()
.find(|a| a.default.unwrap_or(false))
.cloned()
.or_else(|| {
warn!(
"No default agent found, routing request to first agent: {}",
agents[0].name
);
Some(agents[0].clone())
})
.ok_or_else(|| AgentSelectionError::DefaultAgentNotFound(listener_name.to_string()))
}
/// Convert agent descriptions to routing preferences
fn convert_agent_description_to_routing_preferences(
&self,
agents: &[AgentPipeline],
) -> Vec<ModelUsagePreference> {
agents
.iter()
.map(|agent| ModelUsagePreference {
model: agent.name.clone(),
routing_preferences: vec![RoutingPreference {
name: agent.name.clone(),
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
}],
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{AgentPipeline, Listener};
fn create_test_router_service() -> Arc<RouterService> {
Arc::new(RouterService::new(
vec![], // empty providers for testing
"http://localhost:8080".to_string(),
"test-model".to_string(),
"test-provider".to_string(),
))
}
fn create_test_agent(name: &str, description: &str, is_default: bool) -> AgentPipeline {
AgentPipeline {
name: name.to_string(),
description: Some(description.to_string()),
default: Some(is_default),
filter_chain: vec![name.to_string()],
}
}
fn create_test_listener(name: &str, agents: Vec<AgentPipeline>) -> Listener {
Listener {
name: name.to_string(),
agents: Some(agents),
port: 8080,
router: None,
}
}
fn create_test_agent_struct(name: &str) -> Agent {
Agent {
name: name.to_string(),
kind: "test".to_string(),
endpoint: "http://localhost:8080".to_string(),
}
}
#[tokio::test]
async fn test_find_listener_success() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let listener1 = create_test_listener("test-listener", vec![]);
let listener2 = create_test_listener("other-listener", vec![]);
let listeners = vec![listener1.clone(), listener2];
let result = selector
.find_listener(Some("test-listener"), &listeners)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().name, "test-listener");
}
#[tokio::test]
async fn test_find_listener_not_found() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let listeners = vec![create_test_listener("other-listener", vec![])];
let result = selector
.find_listener(Some("nonexistent"), &listeners)
.await;
assert!(result.is_err());
matches!(
result.unwrap_err(),
AgentSelectionError::ListenerNotFound(_)
);
}
#[test]
fn test_create_agent_map() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent_struct("agent1"),
create_test_agent_struct("agent2"),
];
let agent_map = selector.create_agent_map(&agents);
assert_eq!(agent_map.len(), 2);
assert!(agent_map.contains_key("agent1"));
assert!(agent_map.contains_key("agent2"));
}
#[test]
fn test_convert_agent_description_to_routing_preferences() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent description", true),
create_test_agent("agent2", "Second agent description", false),
];
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
assert_eq!(preferences.len(), 2);
assert_eq!(preferences[0].model, "agent1");
assert_eq!(preferences[0].routing_preferences[0].name, "agent1");
assert_eq!(
preferences[0].routing_preferences[0].description,
"First agent description"
);
}
#[test]
fn test_get_default_agent() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent", false),
create_test_agent("agent2", "Default agent", true),
create_test_agent("agent3", "Third agent", false),
];
let result = selector.get_default_agent(&agents, "test-listener");
assert!(result.is_ok());
assert_eq!(result.unwrap().name, "agent2");
}
#[test]
fn test_get_default_agent_fallback_to_first() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent", false),
create_test_agent("agent2", "Second agent", false),
];
let result = selector.get_default_agent(&agents, "test-listener");
assert!(result.is_ok());
assert_eq!(result.unwrap().name, "agent1");
}
}

View file

@ -0,0 +1,143 @@
use std::sync::Arc;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
use hyper::header::HeaderMap;
use crate::handlers::agent_selector::{AgentSelector, AgentSelectionError};
use crate::handlers::pipeline_processor::{PipelineProcessor};
use crate::handlers::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
/// Integration test that demonstrates the modular agent chat flow
/// This test shows how the three main components work together:
/// 1. AgentSelector - selects the appropriate agent based on routing
/// 2. PipelineProcessor - executes the agent pipeline
/// 3. ResponseHandler - handles response streaming
#[cfg(test)]
mod integration_tests {
use super::*;
use common::configuration::{Agent, AgentPipeline, Listener};
fn create_test_router_service() -> Arc<RouterService> {
Arc::new(RouterService::new(
vec![], // empty providers for testing
"http://localhost:8080".to_string(),
"test-model".to_string(),
"test-provider".to_string(),
))
}
fn create_test_message(role: Role, content: &str) -> Message {
Message {
role,
content: MessageContent::Text(content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
#[tokio::test]
async fn test_modular_agent_chat_flow() {
// Setup services
let router_service = create_test_router_service();
let agent_selector = AgentSelector::new(router_service);
let pipeline_processor = PipelineProcessor::default();
// Create test data
let agents = vec![
Agent {
name: "filter-agent".to_string(),
kind: "filter".to_string(),
endpoint: "http://localhost:8081".to_string(),
},
Agent {
name: "terminal-agent".to_string(),
kind: "terminal".to_string(),
endpoint: "http://localhost:8082".to_string(),
},
];
let agent_pipeline = AgentPipeline {
name: "test-pipeline".to_string(),
filter_chain: vec!["filter-agent".to_string(), "terminal-agent".to_string()],
description: Some("Test pipeline".to_string()),
default: Some(true),
};
let listener = Listener {
name: "test-listener".to_string(),
agents: Some(vec![agent_pipeline.clone()]),
port: 8080,
router: None,
};
let listeners = vec![listener];
let messages = vec![create_test_message(Role::User, "Hello world!")];
// Test 1: Agent Selection
let selected_listener = agent_selector
.find_listener(Some("test-listener"), &listeners)
.await;
assert!(selected_listener.is_ok());
let listener = selected_listener.unwrap();
assert_eq!(listener.name, "test-listener");
// Test 2: Agent Map Creation
let agent_map = agent_selector.create_agent_map(&agents);
assert_eq!(agent_map.len(), 2);
assert!(agent_map.contains_key("filter-agent"));
assert!(agent_map.contains_key("terminal-agent"));
// Test 3: Pipeline Processing (empty filter chain for testing)
let request = ChatCompletionsRequest {
messages: messages.clone(),
model: "test-model".to_string(),
..Default::default()
};
// Create a pipeline with only terminal agent to avoid network calls
let test_pipeline = AgentPipeline {
name: "test-pipeline".to_string(),
filter_chain: vec!["terminal-agent".to_string()],
description: None,
default: None,
};
let headers = HeaderMap::new();
let result = pipeline_processor
.process_filter_chain(&request, &test_pipeline, &agent_map, &headers)
.await;
assert!(result.is_ok());
let processed_messages = result.unwrap();
assert_eq!(processed_messages.len(), 1);
// Test 4: Error Response Creation
let error_response = ResponseHandler::create_bad_request("Test error");
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
println!("✅ All modular components working correctly!");
}
#[tokio::test]
async fn test_error_handling_flow() {
let router_service = create_test_router_service();
let agent_selector = AgentSelector::new(router_service);
// Test listener not found
let result = agent_selector
.find_listener(Some("nonexistent"), &[])
.await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), AgentSelectionError::ListenerNotFound(_)));
// Test error response creation
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
assert_eq!(error_response.status(), hyper::StatusCode::INTERNAL_SERVER_ERROR);
println!("✅ Error handling working correctly!");
}
}

View file

@ -1,3 +1,9 @@
pub mod chat_completions;
pub mod models;
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod pipeline_processor;
pub mod response_handler;
#[cfg(test)]
mod integration_tests;

View file

@ -0,0 +1,248 @@
use std::collections::HashMap;
use common::configuration::{Agent, AgentPipeline};
use common::consts::ARCH_UPSTREAM_HOST_HEADER;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use hyper::header::HeaderMap;
use tracing::{debug, warn};
/// Errors that can occur during pipeline processing
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
#[error("HTTP request failed: {0}")]
RequestFailed(#[from] reqwest::Error),
#[error("Failed to parse response: {0}")]
ParseError(#[from] serde_json::Error),
#[error("Agent '{0}' not found in agent map")]
AgentNotFound(String),
#[error("No choices in response from agent '{0}'")]
NoChoicesInResponse(String),
#[error("No content in response from agent '{0}'")]
NoContentInResponse(String),
}
/// Service for processing agent pipelines
pub struct PipelineProcessor {
client: reqwest::Client,
llm_endpoint: String,
}
impl Default for PipelineProcessor {
fn default() -> Self {
Self {
client: reqwest::Client::new(),
llm_endpoint: "http://localhost:11000/v1/chat/completions".to_string(),
}
}
}
impl PipelineProcessor {
pub fn new(llm_endpoint: String) -> Self {
Self {
client: reqwest::Client::new(),
llm_endpoint,
}
}
/// Process the filter chain of agents (all except the terminal agent)
pub async fn process_filter_chain(
&self,
initial_request: &ChatCompletionsRequest,
agent_pipeline: &AgentPipeline,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
let mut chat_completions_history = initial_request.messages.clone();
let filter_chain_without_terminal =
&agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len().saturating_sub(1)];
for agent_name in filter_chain_without_terminal {
debug!("Processing filter agent: {}", agent_name);
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
debug!("Agent details: {:?}", agent);
let response_content = self
.send_agent_request(
&chat_completions_history,
initial_request,
agent,
request_headers,
)
.await?;
debug!("Received response from filter agent {}", agent_name);
// Parse the response content as new message history
chat_completions_history = serde_json::from_str(&response_content).map_err(|err| {
warn!(
"Failed to parse response from agent {}, response: {}",
agent_name, response_content
);
err
})?;
}
Ok(chat_completions_history)
}
/// Send request to a specific agent and return the response content
async fn send_agent_request(
&self,
messages: &[Message],
original_request: &ChatCompletionsRequest,
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<String, PipelineError> {
let mut request = original_request.clone();
request.messages = messages.to_vec();
let request_body = serde_json::to_string(&request)?;
debug!("Sending request to agent {}", agent.name);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.name)
.map_err(|_| PipelineError::AgentNotFound(agent.name.clone()))?,
);
let response = self
.client
.post(&self.llm_endpoint)
.headers(agent_headers)
.body(request_body)
.send()
.await?;
let response_bytes = response.bytes().await?;
// Parse the response as JSON to extract the content
let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?;
let content = response_json
.get("choices")
.and_then(|choices| choices.as_array())
.and_then(|choices| choices.first())
.and_then(|choice| choice.get("message"))
.and_then(|message| message.get("content"))
.and_then(|content| content.as_str())
.ok_or_else(|| PipelineError::NoContentInResponse(agent.name.clone()))?
.to_string();
Ok(content)
}
/// Send request to terminal agent and return the raw response for streaming
pub async fn send_terminal_request(
&self,
messages: &[Message],
original_request: &ChatCompletionsRequest,
terminal_agent: &Agent,
request_headers: &HeaderMap,
) -> Result<reqwest::Response, PipelineError> {
let mut request = original_request.clone();
request.messages = messages.to_vec();
let request_body = serde_json::to_string(&request)?;
debug!("Sending request to terminal agent {}", terminal_agent.name);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&terminal_agent.name)
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.name.clone()))?,
);
let response = self
.client
.post(&self.llm_endpoint)
.headers(agent_headers)
.body(request_body)
.send()
.await?;
Ok(response)
}
}
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai::{Message, MessageContent, Role};
use std::collections::HashMap;
fn create_test_message(role: Role, content: &str) -> Message {
Message {
role,
content: MessageContent::Text(content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
fn create_test_pipeline(agents: Vec<&str>) -> AgentPipeline {
AgentPipeline {
name: "test-pipeline".to_string(),
filter_chain: agents.iter().map(|s| s.to_string()).collect(),
description: None,
default: None,
}
}
#[tokio::test]
async fn test_process_empty_filter_chain() {
let processor = PipelineProcessor::default();
let agent_map = HashMap::new();
let request_headers = HeaderMap::new();
let initial_request = ChatCompletionsRequest {
messages: vec![create_test_message(Role::User, "Hello")],
model: "test-model".to_string(),
..Default::default()
};
// Pipeline with only terminal agent (no filter chain)
let pipeline = create_test_pipeline(vec!["terminal-agent"]);
let result = processor
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
.await;
assert!(result.is_ok());
let messages = result.unwrap();
assert_eq!(messages.len(), 1);
if let MessageContent::Text(text) = &messages[0].content {
assert_eq!(text, "Hello");
}
}
#[tokio::test]
async fn test_agent_not_found_error() {
let processor = PipelineProcessor::default();
let agent_map = HashMap::new();
let request_headers = HeaderMap::new();
let initial_request = ChatCompletionsRequest {
messages: vec![create_test_message(Role::User, "Hello")],
model: "test-model".to_string(),
..Default::default()
};
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
let result = processor
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
.await;
assert!(result.is_err());
matches!(result.unwrap_err(), PipelineError::AgentNotFound(_));
}
}

View file

@ -0,0 +1,160 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::{Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::warn;
/// Errors that can occur during response handling
#[derive(Debug, thiserror::Error)]
pub enum ResponseError {
#[error("Failed to create response: {0}")]
ResponseCreationFailed(#[from] hyper::http::Error),
#[error("Stream error: {0}")]
StreamError(String),
}
/// Service for handling HTTP responses and streaming
pub struct ResponseHandler;
impl ResponseHandler {
pub fn new() -> Self {
Self
}
/// Create a full response body from bytes
pub fn create_full_body<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
/// Create an error response with a given status code and message
pub fn create_error_response(
status: StatusCode,
message: &str,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let mut response = Response::new(Self::create_full_body(message.to_string()));
*response.status_mut() = status;
response
}
/// Create a bad request response
pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Self::create_error_response(StatusCode::BAD_REQUEST, message)
}
/// Create an internal server error response
pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Self::create_error_response(StatusCode::INTERNAL_SERVER_ERROR, message)
}
/// Create a streaming response from a reqwest response
pub async fn create_streaming_response(
&self,
llm_response: reqwest::Response,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ResponseError> {
// Copy headers from the original response
let response_headers = llm_response.headers();
let mut response_builder = Response::builder();
let headers = response_builder.headers_mut().ok_or_else(|| {
ResponseError::StreamError("Failed to get mutable headers".to_string())
})?;
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
// Create channel for async streaming
let (tx, rx) = mpsc::channel::<Bytes>(16);
// Spawn task to stream data
tokio::spawn(async move {
let mut byte_stream = llm_response.bytes_stream();
while let Some(item) = byte_stream.next().await {
let chunk = match item {
Ok(chunk) => chunk,
Err(err) => {
warn!("Error receiving chunk: {:?}", err);
break;
}
};
if tx.send(chunk).await.is_err() {
warn!("Receiver dropped");
break;
}
}
});
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
response_builder
.body(stream_body)
.map_err(ResponseError::from)
}
}
impl Default for ResponseHandler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use hyper::StatusCode;
#[test]
fn test_create_bad_request() {
let response = ResponseHandler::create_bad_request("Invalid request");
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
#[test]
fn test_create_internal_error() {
let response = ResponseHandler::create_internal_error("Server error");
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[test]
fn test_create_error_response() {
let response =
ResponseHandler::create_error_response(StatusCode::NOT_FOUND, "Resource not found");
assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
#[tokio::test]
async fn test_create_streaming_response_with_mock() {
use mockito::Server;
let mut server = Server::new_async().await;
let mock = server
.mock("GET", "/test")
.with_status(200)
.with_header("content-type", "text/plain")
.with_body("streaming response")
.create_async()
.await;
let client = reqwest::Client::new();
let llm_response = client.get(&(server.url() + "/test")).send().await.unwrap();
let handler = ResponseHandler::new();
let result = handler.create_streaming_response(llm_response).await;
mock.assert_async().await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().contains_key("content-type"));
}
}