add output filter chain (#822)

This commit is contained in:
Adil Hafeez 2026-03-18 17:58:20 -07:00 committed by GitHub
parent de2d8847f3
commit 1f23c573bf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
59 changed files with 2961 additions and 2621 deletions

View file

@ -10,7 +10,9 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
use brightstaff::state::StateStorage;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::{Agent, Configuration};
use common::configuration::{
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
};
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
use common::llm_providers::LlmProviders;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
@ -22,6 +24,7 @@ use hyper_util::rt::TokioIo;
use opentelemetry::trace::FutureExt;
use opentelemetry::{global, Context};
use opentelemetry_http::HeaderExtractor;
use std::collections::HashMap;
use std::sync::Arc;
use std::{env, fs};
use tokio::net::TcpListener;
@ -80,11 +83,49 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.cloned()
.collect();
// Build global agent map for resolving filter chain references
let global_agent_map: HashMap<String, Agent> = all_agents
.iter()
.map(|a| (a.id.clone(), a.clone()))
.collect();
// Create expanded provider list for /v1/models endpoint
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
.expect("Failed to create LlmProviders");
let llm_providers = Arc::new(RwLock::new(llm_providers));
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
// Resolve model listener filter chain and agents at startup
// Count the listeners whose type is `ListenerType::Model`; the assertion below
// panics (aborting startup) if the configuration declares more than one.
let model_listener_count = plano_config
.listeners
.iter()
.filter(|l| l.listener_type == ListenerType::Model)
.count();
assert!(
model_listener_count <= 1,
"only one model listener is allowed, found {}",
model_listener_count
);
// Borrow the first (and, given the assertion above, only) model listener.
// `None` when the configuration has no model listener at all.
let model_listener = plano_config
.listeners
.iter()
.find(|l| l.listener_type == ListenerType::Model);
// Turns an optional list of filter ids into a `ResolvedFilterChain`.
//
// `None` in yields `None` out. Otherwise every id is looked up in
// `global_agent_map`; ids that have a matching agent contribute an
// `(id, agent)` pair to `agents`, while ids with no match are skipped there.
// `filter_ids` on the result always carries the full, original id list.
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
    let filter_ids = filter_ids?;
    let agents = filter_ids
        .iter()
        .filter_map(|filter_id| {
            let agent = global_agent_map.get(filter_id)?;
            Some((filter_id.clone(), agent.clone()))
        })
        .collect();
    Some(ResolvedFilterChain { filter_ids, agents })
};
let filter_pipeline = Arc::new(FilterPipeline {
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
});
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
let llm_provider_url =
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
@ -200,6 +241,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let agents_list = combined_agents_filters_list.clone();
let filter_pipeline = filter_pipeline.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -211,6 +253,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let llm_providers = llm_providers.clone();
let model_aliases = Arc::clone(&model_aliases);
let agents_list = agents_list.clone();
let filter_pipeline = filter_pipeline.clone();
let listeners = listeners.clone();
let span_attributes = span_attributes.clone();
let state_storage = state_storage.clone();
@ -269,6 +312,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
llm_providers,
span_attributes,
state_storage,
filter_pipeline,
)
.with_context(parent_cx)
.await