diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index 4d591c9c..1b110863 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use bytes::Bytes; use common::api::open_ai::{ChatCompletionsResponse, Choice}; -use common::configuration::ModelUsagePreference; +use common::configuration::{ModelUsagePreference, RoutingPreference}; use common::consts::{ARCH_PROVIDER_HINT_HEADER, ARCH_UPSTREAM_HOST_HEADER}; use hermesllm::apis::openai::ChatCompletionsRequest; use hermesllm::apis::{Role, Usage}; @@ -38,12 +38,15 @@ pub async fn agent_chat( let listener_name = request.headers().get("x-arch-agent-listener-name"); let listener = { let listeners = listeners.read().await; - listeners.iter().find(|l| { - listener_name - .and_then(|name| name.to_str().ok()) - .map(|name| l.name == name) - .unwrap_or(false) - }).cloned() + listeners + .iter() + .find(|l| { + listener_name + .and_then(|name| name.to_str().ok()) + .map(|name| l.name == name) + .unwrap_or(false) + }) + .cloned() } .unwrap(); @@ -83,8 +86,83 @@ pub async fn agent_chat( map }; + let trace_parent = request_headers + .iter() + .find(|(ty, _)| ty.as_str() == "traceparent") + .map(|(_, value)| value.to_str().unwrap_or_default().to_string()); + + let usage_preferences: Vec = listener + .agents + .as_ref() + .unwrap() + .iter() + .map(|agent| ModelUsagePreference { + model: agent.name.clone(), + routing_preferences: vec![RoutingPreference { + name: agent.name.clone(), + description: agent + .description + .as_ref() + .unwrap_or(&"".to_string()) + .clone(), + }], + }) + .collect(); + + debug!( + "Usage preferences for agent routing: {:?}", + usage_preferences + ); + + let selected_agent = match router_service + .determine_route( + &chat_completions_request.messages, + trace_parent.clone(), + Some(usage_preferences), + ) + .await + { + Ok(route) => match route { + Some((_, model_name)) => Some(model_name), + None => { + debug!("No route determined"); + None + } + }, + Err(err) => { + let err_msg = format!("Failed to determine route: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(internal_error); + } + }; + // find agent to answer the request - let agent_pipeline = listener.agents.as_ref().unwrap()[0].clone(); // for now, just take the first agent pipeline + let agent_pipeline = match selected_agent { + Some(agent_name) => listener + .agents + .as_ref() + .unwrap() + .iter() + .find(|a| a.name == agent_name) + .cloned() + // selected agent must exist in the agent map + .unwrap(), + None => listener + .agents + .as_ref() + .unwrap() + .iter() + .find(|a| a.default.unwrap_or(false)) + .cloned() + .unwrap_or_else(|| { + warn!( + "No default agent found, routing request to first agent: {}", + listener.agents.as_ref().unwrap()[0].name + ); + listener.agents.as_ref().unwrap()[0].clone() + }), + }; // process agent pipeline @@ -92,11 +170,6 @@ pub async fn agent_chat( let mut chat_completions_history = chat_completions_request.messages.clone(); - // let trace_parent = request_headers - // .iter() - // .find(|(ty, _)| ty.as_str() == "traceparent") - // .map(|(_, value)| value.to_str().unwrap_or_default().to_string()); - // if let Some(trace_parent) = trace_parent { // request_headers.insert( // header::HeaderName::from_static("traceparent"), diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index 3b09c115..243ebca7 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -79,7 +79,10 @@ impl RouterService { trace_parent: Option, usage_preferences: Option>, ) -> Result> { - if !self.llm_usage_defined { + if usage_preferences.is_none() + || usage_preferences.as_ref().unwrap().len() < 2 + || messages.is_empty() + { return Ok(None); } diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 37d3d002..4f8cb998 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -23,6 +23,7 @@ pub struct Agent { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AgentPipeline { pub name: String, + pub default: Option, pub description: Option, pub filter_chain: Vec, } diff --git a/demos/use_cases/rag_agent/arch_config.yaml b/demos/use_cases/rag_agent/arch_config.yaml index f12577b5..7c95da61 100644 --- a/demos/use_cases/rag_agent/arch_config.yaml +++ b/demos/use_cases/rag_agent/arch_config.yaml @@ -19,13 +19,14 @@ listeners: router: arch_agent_v2 agents: - name: simple_rag_agent - description: virtual assistant for device contracts. + default: true + description: virtual assistant for device contracts for simple queries filter_chain: - query_rewriter - context_builder - response_generator - name: research_agent - description: deep research agent that can perform searches and gather information. + description: deep research agent that can perform searches and gather information filter_chain: - research_agent - response_generator