mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
consolidate model filter chain into single struct, move validation to config generator
This commit is contained in:
parent
e41e9e1cf4
commit
edf782f07a
4 changed files with 54 additions and 54 deletions
|
|
@ -426,6 +426,13 @@ def validate_and_render_schema():
|
|||
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
||||
)
|
||||
|
||||
# Validate at most one model listener
|
||||
model_listeners = [l for l in listeners if l.get("type") == "model"]
|
||||
if len(model_listeners) > 1:
|
||||
raise Exception(
|
||||
f"Only one model listener is allowed, found {len(model_listeners)}"
|
||||
)
|
||||
|
||||
# Validate filter_chain IDs on listeners reference valid agent/filter IDs
|
||||
for listener in listeners:
|
||||
listener_filter_chain = listener.get("filter_chain", [])
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, AgentFilterChain, ModelAlias, SpanAttributes};
|
||||
use common::configuration::{AgentFilterChain, ModelAlias, ModelFilterChain, SpanAttributes};
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
|
|
@ -45,8 +45,7 @@ pub async fn llm_chat(
|
|||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
span_attributes: Arc<Option<SpanAttributes>>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
filter_chain: Arc<Option<Vec<String>>>,
|
||||
filter_agents: Arc<HashMap<String, Agent>>,
|
||||
model_filter_chain: Arc<Option<ModelFilterChain>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
|
|
@ -86,8 +85,7 @@ pub async fn llm_chat(
|
|||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
filter_chain,
|
||||
filter_agents,
|
||||
model_filter_chain,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
|
|
@ -105,8 +103,7 @@ async fn llm_chat_inner(
|
|||
request_id: String,
|
||||
request_path: String,
|
||||
mut request_headers: hyper::HeaderMap,
|
||||
filter_chain: Arc<Option<Vec<String>>>,
|
||||
filter_agents: Arc<HashMap<String, Agent>>,
|
||||
model_filter_chain: Arc<Option<ModelFilterChain>>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Set service name for LLM operations
|
||||
set_service_name(operation_component::LLM);
|
||||
|
|
@ -261,29 +258,27 @@ async fn llm_chat_inner(
|
|||
}
|
||||
|
||||
// === Filter chain processing for model listener ===
|
||||
{
|
||||
if let Some(ref fc) = *filter_chain {
|
||||
if !fc.is_empty() {
|
||||
debug!(filter_chain = ?fc, "processing model listener filter chain");
|
||||
if let Some(ref mfc) = *model_filter_chain {
|
||||
{
|
||||
debug!(filter_ids = ?mfc.filter_ids, "processing model listener filter chain");
|
||||
|
||||
// Create a temporary AgentFilterChain to reuse PipelineProcessor
|
||||
let temp_filter_chain = AgentFilterChain {
|
||||
id: "model_listener".to_string(),
|
||||
default: None,
|
||||
description: None,
|
||||
filter_chain: Some(fc.clone()),
|
||||
};
|
||||
let temp_filter_chain = AgentFilterChain {
|
||||
id: "model_listener".to_string(),
|
||||
default: None,
|
||||
description: None,
|
||||
filter_chain: Some(mfc.filter_ids.clone()),
|
||||
};
|
||||
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let messages = client_request.get_messages();
|
||||
match pipeline_processor
|
||||
.process_filter_chain(
|
||||
&messages,
|
||||
&temp_filter_chain,
|
||||
&filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let messages = client_request.get_messages();
|
||||
match pipeline_processor
|
||||
.process_filter_chain(
|
||||
&messages,
|
||||
&temp_filter_chain,
|
||||
&mfc.agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
client_request.set_messages(&filtered_messages);
|
||||
|
|
@ -326,7 +321,6 @@ async fn llm_chat_inner(
|
|||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
|||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, Configuration, ListenerType};
|
||||
use common::configuration::{Agent, Configuration, ListenerType, ModelFilterChain};
|
||||
use common::consts::{
|
||||
CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH, PLANO_ORCHESTRATOR_MODEL_NAME,
|
||||
};
|
||||
|
|
@ -94,32 +94,26 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
|
||||
// Resolve model listener filter chain and agents at startup
|
||||
let model_listener_count = plano_config
|
||||
.listeners
|
||||
.iter()
|
||||
.filter(|l| l.listener_type == ListenerType::Model)
|
||||
.count();
|
||||
assert!(
|
||||
model_listener_count <= 1,
|
||||
"only one model listener is allowed, found {}",
|
||||
model_listener_count
|
||||
);
|
||||
let model_listener = plano_config
|
||||
.listeners
|
||||
.iter()
|
||||
.find(|l| l.listener_type == ListenerType::Model);
|
||||
let model_filter_chain: Arc<Option<Vec<String>>> =
|
||||
Arc::new(model_listener.and_then(|l| l.filter_chain.clone()));
|
||||
let model_filter_agents: Arc<HashMap<String, Agent>> = Arc::new(
|
||||
model_filter_chain
|
||||
.as_ref()
|
||||
.as_ref()
|
||||
let model_filter_chain: Arc<Option<ModelFilterChain>> = Arc::new(
|
||||
model_listener
|
||||
.and_then(|l| l.filter_chain.clone())
|
||||
.filter(|fc| !fc.is_empty())
|
||||
.map(|fc| {
|
||||
fc.iter()
|
||||
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default(),
|
||||
let agents = fc
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
global_agent_map.get(id).map(|a| (id.clone(), a.clone()))
|
||||
})
|
||||
.collect();
|
||||
ModelFilterChain {
|
||||
filter_ids: fc,
|
||||
agents,
|
||||
}
|
||||
}),
|
||||
);
|
||||
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
|
|
@ -216,7 +210,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = combined_agents_filters_list.clone();
|
||||
let model_filter_chain = model_filter_chain.clone();
|
||||
let model_filter_agents = model_filter_agents.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -229,7 +222,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let model_filter_chain = model_filter_chain.clone();
|
||||
let model_filter_agents = model_filter_agents.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -289,7 +281,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
span_attributes,
|
||||
state_storage,
|
||||
model_filter_chain,
|
||||
model_filter_agents,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -36,6 +36,14 @@ pub struct AgentFilterChain {
|
|||
pub filter_chain: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// Resolved filter chain for a model listener: the ordered filter IDs
|
||||
/// together with the agent definitions they reference.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelFilterChain {
|
||||
pub filter_ids: Vec<String>,
|
||||
pub agents: HashMap<String, Agent>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ListenerType {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue