Merge branch 'main' into salmanap/add-support-for-bedrock-llms

This commit is contained in:
Adil Hafeez 2025-10-22 11:13:13 -07:00
commit 2b54b8833e
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
29 changed files with 2741 additions and 391 deletions

50
crates/Cargo.lock generated
View file

@ -68,6 +68,16 @@ version = "1.0.98"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e16d2d3311acee920a9eb8d33b8cbc1787ce4a264e85f964c2404b969bdcd487"
[[package]]
name = "assert-json-diff"
version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "async-trait"
version = "0.1.88"
@ -202,6 +212,7 @@ dependencies = [
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"mockito",
"opentelemetry",
"opentelemetry-http",
"opentelemetry-otlp",
@ -283,6 +294,15 @@ dependencies = [
"windows-link",
]
[[package]]
name = "colored"
version = "3.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e"
dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "common"
version = "0.1.0"
@ -1318,6 +1338,30 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "mockito"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48"
dependencies = [
"assert-json-diff",
"bytes",
"colored",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"log",
"rand 0.9.1",
"regex",
"serde_json",
"serde_urlencoded",
"similar",
"tokio",
]
[[package]]
name = "native-tls"
version = "0.2.14"
@ -2269,6 +2313,12 @@ dependencies = [
"libc",
]
[[package]]
name = "similar"
version = "2.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
[[package]]
name = "slab"
version = "0.4.9"

View file

@ -28,8 +28,14 @@ serde_with = "3.13.0"
serde_yaml = "0.9.34"
thiserror = "2.0.12"
tokio = { version = "1.44.2", features = ["full"] }
tokio-stream = "0.1.17"
tokio-stream = "0.1"
time = { version = "0.3", features = ["formatting", "macros"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
[dev-dependencies]
mockito = "1.0"
tokio-stream = "0.1.17"
tracing = "0.1.41"
tracing-opentelemetry = "0.30.0"
tracing-subscriber = { version = "0.3.19", features = ["env-filter", "fmt", "time"] }

View file

@ -0,0 +1,172 @@
use std::sync::Arc;
use bytes::Bytes;
use hermesllm::apis::openai::ChatCompletionsRequest;
use http_body_util::combinators::BoxBody;
use http_body_util::BodyExt;
use hyper::{Request, Response};
use tracing::{debug, info, warn};
use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
pub enum AgentFilterChainError {
#[error("Agent selection error: {0}")]
Selection(#[from] AgentSelectionError),
#[error("Pipeline processing error: {0}")]
Pipeline(#[from] PipelineError),
#[error("Response handling error: {0}")]
Response(#[from] super::response_handler::ResponseError),
#[error("Request parsing error: {0}")]
RequestParsing(#[from] serde_json::Error),
#[error("HTTP error: {0}")]
Http(#[from] hyper::Error),
}
pub async fn agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
_: String,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
match handle_agent_chat(request, router_service, agents_list, listeners).await {
Ok(response) => Ok(response),
Err(err) => {
// Print detailed error information with full error chain
let mut error_chain = Vec::new();
let mut current_error: &dyn std::error::Error = &err;
// Collect the full error chain
loop {
error_chain.push(current_error.to_string());
match current_error.source() {
Some(source) => current_error = source,
None => break,
}
}
// Log the complete error chain
warn!("Agent chat error chain: {:#?}", error_chain);
warn!("Root error: {:?}", err);
// Create structured error response as JSON
let error_json = serde_json::json!({
"error": {
"type": "AgentFilterChainError",
"message": err.to_string(),
"error_chain": error_chain,
"debug_info": format!("{:?}", err)
}
});
// Log the error for debugging
info!("Structured error info: {}", error_json);
// Return JSON error response
Ok(ResponseHandler::create_json_error_response(&error_json))
}
}
}
async fn handle_agent_chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
agents_list: Arc<tokio::sync::RwLock<Option<Vec<common::configuration::Agent>>>>,
listeners: Arc<tokio::sync::RwLock<Vec<common::configuration::Listener>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
// Initialize services
let agent_selector = AgentSelector::new(router_service);
let pipeline_processor = PipelineProcessor::default();
let response_handler = ResponseHandler::new();
// Extract listener name from headers
let listener_name = request
.headers()
.get("x-arch-agent-listener-name")
.and_then(|name| name.to_str().ok());
// Find the appropriate listener
let listener = {
let listeners = listeners.read().await;
agent_selector
.find_listener(listener_name, &listeners)
.await?
};
info!("Handling request for listener: {}", listener.name);
// Parse request body
let request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!(
"Received request body (raw utf8): {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
warn!(
"Failed to parse request body as ChatCompletionsRequest: {}",
err
);
AgentFilterChainError::RequestParsing(err)
})?;
// Extract trace parent for routing
let trace_parent = request_headers
.iter()
.find(|(key, _)| key.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
// Select appropriate agent using arch router llm model
let selected_agent = agent_selector
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
.await?;
debug!("Processing agent pipeline: {}", selected_agent.id);
// Create agent map for pipeline processing
let agent_map = {
let agents = agents_list.read().await;
let agents = agents.as_ref().unwrap();
agent_selector.create_agent_map(agents)
};
// Process the filter chain
let processed_messages = pipeline_processor
.process_filter_chain(
&chat_completions_request,
&selected_agent,
&agent_map,
&request_headers,
)
.await?;
// Get terminal agent and send final response
let terminal_agent_name = selected_agent.id;
let terminal_agent = agent_map.get(&terminal_agent_name).unwrap();
debug!("Processing terminal agent: {}", terminal_agent_name);
debug!("Terminal agent details: {:?}", terminal_agent);
let llm_response = pipeline_processor
.invoke_upstream_agent(
&processed_messages,
&chat_completions_request,
terminal_agent,
&request_headers,
)
.await?;
// Create streaming response
response_handler
.create_streaming_response(llm_response)
.await
.map_err(AgentFilterChainError::from)
}

View file

@ -0,0 +1,296 @@
use std::collections::HashMap;
use std::sync::Arc;
use common::configuration::{
Agent, AgentFilterChain, Listener, ModelUsagePreference, RoutingPreference,
};
use hermesllm::apis::openai::Message;
use tracing::{debug, warn};
use crate::router::llm_router::RouterService;
/// Errors that can occur during agent selection
#[derive(Debug, thiserror::Error)]
pub enum AgentSelectionError {
#[error("Listener not found for name: {0}")]
ListenerNotFound(String),
#[error("No agents configured for listener: {0}")]
NoAgentsConfigured(String),
#[error("Routing service error: {0}")]
RoutingError(String),
#[error("Default agent not found for listener: {0}")]
DefaultAgentNotFound(String),
}
/// Service for selecting agents based on routing preferences and listener configuration
pub struct AgentSelector {
router_service: Arc<RouterService>,
}
impl AgentSelector {
pub fn new(router_service: Arc<RouterService>) -> Self {
Self { router_service }
}
/// Find listener by name from the request headers
pub async fn find_listener(
&self,
listener_name: Option<&str>,
listeners: &[common::configuration::Listener],
) -> Result<Listener, AgentSelectionError> {
let listener = listeners
.iter()
.find(|l| listener_name.map(|name| l.name == name).unwrap_or(false))
.cloned()
.ok_or_else(|| {
AgentSelectionError::ListenerNotFound(
listener_name.unwrap_or("unknown").to_string(),
)
})?;
Ok(listener)
}
/// Create agent name to agent mapping for efficient lookup
pub fn create_agent_map(&self, agents: &[Agent]) -> HashMap<String, Agent> {
agents
.iter()
.map(|agent| (agent.id.clone(), agent.clone()))
.collect()
}
/// Select appropriate agent based on routing preferences
pub async fn select_agent(
&self,
messages: &[Message],
listener: &Listener,
trace_parent: Option<String>,
) -> Result<AgentFilterChain, AgentSelectionError> {
let agents = listener
.agents
.as_ref()
.ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?;
// If only one agent, skip routing
if agents.len() == 1 {
debug!("Only one agent available, skipping routing");
return Ok(agents[0].clone());
}
let usage_preferences = self.convert_agent_description_to_routing_preferences(agents);
debug!(
"Agents usage preferences for agent routing str: {}",
serde_json::to_string(&usage_preferences).unwrap_or_default()
);
match self
.router_service
.determine_route(messages, trace_parent, Some(usage_preferences))
.await
{
Ok(Some((_, agent_name))) => {
debug!("Determined agent: {}", agent_name);
let selected_agent = agents
.iter()
.find(|a| a.id == agent_name)
.cloned()
.ok_or_else(|| {
AgentSelectionError::RoutingError(format!(
"Selected agent '{}' not found in listener agents",
agent_name
))
})?;
Ok(selected_agent)
}
Ok(None) => {
debug!("No agent determined using routing preferences, using default agent");
self.get_default_agent(agents, &listener.name)
}
Err(err) => Err(AgentSelectionError::RoutingError(err.to_string())),
}
}
/// Get the default agent or the first agent if no default is specified
fn get_default_agent(
&self,
agents: &[AgentFilterChain],
listener_name: &str,
) -> Result<AgentFilterChain, AgentSelectionError> {
agents
.iter()
.find(|a| a.default.unwrap_or(false))
.cloned()
.or_else(|| {
warn!(
"No default agent found, routing request to first agent: {}",
agents[0].id
);
Some(agents[0].clone())
})
.ok_or_else(|| AgentSelectionError::DefaultAgentNotFound(listener_name.to_string()))
}
/// Convert agent descriptions to routing preferences
fn convert_agent_description_to_routing_preferences(
&self,
agents: &[AgentFilterChain],
) -> Vec<ModelUsagePreference> {
agents
.iter()
.map(|agent| ModelUsagePreference {
model: agent.id.clone(),
routing_preferences: vec![RoutingPreference {
name: agent.id.clone(),
description: agent.description.as_ref().unwrap_or(&String::new()).clone(),
}],
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{AgentFilterChain, Listener};
fn create_test_router_service() -> Arc<RouterService> {
Arc::new(RouterService::new(
vec![], // empty providers for testing
"http://localhost:8080".to_string(),
"test-model".to_string(),
"test-provider".to_string(),
))
}
fn create_test_agent(name: &str, description: &str, is_default: bool) -> AgentFilterChain {
AgentFilterChain {
id: name.to_string(),
description: Some(description.to_string()),
default: Some(is_default),
filter_chain: vec![name.to_string()],
}
}
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
Listener {
name: name.to_string(),
agents: Some(agents),
port: 8080,
router: None,
}
}
fn create_test_agent_struct(name: &str) -> Agent {
Agent {
id: name.to_string(),
kind: Some("test".to_string()),
url: "http://localhost:8080".to_string(),
}
}
#[tokio::test]
async fn test_find_listener_success() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let listener1 = create_test_listener("test-listener", vec![]);
let listener2 = create_test_listener("other-listener", vec![]);
let listeners = vec![listener1.clone(), listener2];
let result = selector
.find_listener(Some("test-listener"), &listeners)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().name, "test-listener");
}
#[tokio::test]
async fn test_find_listener_not_found() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let listeners = vec![create_test_listener("other-listener", vec![])];
let result = selector
.find_listener(Some("nonexistent"), &listeners)
.await;
assert!(result.is_err());
matches!(
result.unwrap_err(),
AgentSelectionError::ListenerNotFound(_)
);
}
#[test]
fn test_create_agent_map() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent_struct("agent1"),
create_test_agent_struct("agent2"),
];
let agent_map = selector.create_agent_map(&agents);
assert_eq!(agent_map.len(), 2);
assert!(agent_map.contains_key("agent1"));
assert!(agent_map.contains_key("agent2"));
}
#[test]
fn test_convert_agent_description_to_routing_preferences() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent description", true),
create_test_agent("agent2", "Second agent description", false),
];
let preferences = selector.convert_agent_description_to_routing_preferences(&agents);
assert_eq!(preferences.len(), 2);
assert_eq!(preferences[0].model, "agent1");
assert_eq!(preferences[0].routing_preferences[0].name, "agent1");
assert_eq!(
preferences[0].routing_preferences[0].description,
"First agent description"
);
}
#[test]
fn test_get_default_agent() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent", false),
create_test_agent("agent2", "Default agent", true),
create_test_agent("agent3", "Third agent", false),
];
let result = selector.get_default_agent(&agents, "test-listener");
assert!(result.is_ok());
assert_eq!(result.unwrap().id, "agent2");
}
#[test]
fn test_get_default_agent_fallback_to_first() {
let router_service = create_test_router_service();
let selector = AgentSelector::new(router_service);
let agents = vec![
create_test_agent("agent1", "First agent", false),
create_test_agent("agent2", "Second agent", false),
];
let result = selector.get_default_agent(&agents, "test-listener");
assert!(result.is_ok());
assert_eq!(result.unwrap().id, "agent1");
}
}

View file

@ -0,0 +1,155 @@
use std::sync::Arc;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use hyper::header::HeaderMap;
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
use crate::handlers::pipeline_processor::PipelineProcessor;
use crate::handlers::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
/// Integration test that demonstrates the modular agent chat flow
/// This test shows how the three main components work together:
/// 1. AgentSelector - selects the appropriate agent based on routing
/// 2. PipelineProcessor - executes the agent pipeline
/// 3. ResponseHandler - handles response streaming
#[cfg(test)]
mod integration_tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
fn create_test_router_service() -> Arc<RouterService> {
Arc::new(RouterService::new(
vec![], // empty providers for testing
"http://localhost:8080".to_string(),
"test-model".to_string(),
"test-provider".to_string(),
))
}
fn create_test_message(role: Role, content: &str) -> Message {
Message {
role,
content: MessageContent::Text(content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
#[tokio::test]
async fn test_modular_agent_chat_flow() {
// Setup services
let router_service = create_test_router_service();
let agent_selector = AgentSelector::new(router_service);
let pipeline_processor = PipelineProcessor::default();
// Create test data
let agents = vec![
Agent {
id: "filter-agent".to_string(),
kind: Some("filter".to_string()),
url: "http://localhost:8081".to_string(),
},
Agent {
id: "terminal-agent".to_string(),
kind: Some("terminal".to_string()),
url: "http://localhost:8082".to_string(),
},
];
let agent_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: vec!["filter-agent".to_string(), "terminal-agent".to_string()],
description: Some("Test pipeline".to_string()),
default: Some(true),
};
let listener = Listener {
name: "test-listener".to_string(),
agents: Some(vec![agent_pipeline.clone()]),
port: 8080,
router: None,
};
let listeners = vec![listener];
let messages = vec![create_test_message(Role::User, "Hello world!")];
// Test 1: Agent Selection
let selected_listener = agent_selector
.find_listener(Some("test-listener"), &listeners)
.await;
assert!(selected_listener.is_ok());
let listener = selected_listener.unwrap();
assert_eq!(listener.name, "test-listener");
// Test 2: Agent Map Creation
let agent_map = agent_selector.create_agent_map(&agents);
assert_eq!(agent_map.len(), 2);
assert!(agent_map.contains_key("filter-agent"));
assert!(agent_map.contains_key("terminal-agent"));
// Test 3: Pipeline Processing (empty filter chain for testing)
let request = ChatCompletionsRequest {
messages: messages.clone(),
model: "test-model".to_string(),
..Default::default()
};
// Create a pipeline with empty filter chain to avoid network calls
let test_pipeline = AgentFilterChain {
id: "terminal-agent".to_string(),
filter_chain: vec![], // Empty filter chain - no network calls needed
description: None,
default: None,
};
let headers = HeaderMap::new();
let result = pipeline_processor
.process_filter_chain(&request, &test_pipeline, &agent_map, &headers)
.await;
println!("Pipeline processing result: {:?}", result);
assert!(result.is_ok());
let processed_messages = result.unwrap();
// With empty filter chain, should return the original messages unchanged
assert_eq!(processed_messages.len(), 1);
if let MessageContent::Text(content) = &processed_messages[0].content {
assert_eq!(content, "Hello world!");
} else {
panic!("Expected text content");
}
// Test 4: Error Response Creation
let error_response = ResponseHandler::create_bad_request("Test error");
assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST);
println!("✅ All modular components working correctly!");
}
#[tokio::test]
async fn test_error_handling_flow() {
let router_service = create_test_router_service();
let agent_selector = AgentSelector::new(router_service);
// Test listener not found
let result = agent_selector.find_listener(Some("nonexistent"), &[]).await;
assert!(result.is_err());
assert!(matches!(
result.unwrap_err(),
AgentSelectionError::ListenerNotFound(_)
));
// Test error response creation
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
assert_eq!(
error_response.status(),
hyper::StatusCode::INTERNAL_SERVER_ERROR
);
println!("✅ Error handling working correctly!");
}
}

View file

@ -1,2 +1,9 @@
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod chat_completions;
pub mod models;
pub mod pipeline_processor;
pub mod response_handler;
#[cfg(test)]
mod integration_tests;

View file

@ -0,0 +1,228 @@
use std::collections::HashMap;
use common::configuration::{Agent, AgentFilterChain};
use common::consts::{ARCH_UPSTREAM_HOST_HEADER, ENVOY_RETRY_HEADER};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message};
use hyper::header::HeaderMap;
use tracing::{debug, warn};
/// Errors that can occur during pipeline processing
#[derive(Debug, thiserror::Error)]
pub enum PipelineError {
#[error("HTTP request failed: {0}")]
RequestFailed(#[from] reqwest::Error),
#[error("Failed to parse response: {0}")]
ParseError(#[from] serde_json::Error),
#[error("Agent '{0}' not found in agent map")]
AgentNotFound(String),
#[error("No choices in response from agent '{0}'")]
NoChoicesInResponse(String),
#[error("No content in response from agent '{0}'")]
NoContentInResponse(String),
}
/// Service for processing agent pipelines
pub struct PipelineProcessor {
client: reqwest::Client,
url: String,
}
impl Default for PipelineProcessor {
fn default() -> Self {
Self {
client: reqwest::Client::new(),
url: "http://localhost:11000/v1/chat/completions".to_string(),
}
}
}
impl PipelineProcessor {
pub fn new(url: String) -> Self {
Self {
client: reqwest::Client::new(),
url,
}
}
/// Process the filter chain of agents (all except the terminal agent)
pub async fn process_filter_chain(
&self,
initial_request: &ChatCompletionsRequest,
agent_filter_chain: &AgentFilterChain,
agent_map: &HashMap<String, Agent>,
request_headers: &HeaderMap,
) -> Result<Vec<Message>, PipelineError> {
let mut chat_completions_history = initial_request.messages.clone();
for agent_name in &agent_filter_chain.filter_chain {
debug!("Processing filter agent: {}", agent_name);
let agent = agent_map
.get(agent_name)
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
debug!("Agent details: {:?}", agent);
let response_content = self
.send_agent_filter_chain_request(
&chat_completions_history,
initial_request,
agent,
request_headers,
)
.await?;
debug!("Received response from filter agent {}", agent_name);
// Parse the response content as new message history
chat_completions_history =
serde_json::from_str(&response_content).inspect_err(|err| {
warn!(
"Failed to parse response from agent {}, err: {}, response: {}",
agent_name, err, response_content
)
})?;
}
Ok(chat_completions_history)
}
/// Send request to a specific agent and return the response content
async fn send_agent_filter_chain_request(
&self,
messages: &[Message],
original_request: &ChatCompletionsRequest,
agent: &Agent,
request_headers: &HeaderMap,
) -> Result<String, PipelineError> {
let mut request = original_request.clone();
request.messages = messages.to_vec();
let request_body = serde_json::to_string(&request)?;
debug!("Sending request to agent {}", agent.id);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&agent.id)
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let response = self
.client
.post(&self.url)
.headers(agent_headers)
.body(request_body)
.send()
.await?;
let response_bytes = response.bytes().await?;
// Parse the response as JSON to extract the content
let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?;
let content = response_json
.get("choices")
.and_then(|choices| choices.as_array())
.and_then(|choices| choices.first())
.and_then(|choice| choice.get("message"))
.and_then(|message| message.get("content"))
.and_then(|content| content.as_str())
.ok_or_else(|| PipelineError::NoContentInResponse(agent.id.clone()))?
.to_string();
Ok(content)
}
/// Send request to terminal agent and return the raw response for streaming
pub async fn invoke_upstream_agent(
&self,
messages: &[Message],
original_request: &ChatCompletionsRequest,
terminal_agent: &Agent,
request_headers: &HeaderMap,
) -> Result<reqwest::Response, PipelineError> {
let mut request = original_request.clone();
request.messages = messages.to_vec();
let request_body = serde_json::to_string(&request)?;
debug!("Sending request to terminal agent {}", terminal_agent.id);
let mut agent_headers = request_headers.clone();
agent_headers.remove(hyper::header::CONTENT_LENGTH);
agent_headers.insert(
ARCH_UPSTREAM_HOST_HEADER,
hyper::header::HeaderValue::from_str(&terminal_agent.id)
.map_err(|_| PipelineError::AgentNotFound(terminal_agent.id.clone()))?,
);
agent_headers.insert(
ENVOY_RETRY_HEADER,
hyper::header::HeaderValue::from_str("3").unwrap(),
);
let response = self
.client
.post(&self.url)
.headers(agent_headers)
.body(request_body)
.send()
.await?;
Ok(response)
}
}
#[cfg(test)]
mod tests {
use super::*;
use hermesllm::apis::openai::{Message, MessageContent, Role};
use std::collections::HashMap;
fn create_test_message(role: Role, content: &str) -> Message {
Message {
role,
content: MessageContent::Text(content.to_string()),
name: None,
tool_calls: None,
tool_call_id: None,
}
}
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
AgentFilterChain {
id: "test-agent".to_string(),
filter_chain: agents.iter().map(|s| s.to_string()).collect(),
description: None,
default: None,
}
}
#[tokio::test]
async fn test_agent_not_found_error() {
let processor = PipelineProcessor::default();
let agent_map = HashMap::new();
let request_headers = HeaderMap::new();
let initial_request = ChatCompletionsRequest {
messages: vec![create_test_message(Role::User, "Hello")],
model: "test-model".to_string(),
..Default::default()
};
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
let result = processor
.process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
.await;
assert!(result.is_err());
matches!(result.unwrap_err(), PipelineError::AgentNotFound(_));
}
}

View file

@ -0,0 +1,191 @@
use bytes::Bytes;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::{Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::warn;
/// Errors that can occur during response handling
#[derive(Debug, thiserror::Error)]
pub enum ResponseError {
#[error("Failed to create response: {0}")]
ResponseCreationFailed(#[from] hyper::http::Error),
#[error("Stream error: {0}")]
StreamError(String),
}
/// Service for handling HTTP responses and streaming
pub struct ResponseHandler;
impl ResponseHandler {
pub fn new() -> Self {
Self
}
/// Create a full response body from bytes
pub fn create_full_body<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into())
.map_err(|never| match never {})
.boxed()
}
/// Create an error response with a given status code and message
pub fn create_error_response(
status: StatusCode,
message: &str,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let mut response = Response::new(Self::create_full_body(message.to_string()));
*response.status_mut() = status;
response
}
/// Create a bad request response
pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Self::create_error_response(StatusCode::BAD_REQUEST, message)
}
/// Create an internal server error response
pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
Self::create_error_response(StatusCode::INTERNAL_SERVER_ERROR, message)
}
/// Create a JSON error response
pub fn create_json_error_response(
error_json: &serde_json::Value,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let json_string = error_json.to_string();
let mut response = Response::new(Self::create_full_body(json_string));
*response.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),
);
response
}
/// Create a streaming response from a reqwest response
pub async fn create_streaming_response(
&self,
llm_response: reqwest::Response,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ResponseError> {
// Copy headers from the original response
let response_headers = llm_response.headers();
let mut response_builder = Response::builder();
let headers = response_builder.headers_mut().ok_or_else(|| {
ResponseError::StreamError("Failed to get mutable headers".to_string())
})?;
for (header_name, header_value) in response_headers.iter() {
headers.insert(header_name, header_value.clone());
}
// Create channel for async streaming
let (tx, rx) = mpsc::channel::<Bytes>(16);
// Spawn task to stream data
tokio::spawn(async move {
let mut byte_stream = llm_response.bytes_stream();
while let Some(item) = byte_stream.next().await {
let chunk = match item {
Ok(chunk) => chunk,
Err(err) => {
warn!("Error receiving chunk: {:?}", err);
break;
}
};
if tx.send(chunk).await.is_err() {
warn!("Receiver dropped");
break;
}
}
});
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
let stream_body = BoxBody::new(StreamBody::new(stream));
response_builder
.body(stream_body)
.map_err(ResponseError::from)
}
}
impl Default for ResponseHandler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use hyper::StatusCode;
#[test]
fn test_create_bad_request() {
let response = ResponseHandler::create_bad_request("Invalid request");
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
#[test]
fn test_create_internal_error() {
let response = ResponseHandler::create_internal_error("Server error");
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
}
#[test]
fn test_create_error_response() {
let response =
ResponseHandler::create_error_response(StatusCode::NOT_FOUND, "Resource not found");
assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
#[test]
fn test_create_json_error_response() {
let error_json = serde_json::json!({
"error": {
"type": "TestError",
"message": "Test error message"
}
});
let response = ResponseHandler::create_json_error_response(&error_json);
assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
response.headers().get("content-type").unwrap(),
"application/json"
);
}
#[tokio::test]
async fn test_create_streaming_response_with_mock() {
use mockito::Server;
let mut server = Server::new_async().await;
let mock = server
.mock("GET", "/test")
.with_status(200)
.with_header("content-type", "text/plain")
.with_body("streaming response")
.create_async()
.await;
let client = reqwest::Client::new();
let llm_response = client.get(&(server.url() + "/test")).send().await.unwrap();
let handler = ResponseHandler::new();
let result = handler.create_streaming_response(llm_response).await;
mock.assert_async().await;
assert!(result.is_ok());
let response = result.unwrap();
assert_eq!(response.status(), StatusCode::OK);
assert!(response.headers().contains_key("content-type"));
}
}

View file

@ -1,3 +1,4 @@
use brightstaff::handlers::agent_chat_completions::agent_chat;
use brightstaff::handlers::chat_completions::chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
@ -61,7 +62,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let arch_config = Arc::new(config);
let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone()));
let llm_providers = Arc::new(RwLock::new(arch_config.model_providers.clone()));
let agents_list = Arc::new(RwLock::new(arch_config.agents.clone()));
let listeners = Arc::new(RwLock::new(arch_config.listeners.clone()));
debug!(
"arch_config: {:?}",
@ -84,11 +87,11 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let routing_llm_provider = arch_config
.routing
.as_ref()
.and_then(|r| r.llm_provider.clone())
.and_then(|r| r.model_provider.clone())
.unwrap_or_else(|| DEFAULT_ROUTING_LLM_PROVIDER.to_string());
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
arch_config.llm_providers.clone(),
arch_config.model_providers.clone(),
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
routing_model_name,
routing_llm_provider,
@ -106,12 +109,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let agents_list = agents_list.clone();
let listeners = listeners.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let listeners = listeners.clone();
async move {
match (req.method(), req.uri().path()) {
@ -122,8 +129,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
(&Method::OPTIONS, "/v1/models") => {
(&Method::POST, "/agents/v1/chat/completions") => {
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
agent_chat(
req,
router_service,
fully_qualified_url,
agents_list,
listeners,
)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models" | "/agents/v1/models") => {
Ok(list_models(llm_providers).await)
}
// hack for now to get openw-web-ui to work
(&Method::OPTIONS, "/v1/models" | "/agents/v1/models") => {
let mut response = Response::new(empty());
*response.status_mut() = StatusCode::NO_CONTENT;
response
@ -147,6 +170,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
Ok(response)
}
_ => {
debug!("No route for {} {}", req.method(), req.uri().path());
let mut not_found = Response::new(empty());
*not_found.status_mut() = StatusCode::NOT_FOUND;
Ok(not_found)

View file

@ -79,7 +79,13 @@ impl RouterService {
trace_parent: Option<String>,
usage_preferences: Option<Vec<ModelUsagePreference>>,
) -> Result<Option<(String, String)>> {
if !self.llm_usage_defined {
if messages.is_empty() {
return Ok(None);
}
if (usage_preferences.is_none() || usage_preferences.as_ref().unwrap().len() < 2)
&& !self.llm_usage_defined
{
return Ok(None);
}

View file

@ -9,7 +9,7 @@ use crate::api::open_ai::{
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Routing {
pub llm_provider: Option<String>,
pub model_provider: Option<String>,
pub model: Option<String>,
}
@ -18,11 +18,34 @@ pub struct ModelAlias {
pub target: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Agent {
pub id: String,
pub kind: Option<String>,
pub url: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentFilterChain {
pub id: String,
pub default: Option<bool>,
pub description: Option<String>,
pub filter_chain: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Listener {
pub name: String,
pub router: Option<String>,
pub agents: Option<Vec<AgentFilterChain>>,
pub port: u16,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Configuration {
pub version: String,
pub endpoints: Option<HashMap<String, Endpoint>>,
pub llm_providers: Vec<LlmProvider>,
pub model_providers: Vec<LlmProvider>,
pub model_aliases: Option<HashMap<String, ModelAlias>>,
pub overrides: Option<Overrides>,
pub system_prompt: Option<String>,
@ -33,6 +56,8 @@ pub struct Configuration {
pub tracing: Option<Tracing>,
pub mode: Option<GatewayMode>,
pub routing: Option<Routing>,
pub agents: Option<Vec<Agent>>,
pub listeners: Vec<Listener>,
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]

View file

@ -30,3 +30,4 @@ pub const HALLUCINATION_TEMPLATE: &str =
pub const OTEL_COLLECTOR_HTTP: &str = "opentelemetry_collector_http";
pub const OTEL_POST_PATH: &str = "/v1/traces";
pub const LLM_ROUTE_HEADER: &str = "x-arch-llm-route";
pub const ENVOY_RETRY_HEADER: &str = "x-envoy-max-retries";

View file

@ -74,7 +74,7 @@ impl RootContext for FilterContext {
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
self.overrides = Rc::new(config.overrides);
match config.llm_providers.try_into() {
match config.model_providers.try_into() {
Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)),
Err(err) => panic!("{err}"),
}