mirror of
https://github.com/katanemo/plano.git
synced 2026-05-15 11:02:39 +02:00
add output filter chain (#822)
This commit is contained in:
parent
de2d8847f3
commit
1f23c573bf
59 changed files with 2961 additions and 2621 deletions
|
|
@ -10,7 +10,9 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
|||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::configuration::{
|
||||
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
|
||||
};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
|
|
@ -22,6 +24,7 @@ use hyper_util::rt::TokioIo;
|
|||
use opentelemetry::trace::FutureExt;
|
||||
use opentelemetry::{global, Context};
|
||||
use opentelemetry_http::HeaderExtractor;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::{env, fs};
|
||||
use tokio::net::TcpListener;
|
||||
|
|
@ -80,11 +83,49 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Build global agent map for resolving filter chain references
|
||||
let global_agent_map: HashMap<String, Agent> = all_agents
|
||||
.iter()
|
||||
.map(|a| (a.id.clone(), a.clone()))
|
||||
.collect();
|
||||
|
||||
// Create expanded provider list for /v1/models endpoint
|
||||
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
|
||||
.expect("Failed to create LlmProviders");
|
||||
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
|
||||
// Resolve model listener filter chain and agents at startup
|
||||
let model_listener_count = plano_config
|
||||
.listeners
|
||||
.iter()
|
||||
.filter(|l| l.listener_type == ListenerType::Model)
|
||||
.count();
|
||||
assert!(
|
||||
model_listener_count <= 1,
|
||||
"only one model listener is allowed, found {}",
|
||||
model_listener_count
|
||||
);
|
||||
let model_listener = plano_config
|
||||
.listeners
|
||||
.iter()
|
||||
.find(|l| l.listener_type == ListenerType::Model);
|
||||
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
|
||||
filter_ids.map(|ids| {
|
||||
let agents = ids
|
||||
.iter()
|
||||
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
||||
.collect();
|
||||
ResolvedFilterChain {
|
||||
filter_ids: ids,
|
||||
agents,
|
||||
}
|
||||
})
|
||||
};
|
||||
let filter_pipeline = Arc::new(FilterPipeline {
|
||||
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
|
||||
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
|
||||
});
|
||||
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
|
@ -200,6 +241,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = combined_agents_filters_list.clone();
|
||||
let filter_pipeline = filter_pipeline.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -211,6 +253,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let llm_providers = llm_providers.clone();
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let filter_pipeline = filter_pipeline.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -269,6 +312,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
llm_providers,
|
||||
span_attributes,
|
||||
state_storage,
|
||||
filter_pipeline,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue