add support for agents (#564)

Adil Hafeez authored on 2025-10-14 14:01:11 -07:00, committed by GitHub
parent f8991a3c4b
commit 96e0732089
41 changed files with 3571 additions and 856 deletions

View file

@@ -0,0 +1,172 @@
use std::sync::Arc;
use bytes::Bytes;
use hermesllm::apis::openai::ChatCompletionsRequest;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::{Request, Response};
use tracing::{debug, info, warn};
use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
pub enum AgentFilterChainError {
#[error("Agent selection error: {0}")]
Selection(#[from] AgentSelectionError),
#[error("Pipeline processing error: {0}")]
Pipeline(#[from] PipelineError),
#[error("Response handling error: {0}")]
Response(#[from] super::response_handler::ResponseError),
#[error("Request parsing error: {0}")]
RequestParsing(#[from] serde_json::Error),
#[error("HTTP error: {0}")]
Http(#[from] hyper::Error),
}
pub async fn agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
_: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
match handle_agent_chat(request, router_service, agents_list, listeners).await {
Ok(response) => Ok(response),
Err(err) => {
// Log detailed error information with the full error chain
let mut error_chain = Vec::new();
let mut current_error: &dyn std::error::Error = &err;
// Collect the full error chain
loop {
error_chain.push(current_error.to_string());
match current_error.source() {
Some(source) => current_error = source,
None => break,
}
}
// Log the complete error chain
warn!("Agent chat error chain: {:#?}", error_chain);
warn!("Root error: {:?}", err);
// Create structured error response as JSON
let error_json = serde_json::json!({
"error": {
"type": "AgentFilterChainError",
"message": err.to_string(),
"error_chain": error_chain,
"debug_info": format!("{:?}", err)
}
});
// Log the error for debugging
info!("Structured error info: {}", error_json);
// Return JSON error response
Ok(ResponseHandler::create_json_error_response(&error_json))
}
}
}
async fn handle_agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
// Initialize services
let agent_selector = AgentSelector::new(router_service);
let pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
// Extract listener name from headers
let listener_name = request
.headers()
.get("x-arch-agent-listener-name")
.and_then(|name| name.to_str().ok());
// Find the appropriate listener
let listener = {
let listeners = listeners.read().await;
agent_selector
.find_listener(listener_name, &listeners)
.await?
};
info!("Handling request for listener: {}", listener.name);
// Parse request body
let request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!(
"Received request body (raw utf8): {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
warn!(
"Failed to parse request body as ChatCompletionsRequest: {}",
err
);
AgentFilterChainError::RequestParsing(err)
})?;
// Extract trace parent for routing
let trace_parent = request_headers
.iter()
.find(|(key, _)| key.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
// Select appropriate agent using arch router llm model
let selected_agent = agent_selector
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
.await?;
debug!("Processing agent pipeline: {}", selected_agent.id);
// Create agent map for pipeline processing
let agent_map = {
let agents = agents_list.read().await;
let agents = agents.as_ref().unwrap();
agent_selector.create_agent_map(agents)
};
// Process the filter chain
let processed_messages = pipeline_processor
.process_filter_chain(
&chat_completions_request,
&selected_agent,
&agent_map,
&request_headers,
)
.await?;
// Get terminal agent and send final response
let terminal_agent_name = selected_agent.id;
let terminal_agent = agent_map.get(&terminal_agent_name).unwrap();
debug!("Processing terminal agent: {}", terminal_agent_name);
debug!("Terminal agent details: {:?}", terminal_agent);
let llm_response = pipeline_processor
.invoke_upstream_agent(
&processed_messages,
&chat_completions_request,
terminal_agent,
&request_headers,
)
.await?;
// Create streaming response
response_handler
.create_streaming_response(llm_response)
.await
.map_err(AgentFilterChainError::from)
}
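
Editor's note: for orientation, a minimal client-side sketch of how this handler might be exercised once requests are routed to it at `/agents/v1/chat/completions` (the route added further down in this commit). The gateway address, listener name, and payload values are illustrative assumptions; only the path and the `x-arch-agent-listener-name` header come from this change.

```rust
// Hedged sketch: posting an OpenAI-style chat request to the agents endpoint.
// Host/port and listener name are hypothetical; header name and path are from this commit.
use serde_json::json;

#[tokio::main]
async fn main() -> Result<(), reqwest::Error> {
    let body = json!({
        "model": "arch-agent", // passed through; agent selection is routing-based
        "messages": [{ "role": "user", "content": "Hello world!" }]
    });

    let response = reqwest::Client::new()
        .post("http://localhost:10000/agents/v1/chat/completions") // hypothetical gateway address
        .header("x-arch-agent-listener-name", "test-listener") // selects the listener to route within
        .json(&body)
        .send()
        .await?;

    println!("status: {}", response.status());
    Ok(())
}
```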

View file

@@ -0,0 +1,296 @@
use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{
Agent, AgentFilterChain, Listener, ModelUsagePreference, RoutingPreference,
};
use hermesllm::apis::openai::Message;
use tracing::{debug, warn};
use crate::router::llm_router::RouterService;
/// Errors that can occur during agent selection
#[derive(Debug, thiserror::Error)]
pub enum AgentSelectionError {
#[error("Listener not found for name: {0}")]
ListenerNotFound(String),
#[error("No agents configured for listener: {0}")]
NoAgentsConfigured(String),
#[error("Routing service error: {0}")]
RoutingError(String),
#[error("Default agent not found for listener: {0}")]
DefaultAgentNotFound(String),
}
/// Service for selecting agents based on routing preferences and listener configuration
pub struct AgentSelector {
router_service: Arc<RouterService>,
}
impl AgentSelector {
pub fn new(router_service: Arc<RouterService>) -> Self {
Self { router_service }
}
/// Find listener by name from the request headers
pub async fn find_listener(
&self,
listener_name: Option<&str>,
listeners: &[common::configuration::Listener],
) -> Result<Listener, AgentSelectionError> {
let listener = listeners
.iter()
.find(|l| listener_name.map(|name| l.name == name).unwrap_or(false))
.cloned()
.ok_or_else(|| {
AgentSelectionError::ListenerNotFound(
listener_name.unwrap_or("unknown").to_string(),
)
})?;
Ok(listener)
}
/// Create agent name to agent mapping for efficient lookup
pub fn create_agent_map(&self, agents: &[Agent]) -> HashMap<String, Agent> {
agents
.iter()
.map(|agent| (agent.id.clone(), agent.clone()))
.collect()
}
/// Select appropriate agent based on routing preferences
pub async fn select_agent(
&self,
messages: &[Message],
listener: &Listener,
trace_parent: Option<String>,
) -> Result<AgentFilterChain, AgentSelectionError> {
let agents = listener
.agents
.as_ref()
.ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?;
// If only one agent, skip routing
if agents.len() == 1 {
debug!("Only one agent available, skipping routing");
return Ok(agents[0].clone());
}
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
debug!(
"Agents usage preferences for agent routing str: {}",
serde_json::to_string(&usage_preferences).unwrap_or_default()
);
match self
.router_service
.determine_route(messages, trace_parent, Some(usage_preferences))
.await
{
Ok(Some((_, agent_name))) => {
debug!("Determined agent: {}", agent_name);
let selected_agent = agents
.iter()
.find(|a| a.id == agent_name)
.cloned()
.ok_or_else(|| {
AgentSelectionError::RoutingError(format!(
"Selected agent '{}' not found in listener agents",
agent_name
))
})?;
Ok(selected_agent)
}
Ok(None) => {
debug!("No agent determined using routing preferences, using default agent");
self.get_default_agent(agents, &listener.name)
}
Err(err) => Err(AgentSelectionError::RoutingError(err.to_string())),
}
}
/// Get the default agent or the first agent if no default is specified
fn get_default_agent(
&self,
agents: &[AgentFilterChain],
listener_name: &str,
) -> Result<AgentFilterChain, AgentSelectionError> {
agents
.iter()
.find(|a| a.default.unwrap_or(false))
.cloned()
.or_else(|| {
warn!(
"No default agent found, routing request to first agent: {}",
agents[0].id
);
Some(agents[0].clone())
})
.ok_or_else(|| AgentSelectionError::DefaultAgentNotFound(listener_name.to_string()))
}
/// Convert agent descriptions to routing preferences
fn convert_agent_description_to_routing_preferences(
&self,
agents: &[AgentFilterChain],
) -> Vec<ModelUsagePreference> {
agents
.iter()
.map(|agent| ModelUsagePreference {
model: agent.id.clone(),
routing_preferences: vec![RoutingPreference {
name: agent.id.clone(),
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
}],
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{AgentFilterChain, Listener};
fn create_test_router_service() -> Arc<RouterService> {
Arc::new(RouterService::new(
vec![], // empty providers for testing
"http://localhost:8080".to_string(),
"test-model".to_string(),
"test-provider".to_string(),
))
}
fn create_test_agent(name: &str, description: &str, is_default: bool) -> AgentFilterChain {
AgentFilterChain {
id: name.to_string(),
description: Some(description.to_string()),
default: Some(is_default),
filter_chain: vec![name.to_string()],
}
}
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
Listener {
name: name.to_string(),
agents: Some(agents),
port: 8080,
router: None,
}
}
fn create_test_agent_struct(name: &str) -> Agent {
Agent {
id: name.to_string(),
kind: Some("test".to_string()),
url: "http://localhost:8080".to_string(),
}
}
#[tokio::test]
async fn test_find_listener_success() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let listener1 = create_test_listener("test-listener", vec![]);
let listener2 = create_test_listener("other-listener", vec![]);
let listeners = vec![listener1.clone(), listener2];
let result = selector
.find_listener(Some("test-listener"), &listeners)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().name, "test-listener");
}
#[tokio::test]
async fn test_find_listener_not_found() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let listeners = vec![create_test_listener("other-listener", vec![])];
let result = selector
.find_listener(Some("nonexistent"), &listeners)
.await;
assert!(result.is_err());
matches!(
result.unwrap_err(),
AgentSelectionError::ListenerNotFound(_)
);
}
#[test]
fn test_create_agent_map() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent_struct("agent1"),
create_test_agent_struct("agent2"),
];
let agent_map = selector.create_agent_map(&agents);
assert_eq!(agent_map.len(), 2);
assert!(agent_map.contains_key("agent1"));
assert!(agent_map.contains_key("agent2"));
}
#[test]
fn test_convert_agent_description_to_routing_preferences() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent description", true),
create_test_agent("agent2", "Second agent description", false),
];
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
assert_eq!(preferences.len(), 2);
assert_eq!(preferences[0].model, "agent1");
assert_eq!(preferences[0].routing_preferences[0].name, "agent1");
assert_eq!(
preferences[0].routing_preferences[0].description,
"First agent description"
);
}
#[test]
fn test_get_default_agent() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent", false),
create_test_agent("agent2", "Default agent", true),
create_test_agent("agent3", "Third agent", false),
];
let result = selector.get_default_agent(&agents, "test-listener");
assert!(result.is_ok());
assert_eq!(result.unwrap().id, "agent2");
}
#[test]
fn test_get_default_agent_fallback_to_first() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent", false),
create_test_agent("agent2", "Second agent", false),
];
let result = selector.get_default_agent(&agents, "test-listener");
assert!(result.is_ok());
assert_eq!(result.unwrap().id, "agent1");
}
}
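
Editor's note: a small sketch, not part of the commit, of the routing-preference shape that `convert_agent_description_to_routing_preferences` produces and `select_agent` feeds to `determine_route`: each agent becomes a `ModelUsagePreference` whose `model` is the agent id and whose single `RoutingPreference` carries the agent's description. The agent ids and descriptions below are hypothetical.

```rust
use common::configuration::{ModelUsagePreference, RoutingPreference};

fn main() {
    // Two hypothetical agents, mirroring what the selector hands to the router.
    let preferences = vec![
        ModelUsagePreference {
            model: "weather-agent".to_string(),
            routing_preferences: vec![RoutingPreference {
                name: "weather-agent".to_string(),
                description: "Answers questions about current weather".to_string(),
            }],
        },
        ModelUsagePreference {
            model: "travel-agent".to_string(),
            routing_preferences: vec![RoutingPreference {
                name: "travel-agent".to_string(),
                description: "Plans trips and books travel".to_string(),
            }],
        },
    ];
    println!("{}", serde_json::to_string_pretty(&preferences).unwrap());
}
```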

View file

@@ -1,5 +1,3 @@
use bytes::Bytes;
use common::configuration::{ModelAlias, ModelUsagePreference};
use common::consts::ARCH_PROVIDER_HINT_HEADER;
@@ -11,6 +9,8 @@ use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
@@ -30,14 +30,19 @@ pub async fn chat(
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
debug!(
"Received request body (raw utf8): {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let mut client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!("Failed to parse request as ProviderRequestType: {}", err);
@@ -77,7 +82,10 @@ pub async fn chat(
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
match ProviderRequestType::try_from((
client_request,
&SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
)) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok(ProviderRequestType::MessagesRequest(_)) => {
// This should not happen after conversion to OpenAI format
@@ -86,9 +94,12 @@ pub async fn chat(
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
}
Err(err) => {
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
warn!(
"Failed to convert request to ChatCompletionsRequest: {}",
err
);
let err_msg = format!("Failed to convert request: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
@@ -106,28 +117,29 @@ pub async fn chat(
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("archgw_preference_config")
.map(|value| value.to_string())
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
let latest_message_for_log = chat_completions_request_for_arch_router
.messages
.last()
.map_or("None".to_string(), |msg| {
msg.content.to_string().replace('\n', "\\n")
});
const MAX_MESSAGE_LENGTH: usize = 50;
let latest_message_for_log = if latest_message_for_log.chars().count() > MAX_MESSAGE_LENGTH {
let truncated: String = latest_message_for_log
.chars()
.take(MAX_MESSAGE_LENGTH)
.collect();
format!("{}...", truncated)
} else {
latest_message_for_log
@@ -153,12 +165,11 @@ pub async fn chat(
Ok(route) => match route {
Some((_, model_name)) => model_name,
None => {
info!(
"No route determined, using default model from request: {}",
chat_completions_request_for_arch_router.model
);
chat_completions_request_for_arch_router.model.clone()
}
},
Err(err) => {

View file

@@ -0,0 +1,155 @@
use std::sync::Arc;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use hyper::header::HeaderMap;
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
use crate::handlers::pipeline_processor::PipelineProcessor;
use crate::handlers::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
/// Integration test that demonstrates the modular agent chat flow
/// This test shows how the three main components work together:
/// 1. AgentSelector - selects the appropriate agent based on routing
/// 2. PipelineProcessor - executes the agent pipeline
/// 3. ResponseHandler - handles response streaming
#[cfg(test)]
mod integration_tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
fn create_test_router_service() -> Arc<RouterService> {
Arc::new(RouterService::new(
vec![], // empty providers for testing
"http://localhost:8080".to_string(),
"test-model".to_string(),
"test-provider".to_string(),
))
}
fn create_test_message(role: Role, content: &str) -> Message {
Message {
role,
content: MessageContent::Text(content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
#[tokio::test]
async fn test_modular_agent_chat_flow() {
// Setup services
let router_service = create_test_router_service();
let agent_selector = AgentSelector::new(router_service);
let pipeline_processor = PipelineProcessor::default();
// Create test data
let agents = vec![
Agent {
id: "filter-agent".to_string(),
kind: Some("filter".to_string()),
url: "http://localhost:8081".to_string(),
},
Agent {
id: "terminal-agent".to_string(),
kind: Some("terminal".to_string()),
url: "http://localhost:8082".to_string(),
},
];
let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: vec!["filter-agent".to_string(), "terminal-agent".to_string()],
description: Some("Test pipeline".to_string()),
default: Some(true),
};
let listener = Listener {
name: "test-listener".to_string(),
agents: Some(vec![agent_pipeline.clone()]),
port: 8080,
router: None,
};
let listeners = vec![listener];
let messages = vec![create_test_message(Role::User, "Hello world!")];
// Test 1: Agent Selection
let selected_listener = agent_selector
.find_listener(Some("test-listener"), &listeners)
.await;
assert!(selected_listener.is_ok());
let listener = selected_listener.unwrap();
assert_eq!(listener.name, "test-listener");
// Test 2: Agent Map Creation
let agent_map = agent_selector.create_agent_map(&agents);
assert_eq!(agent_map.len(), 2);
assert!(agent_map.contains_key("filter-agent"));
assert!(agent_map.contains_key("terminal-agent"));
// Test 3: Pipeline Processing (empty filter chain for testing)
let request = ChatCompletionsRequest {
messages: messages.clone(),
model: "test-model".to_string(),
..Default::default()
};
// Create a pipeline with empty filter chain to avoid network calls
let test_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: vec![], // Empty filter chain - no network calls needed
description: None,
default: None,
};
let headers = HeaderMap::new();
let result = pipeline_processor
.process_filter_chain(&request, &test_pipeline, &agent_map, &headers)
.await;
println!("Pipeline processing result: {:?}", result);
assert!(result.is_ok());
let processed_messages = result.unwrap();
// With empty filter chain, should return the original messages unchanged
assert_eq!(processed_messages.len(), 1);
if let MessageContent::Text(content) = &processed_messages[0].content {
assert_eq!(content, "Hello world!");
} else {
panic!("Expected text content");
}
// Test 4: Error Response Creation
let error_response = ResponseHandler::create_bad_request("Test error");
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
println!("✅ All modular components working correctly!");
}
#[tokio::test]
async fn test_error_handling_flow() {
let router_service = create_test_router_service();
let agent_selector = AgentSelector::new(router_service);
// Test listener not found
let result = agent_selector.find_listener(Some("nonexistent"), &[]).await;
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
AgentSelectionError::ListenerNotFound(_)
));
// Test error response creation
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
assert_eq!(
error_response.status(),
hyper::StatusCode::INTERNAL_SERVER_ERROR
);
println!("✅ Error handling working correctly!");
}
}

View file

@@ -1,2 +1,9 @@
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod chat_completions;
pub mod models;
pub mod pipeline_processor;
pub mod response_handler;
#[cfg(test)]
mod integration_tests;

View file

@@ -0,0 +1,228 @@
use std::collections::HashMap;
use common::configuration::{Agent, AgentFilterChain};
use common::consts::{ARCH_UPSTREAM_HOST_HEADER, ENVOY_RETRY_HEADER};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use hyper::header::HeaderMap;
use tracing::{debug, warn};
/// Errors that can occur during pipeline processing
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
#[error("HTTP request failed: {0}")]
RequestFailed(#[from] reqwest::Error),
#[error("Failed to parse response: {0}")]
ParseError(#[from] serde_json::Error),
#[error("Agent '{0}' not found in agent map")]
AgentNotFound(String),
#[error("No choices in response from agent '{0}'")]
NoChoicesInResponse(String),
#[error("No content in response from agent '{0}'")]
NoContentInResponse(String),
}
/// Service for processing agent pipelines
pub struct PipelineProcessor {
client: reqwest::Client,
url: String,
}
impl Default for PipelineProcessor {
fn default() -> Self {
Self {
client: reqwest::Client::new(),
url: "http://localhost:11000/v1/chat/completions".to_string(),
}
}
}
impl PipelineProcessor {
pub fn new(url: String) -> Self {
Self {
client: reqwest::Client::new(),
url,
}
}
/// Process the filter chain of agents (all except the terminal agent)
pub async fn process_filter_chain(
&self,
initial_request: &ChatCompletionsRequest,
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
let mut chat_completions_history = initial_request.messages.clone();
for agent_name in &agent_filter_chain.filter_chain {
debug!("Processing filter agent: {}", agent_name);
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
debug!("Agent details: {:?}", agent);
let response_content = self
.send_agent_filter_chain_request(
&chat_completions_history,
initial_request,
agent,
request_headers,
)
.await?;
debug!("Received response from filter agent {}", agent_name);
// Parse the response content as new message history
chat_completions_history =
serde_json::from_str(&response_content).inspect_err(|err| {
warn!(
"Failed to parse response from agent {}, err: {}, response: {}",
agent_name, err, response_content
)
})?;
}
Ok(chat_completions_history)
}
/// Send request to a specific agent and return the response content
async fn send_agent_filter_chain_request(
&self,
messages: &[Message],
original_request: &ChatCompletionsRequest,
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<String, PipelineError> {
let mut request = original_request.clone();
request.messages = messages.to_vec();
let request_body = serde_json::to_string(&request)?;
debug!("Sending request to agent {}", agent.id);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let response = self
.client
.post(&self.url)
.headers(agent_headers)
.body(request_body)
.send()
.await?;
let response_bytes = response.bytes().await?;
// Parse the response as JSON to extract the content
let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?;
let content = response_json
.get("choices")
.and_then(|choices| choices.as_array())
.and_then(|choices| choices.first())
.and_then(|choice| choice.get("message"))
.and_then(|message| message.get("content"))
.and_then(|content| content.as_str())
.ok_or_else(|| PipelineError::NoContentInResponse(agent.id.clone()))?
.to_string();
Ok(content)
}
/// Send request to terminal agent and return the raw response for streaming
pub async fn invoke_upstream_agent(
&self,
messages: &[Message],
original_request: &ChatCompletionsRequest,
terminal_agent: &Agent,
request_headers: &HeaderMap,
) -> Result<reqwest::Response, PipelineError> {
let mut request = original_request.clone();
request.messages = messages.to_vec();
let request_body = serde_json::to_string(&request)?;
debug!("Sending request to terminal agent {}", terminal_agent.id);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&terminal_agent.id)
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let response = self
.client
.post(&self.url)
.headers(agent_headers)
.body(request_body)
.send()
.await?;
Ok(response)
}
}
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai::{Message, MessageContent, Role};
use std::collections::HashMap;
fn create_test_message(role: Role, content: &str) -> Message {
Message {
role,
content: MessageContent::Text(content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
AgentFilterChain {
id: "test-agent".to_string(),
filter_chain: agents.iter().map(|s| s.to_string()).collect(),
description: None,
default: None,
}
}
#[tokio::test]
async fn test_agent_not_found_error() {
let processor = PipelineProcessor::default();
let agent_map = HashMap::new();
let request_headers = HeaderMap::new();
let initial_request = ChatCompletionsRequest {
messages: vec![create_test_message(Role::User, "Hello")],
model: "test-model".to_string(),
..Default::default()
};
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
let result = processor
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
.await;
assert!(result.is_err());
matches!(result.unwrap_err(), PipelineError::AgentNotFound(_));
}
}
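
Editor's note: one detail worth calling out from `process_filter_chain` above: a filter agent is expected to reply with a normal chat completion whose first choice's message `content` is itself a JSON-encoded array of messages, and that array replaces the running history before the next agent (or the terminal agent) is invoked. Below is a minimal sketch of that contract, assuming `hermesllm`'s `Message` deserializes OpenAI-style message objects; the payload is illustrative.

```rust
use hermesllm::apis::openai::Message;

// Parse the content string returned by a filter agent into the new message history,
// mirroring the serde_json::from_str call in process_filter_chain.
fn parse_filter_agent_content(content: &str) -> Result<Vec<Message>, serde_json::Error> {
    serde_json::from_str(content)
}

fn main() -> Result<(), serde_json::Error> {
    // Illustrative content a filter agent might return inside choices[0].message.content.
    let content = r#"[
        {"role": "user", "content": "Hello world!"},
        {"role": "assistant", "content": "Hello! (history rewritten by the filter agent)"}
    ]"#;

    let history = parse_filter_agent_content(content)?;
    assert_eq!(history.len(), 2);
    Ok(())
}
```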

View file

@@ -0,0 +1,191 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::{Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::warn;
/// Errors that can occur during response handling
#[derive(Debug, thiserror::Error)]
pub enum ResponseError {
#[error("Failed to create response: {0}")]
ResponseCreationFailed(#[from] hyper::http::Error),
#[error("Stream error: {0}")]
StreamError(String),
}
/// Service for handling HTTP responses and streaming
pub struct ResponseHandler;
impl ResponseHandler {
pub fn new() -> Self {
Self
}
/// Create a full response body from bytes
pub fn create_full_body<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
/// Create an error response with a given status code and message
pub fn create_error_response(
status: StatusCode,
message: &str,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let mut response = Response::new(Self::create_full_body(message.to_string()));
*response.status_mut() = status;
response
}
/// Create a bad request response
pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Self::create_error_response(StatusCode::BAD_REQUEST, message)
}
/// Create an internal server error response
pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Self::create_error_response(StatusCode::INTERNAL_SERVER_ERROR, message)
}
/// Create a JSON error response
pub fn create_json_error_response(
error_json: &serde_json::Value,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let json_string = error_json.to_string();
let mut response = Response::new(Self::create_full_body(json_string));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
response
}
/// Create a streaming response from a reqwest response
pub async fn create_streaming_response(
&self,
llm_response: reqwest::Response,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ResponseError> {
// Copy headers from the original response
let response_headers = llm_response.headers();
let mut response_builder = Response::builder();
let headers = response_builder.headers_mut().ok_or_else(|| {
ResponseError::StreamError("Failed to get mutable headers".to_string())
})?;
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
// Create channel for async streaming
let (tx, rx) = mpsc::channel::<Bytes>(16);
// Spawn task to stream data
tokio::spawn(async move {
let mut byte_stream = llm_response.bytes_stream();
while let Some(item) = byte_stream.next().await {
let chunk = match item {
Ok(chunk) => chunk,
Err(err) => {
warn!("Error receiving chunk: {:?}", err);
break;
}
};
if tx.send(chunk).await.is_err() {
warn!("Receiver dropped");
break;
}
}
});
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
response_builder
.body(stream_body)
.map_err(ResponseError::from)
}
}
impl Default for ResponseHandler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use hyper::StatusCode;
#[test]
fn test_create_bad_request() {
let response = ResponseHandler::create_bad_request("Invalid request");
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
#[test]
fn test_create_internal_error() {
let response = ResponseHandler::create_internal_error("Server error");
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[test]
fn test_create_error_response() {
let response =
ResponseHandler::create_error_response(StatusCode::NOT_FOUND, "Resource not found");
assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
#[test]
fn test_create_json_error_response() {
let error_json = serde_json::json!({
"error": {
"type": "TestError",
"message": "Test error message"
}
});
let response = ResponseHandler::create_json_error_response(&error_json);
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
response.headers().get("content-type").unwrap(),
"application/json"
);
}
#[tokio::test]
async fn test_create_streaming_response_with_mock() {
use mockito::Server;
let mut server = Server::new_async().await;
let mock = server
.mock("GET", "/test")
.with_status(200)
.with_header("content-type", "text/plain")
.with_body("streaming response")
.create_async()
.await;
let client = reqwest::Client::new();
let llm_response = client.get(&(server.url() + "/test")).send().await.unwrap();
let handler = ResponseHandler::new();
let result = handler.create_streaming_response(llm_response).await;
mock.assert_async().await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().contains_key("content-type"));
}
}

View file

@@ -1,3 +1,4 @@
use brightstaff::handlers::agent_chat_completions::agent_chat;
use brightstaff::handlers::chat_completions::chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
@@ -61,15 +62,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let arch_config = Arc::new(config);
let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone()));
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
debug!(
"arch_config: {:?}",
&serde_json::to_string(arch_config.as_ref()).unwrap()
);
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
info!("llm provider url: {}", llm_provider_url);
info!("listening on http://{}", bind_address);
@@ -84,11 +87,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let routing_llm_provider = arch_config
.routing
.as_ref()
.and_then(|r| r.llm_provider.clone())
.and_then(|r| r.model_provider.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
arch_config.llm_providers.clone(),
arch_config.model_providers.clone(),
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
routing_model_name,
routing_llm_provider,
@@ -96,7 +99,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_aliases = Arc::new(arch_config.model_aliases.clone());
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
@@ -107,24 +109,44 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let agents_list = agents_list.clone();
let listeners = listeners.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let listeners = listeners.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url, model_aliases)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
(&Method::OPTIONS, "/v1/models") => {
(&Method::POST, "/agents/v1/chat/completions") => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
agent_chat(
req,
router_service,
fully_qualified_url,
agents_list,
listeners,
)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
Ok(list_models(llm_providers).await)
}
// hack for now to get open-webui to work
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
let mut response = Response::new(empty());
*response.status_mut() = StatusCode::NO_CONTENT;
response
@@ -148,6 +170,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(response)
}
_ => {
debug!("No route for {} {}", req.method(), req.uri().path());
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)

View file

@@ -79,7 +79,13 @@ impl RouterService {
trace_parent: Option<String>,
usage_preferences: Option<Vec<ModelUsagePreference>>,
) -> Result<Option<(String, String)>> {
if !self.llm_usage_defined {
if messages.is_empty() {
return Ok(None);
}
if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2)
&& !self.llm_usage_defined
{
return Ok(None);
}

View file

@@ -1,9 +1,7 @@
use std::collections::HashMap;
use common::configuration::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};

View file

@@ -1,20 +1,27 @@
use std::fmt;
use std::sync::OnceLock;
use opentelemetry::global;
use opentelemetry_sdk::{propagation::TraceContextPropagator, trace::SdkTracerProvider};
use opentelemetry_stdout::SpanExporter;
use time::macros::format_description;
use tracing::{Event, Subscriber};
use tracing_subscriber::fmt::{format, time::FormatTime, FmtContext, FormatEvent, FormatFields};
use tracing_subscriber::EnvFilter;
struct BracketedTime;
impl FormatTime for BracketedTime {
fn format_time(&self, w: &mut format::Writer<'_>) -> fmt::Result {
let now = time::OffsetDateTime::now_utc();
write!(w, "[{}]", now.format(&format_description!("[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]")).unwrap())
write!(
w,
"[{}]",
now.format(&format_description!(
"[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]"
))
.unwrap()
)
}
}
@@ -34,7 +41,11 @@
let timer = BracketedTime;
timer.format_time(&mut writer)?;
write!(writer, "[{}] ", event.metadata().level().to_string().to_lowercase())?;
write!(
writer,
"[{}] ",
event.metadata().level().to_string().to_lowercase()
)?;
ctx.field_format().format_fields(writer.by_ref(), event)?;