diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 1f33b4c2..524f3cd9 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -104,6 +104,16 @@ version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dde20b3d026af13f561bdd0f15edf01fc734f0dafcedbaf42bba506a9517f223" +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -216,6 +226,7 @@ dependencies = [ "http-body-util", "hyper 1.6.0", "hyper-util", + "mockito", "opentelemetry", "opentelemetry-http", "opentelemetry-otlp", @@ -321,6 +332,15 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67ba02a97a2bd10f4b59b25c7973101c79642302776489e030cd13cdab09ed15" +[[package]] +name = "colored" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "common" version = "0.1.0" @@ -1731,6 +1751,30 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "mockito" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7760e0e418d9b7e5777c0374009ca4c93861b9066f18cb334a20ce50ab63aa48" +dependencies = [ + "assert-json-diff", + "bytes", + "colored", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "hyper 1.6.0", + "hyper-util", + "log", + "rand 0.9.1", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "more-asserts" version = "0.3.1" @@ -2817,6 +2861,12 @@ dependencies = [ "libc", ] +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "slab" version = "0.4.9" diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index 23f3e8b1..721870fa 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -28,6 +28,12 @@ serde_with = "3.13.0" serde_yaml = "0.9.34" thiserror = "2.0.12" tokio = { version = "1.44.2", features = ["full"] } +tokio-stream = "0.1" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[dev-dependencies] +mockito = "1.0" tokio-stream = "0.1.17" tracing = "0.1.41" tracing-opentelemetry = "0.30.0" diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index f362a14d..cb4052d3 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -1,25 +1,30 @@ use std::sync::Arc; use bytes::Bytes; -use common::configuration::{AgentPipeline, ModelUsagePreference, RoutingPreference}; -use common::consts::ARCH_UPSTREAM_HOST_HEADER; use hermesllm::apis::openai::ChatCompletionsRequest; use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full, StreamBody}; -use hyper::body::Frame; -use hyper::header::{self}; -use hyper::{Request, Response, StatusCode}; -use tokio::sync::mpsc; -use tokio_stream::wrappers::ReceiverStream; -use tokio_stream::StreamExt; +use http_body_util::BodyExt; +use hyper::{Request, Response}; use tracing::{debug, info, warn}; use crate::router::llm_router::RouterService; +use super::agent_selector::{AgentSelector, AgentSelectionError}; +use super::pipeline_processor::{PipelineProcessor, PipelineError}; +use super::response_handler::ResponseHandler; -fn full>(chunk: T) -> BoxBody { - Full::new(chunk.into()) - .map_err(|never| match never {}) - .boxed() +/// Main errors for agent chat completions +#[derive(Debug, thiserror::Error)] +pub 
enum AgentChatError { + #[error("Agent selection error: {0}")] + Selection(#[from] AgentSelectionError), + #[error("Pipeline processing error: {0}")] + Pipeline(#[from] PipelineError), + #[error("Response handling error: {0}")] + Response(#[from] super::response_handler::ResponseError), + #[error("Request parsing error: {0}")] + RequestParsing(#[from] serde_json::Error), + #[error("HTTP error: {0}")] + Http(#[from] hyper::Error), } pub async fn agent_chat( @@ -29,25 +34,44 @@ pub async fn agent_chat( agents_list: Arc>>>, listeners: Arc>>, ) -> Result>, hyper::Error> { - // find listener that is running at port 8001 for agents - let listener_name = request.headers().get("x-arch-agent-listener-name"); + match handle_agent_chat(request, router_service, agents_list, listeners).await { + Ok(response) => Ok(response), + Err(err) => { + warn!("Agent chat error: {}", err); + Ok(ResponseHandler::create_internal_error(&format!("Internal error: {}", err))) + } + } +} + +async fn handle_agent_chat( + request: Request, + router_service: Arc, + agents_list: Arc>>>, + listeners: Arc>>, +) -> Result>, AgentChatError> { + // Initialize services + let agent_selector = AgentSelector::new(router_service); + let pipeline_processor = PipelineProcessor::default(); + let response_handler = ResponseHandler::new(); + + // Extract listener name from headers + let listener_name = request + .headers() + .get("x-arch-agent-listener-name") + .and_then(|name| name.to_str().ok()); + + // Find the appropriate listener let listener = { let listeners = listeners.read().await; - listeners - .iter() - .find(|l| { - listener_name - .and_then(|name| name.to_str().ok()) - .map(|name| l.name == name) - .unwrap_or(false) - }) - .cloned() - } - .unwrap(); + agent_selector + .find_listener(listener_name, &listeners) + .await? 
+ }; info!("Handling request for listener: {}", listener.name); - let mut request_headers = request.headers().clone(); + // Parse request body + let request_headers = request.headers().clone(); let chat_request_bytes = request.collect().await?.to_bytes(); debug!( @@ -56,269 +80,64 @@ pub async fn agent_chat( ); let chat_completions_request: ChatCompletionsRequest = - match serde_json::from_slice(&chat_request_bytes) { - Ok(req) => req, - Err(err) => { - warn!( - "Failed to parse request body as ChatCompletionsRequest: {}", - err - ); - let err_msg = format!("Failed to parse request body: {}", err); - let mut bad_request = Response::new(full(err_msg)); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - return Ok(bad_request); - } - }; - - let agent_name_map = { - let agents = agents_list.read().await; - let agents = agents.as_ref().unwrap(); - let mut map = std::collections::HashMap::new(); - for agent in agents.iter() { - map.insert(agent.name.clone(), agent.clone()); - } - map - }; + serde_json::from_slice(&chat_request_bytes).map_err(|err| { + warn!("Failed to parse request body as ChatCompletionsRequest: {}", err); + AgentChatError::RequestParsing(err) + })?; + // Extract trace parent for routing let trace_parent = request_headers .iter() - .find(|(ty, _)| ty.as_str() == "traceparent") + .find(|(key, _)| key.as_str() == "traceparent") .map(|(_, value)| value.to_str().unwrap_or_default().to_string()); - let agents_usage_preferences: Vec = - convert_agent_description_to_routing_preferences(&listener.agents.as_ref().unwrap()); + // Select appropriate agent using arch router llm model + let selected_agent = agent_selector + .select_agent( + &chat_completions_request.messages, + &listener, + trace_parent, + ) + .await?; - debug!( - "Agents usage preferences for agent routing: {:?}", - agents_usage_preferences - ); + debug!("Processing agent pipeline: {}", selected_agent.name); - let agent_pipeline = match agents_usage_preferences.len() > 1 { - false => { - 
debug!("Only one agent available, skipping routing"); - listener.agents.as_ref().unwrap()[0].clone() - } - true => { - let selected_agent = match router_service - .determine_route( - &chat_completions_request.messages, - trace_parent.clone(), - Some(agents_usage_preferences), - ) - .await - { - Ok(route) => { - match route { - Some((_, agent_name)) => { - debug!("Determined agent: {}", agent_name); - listener - .agents - .as_ref() - .unwrap() - .iter() - .find(|a| a.name == agent_name) - .cloned() - // selected agent must exist in the agent map - .unwrap() - } - None => { - debug!("No agent determined using routing preferences, using default agent"); - listener - .agents - .as_ref() - .unwrap() - .iter() - .find(|a| a.default.unwrap_or(false)) - .cloned() - .unwrap_or_else(|| { - warn!( - "No default agent found, routing request to first agent: {}", - listener.agents.as_ref().unwrap()[0].name - ); - listener.agents.as_ref().unwrap()[0].clone() - }) - } - } - } - Err(err) => { - let err_msg = format!("Failed to determine route: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - selected_agent - } + // Create agent map for pipeline processing + let agent_map = { + let agents = agents_list.read().await; + let agents = agents.as_ref().unwrap(); + agent_selector.create_agent_map(agents) }; - debug!("Processing agent pipeline: {}", agent_pipeline.name); + // Process the filter chain + let processed_messages = pipeline_processor + .process_filter_chain( + &chat_completions_request, + &selected_agent, + &agent_map, + &request_headers, + ) + .await?; - let mut chat_completions_history = chat_completions_request.messages.clone(); + // Get terminal agent and send final response + let terminal_agent_name = selected_agent.filter_chain.last().unwrap(); + let terminal_agent = agent_map.get(terminal_agent_name).unwrap(); - 
request_headers.remove(header::CONTENT_LENGTH); - - let filter_chain_without_terminal_agent = - &agent_pipeline.filter_chain[..agent_pipeline.filter_chain.len() - 1]; - - for agent_name in filter_chain_without_terminal_agent { - debug!("Processing agent: {}", agent_name); - let agent = agent_name_map.get(agent_name).unwrap(); - debug!("Agent details: {:?}", agent); - - let mut request = chat_completions_request.clone(); - request.messages = chat_completions_history.clone(); - - let request_str = serde_json::to_string(&request).unwrap(); - debug!("Sending request to agent {}", agent_name); - - let mut agent_request_headers = request_headers.clone(); - agent_request_headers.insert( - ARCH_UPSTREAM_HOST_HEADER, - hyper::header::HeaderValue::from_str(agent.name.as_str()).unwrap(), - ); - - let response = match reqwest::Client::new() - .post("http://localhost:11000/v1/chat/completions") - .headers(agent_request_headers) - .body(request_str) - .send() - .await - { - Ok(res) => res, - Err(err) => { - let err_msg = format!("Failed to send request: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - let response_bytes = match response.bytes().await { - Ok(bytes) => bytes, - Err(err) => { - let err_msg = format!("Failed to read response bytes: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - let chat_completions_response: hermesllm::apis::openai::ChatCompletionsResponse = - match serde_json::from_slice(&response_bytes) { - Ok(res) => res, - Err(err) => { - let err_msg = format!("Failed to parse response body: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - let response_str = 
chat_completions_response.choices[0] - .message - .content - .clone() - .unwrap(); - - debug!("Received response from agent {}", agent_name); - - chat_completions_history = serde_json::from_str(response_str.as_str()).unwrap_or(vec![]); - } - - let terminal_agent_name = agent_pipeline.filter_chain.last().unwrap(); - let terminal_agent = agent_name_map.get(terminal_agent_name).unwrap(); debug!("Processing terminal agent: {}", terminal_agent_name); debug!("Terminal agent details: {:?}", terminal_agent); - let mut request = chat_completions_request.clone(); - request.messages = chat_completions_history.clone(); + let llm_response = pipeline_processor + .send_terminal_request( + &processed_messages, + &chat_completions_request, + terminal_agent, + &request_headers, + ) + .await?; - let request_str = serde_json::to_string(&request).unwrap(); - debug!("Sending request to agent {}", terminal_agent_name); - - let mut agent_request_headers = request_headers.clone(); - agent_request_headers.insert( - ARCH_UPSTREAM_HOST_HEADER, - hyper::header::HeaderValue::from_str(terminal_agent.name.as_str()).unwrap(), - ); - - let llm_response = match reqwest::Client::new() - .post("http://localhost:11000/v1/chat/completions") - .headers(agent_request_headers) - .body(request_str) - .send() + // Create streaming response + response_handler + .create_streaming_response(llm_response) .await - { - Ok(res) => res, - Err(err) => { - let err_msg = format!("Failed to send request: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); - } - }; - - // copy over the headers from the original response - let response_headers = llm_response.headers().clone(); - let mut response = Response::builder(); - let headers = response.headers_mut().unwrap(); - for (header_name, header_value) in response_headers.iter() { - headers.insert(header_name, header_value.clone()); - } - - // channel to create 
async stream - let (tx, rx) = mpsc::channel::(16); - - // Spawn a task to send data as it becomes available - tokio::spawn(async move { - let mut byte_stream = llm_response.bytes_stream(); - - while let Some(item) = byte_stream.next().await { - let item = match item { - Ok(item) => item, - Err(err) => { - warn!("Error receiving chunk: {:?}", err); - break; - } - }; - - if tx.send(item).await.is_err() { - warn!("Receiver dropped"); - break; - } - } - }); - - let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); - - let stream_body = BoxBody::new(StreamBody::new(stream)); - - match response.body(stream_body) { - Ok(response) => Ok(response), - Err(err) => { - let err_msg = format!("Failed to create response: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - Ok(internal_error) - } - } -} - -fn convert_agent_description_to_routing_preferences( - agents: &Vec, -) -> Vec { - agents - .iter() - .map(|agent| ModelUsagePreference { - model: agent.name.clone(), - routing_preferences: vec![RoutingPreference { - name: agent.name.clone(), - description: agent - .description - .as_ref() - .unwrap_or(&"".to_string()) - .clone(), - }], - }) - .collect() + .map_err(AgentChatError::from) } diff --git a/crates/brightstaff/src/handlers/agent_selector.rs b/crates/brightstaff/src/handlers/agent_selector.rs new file mode 100644 index 00000000..cf833954 --- /dev/null +++ b/crates/brightstaff/src/handlers/agent_selector.rs @@ -0,0 +1,296 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use common::configuration::{ + Agent, AgentPipeline, Listener, ModelUsagePreference, RoutingPreference, +}; +use hermesllm::apis::openai::Message; +use tracing::{debug, warn}; + +use crate::router::llm_router::RouterService; + +/// Errors that can occur during agent selection +#[derive(Debug, thiserror::Error)] +pub enum AgentSelectionError { + #[error("Listener not found 
for name: {0}")] + ListenerNotFound(String), + #[error("No agents configured for listener: {0}")] + NoAgentsConfigured(String), + #[error("Routing service error: {0}")] + RoutingError(String), + #[error("Default agent not found for listener: {0}")] + DefaultAgentNotFound(String), +} + +/// Service for selecting agents based on routing preferences and listener configuration +pub struct AgentSelector { + router_service: Arc, +} + +impl AgentSelector { + pub fn new(router_service: Arc) -> Self { + Self { router_service } + } + + /// Find listener by name from the request headers + pub async fn find_listener( + &self, + listener_name: Option<&str>, + listeners: &[common::configuration::Listener], + ) -> Result { + let listener = listeners + .iter() + .find(|l| listener_name.map(|name| l.name == name).unwrap_or(false)) + .cloned() + .ok_or_else(|| { + AgentSelectionError::ListenerNotFound( + listener_name.unwrap_or("unknown").to_string(), + ) + })?; + + Ok(listener) + } + + /// Create agent name to agent mapping for efficient lookup + pub fn create_agent_map(&self, agents: &[Agent]) -> HashMap { + agents + .iter() + .map(|agent| (agent.name.clone(), agent.clone())) + .collect() + } + + /// Select appropriate agent based on routing preferences + pub async fn select_agent( + &self, + messages: &[Message], + listener: &Listener, + trace_parent: Option, + ) -> Result { + let agents = listener + .agents + .as_ref() + .ok_or_else(|| AgentSelectionError::NoAgentsConfigured(listener.name.clone()))?; + + // If only one agent, skip routing + if agents.len() == 1 { + debug!("Only one agent available, skipping routing"); + return Ok(agents[0].clone()); + } + + let usage_preferences = self.convert_agent_description_to_routing_preferences(agents); + debug!( + "Agents usage preferences for agent routing str: {}", + serde_json::to_string(&usage_preferences).unwrap_or_default() + ); + + match self + .router_service + .determine_route(messages, trace_parent, Some(usage_preferences)) + 
.await + { + Ok(Some((_, agent_name))) => { + debug!("Determined agent: {}", agent_name); + let selected_agent = agents + .iter() + .find(|a| a.name == agent_name) + .cloned() + .ok_or_else(|| { + AgentSelectionError::RoutingError(format!( + "Selected agent '{}' not found in listener agents", + agent_name + )) + })?; + Ok(selected_agent) + } + Ok(None) => { + debug!("No agent determined using routing preferences, using default agent"); + self.get_default_agent(agents, &listener.name) + } + Err(err) => Err(AgentSelectionError::RoutingError(err.to_string())), + } + } + + /// Get the default agent or the first agent if no default is specified + fn get_default_agent( + &self, + agents: &[AgentPipeline], + listener_name: &str, + ) -> Result { + agents + .iter() + .find(|a| a.default.unwrap_or(false)) + .cloned() + .or_else(|| { + warn!( + "No default agent found, routing request to first agent: {}", + agents[0].name + ); + Some(agents[0].clone()) + }) + .ok_or_else(|| AgentSelectionError::DefaultAgentNotFound(listener_name.to_string())) + } + + /// Convert agent descriptions to routing preferences + fn convert_agent_description_to_routing_preferences( + &self, + agents: &[AgentPipeline], + ) -> Vec { + agents + .iter() + .map(|agent| ModelUsagePreference { + model: agent.name.clone(), + routing_preferences: vec![RoutingPreference { + name: agent.name.clone(), + description: agent.description.as_ref().unwrap_or(&String::new()).clone(), + }], + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use common::configuration::{AgentPipeline, Listener}; + + fn create_test_router_service() -> Arc { + Arc::new(RouterService::new( + vec![], // empty providers for testing + "http://localhost:8080".to_string(), + "test-model".to_string(), + "test-provider".to_string(), + )) + } + + fn create_test_agent(name: &str, description: &str, is_default: bool) -> AgentPipeline { + AgentPipeline { + name: name.to_string(), + description: Some(description.to_string()), + 
default: Some(is_default), + filter_chain: vec![name.to_string()], + } + } + + fn create_test_listener(name: &str, agents: Vec) -> Listener { + Listener { + name: name.to_string(), + agents: Some(agents), + port: 8080, + router: None, + } + } + + fn create_test_agent_struct(name: &str) -> Agent { + Agent { + name: name.to_string(), + kind: "test".to_string(), + endpoint: "http://localhost:8080".to_string(), + } + } + + #[tokio::test] + async fn test_find_listener_success() { + let router_service = create_test_router_service(); + let selector = AgentSelector::new(router_service); + + let listener1 = create_test_listener("test-listener", vec![]); + let listener2 = create_test_listener("other-listener", vec![]); + let listeners = vec![listener1.clone(), listener2]; + + let result = selector + .find_listener(Some("test-listener"), &listeners) + .await; + + assert!(result.is_ok()); + assert_eq!(result.unwrap().name, "test-listener"); + } + + #[tokio::test] + async fn test_find_listener_not_found() { + let router_service = create_test_router_service(); + let selector = AgentSelector::new(router_service); + + let listeners = vec![create_test_listener("other-listener", vec![])]; + + let result = selector + .find_listener(Some("nonexistent"), &listeners) + .await; + + assert!(result.is_err()); + matches!( + result.unwrap_err(), + AgentSelectionError::ListenerNotFound(_) + ); + } + + #[test] + fn test_create_agent_map() { + let router_service = create_test_router_service(); + let selector = AgentSelector::new(router_service); + + let agents = vec![ + create_test_agent_struct("agent1"), + create_test_agent_struct("agent2"), + ]; + + let agent_map = selector.create_agent_map(&agents); + + assert_eq!(agent_map.len(), 2); + assert!(agent_map.contains_key("agent1")); + assert!(agent_map.contains_key("agent2")); + } + + #[test] + fn test_convert_agent_description_to_routing_preferences() { + let router_service = create_test_router_service(); + let selector = 
AgentSelector::new(router_service); + + let agents = vec![ + create_test_agent("agent1", "First agent description", true), + create_test_agent("agent2", "Second agent description", false), + ]; + + let preferences = selector.convert_agent_description_to_routing_preferences(&agents); + + assert_eq!(preferences.len(), 2); + assert_eq!(preferences[0].model, "agent1"); + assert_eq!(preferences[0].routing_preferences[0].name, "agent1"); + assert_eq!( + preferences[0].routing_preferences[0].description, + "First agent description" + ); + } + + #[test] + fn test_get_default_agent() { + let router_service = create_test_router_service(); + let selector = AgentSelector::new(router_service); + + let agents = vec![ + create_test_agent("agent1", "First agent", false), + create_test_agent("agent2", "Default agent", true), + create_test_agent("agent3", "Third agent", false), + ]; + + let result = selector.get_default_agent(&agents, "test-listener"); + + assert!(result.is_ok()); + assert_eq!(result.unwrap().name, "agent2"); + } + + #[test] + fn test_get_default_agent_fallback_to_first() { + let router_service = create_test_router_service(); + let selector = AgentSelector::new(router_service); + + let agents = vec![ + create_test_agent("agent1", "First agent", false), + create_test_agent("agent2", "Second agent", false), + ]; + + let result = selector.get_default_agent(&agents, "test-listener"); + + assert!(result.is_ok()); + assert_eq!(result.unwrap().name, "agent1"); + } +} diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs new file mode 100644 index 00000000..44a4bbf4 --- /dev/null +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -0,0 +1,143 @@ +use std::sync::Arc; + +use hermesllm::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent}; +use hyper::header::HeaderMap; + +use crate::handlers::agent_selector::{AgentSelector, AgentSelectionError}; +use 
crate::handlers::pipeline_processor::{PipelineProcessor}; +use crate::handlers::response_handler::ResponseHandler; +use crate::router::llm_router::RouterService; + +/// Integration test that demonstrates the modular agent chat flow +/// This test shows how the three main components work together: +/// 1. AgentSelector - selects the appropriate agent based on routing +/// 2. PipelineProcessor - executes the agent pipeline +/// 3. ResponseHandler - handles response streaming +#[cfg(test)] +mod integration_tests { + use super::*; + use common::configuration::{Agent, AgentPipeline, Listener}; + + fn create_test_router_service() -> Arc { + Arc::new(RouterService::new( + vec![], // empty providers for testing + "http://localhost:8080".to_string(), + "test-model".to_string(), + "test-provider".to_string(), + )) + } + + fn create_test_message(role: Role, content: &str) -> Message { + Message { + role, + content: MessageContent::Text(content.to_string()), + name: None, + tool_calls: None, + tool_call_id: None, + } + } + + #[tokio::test] + async fn test_modular_agent_chat_flow() { + // Setup services + let router_service = create_test_router_service(); + let agent_selector = AgentSelector::new(router_service); + let pipeline_processor = PipelineProcessor::default(); + + // Create test data + let agents = vec![ + Agent { + name: "filter-agent".to_string(), + kind: "filter".to_string(), + endpoint: "http://localhost:8081".to_string(), + }, + Agent { + name: "terminal-agent".to_string(), + kind: "terminal".to_string(), + endpoint: "http://localhost:8082".to_string(), + }, + ]; + + let agent_pipeline = AgentPipeline { + name: "test-pipeline".to_string(), + filter_chain: vec!["filter-agent".to_string(), "terminal-agent".to_string()], + description: Some("Test pipeline".to_string()), + default: Some(true), + }; + + let listener = Listener { + name: "test-listener".to_string(), + agents: Some(vec![agent_pipeline.clone()]), + port: 8080, + router: None, + }; + + let listeners = 
vec![listener]; + let messages = vec![create_test_message(Role::User, "Hello world!")]; + + // Test 1: Agent Selection + let selected_listener = agent_selector + .find_listener(Some("test-listener"), &listeners) + .await; + + assert!(selected_listener.is_ok()); + let listener = selected_listener.unwrap(); + assert_eq!(listener.name, "test-listener"); + + // Test 2: Agent Map Creation + let agent_map = agent_selector.create_agent_map(&agents); + assert_eq!(agent_map.len(), 2); + assert!(agent_map.contains_key("filter-agent")); + assert!(agent_map.contains_key("terminal-agent")); + + // Test 3: Pipeline Processing (empty filter chain for testing) + let request = ChatCompletionsRequest { + messages: messages.clone(), + model: "test-model".to_string(), + ..Default::default() + }; + + // Create a pipeline with only terminal agent to avoid network calls + let test_pipeline = AgentPipeline { + name: "test-pipeline".to_string(), + filter_chain: vec!["terminal-agent".to_string()], + description: None, + default: None, + }; + + let headers = HeaderMap::new(); + let result = pipeline_processor + .process_filter_chain(&request, &test_pipeline, &agent_map, &headers) + .await; + + assert!(result.is_ok()); + let processed_messages = result.unwrap(); + assert_eq!(processed_messages.len(), 1); + + // Test 4: Error Response Creation + let error_response = ResponseHandler::create_bad_request("Test error"); + assert_eq!(error_response.status(), hyper::StatusCode::BAD_REQUEST); + + println!("✅ All modular components working correctly!"); + } + + #[tokio::test] + async fn test_error_handling_flow() { + let router_service = create_test_router_service(); + let agent_selector = AgentSelector::new(router_service); + + // Test listener not found + let result = agent_selector + .find_listener(Some("nonexistent"), &[]) + .await; + + assert!(result.is_err()); + assert!(matches!(result.unwrap_err(), AgentSelectionError::ListenerNotFound(_))); + + // Test error response creation + let 
error_response = ResponseHandler::create_internal_error("Pipeline failed"); + assert_eq!(error_response.status(), hyper::StatusCode::INTERNAL_SERVER_ERROR); + + println!("✅ Error handling working correctly!"); + } +} diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index 6fe1404d..93f64889 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,3 +1,9 @@ pub mod chat_completions; pub mod models; pub mod agent_chat_completions; +pub mod agent_selector; +pub mod pipeline_processor; +pub mod response_handler; + +#[cfg(test)] +mod integration_tests; diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs new file mode 100644 index 00000000..c04cac96 --- /dev/null +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -0,0 +1,248 @@ +use std::collections::HashMap; + +use common::configuration::{Agent, AgentPipeline}; +use common::consts::ARCH_UPSTREAM_HOST_HEADER; +use hermesllm::apis::openai::{ChatCompletionsRequest, Message}; +use hyper::header::HeaderMap; +use tracing::{debug, warn}; + +/// Errors that can occur during pipeline processing +#[derive(Debug, thiserror::Error)] +pub enum PipelineError { + #[error("HTTP request failed: {0}")] + RequestFailed(#[from] reqwest::Error), + #[error("Failed to parse response: {0}")] + ParseError(#[from] serde_json::Error), + #[error("Agent '{0}' not found in agent map")] + AgentNotFound(String), + #[error("No choices in response from agent '{0}'")] + NoChoicesInResponse(String), + #[error("No content in response from agent '{0}'")] + NoContentInResponse(String), +} + +/// Service for processing agent pipelines +pub struct PipelineProcessor { + client: reqwest::Client, + llm_endpoint: String, +} + +impl Default for PipelineProcessor { + fn default() -> Self { + Self { + client: reqwest::Client::new(), + llm_endpoint: 
"http://localhost:11000/v1/chat/completions".to_string(),
+        }
+    }
+}
+
+impl PipelineProcessor {
+    /// Create a processor that posts every agent request to `llm_endpoint`.
+    pub fn new(llm_endpoint: String) -> Self {
+        Self {
+            client: reqwest::Client::new(),
+            llm_endpoint,
+        }
+    }
+
+    /// Process the filter chain of agents (all except the terminal agent)
+    ///
+    /// Each filter agent receives the message history produced by the previous one; the
+    /// content of its reply is parsed as JSON and becomes the new history. The last entry of
+    /// `filter_chain` is skipped here.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`PipelineError::AgentNotFound`] when a chain entry is missing from `agent_map`,
+    /// and propagates request, response and JSON-parsing failures.
+    pub async fn process_filter_chain(
+        &self,
+        initial_request: &ChatCompletionsRequest,
+        agent_pipeline: &AgentPipeline,
+        agent_map: &HashMap<String, Agent>,
+        request_headers: &HeaderMap,
+    ) -> Result<Vec<Message>, PipelineError> {
+        let mut chat_completions_history = initial_request.messages.clone();
+
+        // Everything but the last entry is a filter; an empty chain yields no filters.
+        let filter_agents = agent_pipeline
+            .filter_chain
+            .split_last()
+            .map(|(_terminal, filters)| filters)
+            .unwrap_or_default();
+
+        for agent_name in filter_agents {
+            debug!("Processing filter agent: {}", agent_name);
+
+            let agent = agent_map
+                .get(agent_name)
+                .ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
+
+            debug!("Agent details: {:?}", agent);
+
+            let response_content = self
+                .send_agent_request(
+                    &chat_completions_history,
+                    initial_request,
+                    agent,
+                    request_headers,
+                )
+                .await?;
+
+            debug!("Received response from filter agent {}", agent_name);
+
+            // Parse the response content as new message history
+            chat_completions_history = serde_json::from_str(&response_content).map_err(|err| {
+                warn!(
+                    "Failed to parse response from agent {}: {}, response: {}",
+                    agent_name, err, response_content
+                );
+                err
+            })?;
+        }
+
+        Ok(chat_completions_history)
+    }
+
+    /// Serialize `original_request` with `messages` swapped in and POST it to the endpoint.
+    ///
+    /// The caller's headers are forwarded with `ARCH_UPSTREAM_HOST_HEADER` set to the agent's
+    /// name. Shared by the filter-agent and terminal-agent request paths.
+    async fn post_to_agent(
+        &self,
+        messages: &[Message],
+        original_request: &ChatCompletionsRequest,
+        agent: &Agent,
+        request_headers: &HeaderMap,
+    ) -> Result<reqwest::Response, PipelineError> {
+        let mut request = original_request.clone();
+        request.messages = messages.to_vec();
+        let request_body = serde_json::to_string(&request)?;
+
+        let mut agent_headers = request_headers.clone();
+        // The body was re-serialized, so the incoming Content-Length no longer applies.
+        agent_headers.remove(hyper::header::CONTENT_LENGTH);
+        agent_headers.insert(
+            ARCH_UPSTREAM_HOST_HEADER,
+            // NOTE(review): a name that is not a valid header value is reported as
+            // `AgentNotFound`; a dedicated error variant would be clearer.
+            hyper::header::HeaderValue::from_str(&agent.name)
+                .map_err(|_| PipelineError::AgentNotFound(agent.name.clone()))?,
+        );
+
+        let response = self
+            .client
+            .post(&self.llm_endpoint)
+            .headers(agent_headers)
+            .body(request_body)
+            .send()
+            .await?;
+
+        Ok(response)
+    }
+
+    /// Send request to a specific agent and return the response content
+    ///
+    /// The content is read from `choices[0].message.content` of the JSON reply.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`PipelineError::NoContentInResponse`] when that field is absent or not a
+    /// string, and propagates request and JSON failures.
+    async fn send_agent_request(
+        &self,
+        messages: &[Message],
+        original_request: &ChatCompletionsRequest,
+        agent: &Agent,
+        request_headers: &HeaderMap,
+    ) -> Result<String, PipelineError> {
+        debug!("Sending request to agent {}", agent.name);
+
+        let response = self
+            .post_to_agent(messages, original_request, agent, request_headers)
+            .await?;
+        let response_bytes = response.bytes().await?;
+
+        // Parse the response as JSON to extract the content
+        let response_json: serde_json::Value = serde_json::from_slice(&response_bytes)?;
+
+        response_json
+            .get("choices")
+            .and_then(|choices| choices.as_array())
+            .and_then(|choices| choices.first())
+            .and_then(|choice| choice.get("message"))
+            .and_then(|message| message.get("content"))
+            .and_then(|content| content.as_str())
+            .map(str::to_owned)
+            .ok_or_else(|| PipelineError::NoContentInResponse(agent.name.clone()))
+    }
+
+    /// Send request to terminal agent and return the raw response for streaming
+    ///
+    /// The response body is not read here, so the caller can stream it.
+    pub async fn send_terminal_request(
+        &self,
+        messages: &[Message],
+        original_request: &ChatCompletionsRequest,
+        terminal_agent: &Agent,
+        request_headers: &HeaderMap,
+    ) -> Result<reqwest::Response, PipelineError> {
+        debug!("Sending request to terminal agent {}", terminal_agent.name);
+
+        self.post_to_agent(messages, original_request, terminal_agent, request_headers)
+            .await
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use hermesllm::apis::openai::{Message, MessageContent, Role};
+    use std::collections::HashMap;
+
+    fn create_test_message(role: Role, content: &str) -> Message {
+        Message {
+            role,
+            content: MessageContent::Text(content.to_string()),
+            name: None,
+            tool_calls: None,
+            tool_call_id: None,
+        }
+    }
+
+    fn create_test_pipeline(agents: Vec<&str>) -> AgentPipeline {
+        AgentPipeline {
+            name: "test-pipeline".to_string(),
+            filter_chain: agents.iter().map(|s| s.to_string()).collect(),
+            description: None,
+            default: None,
+        }
+    }
+
+    #[tokio::test]
+    async fn test_process_empty_filter_chain() {
+        let processor = PipelineProcessor::default();
+        let agent_map = HashMap::new();
+        let request_headers = HeaderMap::new();
+
+        let initial_request = ChatCompletionsRequest {
+            messages: vec![create_test_message(Role::User, "Hello")],
+            model: "test-model".to_string(),
+            ..Default::default()
+        };
+
+        // Pipeline with only terminal agent (no filter chain)
+        let pipeline = create_test_pipeline(vec!["terminal-agent"]);
+
+        let result = processor
+            .process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
+            .await;
+
+        let messages = result.expect("a chain with only a terminal agent must succeed");
+        assert_eq!(messages.len(), 1);
+        // Fail loudly if the content is not the original text, instead of skipping the check.
+        assert!(
+            matches!(&messages[0].content, MessageContent::Text(text) if text == "Hello"),
+            "expected the original text message to pass through unchanged"
+        );
+    }
+
+    #[tokio::test]
+    async fn test_agent_not_found_error() {
+        let processor = PipelineProcessor::default();
+        let agent_map = HashMap::new();
+        let request_headers = HeaderMap::new();
+
+        let initial_request = ChatCompletionsRequest {
+            messages: vec![create_test_message(Role::User, "Hello")],
+            model: "test-model".to_string(),
+            ..Default::default()
+        };
+
+        let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
+
+        let result = processor
+            .process_filter_chain(&initial_request, &pipeline, &agent_map, &request_headers)
+            .await;
+
+        // `matches!` only yields a bool; it has to be asserted to check the variant.
+        assert!(matches!(result, Err(PipelineError::AgentNotFound(_))));
+    }
+}
diff --git a/crates/brightstaff/src/handlers/response_handler.rs
b/crates/brightstaff/src/handlers/response_handler.rs
new file mode 100644
index 00000000..3d16a60c
--- /dev/null
+++ b/crates/brightstaff/src/handlers/response_handler.rs
@@ -0,0 +1,201 @@
+use bytes::Bytes;
+use http_body_util::combinators::BoxBody;
+use http_body_util::{BodyExt, Full, StreamBody};
+use hyper::body::Frame;
+use hyper::{Response, StatusCode};
+use tokio::sync::mpsc;
+use tokio_stream::wrappers::ReceiverStream;
+use tokio_stream::StreamExt;
+use tracing::warn;
+
+/// Errors that can occur during response handling
+#[derive(Debug, thiserror::Error)]
+pub enum ResponseError {
+    #[error("Failed to create response: {0}")]
+    ResponseCreationFailed(#[from] hyper::http::Error),
+    #[error("Stream error: {0}")]
+    StreamError(String),
+}
+
+/// Service for handling HTTP responses and streaming
+#[derive(Debug, Clone, Copy)]
+pub struct ResponseHandler;
+
+impl ResponseHandler {
+    /// Create a new (stateless) handler.
+    pub fn new() -> Self {
+        Self
+    }
+
+    /// Wrap an in-memory chunk in a boxed body that yields it as a single frame.
+    pub fn create_full_body<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
+        Full::new(chunk.into())
+            // `Full` is infallible; the empty match converts its error type to `hyper::Error`.
+            .map_err(|never| match never {})
+            .boxed()
+    }
+
+    /// Create an error response with a given status code and message
+    pub fn create_error_response(
+        status: StatusCode,
+        message: &str,
+    ) -> Response<BoxBody<Bytes, hyper::Error>> {
+        let mut response = Response::new(Self::create_full_body(message.to_string()));
+        *response.status_mut() = status;
+        response
+    }
+
+    /// Create a bad request response
+    pub fn create_bad_request(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
+        Self::create_error_response(StatusCode::BAD_REQUEST, message)
+    }
+
+    /// Create an internal server error response
+    pub fn create_internal_error(message: &str) -> Response<BoxBody<Bytes, hyper::Error>> {
+        Self::create_error_response(StatusCode::INTERNAL_SERVER_ERROR, message)
+    }
+
+    /// Create a streaming response from a reqwest response
+    ///
+    /// The upstream status code and headers are copied onto the returned response, and the
+    /// upstream body is relayed chunk by chunk through a bounded channel by a spawned task.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ResponseError`] if the response builder cannot be populated or finalized.
+    pub async fn create_streaming_response(
+        &self,
+        llm_response: reqwest::Response,
+    ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, ResponseError> {
+        // Mirror the upstream status; the builder would otherwise default to 200 OK.
+        let mut response_builder = Response::builder().status(llm_response.status());
+
+        let headers = response_builder.headers_mut().ok_or_else(|| {
+            ResponseError::StreamError("Failed to get mutable headers".to_string())
+        })?;
+
+        // `append` (not `insert`) keeps every value of repeated headers such as `set-cookie`.
+        for (header_name, header_value) in llm_response.headers() {
+            headers.append(header_name, header_value.clone());
+        }
+
+        // Create channel for async streaming
+        let (tx, rx) = mpsc::channel::<Bytes>(16);
+
+        // Spawn task to stream data
+        tokio::spawn(async move {
+            let mut byte_stream = llm_response.bytes_stream();
+
+            while let Some(item) = byte_stream.next().await {
+                let chunk = match item {
+                    Ok(chunk) => chunk,
+                    Err(err) => {
+                        // The body simply ends here; the error is only logged.
+                        warn!("Error receiving chunk: {:?}", err);
+                        break;
+                    }
+                };
+
+                if tx.send(chunk).await.is_err() {
+                    warn!("Receiver dropped");
+                    break;
+                }
+            }
+        });
+
+        let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
+        let stream_body = BoxBody::new(StreamBody::new(stream));
+
+        response_builder
+            .body(stream_body)
+            .map_err(ResponseError::from)
+    }
+}
+
+impl Default for ResponseHandler {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use hyper::StatusCode;
+
+    #[test]
+    fn test_create_bad_request() {
+        let response = ResponseHandler::create_bad_request("Invalid request");
+        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
+    }
+
+    #[test]
+    fn test_create_internal_error() {
+        let response = ResponseHandler::create_internal_error("Server error");
+        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
+    }
+
+    #[test]
+    fn test_create_error_response() {
+        let response =
+            ResponseHandler::create_error_response(StatusCode::NOT_FOUND, "Resource not found");
+        assert_eq!(response.status(), StatusCode::NOT_FOUND);
+    }
+
+    #[tokio::test]
+    async fn test_create_streaming_response_with_mock() {
+        use mockito::Server;
+
+        let mut server = Server::new_async().await;
+        let mock = server
+            .mock("GET", "/test")
+            .with_status(200)
+            .with_header("content-type", "text/plain")
+            .with_body("streaming response")
+            .create_async()
+            .await;
+
+        let client = reqwest::Client::new();
+        let llm_response = client.get(&(server.url() + "/test")).send().await.unwrap();
+
+        let handler = ResponseHandler::new();
+        let result = handler.create_streaming_response(llm_response).await;
+
+        mock.assert_async().await;
+        let response = result.expect("streaming response should be created");
+        assert_eq!(response.status(), StatusCode::OK);
+        assert!(response.headers().contains_key("content-type"));
+
+        // The relayed body must match what the upstream server sent.
+        let body = response.into_body().collect().await.unwrap().to_bytes();
+        assert_eq!(&body[..], b"streaming response");
+    }
+
+    #[tokio::test]
+    async fn test_create_streaming_response_preserves_upstream_status() {
+        use mockito::Server;
+
+        let mut server = Server::new_async().await;
+        let _mock = server
+            .mock("GET", "/unavailable")
+            .with_status(503)
+            .with_body("overloaded")
+            .create_async()
+            .await;
+
+        let client = reqwest::Client::new();
+        let llm_response = client
+            .get(&(server.url() + "/unavailable"))
+            .send()
+            .await
+            .unwrap();
+
+        let handler = ResponseHandler::new();
+        let response = handler
+            .create_streaming_response(llm_response)
+            .await
+            .expect("streaming response should be created");
+
+        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);
+    }
+}