refactor more

This commit is contained in:
Adil Hafeez 2026-03-12 13:54:23 -07:00
parent 8c71acfc76
commit efa677683a
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 61 additions and 41 deletions

View file

@ -172,7 +172,7 @@ impl AgentSelector {
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{AgentFilterChain, Listener};
use common::configuration::{AgentFilterChain, Listener, ListenerType};
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
@ -192,12 +192,12 @@ mod tests {
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
Listener {
listener_type: ListenerType::Agent,
name: name.to_string(),
agents: Some(agents),
filter_chain: None,
port: 8080,
router: None,
filter_agents: None,
}
}

View file

@ -17,7 +17,7 @@ use hyper::StatusCode;
#[cfg(test)]
mod tests {
use super::*;
use common::configuration::{Agent, AgentFilterChain, Listener};
use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
Arc::new(OrchestratorService::new(
@ -72,12 +72,12 @@ mod tests {
};
let listener = Listener {
listener_type: ListenerType::Agent,
name: "test-listener".to_string(),
agents: Some(vec![agent_pipeline.clone()]),
filter_chain: None,
port: 8080,
router: None,
filter_agents: None,
};
let listeners = vec![listener];

View file

@ -1,5 +1,5 @@
use bytes::Bytes;
use common::configuration::{AgentFilterChain, Listener, ModelAlias, SpanAttributes};
use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes};
use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
};
@ -45,7 +45,8 @@ pub async fn llm_chat(
llm_providers: Arc<RwLock<LlmProviders>>,
span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>,
listeners: Arc<RwLock<Vec<Listener>>>,
filter_chain: Arc<Option<Vec<String>>>,
filter_agents: Arc<HashMap<String, Agent>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone();
@ -85,7 +86,8 @@ pub async fn llm_chat(
request_id,
request_path,
request_headers,
listeners,
filter_chain,
filter_agents,
)
.instrument(request_span)
.await
@ -103,7 +105,8 @@ async fn llm_chat_inner(
request_id: String,
request_path: String,
mut request_headers: hyper::HeaderMap,
listeners: Arc<RwLock<Vec<Listener>>>,
filter_chain: Arc<Option<Vec<String>>>,
filter_agents: Arc<HashMap<String, Agent>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations
set_service_name(operation_component::LLM);
@ -257,20 +260,9 @@ async fn llm_chat_inner(
debug!("removed plano_preference_config from metadata");
}
// === Filter chain processing for model listeners ===
// Check if any model listener (no agents) has a filter_chain configured
// === Filter chain processing for model listener ===
{
let listeners_guard = listeners.read().await;
let model_listener = listeners_guard
.iter()
.find(|l| l.agents.is_none() && l.filter_chain.is_some());
let filter_chain = model_listener.and_then(|l| l.filter_chain.clone());
let agent_map = model_listener
.and_then(|l| l.filter_agents.clone())
.unwrap_or_default();
if let Some(ref fc) = filter_chain {
if let Some(ref fc) = *filter_chain {
if !fc.is_empty() {
debug!(filter_chain = ?fc, "processing model listener filter chain");
@ -288,7 +280,7 @@ async fn llm_chat_inner(
.process_filter_chain(
&messages,
&temp_filter_chain,
&agent_map,
&filter_agents,
&request_headers,
)
.await

View file

@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration};
use common::configuration::{Agent, Configuration, ListenerType};
use common::consts::{
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
};
@ -87,26 +87,41 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.map(|a| (a.id.clone(), a.clone()))
.collect();
// Resolve filter_agents on each listener at startup
let mut listeners_resolved = plano_config.listeners.clone();
for listener in &mut listeners_resolved {
if let Some(ref fc) = listener.filter_chain {
let filter_agents: HashMap<String, Agent> = fc
.iter()
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
.collect();
if !filter_agents.is_empty() {
listener.filter_agents = Some(filter_agents);
}
}
}
// Create expanded provider list for /v1/models endpoint
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
.expect("Failed to create LlmProviders");
let llm_providers = Arc::new(RwLock::new(llm_providers));
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
let listeners = Arc::new(RwLock::new(listeners_resolved));
// Resolve model listener filter chain and agents at startup
let model_listener_count = plano_config
.listeners
.iter()
.filter(|l| l.listener_type == ListenerType::Model)
.count();
assert!(
model_listener_count <= 1,
"only one model listener is allowed, found {}",
model_listener_count
);
let model_listener = plano_config
.listeners
.iter()
.find(|l| l.listener_type == ListenerType::Model);
let model_filter_chain: Arc<Option<Vec<String>>> =
Arc::new(model_listener.and_then(|l| l.filter_chain.clone()));
let model_filter_agents: Arc<HashMap<String, Agent>> = Arc::new(
model_filter_chain
.as_ref()
.as_ref()
.map(|fc| {
fc.iter()
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
.collect()
})
.unwrap_or_default(),
);
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
@ -200,6 +215,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let agents_list = combined_agents_filters_list.clone();
let model_filter_chain = model_filter_chain.clone();
let model_filter_agents = model_filter_agents.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -211,6 +228,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let model_filter_chain = model_filter_chain.clone();
let model_filter_agents = model_filter_agents.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -269,7 +288,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
llm_providers,
span_attributes,
state_storage,
listeners,
model_filter_chain,
model_filter_agents,
)
.with_context(parent_cx)
.await

View file

@ -36,15 +36,23 @@ pub struct AgentFilterChain {
pub filter_chain: Option<Vec<String>>,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum ListenerType {
Model,
Agent,
Prompt,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Listener {
#[serde(rename = "type")]
pub listener_type: ListenerType,
pub name: String,
pub router: Option<String>,
pub agents: Option<Vec<AgentFilterChain>>,
pub filter_chain: Option<Vec<String>>,
pub port: u16,
#[serde(skip)]
pub filter_agents: Option<HashMap<String, Agent>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]