diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index cb4052d3..8b0d9050 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -7,10 +7,10 @@ use http_body_util::BodyExt; use hyper::{Request, Response}; use tracing::{debug, info, warn}; -use crate::router::llm_router::RouterService; -use super::agent_selector::{AgentSelector, AgentSelectionError}; -use super::pipeline_processor::{PipelineProcessor, PipelineError}; +use super::agent_selector::{AgentSelectionError, AgentSelector}; +use super::pipeline_processor::{PipelineError, PipelineProcessor}; use super::response_handler::ResponseHandler; +use crate::router::llm_router::RouterService; /// Main errors for agent chat completions #[derive(Debug, thiserror::Error)] @@ -38,7 +38,10 @@ pub async fn agent_chat( Ok(response) => Ok(response), Err(err) => { warn!("Agent chat error: {}", err); - Ok(ResponseHandler::create_internal_error(&format!("Internal error: {}", err))) + Ok(ResponseHandler::create_internal_error(&format!( + "Internal error: {}", + err + ))) } } } @@ -81,7 +84,10 @@ async fn handle_agent_chat( let chat_completions_request: ChatCompletionsRequest = serde_json::from_slice(&chat_request_bytes).map_err(|err| { - warn!("Failed to parse request body as ChatCompletionsRequest: {}", err); + warn!( + "Failed to parse request body as ChatCompletionsRequest: {}", + err + ); AgentChatError::RequestParsing(err) })?; @@ -93,11 +99,7 @@ async fn handle_agent_chat( // Select appropriate agent using arch router llm model let selected_agent = agent_selector - .select_agent( - &chat_completions_request.messages, - &listener, - trace_parent, - ) + .select_agent(&chat_completions_request.messages, &listener, trace_parent) .await?; debug!("Processing agent pipeline: {}", selected_agent.name); diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index b049e4f3..4ccbf48d 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -1,5 +1,3 @@ -use std::sync::Arc; -use std::collections::HashMap; use bytes::Bytes; use common::configuration::{ModelAlias, ModelUsagePreference}; use common::consts::ARCH_PROVIDER_HINT_HEADER; @@ -11,6 +9,8 @@ use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Frame; use hyper::header::{self}; use hyper::{Request, Response, StatusCode}; +use std::collections::HashMap; +use std::sync::Arc; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; @@ -30,14 +30,19 @@ pub async fn chat( full_qualified_llm_provider_url: String, model_aliases: Arc>>, ) -> Result>, hyper::Error> { - let request_path = request.uri().path().to_string(); let mut request_headers = request.headers().clone(); let chat_request_bytes = request.collect().await?.to_bytes(); - debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes)); + debug!( + "Received request body (raw utf8): {}", + String::from_utf8_lossy(&chat_request_bytes) + ); - let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) { + let mut client_request = match ProviderRequestType::try_from(( + &chat_request_bytes[..], + &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(), + )) { Ok(request) => request, Err(err) => { warn!("Failed to parse request as ProviderRequestType: {}", err); @@ -77,7 +82,10 @@ pub async fn chat( // Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original) let chat_completions_request_for_arch_router: ChatCompletionsRequest = - match ProviderRequestType::try_from((client_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) { + match ProviderRequestType::try_from(( + client_request, + &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions), + )) { Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req, Ok(ProviderRequestType::MessagesRequest(_)) => { // This should not happen after conversion to OpenAI format @@ -86,9 +94,12 @@ pub async fn chat( let mut bad_request = Response::new(full(err_msg)); *bad_request.status_mut() = StatusCode::BAD_REQUEST; return Ok(bad_request); - }, + } Err(err) => { - warn!("Failed to convert request to ChatCompletionsRequest: {}", err); + warn!( + "Failed to convert request to ChatCompletionsRequest: {}", + err + ); let err_msg = format!("Failed to convert request: {}", err); let mut bad_request = Response::new(full(err_msg)); *bad_request.status_mut() = StatusCode::BAD_REQUEST; @@ -106,24 +117,22 @@ pub async fn chat( .find(|(ty, _)| ty.as_str() == "traceparent") .map(|(_, value)| value.to_str().unwrap_or_default().to_string()); - let usage_preferences_str: Option = - routing_metadata.as_ref().and_then(|metadata| { - metadata - .get("archgw_preference_config") - .map(|value| value.to_string()) - }); + let usage_preferences_str: Option = routing_metadata.as_ref().and_then(|metadata| { + metadata + .get("archgw_preference_config") + .map(|value| value.to_string()) + }); let usage_preferences: Option> = usage_preferences_str .as_ref() .and_then(|s| serde_yaml::from_str(s).ok()); - let latest_message_for_log = - chat_completions_request_for_arch_router - .messages - .last() - .map_or("None".to_string(), |msg| { - msg.content.to_string().replace('\n', "\\n") - }); + let latest_message_for_log = chat_completions_request_for_arch_router + .messages + .last() + .map_or("None".to_string(), |msg| { + msg.content.to_string().replace('\n', "\\n") + }); const MAX_MESSAGE_LENGTH: usize = 50; let latest_message_for_log = if latest_message_for_log.len() > MAX_MESSAGE_LENGTH { @@ -152,12 +161,11 @@ pub async fn chat( Ok(route) => match route { Some((_, model_name)) => model_name, None => { - debug!( + debug!( "No route determined, using default model from request: {}", chat_completions_request_for_arch_router.model ); chat_completions_request_for_arch_router.model.clone() - } }, Err(err) => { diff --git a/crates/brightstaff/src/handlers/integration_tests.rs b/crates/brightstaff/src/handlers/integration_tests.rs index 44a4bbf4..01e22629 100644 --- a/crates/brightstaff/src/handlers/integration_tests.rs +++ b/crates/brightstaff/src/handlers/integration_tests.rs @@ -1,10 +1,10 @@ use std::sync::Arc; -use hermesllm::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent}; +use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; use hyper::header::HeaderMap; -use crate::handlers::agent_selector::{AgentSelector, AgentSelectionError}; -use crate::handlers::pipeline_processor::{PipelineProcessor}; +use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector}; +use crate::handlers::pipeline_processor::PipelineProcessor; use crate::handlers::response_handler::ResponseHandler; use crate::router::llm_router::RouterService; @@ -127,16 +127,20 @@ mod integration_tests { let agent_selector = AgentSelector::new(router_service); // Test listener not found - let result = agent_selector - .find_listener(Some("nonexistent"), &[]) - .await; + let result = agent_selector.find_listener(Some("nonexistent"), &[]).await; assert!(result.is_err()); - assert!(matches!(result.unwrap_err(), AgentSelectionError::ListenerNotFound(_))); + assert!(matches!( + result.unwrap_err(), + AgentSelectionError::ListenerNotFound(_) + )); // Test error response creation let error_response = ResponseHandler::create_internal_error("Pipeline failed"); - assert_eq!(error_response.status(), hyper::StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + error_response.status(), + hyper::StatusCode::INTERNAL_SERVER_ERROR + ); println!("✅ Error handling working correctly!"); } diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index 93f64889..66c5449b 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,7 +1,7 @@ -pub mod chat_completions; -pub mod models; pub mod agent_chat_completions; pub mod agent_selector; +pub mod chat_completions; +pub mod models; pub mod pipeline_processor; pub mod response_handler; diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index c23f675a..adecd5c6 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -71,8 +71,8 @@ async fn main() -> Result<(), Box> { &serde_json::to_string(arch_config.as_ref()).unwrap() ); - let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT") - .unwrap_or_else(|_| "http://localhost:12001".to_string()); + let llm_provider_url = + env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string()); info!("llm provider url: {}", llm_provider_url); info!("listening on http://{}", bind_address); @@ -99,7 +99,6 @@ async fn main() -> Result<(), Box> { let model_aliases = Arc::new(arch_config.model_aliases.clone()); - loop { let (stream, _) = listener.accept().await?; let peer_addr = stream.peer_addr()?; @@ -113,7 +112,6 @@ async fn main() -> Result<(), Box> { let agents_list = agents_list.clone(); let listeners = listeners.clone(); let service = service_fn(move |req| { - let router_service = Arc::clone(&router_service); let parent_cx = extract_context_from_request(&req); let llm_provider_url = llm_provider_url.clone(); @@ -125,16 +123,24 @@ async fn main() -> Result<(), Box> { async move { match (req.method(), req.uri().path()) { (&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => { - let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path()); + let fully_qualified_url = + format!("{}{}", llm_provider_url, req.uri().path()); chat(req, router_service, fully_qualified_url, model_aliases) .with_context(parent_cx) .await } (&Method::POST, "/agents/v1/chat/completions") => { - let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path()); - agent_chat(req, router_service, fully_qualified_url, agents_list, listeners) - .with_context(parent_cx) - .await + let fully_qualified_url = + format!("{}{}", llm_provider_url, req.uri().path()); + agent_chat( + req, + router_service, + fully_qualified_url, + agents_list, + listeners, + ) + .with_context(parent_cx) + .await } (&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await), (&Method::OPTIONS, "/v1/models") => { diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 1c1c14ef..758cf83a 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,9 +1,7 @@ use std::collections::HashMap; -use common::{ - configuration::{ModelUsagePreference, RoutingPreference}, -}; -use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role}; +use common::configuration::{ModelUsagePreference, RoutingPreference}; +use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role}; use serde::{Deserialize, Serialize}; use tracing::{debug, warn};