refactor a bit to send filters instead of agents

This commit is contained in:
Adil Hafeez 2026-03-12 13:26:19 -07:00
parent 692499d910
commit 8c71acfc76
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 34 additions and 16 deletions

View file

@ -197,6 +197,7 @@ mod tests {
filter_chain: None, filter_chain: None,
port: 8080, port: 8080,
router: None, router: None,
filter_agents: None,
} }
} }

View file

@ -77,6 +77,7 @@ mod tests {
filter_chain: None, filter_chain: None,
port: 8080, port: 8080,
router: None, router: None,
filter_agents: None,
}; };
let listeners = vec![listener]; let listeners = vec![listener];

View file

@ -1,5 +1,5 @@
use bytes::Bytes; use bytes::Bytes;
use common::configuration::{Agent, AgentFilterChain, Listener, ModelAlias, SpanAttributes}; use common::configuration::{AgentFilterChain, Listener, ModelAlias, SpanAttributes};
use common::consts::{ use common::consts::{
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER, ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
}; };
@ -46,7 +46,6 @@ pub async fn llm_chat(
span_attributes: Arc<Option<SpanAttributes>>, span_attributes: Arc<Option<SpanAttributes>>,
state_storage: Option<Arc<dyn StateStorage>>, state_storage: Option<Arc<dyn StateStorage>>,
listeners: Arc<RwLock<Vec<Listener>>>, listeners: Arc<RwLock<Vec<Listener>>>,
agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string(); let request_path = request.uri().path().to_string();
let request_headers = request.headers().clone(); let request_headers = request.headers().clone();
@ -87,7 +86,6 @@ pub async fn llm_chat(
request_path, request_path,
request_headers, request_headers,
listeners, listeners,
agents_list,
) )
.instrument(request_span) .instrument(request_span)
.await .await
@ -106,7 +104,6 @@ async fn llm_chat_inner(
request_path: String, request_path: String,
mut request_headers: hyper::HeaderMap, mut request_headers: hyper::HeaderMap,
listeners: Arc<RwLock<Vec<Listener>>>, listeners: Arc<RwLock<Vec<Listener>>>,
agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> { ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
// Set service name for LLM operations // Set service name for LLM operations
set_service_name(operation_component::LLM); set_service_name(operation_component::LLM);
@ -264,22 +261,19 @@ async fn llm_chat_inner(
// Check if any model listener (no agents) has a filter_chain configured // Check if any model listener (no agents) has a filter_chain configured
{ {
let listeners_guard = listeners.read().await; let listeners_guard = listeners.read().await;
let filter_chain: Option<Vec<String>> = listeners_guard let model_listener = listeners_guard
.iter() .iter()
.find(|l| l.agents.is_none() && l.filter_chain.is_some()) .find(|l| l.agents.is_none() && l.filter_chain.is_some());
.and_then(|l| l.filter_chain.clone());
let filter_chain = model_listener.and_then(|l| l.filter_chain.clone());
let agent_map = model_listener
.and_then(|l| l.filter_agents.clone())
.unwrap_or_default();
if let Some(ref fc) = filter_chain { if let Some(ref fc) = filter_chain {
if !fc.is_empty() { if !fc.is_empty() {
debug!(filter_chain = ?fc, "processing model listener filter chain"); debug!(filter_chain = ?fc, "processing model listener filter chain");
// Build agent map from agents_list
let agents_guard = agents_list.read().await;
let agent_map: HashMap<String, Agent> = agents_guard
.as_ref()
.map(|agents| agents.iter().map(|a| (a.id.clone(), a.clone())).collect())
.unwrap_or_default();
// Create a temporary AgentFilterChain to reuse PipelineProcessor // Create a temporary AgentFilterChain to reuse PipelineProcessor
let temp_filter_chain = AgentFilterChain { let temp_filter_chain = AgentFilterChain {
id: "model_listener".to_string(), id: "model_listener".to_string(),

View file

@ -24,6 +24,7 @@ use hyper_util::rt::TokioIo;
use opentelemetry::trace::FutureExt; use opentelemetry::trace::FutureExt;
use opentelemetry::{global, Context}; use opentelemetry::{global, Context};
use opentelemetry_http::HeaderExtractor; use opentelemetry_http::HeaderExtractor;
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::{env, fs}; use std::{env, fs};
use tokio::net::TcpListener; use tokio::net::TcpListener;
@ -80,12 +81,32 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.cloned() .cloned()
.collect(); .collect();
// Build global agent map for resolving filter chain references
let global_agent_map: HashMap<String, Agent> = all_agents
.iter()
.map(|a| (a.id.clone(), a.clone()))
.collect();
// Resolve filter_agents on each listener at startup
let mut listeners_resolved = plano_config.listeners.clone();
for listener in &mut listeners_resolved {
if let Some(ref fc) = listener.filter_chain {
let filter_agents: HashMap<String, Agent> = fc
.iter()
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
.collect();
if !filter_agents.is_empty() {
listener.filter_agents = Some(filter_agents);
}
}
}
// Create expanded provider list for /v1/models endpoint // Create expanded provider list for /v1/models endpoint
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone()) let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
.expect("Failed to create LlmProviders"); .expect("Failed to create LlmProviders");
let llm_providers = Arc::new(RwLock::new(llm_providers)); let llm_providers = Arc::new(RwLock::new(llm_providers));
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents))); let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone())); let listeners = Arc::new(RwLock::new(listeners_resolved));
let llm_provider_url = let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string()); env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
@ -249,7 +270,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
span_attributes, span_attributes,
state_storage, state_storage,
listeners, listeners,
agents_list,
) )
.with_context(parent_cx) .with_context(parent_cx)
.await .await

View file

@ -43,6 +43,8 @@ pub struct Listener {
pub agents: Option<Vec<AgentFilterChain>>, pub agents: Option<Vec<AgentFilterChain>>,
pub filter_chain: Option<Vec<String>>, pub filter_chain: Option<Vec<String>>,
pub port: u16, pub port: u16,
#[serde(skip)]
pub filter_agents: Option<HashMap<String, Agent>>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]