mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
refactor a bit to send filters instead of agents
This commit is contained in:
parent
692499d910
commit
8c71acfc76
5 changed files with 34 additions and 16 deletions
|
|
@ -197,6 +197,7 @@ mod tests {
|
||||||
filter_chain: None,
|
filter_chain: None,
|
||||||
port: 8080,
|
port: 8080,
|
||||||
router: None,
|
router: None,
|
||||||
|
filter_agents: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,7 @@ mod tests {
|
||||||
filter_chain: None,
|
filter_chain: None,
|
||||||
port: 8080,
|
port: 8080,
|
||||||
router: None,
|
router: None,
|
||||||
|
filter_agents: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let listeners = vec![listener];
|
let listeners = vec![listener];
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::{Agent, AgentFilterChain, Listener, ModelAlias, SpanAttributes};
|
use common::configuration::{AgentFilterChain, Listener, ModelAlias, SpanAttributes};
|
||||||
use common::consts::{
|
use common::consts::{
|
||||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||||
};
|
};
|
||||||
|
|
@ -46,7 +46,6 @@ pub async fn llm_chat(
|
||||||
span_attributes: Arc<Option<SpanAttributes>>,
|
span_attributes: Arc<Option<SpanAttributes>>,
|
||||||
state_storage: Option<Arc<dyn StateStorage>>,
|
state_storage: Option<Arc<dyn StateStorage>>,
|
||||||
listeners: Arc<RwLock<Vec<Listener>>>,
|
listeners: Arc<RwLock<Vec<Listener>>>,
|
||||||
agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
|
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
let request_path = request.uri().path().to_string();
|
let request_path = request.uri().path().to_string();
|
||||||
let request_headers = request.headers().clone();
|
let request_headers = request.headers().clone();
|
||||||
|
|
@ -87,7 +86,6 @@ pub async fn llm_chat(
|
||||||
request_path,
|
request_path,
|
||||||
request_headers,
|
request_headers,
|
||||||
listeners,
|
listeners,
|
||||||
agents_list,
|
|
||||||
)
|
)
|
||||||
.instrument(request_span)
|
.instrument(request_span)
|
||||||
.await
|
.await
|
||||||
|
|
@ -106,7 +104,6 @@ async fn llm_chat_inner(
|
||||||
request_path: String,
|
request_path: String,
|
||||||
mut request_headers: hyper::HeaderMap,
|
mut request_headers: hyper::HeaderMap,
|
||||||
listeners: Arc<RwLock<Vec<Listener>>>,
|
listeners: Arc<RwLock<Vec<Listener>>>,
|
||||||
agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
|
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
// Set service name for LLM operations
|
// Set service name for LLM operations
|
||||||
set_service_name(operation_component::LLM);
|
set_service_name(operation_component::LLM);
|
||||||
|
|
@ -264,22 +261,19 @@ async fn llm_chat_inner(
|
||||||
// Check if any model listener (no agents) has a filter_chain configured
|
// Check if any model listener (no agents) has a filter_chain configured
|
||||||
{
|
{
|
||||||
let listeners_guard = listeners.read().await;
|
let listeners_guard = listeners.read().await;
|
||||||
let filter_chain: Option<Vec<String>> = listeners_guard
|
let model_listener = listeners_guard
|
||||||
.iter()
|
.iter()
|
||||||
.find(|l| l.agents.is_none() && l.filter_chain.is_some())
|
.find(|l| l.agents.is_none() && l.filter_chain.is_some());
|
||||||
.and_then(|l| l.filter_chain.clone());
|
|
||||||
|
let filter_chain = model_listener.and_then(|l| l.filter_chain.clone());
|
||||||
|
let agent_map = model_listener
|
||||||
|
.and_then(|l| l.filter_agents.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
if let Some(ref fc) = filter_chain {
|
if let Some(ref fc) = filter_chain {
|
||||||
if !fc.is_empty() {
|
if !fc.is_empty() {
|
||||||
debug!(filter_chain = ?fc, "processing model listener filter chain");
|
debug!(filter_chain = ?fc, "processing model listener filter chain");
|
||||||
|
|
||||||
// Build agent map from agents_list
|
|
||||||
let agents_guard = agents_list.read().await;
|
|
||||||
let agent_map: HashMap<String, Agent> = agents_guard
|
|
||||||
.as_ref()
|
|
||||||
.map(|agents| agents.iter().map(|a| (a.id.clone(), a.clone())).collect())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
// Create a temporary AgentFilterChain to reuse PipelineProcessor
|
// Create a temporary AgentFilterChain to reuse PipelineProcessor
|
||||||
let temp_filter_chain = AgentFilterChain {
|
let temp_filter_chain = AgentFilterChain {
|
||||||
id: "model_listener".to_string(),
|
id: "model_listener".to_string(),
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ use hyper_util::rt::TokioIo;
|
||||||
use opentelemetry::trace::FutureExt;
|
use opentelemetry::trace::FutureExt;
|
||||||
use opentelemetry::{global, Context};
|
use opentelemetry::{global, Context};
|
||||||
use opentelemetry_http::HeaderExtractor;
|
use opentelemetry_http::HeaderExtractor;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::{env, fs};
|
use std::{env, fs};
|
||||||
use tokio::net::TcpListener;
|
use tokio::net::TcpListener;
|
||||||
|
|
@ -80,12 +81,32 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
// Build global agent map for resolving filter chain references
|
||||||
|
let global_agent_map: HashMap<String, Agent> = all_agents
|
||||||
|
.iter()
|
||||||
|
.map(|a| (a.id.clone(), a.clone()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Resolve filter_agents on each listener at startup
|
||||||
|
let mut listeners_resolved = plano_config.listeners.clone();
|
||||||
|
for listener in &mut listeners_resolved {
|
||||||
|
if let Some(ref fc) = listener.filter_chain {
|
||||||
|
let filter_agents: HashMap<String, Agent> = fc
|
||||||
|
.iter()
|
||||||
|
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
||||||
|
.collect();
|
||||||
|
if !filter_agents.is_empty() {
|
||||||
|
listener.filter_agents = Some(filter_agents);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Create expanded provider list for /v1/models endpoint
|
// Create expanded provider list for /v1/models endpoint
|
||||||
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
|
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
|
||||||
.expect("Failed to create LlmProviders");
|
.expect("Failed to create LlmProviders");
|
||||||
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
||||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||||
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
let listeners = Arc::new(RwLock::new(listeners_resolved));
|
||||||
let llm_provider_url =
|
let llm_provider_url =
|
||||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||||
|
|
||||||
|
|
@ -249,7 +270,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
span_attributes,
|
span_attributes,
|
||||||
state_storage,
|
state_storage,
|
||||||
listeners,
|
listeners,
|
||||||
agents_list,
|
|
||||||
)
|
)
|
||||||
.with_context(parent_cx)
|
.with_context(parent_cx)
|
||||||
.await
|
.await
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,8 @@ pub struct Listener {
|
||||||
pub agents: Option<Vec<AgentFilterChain>>,
|
pub agents: Option<Vec<AgentFilterChain>>,
|
||||||
pub filter_chain: Option<Vec<String>>,
|
pub filter_chain: Option<Vec<String>>,
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
|
#[serde(skip)]
|
||||||
|
pub filter_agents: Option<HashMap<String, Agent>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue