mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
consolidate model filter chain into single struct, move validation to config generator
This commit is contained in:
parent
e41e9e1cf4
commit
edf782f07a
4 changed files with 54 additions and 54 deletions
|
|
@ -426,6 +426,13 @@ def validate_and_render_schema():
|
||||||
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate at most one model listener
|
||||||
|
model_listeners = [l for l in listeners if l.get("type") == "model"]
|
||||||
|
if len(model_listeners) > 1:
|
||||||
|
raise Exception(
|
||||||
|
f"Only one model listener is allowed, found {len(model_listeners)}"
|
||||||
|
)
|
||||||
|
|
||||||
# Validate filter_chain IDs on listeners reference valid agent/filter IDs
|
# Validate filter_chain IDs on listeners reference valid agent/filter IDs
|
||||||
for listener in listeners:
|
for listener in listeners:
|
||||||
listener_filter_chain = listener.get("filter_chain", [])
|
listener_filter_chain = listener.get("filter_chain", [])
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes};
|
use common::configuration::{AgentFilterChain, ModelAlias, ModelFilterChain, SpanAttributes};
|
||||||
use common::consts::{
|
use common::consts::{
|
||||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||||
};
|
};
|
||||||
|
|
@ -45,8 +45,7 @@ pub async fn llm_chat(
|
||||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||||
span_attributes: Arc<Option<SpanAttributes>>,
|
span_attributes: Arc<Option<SpanAttributes>>,
|
||||||
state_storage: Option<Arc<dyn StateStorage>>,
|
state_storage: Option<Arc<dyn StateStorage>>,
|
||||||
filter_chain: Arc<Option<Vec<String>>>,
|
model_filter_chain: Arc<Option<ModelFilterChain>>,
|
||||||
filter_agents: Arc<HashMap<String, Agent>>,
|
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
let request_path = request.uri().path().to_string();
|
let request_path = request.uri().path().to_string();
|
||||||
let request_headers = request.headers().clone();
|
let request_headers = request.headers().clone();
|
||||||
|
|
@ -86,8 +85,7 @@ pub async fn llm_chat(
|
||||||
request_id,
|
request_id,
|
||||||
request_path,
|
request_path,
|
||||||
request_headers,
|
request_headers,
|
||||||
filter_chain,
|
model_filter_chain,
|
||||||
filter_agents,
|
|
||||||
)
|
)
|
||||||
.instrument(request_span)
|
.instrument(request_span)
|
||||||
.await
|
.await
|
||||||
|
|
@ -105,8 +103,7 @@ async fn llm_chat_inner(
|
||||||
request_id: String,
|
request_id: String,
|
||||||
request_path: String,
|
request_path: String,
|
||||||
mut request_headers: hyper::HeaderMap,
|
mut request_headers: hyper::HeaderMap,
|
||||||
filter_chain: Arc<Option<Vec<String>>>,
|
model_filter_chain: Arc<Option<ModelFilterChain>>,
|
||||||
filter_agents: Arc<HashMap<String, Agent>>,
|
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
// Set service name for LLM operations
|
// Set service name for LLM operations
|
||||||
set_service_name(operation_component::LLM);
|
set_service_name(operation_component::LLM);
|
||||||
|
|
@ -261,29 +258,27 @@ async fn llm_chat_inner(
|
||||||
}
|
}
|
||||||
|
|
||||||
// === Filter chain processing for model listener ===
|
// === Filter chain processing for model listener ===
|
||||||
{
|
if let Some(ref mfc) = *model_filter_chain {
|
||||||
if let Some(ref fc) = *filter_chain {
|
{
|
||||||
if !fc.is_empty() {
|
debug!(filter_ids = ?mfc.filter_ids, "processing model listener filter chain");
|
||||||
debug!(filter_chain = ?fc, "processing model listener filter chain");
|
|
||||||
|
|
||||||
// Create a temporary AgentFilterChain to reuse PipelineProcessor
|
let temp_filter_chain = AgentFilterChain {
|
||||||
let temp_filter_chain = AgentFilterChain {
|
id: "model_listener".to_string(),
|
||||||
id: "model_listener".to_string(),
|
default: None,
|
||||||
default: None,
|
description: None,
|
||||||
description: None,
|
filter_chain: Some(mfc.filter_ids.clone()),
|
||||||
filter_chain: Some(fc.clone()),
|
};
|
||||||
};
|
|
||||||
|
|
||||||
let mut pipeline_processor = PipelineProcessor::default();
|
let mut pipeline_processor = PipelineProcessor::default();
|
||||||
let messages = client_request.get_messages();
|
let messages = client_request.get_messages();
|
||||||
match pipeline_processor
|
match pipeline_processor
|
||||||
.process_filter_chain(
|
.process_filter_chain(
|
||||||
&messages,
|
&messages,
|
||||||
&temp_filter_chain,
|
&temp_filter_chain,
|
||||||
&filter_agents,
|
&mfc.agents,
|
||||||
&request_headers,
|
&request_headers,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(filtered_messages) => {
|
Ok(filtered_messages) => {
|
||||||
client_request.set_messages(&filtered_messages);
|
client_request.set_messages(&filtered_messages);
|
||||||
|
|
@ -326,7 +321,6 @@ async fn llm_chat_inner(
|
||||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||||
return Ok(internal_error);
|
return Ok(internal_error);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||||
use brightstaff::state::StateStorage;
|
use brightstaff::state::StateStorage;
|
||||||
use brightstaff::utils::tracing::init_tracer;
|
use brightstaff::utils::tracing::init_tracer;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::{Agent, Configuration, ListenerType};
|
use common::configuration::{Agent, Configuration, ListenerType, ModelFilterChain};
|
||||||
use common::consts::{
|
use common::consts::{
|
||||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||||
};
|
};
|
||||||
|
|
@ -94,32 +94,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||||
|
|
||||||
// Resolve model listener filter chain and agents at startup
|
// Resolve model listener filter chain and agents at startup
|
||||||
let model_listener_count = plano_config
|
|
||||||
.listeners
|
|
||||||
.iter()
|
|
||||||
.filter(|l| l.listener_type == ListenerType::Model)
|
|
||||||
.count();
|
|
||||||
assert!(
|
|
||||||
model_listener_count <= 1,
|
|
||||||
"only one model listener is allowed, found {}",
|
|
||||||
model_listener_count
|
|
||||||
);
|
|
||||||
let model_listener = plano_config
|
let model_listener = plano_config
|
||||||
.listeners
|
.listeners
|
||||||
.iter()
|
.iter()
|
||||||
.find(|l| l.listener_type == ListenerType::Model);
|
.find(|l| l.listener_type == ListenerType::Model);
|
||||||
let model_filter_chain: Arc<Option<Vec<String>>> =
|
let model_filter_chain: Arc<Option<ModelFilterChain>> = Arc::new(
|
||||||
Arc::new(model_listener.and_then(|l| l.filter_chain.clone()));
|
model_listener
|
||||||
let model_filter_agents: Arc<HashMap<String, Agent>> = Arc::new(
|
.and_then(|l| l.filter_chain.clone())
|
||||||
model_filter_chain
|
.filter(|fc| !fc.is_empty())
|
||||||
.as_ref()
|
|
||||||
.as_ref()
|
|
||||||
.map(|fc| {
|
.map(|fc| {
|
||||||
fc.iter()
|
let agents = fc
|
||||||
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
.iter()
|
||||||
.collect()
|
.filter_map(|id| {
|
||||||
})
|
global_agent_map.get(id).map(|a| (id.clone(), a.clone()))
|
||||||
.unwrap_or_default(),
|
})
|
||||||
|
.collect();
|
||||||
|
ModelFilterChain {
|
||||||
|
filter_ids: fc,
|
||||||
|
agents,
|
||||||
|
}
|
||||||
|
}),
|
||||||
);
|
);
|
||||||
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
||||||
let llm_provider_url =
|
let llm_provider_url =
|
||||||
|
|
@ -216,7 +210,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let llm_providers = llm_providers.clone();
|
let llm_providers = llm_providers.clone();
|
||||||
let agents_list = combined_agents_filters_list.clone();
|
let agents_list = combined_agents_filters_list.clone();
|
||||||
let model_filter_chain = model_filter_chain.clone();
|
let model_filter_chain = model_filter_chain.clone();
|
||||||
let model_filter_agents = model_filter_agents.clone();
|
|
||||||
let listeners = listeners.clone();
|
let listeners = listeners.clone();
|
||||||
let span_attributes = span_attributes.clone();
|
let span_attributes = span_attributes.clone();
|
||||||
let state_storage = state_storage.clone();
|
let state_storage = state_storage.clone();
|
||||||
|
|
@ -229,7 +222,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let model_aliases = Arc::clone(&model_aliases);
|
let model_aliases = Arc::clone(&model_aliases);
|
||||||
let agents_list = agents_list.clone();
|
let agents_list = agents_list.clone();
|
||||||
let model_filter_chain = model_filter_chain.clone();
|
let model_filter_chain = model_filter_chain.clone();
|
||||||
let model_filter_agents = model_filter_agents.clone();
|
|
||||||
let listeners = listeners.clone();
|
let listeners = listeners.clone();
|
||||||
let span_attributes = span_attributes.clone();
|
let span_attributes = span_attributes.clone();
|
||||||
let state_storage = state_storage.clone();
|
let state_storage = state_storage.clone();
|
||||||
|
|
@ -289,7 +281,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
span_attributes,
|
span_attributes,
|
||||||
state_storage,
|
state_storage,
|
||||||
model_filter_chain,
|
model_filter_chain,
|
||||||
model_filter_agents,
|
|
||||||
)
|
)
|
||||||
.with_context(parent_cx)
|
.with_context(parent_cx)
|
||||||
.await
|
.await
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,14 @@ pub struct AgentFilterChain {
|
||||||
pub filter_chain: Option<Vec<String>>,
|
pub filter_chain: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resolved filter chain for a model listener: the ordered filter IDs
|
||||||
|
/// together with the agent definitions they reference.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ModelFilterChain {
|
||||||
|
pub filter_ids: Vec<String>,
|
||||||
|
pub agents: HashMap<String, Agent>,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum ListenerType {
|
pub enum ListenerType {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue