refactor more

This commit is contained in:
Adil Hafeez 2026-03-12 13:54:23 -07:00
parent 8c71acfc76
commit efa677683a
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 61 additions and 41 deletions

View file

@@ -172,7 +172,7 @@ impl AgentSelector {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use common::configuration::{AgentFilterChain, Listener};
+    use common::configuration::{AgentFilterChain, Listener, ListenerType};
     fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
         Arc::new(OrchestratorService::new(
@@ -192,12 +192,12 @@ mod tests {
     fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
         Listener {
+            listener_type: ListenerType::Agent,
             name: name.to_string(),
             agents: Some(agents),
             filter_chain: None,
             port: 8080,
             router: None,
-            filter_agents: None,
         }
     }

View file

@@ -17,7 +17,7 @@ use hyper::StatusCode;
 #[cfg(test)]
 mod tests {
     use super::*;
-    use common::configuration::{Agent, AgentFilterChain, Listener};
+    use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};
     fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
         Arc::new(OrchestratorService::new(
@@ -72,12 +72,12 @@ mod tests {
         };
         let listener = Listener {
+            listener_type: ListenerType::Agent,
             name: "test-listener".to_string(),
             agents: Some(vec![agent_pipeline.clone()]),
             filter_chain: None,
             port: 8080,
             router: None,
-            filter_agents: None,
         };
         let listeners = vec![listener];

View file

@@ -1,5 +1,5 @@
 use bytes::Bytes;
-use common::configuration::{AgentFilterChain, Listener, ModelAlias, SpanAttributes};
+use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes};
 use common::consts::{
     ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
 };
@@ -45,7 +45,8 @@ pub async fn llm_chat(
     llm_providers: Arc<RwLock<LlmProviders>>,
     span_attributes: Arc<Option<SpanAttributes>>,
     state_storage: Option<Arc<dyn StateStorage>>,
-    listeners: Arc<RwLock<Vec<Listener>>>,
+    filter_chain: Arc<Option<Vec<String>>>,
+    filter_agents: Arc<HashMap<String, Agent>>,
 ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
     let request_path = request.uri().path().to_string();
     let request_headers = request.headers().clone();
@@ -85,7 +86,8 @@ pub async fn llm_chat(
         request_id,
         request_path,
         request_headers,
-        listeners,
+        filter_chain,
+        filter_agents,
     )
     .instrument(request_span)
     .await
@@ -103,7 +105,8 @@ async fn llm_chat_inner(
     request_id: String,
     request_path: String,
     mut request_headers: hyper::HeaderMap,
-    listeners: Arc<RwLock<Vec<Listener>>>,
+    filter_chain: Arc<Option<Vec<String>>>,
+    filter_agents: Arc<HashMap<String, Agent>>,
 ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
     // Set service name for LLM operations
     set_service_name(operation_component::LLM);
@@ -257,20 +260,9 @@ async fn llm_chat_inner(
         debug!("removed plano_preference_config from metadata");
     }
-    // === Filter chain processing for model listeners ===
-    // Check if any model listener (no agents) has a filter_chain configured
+    // === Filter chain processing for model listener ===
     {
-        let listeners_guard = listeners.read().await;
-        let model_listener = listeners_guard
-            .iter()
-            .find(|l| l.agents.is_none() && l.filter_chain.is_some());
-        let filter_chain = model_listener.and_then(|l| l.filter_chain.clone());
-        let agent_map = model_listener
-            .and_then(|l| l.filter_agents.clone())
-            .unwrap_or_default();
-        if let Some(ref fc) = filter_chain {
+        if let Some(ref fc) = *filter_chain {
             if !fc.is_empty() {
                 debug!(filter_chain = ?fc, "processing model listener filter chain");
@@ -288,7 +280,7 @@ async fn llm_chat_inner(
                     .process_filter_chain(
                         &messages,
                         &temp_filter_chain,
-                        &agent_map,
+                        &filter_agents,
                         &request_headers,
                     )
                     .await

View file

@@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
 use brightstaff::state::StateStorage;
 use brightstaff::utils::tracing::init_tracer;
 use bytes::Bytes;
-use common::configuration::{Agent, Configuration};
+use common::configuration::{Agent, Configuration, ListenerType};
 use common::consts::{
     CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
 };
@@ -87,26 +87,41 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         .map(|a| (a.id.clone(), a.clone()))
         .collect();
-    // Resolve filter_agents on each listener at startup
-    let mut listeners_resolved = plano_config.listeners.clone();
-    for listener in &mut listeners_resolved {
-        if let Some(ref fc) = listener.filter_chain {
-            let filter_agents: HashMap<String, Agent> = fc
-                .iter()
-                .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
-                .collect();
-            if !filter_agents.is_empty() {
-                listener.filter_agents = Some(filter_agents);
-            }
-        }
-    }
     // Create expanded provider list for /v1/models endpoint
     let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
         .expect("Failed to create LlmProviders");
     let llm_providers = Arc::new(RwLock::new(llm_providers));
     let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
-    let listeners = Arc::new(RwLock::new(listeners_resolved));
+    // Resolve model listener filter chain and agents at startup
+    let model_listener_count = plano_config
+        .listeners
+        .iter()
+        .filter(|l| l.listener_type == ListenerType::Model)
+        .count();
+    assert!(
+        model_listener_count <= 1,
+        "only one model listener is allowed, found {}",
+        model_listener_count
+    );
+    let model_listener = plano_config
+        .listeners
+        .iter()
+        .find(|l| l.listener_type == ListenerType::Model);
+    let model_filter_chain: Arc<Option<Vec<String>>> =
+        Arc::new(model_listener.and_then(|l| l.filter_chain.clone()));
+    let model_filter_agents: Arc<HashMap<String, Agent>> = Arc::new(
+        model_filter_chain
+            .as_ref()
+            .as_ref()
+            .map(|fc| {
+                fc.iter()
+                    .filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
+                    .collect()
+            })
+            .unwrap_or_default(),
+    );
+    let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
     let llm_provider_url =
         env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
@@ -200,6 +215,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
         let llm_providers = llm_providers.clone();
         let agents_list = combined_agents_filters_list.clone();
+        let model_filter_chain = model_filter_chain.clone();
+        let model_filter_agents = model_filter_agents.clone();
         let listeners = listeners.clone();
         let span_attributes = span_attributes.clone();
         let state_storage = state_storage.clone();
@@ -211,6 +228,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
             let llm_providers = llm_providers.clone();
             let model_aliases = Arc::clone(&model_aliases);
             let agents_list = agents_list.clone();
+            let model_filter_chain = model_filter_chain.clone();
+            let model_filter_agents = model_filter_agents.clone();
             let listeners = listeners.clone();
             let span_attributes = span_attributes.clone();
             let state_storage = state_storage.clone();
@@ -269,7 +288,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
                         llm_providers,
                         span_attributes,
                         state_storage,
-                        listeners,
+                        model_filter_chain,
+                        model_filter_agents,
                     )
                     .with_context(parent_cx)
                     .await

View file

@@ -36,15 +36,23 @@ pub struct AgentFilterChain {
     pub filter_chain: Option<Vec<String>>,
 }
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
+#[serde(rename_all = "lowercase")]
+pub enum ListenerType {
+    Model,
+    Agent,
+    Prompt,
+}
 #[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Listener {
+    #[serde(rename = "type")]
+    pub listener_type: ListenerType,
     pub name: String,
     pub router: Option<String>,
     pub agents: Option<Vec<AgentFilterChain>>,
     pub filter_chain: Option<Vec<String>>,
     pub port: u16,
-    #[serde(skip)]
-    pub filter_agents: Option<HashMap<String, Agent>>,
 }
 #[derive(Debug, Clone, Serialize, Deserialize)]