cargo fmt brightstaff

This commit is contained in:
Adil Hafeez 2025-09-17 11:08:08 -07:00
parent 2229f0d4d4
commit 0c6600ac47
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
6 changed files with 74 additions and 56 deletions

View file

@ -7,10 +7,10 @@ use http_body_util::BodyExt;
use hyper::{Request, Response};
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
use super::agent_selector::{AgentSelector, AgentSelectionError};
use super::pipeline_processor::{PipelineProcessor, PipelineError};
use super::agent_selector::{AgentSelectionError, AgentSelector};
use super::pipeline_processor::{PipelineError, PipelineProcessor};
use super::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
/// Main errors for agent chat completions
#[derive(Debug, thiserror::Error)]
@ -38,7 +38,10 @@ pub async fn agent_chat(
Ok(response) => Ok(response),
Err(err) => {
warn!("Agent chat error: {}", err);
Ok(ResponseHandler::create_internal_error(&format!("Internal error: {}", err)))
Ok(ResponseHandler::create_internal_error(&format!(
"Internal error: {}",
err
)))
}
}
}
@ -81,7 +84,10 @@ async fn handle_agent_chat(
let chat_completions_request: ChatCompletionsRequest =
serde_json::from_slice(&chat_request_bytes).map_err(|err| {
warn!("Failed to parse request body as ChatCompletionsRequest: {}", err);
warn!(
"Failed to parse request body as ChatCompletionsRequest: {}",
err
);
AgentChatError::RequestParsing(err)
})?;
@ -93,11 +99,7 @@ async fn handle_agent_chat(
// Select appropriate agent using arch router llm model
let selected_agent = agent_selector
.select_agent(
&chat_completions_request.messages,
&listener,
trace_parent,
)
.select_agent(&chat_completions_request.messages, &listener, trace_parent)
.await?;
debug!("Processing agent pipeline: {}", selected_agent.name);

View file

@ -1,5 +1,3 @@
use std::sync::Arc;
use std::collections::HashMap;
use bytes::Bytes;
use common::configuration::{ModelAlias, ModelUsagePreference};
use common::consts::ARCH_PROVIDER_HINT_HEADER;
@ -11,6 +9,8 @@ use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
use hyper::header::{self};
use hyper::{Request, Response, StatusCode};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
@ -30,14 +30,19 @@ pub async fn chat(
full_qualified_llm_provider_url: String,
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
debug!(
"Received request body (raw utf8): {}",
String::from_utf8_lossy(&chat_request_bytes)
);
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
let mut client_request = match ProviderRequestType::try_from((
&chat_request_bytes[..],
&SupportedAPIs::from_endpoint(request_path.as_str()).unwrap(),
)) {
Ok(request) => request,
Err(err) => {
warn!("Failed to parse request as ProviderRequestType: {}", err);
@ -77,7 +82,10 @@ pub async fn chat(
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
match ProviderRequestType::try_from((client_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) {
match ProviderRequestType::try_from((
client_request,
&SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions),
)) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok(ProviderRequestType::MessagesRequest(_)) => {
// This should not happen after conversion to OpenAI format
@ -86,9 +94,12 @@ pub async fn chat(
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
return Ok(bad_request);
},
}
Err(err) => {
warn!("Failed to convert request to ChatCompletionsRequest: {}", err);
warn!(
"Failed to convert request to ChatCompletionsRequest: {}",
err
);
let err_msg = format!("Failed to convert request: {}", err);
let mut bad_request = Response::new(full(err_msg));
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
@ -106,24 +117,22 @@ pub async fn chat(
.find(|(ty, _)| ty.as_str() == "traceparent")
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let usage_preferences_str: Option<String> =
routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("archgw_preference_config")
.map(|value| value.to_string())
});
let usage_preferences_str: Option<String> = routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("archgw_preference_config")
.map(|value| value.to_string())
});
let usage_preferences: Option<Vec<ModelUsagePreference>> = usage_preferences_str
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
let latest_message_for_log =
chat_completions_request_for_arch_router
.messages
.last()
.map_or("None".to_string(), |msg| {
msg.content.to_string().replace('\n', "\\n")
});
let latest_message_for_log = chat_completions_request_for_arch_router
.messages
.last()
.map_or("None".to_string(), |msg| {
msg.content.to_string().replace('\n', "\\n")
});
const MAX_MESSAGE_LENGTH: usize = 50;
let latest_message_for_log = if latest_message_for_log.len() > MAX_MESSAGE_LENGTH {
@ -152,12 +161,11 @@ pub async fn chat(
Ok(route) => match route {
Some((_, model_name)) => model_name,
None => {
debug!(
debug!(
"No route determined, using default model from request: {}",
chat_completions_request_for_arch_router.model
);
chat_completions_request_for_arch_router.model.clone()
}
},
Err(err) => {

View file

@ -1,10 +1,10 @@
use std::sync::Arc;
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, Role, MessageContent};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use hyper::header::HeaderMap;
use crate::handlers::agent_selector::{AgentSelector, AgentSelectionError};
use crate::handlers::pipeline_processor::{PipelineProcessor};
use crate::handlers::agent_selector::{AgentSelectionError, AgentSelector};
use crate::handlers::pipeline_processor::PipelineProcessor;
use crate::handlers::response_handler::ResponseHandler;
use crate::router::llm_router::RouterService;
@ -127,16 +127,20 @@ mod integration_tests {
let agent_selector = AgentSelector::new(router_service);
// Test listener not found
let result = agent_selector
.find_listener(Some("nonexistent"), &[])
.await;
let result = agent_selector.find_listener(Some("nonexistent"), &[]).await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), AgentSelectionError::ListenerNotFound(_)));
assert!(matches!(
result.unwrap_err(),
AgentSelectionError::ListenerNotFound(_)
));
// Test error response creation
let error_response = ResponseHandler::create_internal_error("Pipeline failed");
assert_eq!(error_response.status(), hyper::StatusCode::INTERNAL_SERVER_ERROR);
assert_eq!(
error_response.status(),
hyper::StatusCode::INTERNAL_SERVER_ERROR
);
println!("✅ Error handling working correctly!");
}

View file

@ -1,7 +1,7 @@
pub mod chat_completions;
pub mod models;
pub mod agent_chat_completions;
pub mod agent_selector;
pub mod chat_completions;
pub mod models;
pub mod pipeline_processor;
pub mod response_handler;

View file

@ -71,8 +71,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
&serde_json::to_string(arch_config.as_ref()).unwrap()
);
let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT")
.unwrap_or_else(|_| "http://localhost:12001".to_string());
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
info!("llm provider url: {}", llm_provider_url);
info!("listening on http://{}", bind_address);
@ -99,7 +99,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let model_aliases = Arc::new(arch_config.model_aliases.clone());
loop {
let (stream, _) = listener.accept().await?;
let peer_addr = stream.peer_addr()?;
@ -113,7 +112,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let agents_list = agents_list.clone();
let listeners = listeners.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
let llm_provider_url = llm_provider_url.clone();
@ -125,16 +123,24 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH) => {
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url, model_aliases)
.with_context(parent_cx)
.await
}
(&Method::POST, "/agents/v1/chat/completions") => {
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
agent_chat(req, router_service, fully_qualified_url, agents_list, listeners)
.with_context(parent_cx)
.await
let fully_qualified_url =
format!("{}{}", llm_provider_url, req.uri().path());
agent_chat(
req,
router_service,
fully_qualified_url,
agents_list,
listeners,
)
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
(&Method::OPTIONS, "/v1/models") => {

View file

@ -1,9 +1,7 @@
use std::collections::HashMap;
use common::{
configuration::{ModelUsagePreference, RoutingPreference},
};
use hermesllm::apis::openai::{ChatCompletionsRequest, MessageContent, Message, Role};
use common::configuration::{ModelUsagePreference, RoutingPreference};
use hermesllm::apis::openai::{ChatCompletionsRequest, Message, MessageContent, Role};
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};