mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
support filter chains on model listeners
This commit is contained in:
parent
21b043e6d0
commit
8136d7d6ab
8 changed files with 121 additions and 1 deletions
|
|
@ -426,6 +426,16 @@ def validate_and_render_schema():
|
||||||
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
"Please provide model_providers either under listeners or at root level, not both. Currently we don't support multiple listeners with model_providers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate filter_chain IDs on listeners reference valid agent/filter IDs
|
||||||
|
for listener in listeners:
|
||||||
|
listener_filter_chain = listener.get("filter_chain", [])
|
||||||
|
for fc_id in listener_filter_chain:
|
||||||
|
if fc_id not in agent_id_keys:
|
||||||
|
raise Exception(
|
||||||
|
f"Listener '{listener.get('name', 'unknown')}' references filter_chain id '{fc_id}' "
|
||||||
|
f"which is not defined in agents or filters. Available ids: {', '.join(sorted(agent_id_keys))}"
|
||||||
|
)
|
||||||
|
|
||||||
# Validate model aliases if present
|
# Validate model aliases if present
|
||||||
if "model_aliases" in config_yaml:
|
if "model_aliases" in config_yaml:
|
||||||
model_aliases = config_yaml["model_aliases"]
|
model_aliases = config_yaml["model_aliases"]
|
||||||
|
|
|
||||||
|
|
@ -93,6 +93,10 @@ properties:
|
||||||
required:
|
required:
|
||||||
- id
|
- id
|
||||||
- description
|
- description
|
||||||
|
filter_chain:
|
||||||
|
type: array
|
||||||
|
items:
|
||||||
|
type: string
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
required:
|
required:
|
||||||
- type
|
- type
|
||||||
|
|
|
||||||
|
|
@ -194,6 +194,7 @@ mod tests {
|
||||||
Listener {
|
Listener {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
agents: Some(agents),
|
agents: Some(agents),
|
||||||
|
filter_chain: None,
|
||||||
port: 8080,
|
port: 8080,
|
||||||
router: None,
|
router: None,
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -73,6 +73,7 @@ mod tests {
|
||||||
let listener = Listener {
|
let listener = Listener {
|
||||||
name: "test-listener".to_string(),
|
name: "test-listener".to_string(),
|
||||||
agents: Some(vec![agent_pipeline.clone()]),
|
agents: Some(vec![agent_pipeline.clone()]),
|
||||||
|
filter_chain: None,
|
||||||
port: 8080,
|
port: 8080,
|
||||||
router: None,
|
router: None,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::ModelAlias;
|
use common::configuration::{Agent, AgentFilterChain, Listener, ModelAlias};
|
||||||
use common::consts::{
|
use common::consts::{
|
||||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||||
};
|
};
|
||||||
|
|
@ -19,6 +19,8 @@ use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tracing::{debug, info, info_span, warn, Instrument};
|
use tracing::{debug, info, info_span, warn, Instrument};
|
||||||
|
|
||||||
|
use super::pipeline_processor::PipelineProcessor;
|
||||||
|
|
||||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||||
use crate::handlers::utils::{
|
use crate::handlers::utils::{
|
||||||
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
||||||
|
|
@ -36,6 +38,7 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||||
.boxed()
|
.boxed()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub async fn llm_chat(
|
pub async fn llm_chat(
|
||||||
request: Request<hyper::body::Incoming>,
|
request: Request<hyper::body::Incoming>,
|
||||||
router_service: Arc<RouterService>,
|
router_service: Arc<RouterService>,
|
||||||
|
|
@ -43,6 +46,8 @@ pub async fn llm_chat(
|
||||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||||
llm_providers: Arc<RwLock<LlmProviders>>,
|
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||||
state_storage: Option<Arc<dyn StateStorage>>,
|
state_storage: Option<Arc<dyn StateStorage>>,
|
||||||
|
listeners: Arc<RwLock<Vec<Listener>>>,
|
||||||
|
agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
let request_path = request.uri().path().to_string();
|
let request_path = request.uri().path().to_string();
|
||||||
let request_headers = request.headers().clone();
|
let request_headers = request.headers().clone();
|
||||||
|
|
@ -79,6 +84,8 @@ pub async fn llm_chat(
|
||||||
request_id,
|
request_id,
|
||||||
request_path,
|
request_path,
|
||||||
request_headers,
|
request_headers,
|
||||||
|
listeners,
|
||||||
|
agents_list,
|
||||||
)
|
)
|
||||||
.instrument(request_span)
|
.instrument(request_span)
|
||||||
.await
|
.await
|
||||||
|
|
@ -95,6 +102,8 @@ async fn llm_chat_inner(
|
||||||
request_id: String,
|
request_id: String,
|
||||||
request_path: String,
|
request_path: String,
|
||||||
mut request_headers: hyper::HeaderMap,
|
mut request_headers: hyper::HeaderMap,
|
||||||
|
listeners: Arc<RwLock<Vec<Listener>>>,
|
||||||
|
agents_list: Arc<RwLock<Option<Vec<Agent>>>>,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
// Set service name for LLM operations
|
// Set service name for LLM operations
|
||||||
set_service_name(operation_component::LLM);
|
set_service_name(operation_component::LLM);
|
||||||
|
|
@ -235,6 +244,96 @@ async fn llm_chat_inner(
|
||||||
debug!("removed plano_preference_config from metadata");
|
debug!("removed plano_preference_config from metadata");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === Filter chain processing for model listeners ===
|
||||||
|
// Check if any model listener (no agents) has a filter_chain configured
|
||||||
|
{
|
||||||
|
let listeners_guard = listeners.read().await;
|
||||||
|
let filter_chain: Option<Vec<String>> = listeners_guard
|
||||||
|
.iter()
|
||||||
|
.find(|l| l.agents.is_none() && l.filter_chain.is_some())
|
||||||
|
.and_then(|l| l.filter_chain.clone());
|
||||||
|
|
||||||
|
if let Some(ref fc) = filter_chain {
|
||||||
|
if !fc.is_empty() {
|
||||||
|
debug!(filter_chain = ?fc, "processing model listener filter chain");
|
||||||
|
|
||||||
|
// Build agent map from agents_list
|
||||||
|
let agents_guard = agents_list.read().await;
|
||||||
|
let agent_map: HashMap<String, Agent> = agents_guard
|
||||||
|
.as_ref()
|
||||||
|
.map(|agents| {
|
||||||
|
agents
|
||||||
|
.iter()
|
||||||
|
.map(|a| (a.id.clone(), a.clone()))
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
// Create a temporary AgentFilterChain to reuse PipelineProcessor
|
||||||
|
let temp_filter_chain = AgentFilterChain {
|
||||||
|
id: "model_listener".to_string(),
|
||||||
|
default: None,
|
||||||
|
description: None,
|
||||||
|
filter_chain: Some(fc.clone()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut pipeline_processor = PipelineProcessor::default();
|
||||||
|
let messages = client_request.get_messages();
|
||||||
|
match pipeline_processor
|
||||||
|
.process_filter_chain(
|
||||||
|
&messages,
|
||||||
|
&temp_filter_chain,
|
||||||
|
&agent_map,
|
||||||
|
&request_headers,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(filtered_messages) => {
|
||||||
|
client_request.set_messages(&filtered_messages);
|
||||||
|
info!(
|
||||||
|
original_count = messages.len(),
|
||||||
|
filtered_count = filtered_messages.len(),
|
||||||
|
"filter chain processed successfully"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(super::pipeline_processor::PipelineError::ClientError {
|
||||||
|
agent,
|
||||||
|
status,
|
||||||
|
body,
|
||||||
|
}) => {
|
||||||
|
warn!(
|
||||||
|
agent = %agent,
|
||||||
|
status = %status,
|
||||||
|
body = %body,
|
||||||
|
"client error from filter chain"
|
||||||
|
);
|
||||||
|
let error_json = serde_json::json!({
|
||||||
|
"error": "FilterChainError",
|
||||||
|
"agent": agent,
|
||||||
|
"status": status,
|
||||||
|
"agent_response": body
|
||||||
|
});
|
||||||
|
let mut error_response = Response::new(full(error_json.to_string()));
|
||||||
|
*error_response.status_mut() =
|
||||||
|
StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST);
|
||||||
|
error_response.headers_mut().insert(
|
||||||
|
hyper::header::CONTENT_TYPE,
|
||||||
|
"application/json".parse().unwrap(),
|
||||||
|
);
|
||||||
|
return Ok(error_response);
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
warn!(error = %err, "filter chain processing failed");
|
||||||
|
let err_msg = format!("Filter chain processing failed: {}", err);
|
||||||
|
let mut internal_error = Response::new(full(err_msg));
|
||||||
|
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||||
|
return Ok(internal_error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// === v1/responses state management: Determine upstream API and combine input if needed ===
|
// === v1/responses state management: Determine upstream API and combine input if needed ===
|
||||||
// Do this BEFORE routing since routing consumes the request
|
// Do this BEFORE routing since routing consumes the request
|
||||||
// Only process state if state_storage is configured
|
// Only process state if state_storage is configured
|
||||||
|
|
|
||||||
|
|
@ -221,6 +221,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
model_aliases,
|
model_aliases,
|
||||||
llm_providers,
|
llm_providers,
|
||||||
state_storage,
|
state_storage,
|
||||||
|
listeners,
|
||||||
|
agents_list,
|
||||||
)
|
)
|
||||||
.with_context(parent_cx)
|
.with_context(parent_cx)
|
||||||
.await
|
.await
|
||||||
|
|
|
||||||
|
|
@ -41,6 +41,7 @@ pub struct Listener {
|
||||||
pub name: String,
|
pub name: String,
|
||||||
pub router: Option<String>,
|
pub router: Option<String>,
|
||||||
pub agents: Option<Vec<AgentFilterChain>>,
|
pub agents: Option<Vec<AgentFilterChain>>,
|
||||||
|
pub filter_chain: Option<Vec<String>>,
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,8 @@ listeners:
|
||||||
port: 443
|
port: 443
|
||||||
protocol: https
|
protocol: https
|
||||||
provider_interface: openai
|
provider_interface: openai
|
||||||
|
filter_chain:
|
||||||
|
- input_guards
|
||||||
name: model_1
|
name: model_1
|
||||||
port: 12000
|
port: 12000
|
||||||
type: model
|
type: model
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue