diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs
index adfdce02..22722895 100644
--- a/crates/brightstaff/src/handlers/agent_chat_completions.rs
+++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs
@@ -2,15 +2,17 @@ use std::sync::Arc;
 use std::time::Instant;
 
 use bytes::Bytes;
+use common::llm_providers::LlmProviders;
 use hermesllm::apis::OpenAIMessage;
 use hermesllm::clients::SupportedAPIsFromClient;
 use hermesllm::providers::request::ProviderRequest;
 use hermesllm::ProviderRequestType;
 use http_body_util::combinators::BoxBody;
 use http_body_util::BodyExt;
-use hyper::{Request, Response};
+use hyper::{Request, Response, StatusCode};
 use opentelemetry::trace::get_active_span;
 use serde::ser::Error as SerError;
+use tokio::sync::RwLock;
 use tracing::{debug, info, info_span, warn, Instrument};
 
 use super::agent_selector::{AgentSelectionError, AgentSelector};
@@ -40,6 +42,7 @@ pub async fn agent_chat(
     _: String,
     agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
     listeners: Arc<RwLock<Vec<Listener>>>,
+    llm_providers: Arc<RwLock<LlmProviders>>,
 ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
     // Extract request_id from headers or generate a new one
     let request_id: String = match request
@@ -71,6 +74,7 @@ pub async fn agent_chat(
         orchestrator_service,
         agents_list,
         listeners,
+        llm_providers,
         request_id,
     )
     .await
@@ -155,6 +159,7 @@ async fn handle_agent_chat_inner(
     orchestrator_service: Arc<OrchestratorService>,
     agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
     listeners: Arc<RwLock<Vec<Listener>>>,
+    llm_providers: Arc<RwLock<LlmProviders>>,
     request_id: String,
 ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, AgentFilterChainError> {
     // Initialize services
@@ -221,16 +226,36 @@ async fn handle_agent_chat_inner(
             AgentFilterChainError::RequestParsing(serde_json::Error::custom(err_msg))
         })?;
 
-    let client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
-        Ok(request) => request,
-        Err(err) => {
-            warn!("failed to parse request as ProviderRequestType: {}", err);
-            let err_msg = format!("Failed to parse request: {}", err);
-            return Err(AgentFilterChainError::RequestParsing(
-                serde_json::Error::custom(err_msg),
-            ));
+    let mut client_request =
+        match ProviderRequestType::try_from((&chat_request_bytes[..], &api_type)) {
+            Ok(request) => request,
+            Err(err) => {
+                warn!("failed to parse request as ProviderRequestType: {}", err);
+                let err_msg = format!("Failed to parse request: {}", err);
+                return Err(AgentFilterChainError::RequestParsing(
+                    serde_json::Error::custom(err_msg),
+                ));
+            }
+        };
+
+    // If model is not specified in the request, resolve from default provider
+    if client_request.model().is_empty() {
+        match llm_providers.read().await.default() {
+            Some(default_provider) => {
+                let default_model = default_provider.name.clone();
+                info!(default_model = %default_model, "no model specified in request, using default provider");
+                client_request.set_model(default_model);
+            }
+            None => {
+                let err_msg = "No model specified in request and no default provider configured";
+                warn!("{}", err_msg);
+                let mut bad_request =
+                    Response::new(ResponseHandler::create_full_body(err_msg.to_string()));
+                *bad_request.status_mut() = StatusCode::BAD_REQUEST;
+                return Ok(bad_request);
+            }
         }
-    };
+    }
 
     let message: Vec<OpenAIMessage> = client_request.get_messages();
 
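
// A minimal, self-contained sketch of the default-model fallback introduced
// in the hunk above. `Provider`, `Providers`, and `ChatRequest` are
// simplified stand-ins invented for illustration; the crate's real types
// (`LlmProviders`, `ProviderRequestType`) have richer APIs, so treat this as
// a sketch of the control flow, not the actual implementation.

use std::sync::Arc;
use tokio::sync::RwLock;

struct Provider {
    name: String,
    is_default: bool,
}

struct Providers(Vec<Provider>);

impl Providers {
    // Analogue of the `default()` lookup used in the diff above.
    fn default_provider(&self) -> Option<&Provider> {
        self.0.iter().find(|p| p.is_default)
    }
}

struct ChatRequest {
    model: String,
}

impl ChatRequest {
    fn model(&self) -> &str {
        &self.model
    }
    fn set_model(&mut self, model: String) {
        self.model = model;
    }
}

// If the client omitted `model`, fill it in from the default provider; with
// no default configured, surface an error (mapped to a 400 response above).
async fn resolve_model(
    request: &mut ChatRequest,
    providers: &Arc<RwLock<Providers>>,
) -> Result<(), &'static str> {
    if request.model().is_empty() {
        // The read guard lives only for this match, mirroring
        // `llm_providers.read().await.default()` in the diff.
        match providers.read().await.default_provider() {
            Some(p) => request.set_model(p.name.clone()),
            None => return Err("No model specified in request and no default provider configured"),
        }
    }
    Ok(())
}

#[tokio::main]
async fn main() {
    let providers = Arc::new(RwLock::new(Providers(vec![Provider {
        name: "gpt-4o-mini".to_string(), // hypothetical default model name
        is_default: true,
    }])));

    let mut request = ChatRequest { model: String::new() };
    resolve_model(&mut request, &providers).await.unwrap();
    assert_eq!(request.model(), "gpt-4o-mini");
}
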
diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs
index 10a68c1a..435fb6f5 100644
--- a/crates/brightstaff/src/handlers/llm.rs
+++ b/crates/brightstaff/src/handlers/llm.rs
@@ -150,9 +150,30 @@ async fn llm_chat_inner(
         Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_))
     );
 
+    // If model is not specified in the request, resolve from default provider
+    let model_from_request = client_request.model().to_string();
+    let model_from_request = if model_from_request.is_empty() {
+        match llm_providers.read().await.default() {
+            Some(default_provider) => {
+                let default_model = default_provider.name.clone();
+                info!(default_model = %default_model, "no model specified in request, using default provider");
+                client_request.set_model(default_model.clone());
+                default_model
+            }
+            None => {
+                let err_msg = "No model specified in request and no default provider configured";
+                warn!("{}", err_msg);
+                let mut bad_request = Response::new(full(err_msg.to_string()));
+                *bad_request.status_mut() = StatusCode::BAD_REQUEST;
+                return Ok(bad_request);
+            }
+        }
+    } else {
+        model_from_request
+    };
+
     // Model alias resolution: update model field in client_request immediately
     // This ensures all downstream objects use the resolved model
-    let model_from_request = client_request.model().to_string();
     let temperature = client_request.get_temperature();
     let is_streaming_request = client_request.is_streaming();
     let alias_resolved_model = resolve_model_alias(&model_from_request, &model_aliases);
diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs
index fff69b00..87deda6a 100644
--- a/crates/brightstaff/src/main.rs
+++ b/crates/brightstaff/src/main.rs
@@ -202,6 +202,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
                         fully_qualified_url,
                         agents_list,
                         listeners,
+                        llm_providers,
                     )
                     .with_context(parent_cx)
                     .await;
diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs
index 6e53e6db..3cb06828 100644
--- a/crates/hermesllm/src/apis/anthropic.rs
+++ b/crates/hermesllm/src/apis/anthropic.rs
@@ -102,6 +102,7 @@ pub struct McpServer {
 #[skip_serializing_none]
 #[derive(Serialize, Deserialize, Debug, Clone)]
 pub struct MessagesRequest {
+    #[serde(default)]
     pub model: String,
     pub messages: Vec<Message>,
     pub max_tokens: u32,
diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs
index cd4e7d0b..53eee442 100644
--- a/crates/hermesllm/src/apis/openai.rs
+++ b/crates/hermesllm/src/apis/openai.rs
@@ -74,6 +74,7 @@ impl ApiDefinition for OpenAIApi {
 #[derive(Serialize, Deserialize, Debug, Clone, Default)]
 pub struct ChatCompletionsRequest {
     pub messages: Vec<Message>,
+    #[serde(default)]
     pub model: String,
     // pub audio: Option
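
// In isolation, what `#[serde(default)]` on `model` changes in the two
// request structs above: a body without a `"model"` key now deserializes
// with `model == ""` (which the handlers then detect and replace with the
// default provider's model) instead of failing outright. The struct below
// is a trimmed stand-in for `ChatCompletionsRequest`, with `messages`
// simplified to `Vec<String>` for the demo.

use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct ChatCompletionsRequest {
    messages: Vec<String>,
    #[serde(default)]
    model: String,
}

fn main() {
    // "model" omitted: deserialization succeeds and the field defaults to "".
    let r: ChatCompletionsRequest =
        serde_json::from_str(r#"{"messages": ["hello"]}"#).unwrap();
    assert!(r.model.is_empty());

    // An explicit "model" is preserved exactly as before.
    let r: ChatCompletionsRequest =
        serde_json::from_str(r#"{"messages": ["hello"], "model": "gpt-4o"}"#).unwrap();
    assert_eq!(r.model, "gpt-4o");
    println!("{:?}", r);
}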