mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
add output filter chain (#822)
This commit is contained in:
parent
de2d8847f3
commit
1f23c573bf
59 changed files with 2961 additions and 2621 deletions
|
|
@ -332,15 +332,38 @@ async fn handle_agent_chat_inner(
|
|||
"processing agent"
|
||||
);
|
||||
|
||||
// Process the filter chain
|
||||
let chat_history = pipeline_processor
|
||||
.process_filter_chain(
|
||||
¤t_messages,
|
||||
selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
)
|
||||
.await?;
|
||||
// Process input filters — serialize current request as OpenAI chat completions body,
|
||||
// pass raw bytes through each filter, then extract updated messages from the result.
|
||||
let chat_history = if selected_agent
|
||||
.input_filters
|
||||
.as_ref()
|
||||
.map(|f| !f.is_empty())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
let filter_body = serde_json::json!({
|
||||
"model": client_request.model(),
|
||||
"messages": current_messages,
|
||||
});
|
||||
let filter_bytes =
|
||||
serde_json::to_vec(&filter_body).map_err(PipelineError::ParseError)?;
|
||||
|
||||
let filtered_bytes = pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&filter_bytes,
|
||||
selected_agent,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
"/v1/chat/completions",
|
||||
)
|
||||
.await?;
|
||||
|
||||
let filtered_body: serde_json::Value =
|
||||
serde_json::from_slice(&filtered_bytes).map_err(PipelineError::ParseError)?;
|
||||
serde_json::from_value(filtered_body["messages"].clone())
|
||||
.map_err(PipelineError::ParseError)?
|
||||
} else {
|
||||
current_messages.clone()
|
||||
};
|
||||
|
||||
// Get agent details and invoke
|
||||
let agent = agent_map.get(&agent_name).unwrap();
|
||||
|
|
|
|||
|
|
@ -172,7 +172,7 @@ impl AgentSelector {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{AgentFilterChain, Listener};
|
||||
use common::configuration::{AgentFilterChain, Listener, ListenerType};
|
||||
|
||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||
Arc::new(OrchestratorService::new(
|
||||
|
|
@ -187,14 +187,17 @@ mod tests {
|
|||
id: name.to_string(),
|
||||
description: Some(description.to_string()),
|
||||
default: Some(is_default),
|
||||
filter_chain: Some(vec![name.to_string()]),
|
||||
input_filters: Some(vec![name.to_string()]),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_listener(name: &str, agents: Vec<AgentFilterChain>) -> Listener {
|
||||
Listener {
|
||||
listener_type: ListenerType::Agent,
|
||||
name: name.to_string(),
|
||||
agents: Some(agents),
|
||||
input_filters: None,
|
||||
output_filters: None,
|
||||
port: 8080,
|
||||
router: None,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use hyper::StatusCode;
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener};
|
||||
use common::configuration::{Agent, AgentFilterChain, Listener, ListenerType};
|
||||
|
||||
fn create_test_orchestrator_service() -> Arc<OrchestratorService> {
|
||||
Arc::new(OrchestratorService::new(
|
||||
|
|
@ -64,7 +64,7 @@ mod tests {
|
|||
|
||||
let agent_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: Some(vec![
|
||||
input_filters: Some(vec![
|
||||
"filter-agent".to_string(),
|
||||
"terminal-agent".to_string(),
|
||||
]),
|
||||
|
|
@ -73,8 +73,11 @@ mod tests {
|
|||
};
|
||||
|
||||
let listener = Listener {
|
||||
listener_type: ListenerType::Agent,
|
||||
name: "test-listener".to_string(),
|
||||
agents: Some(vec![agent_pipeline.clone()]),
|
||||
input_filters: None,
|
||||
output_filters: None,
|
||||
port: 8080,
|
||||
router: None,
|
||||
};
|
||||
|
|
@ -107,23 +110,32 @@ mod tests {
|
|||
// Create a pipeline with empty filter chain to avoid network calls
|
||||
let test_pipeline = AgentFilterChain {
|
||||
id: "terminal-agent".to_string(),
|
||||
filter_chain: Some(vec![]), // Empty filter chain - no network calls needed
|
||||
input_filters: Some(vec![]), // Empty filter chain - no network calls needed
|
||||
description: None,
|
||||
default: None,
|
||||
};
|
||||
|
||||
let headers = HeaderMap::new();
|
||||
let request_bytes = serde_json::to_vec(&request).expect("failed to serialize request");
|
||||
let result = pipeline_processor
|
||||
.process_filter_chain(&request.messages, &test_pipeline, &agent_map, &headers)
|
||||
.process_raw_filter_chain(
|
||||
&request_bytes,
|
||||
&test_pipeline,
|
||||
&agent_map,
|
||||
&headers,
|
||||
"/v1/chat/completions",
|
||||
)
|
||||
.await;
|
||||
|
||||
println!("Pipeline processing result: {:?}", result);
|
||||
|
||||
assert!(result.is_ok());
|
||||
let processed_messages = result.unwrap();
|
||||
// With empty filter chain, should return the original messages unchanged
|
||||
assert_eq!(processed_messages.len(), 1);
|
||||
if let Some(MessageContent::Text(content)) = &processed_messages[0].content {
|
||||
let processed_bytes = result.unwrap();
|
||||
// With empty filter chain, should return the original bytes unchanged
|
||||
let processed_request: ChatCompletionsRequest =
|
||||
serde_json::from_slice(&processed_bytes).expect("failed to deserialize response");
|
||||
assert_eq!(processed_request.messages.len(), 1);
|
||||
if let Some(MessageContent::Text(content)) = &processed_request.messages[0].content {
|
||||
assert_eq!(content, "Hello world!");
|
||||
} else {
|
||||
panic!("Expected text content");
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::{ModelAlias, SpanAttributes};
|
||||
use common::configuration::{FilterPipeline, ModelAlias, SpanAttributes};
|
||||
use common::consts::{
|
||||
ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER,
|
||||
};
|
||||
|
|
@ -8,9 +8,9 @@ use hermesllm::apis::openai_responses::InputParam;
|
|||
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::BodyExt;
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::header::{self};
|
||||
use hyper::{Request, Response};
|
||||
use hyper::{Request, Response, StatusCode};
|
||||
use opentelemetry::global;
|
||||
use opentelemetry::trace::get_active_span;
|
||||
use opentelemetry_http::HeaderInjector;
|
||||
|
|
@ -19,9 +19,12 @@ use std::sync::Arc;
|
|||
use tokio::sync::RwLock;
|
||||
use tracing::{debug, info, info_span, warn, Instrument};
|
||||
|
||||
use super::pipeline_processor::PipelineProcessor;
|
||||
|
||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||
use crate::handlers::utils::{
|
||||
create_streaming_response, truncate_message, ObservableStreamProcessor,
|
||||
use crate::handlers::streaming::{
|
||||
create_streaming_response, create_streaming_response_with_output_filter, truncate_message,
|
||||
ObservableStreamProcessor, StreamProcessor,
|
||||
};
|
||||
use crate::router::llm_router::RouterService;
|
||||
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||
|
|
@ -34,6 +37,7 @@ use crate::tracing::{
|
|||
|
||||
use common::errors::BrightStaffError;
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn llm_chat(
|
||||
request: Request<hyper::body::Incoming>,
|
||||
router_service: Arc<RouterService>,
|
||||
|
|
@ -42,6 +46,7 @@ pub async fn llm_chat(
|
|||
llm_providers: Arc<RwLock<LlmProviders>>,
|
||||
span_attributes: Arc<Option<SpanAttributes>>,
|
||||
state_storage: Option<Arc<dyn StateStorage>>,
|
||||
filter_pipeline: Arc<FilterPipeline>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
let request_path = request.uri().path().to_string();
|
||||
let request_headers = request.headers().clone();
|
||||
|
|
@ -81,6 +86,7 @@ pub async fn llm_chat(
|
|||
request_id,
|
||||
request_path,
|
||||
request_headers,
|
||||
filter_pipeline,
|
||||
)
|
||||
.instrument(request_span)
|
||||
.await
|
||||
|
|
@ -98,6 +104,7 @@ async fn llm_chat_inner(
|
|||
request_id: String,
|
||||
request_path: String,
|
||||
mut request_headers: hyper::HeaderMap,
|
||||
filter_pipeline: Arc<FilterPipeline>,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
// Set service name for LLM operations
|
||||
set_service_name(operation_component::LLM);
|
||||
|
|
@ -250,6 +257,85 @@ async fn llm_chat_inner(
|
|||
if client_request.remove_metadata_key("plano_preference_config") {
|
||||
debug!("removed plano_preference_config from metadata");
|
||||
}
|
||||
|
||||
// === Input filters processing for model listener ===
|
||||
// Filters receive the raw request bytes and return (possibly modified) raw bytes.
|
||||
// The returned bytes are re-parsed into a ProviderRequestType to continue the request.
|
||||
{
|
||||
if let Some(ref input_chain) = filter_pipeline.input {
|
||||
if !input_chain.is_empty() {
|
||||
debug!(input_filters = ?input_chain.filter_ids, "processing model listener input filters");
|
||||
|
||||
let chain = input_chain.to_agent_filter_chain("model_listener");
|
||||
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chat_request_bytes,
|
||||
&chain,
|
||||
&input_chain.agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_bytes) => {
|
||||
match ProviderRequestType::try_from((
|
||||
&filtered_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(updated_request) => {
|
||||
client_request = updated_request;
|
||||
info!("input filter chain processed successfully");
|
||||
}
|
||||
Err(parse_err) => {
|
||||
warn!(error = %parse_err, "input filter returned invalid request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Input filter returned invalid request: {}",
|
||||
parse_err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(super::pipeline_processor::PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"client error from filter chain"
|
||||
);
|
||||
let error_json = serde_json::json!({
|
||||
"error": "FilterChainError",
|
||||
"agent": agent,
|
||||
"status": status,
|
||||
"agent_response": body
|
||||
});
|
||||
let mut error_response = Response::new(full(error_json.to_string()));
|
||||
*error_response.status_mut() =
|
||||
StatusCode::from_u16(status).unwrap_or(StatusCode::BAD_REQUEST);
|
||||
error_response.headers_mut().insert(
|
||||
hyper::header::CONTENT_TYPE,
|
||||
"application/json".parse().unwrap(),
|
||||
);
|
||||
return Ok(error_response);
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(error = %err, "filter chain processing failed");
|
||||
let err_msg = format!("Filter chain processing failed: {}", err);
|
||||
let mut internal_error = Response::new(full(err_msg));
|
||||
*internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
|
||||
return Ok(internal_error);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref client_api_kind) = client_api {
|
||||
let upstream_api =
|
||||
provider_id.compatible_api_for_client(client_api_kind, is_streaming_request);
|
||||
|
|
@ -420,6 +506,18 @@ async fn llm_chat_inner(
|
|||
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
||||
});
|
||||
|
||||
// Output filters run for any API shape that reaches this handler (e.g. /v1/chat/completions,
|
||||
// /v1/messages, /v1/responses). Brightstaff does inbound translation and llm_gateway does
|
||||
// outbound translation; filters receive raw response bytes and request path.
|
||||
let has_output_filter = filter_pipeline.has_output_filters();
|
||||
|
||||
// Save request headers for output filters (before they're consumed by upstream request)
|
||||
let output_filter_request_headers = if has_output_filter {
|
||||
Some(request_headers.clone())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Capture start time right before sending request to upstream
|
||||
let request_start_time = std::time::Instant::now();
|
||||
let _request_start_system_time = std::time::SystemTime::now();
|
||||
|
|
@ -462,20 +560,18 @@ async fn llm_chat_inner(
|
|||
);
|
||||
|
||||
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
|
||||
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
|
||||
let streaming_response = if let (true, false, Some(state_store)) = (
|
||||
// Pick the right processor: state-aware if needed, otherwise base metrics-only.
|
||||
let processor: Box<dyn StreamProcessor> = if let (true, false, Some(state_store)) = (
|
||||
should_manage_state,
|
||||
original_input_items.is_empty(),
|
||||
state_storage,
|
||||
) {
|
||||
// Extract Content-Encoding header to handle decompression for state parsing
|
||||
let content_encoding = response_headers
|
||||
.get("content-encoding")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
// Wrap with state management processor to store state after response completes
|
||||
let state_processor = ResponsesStateProcessor::new(
|
||||
Box::new(ResponsesStateProcessor::new(
|
||||
base_processor,
|
||||
state_store,
|
||||
original_input_items,
|
||||
|
|
@ -485,11 +581,23 @@ async fn llm_chat_inner(
|
|||
false, // Not OpenAI upstream since should_manage_state is true
|
||||
content_encoding,
|
||||
request_id,
|
||||
);
|
||||
create_streaming_response(byte_stream, state_processor, 16)
|
||||
))
|
||||
} else {
|
||||
// Use base processor without state management
|
||||
create_streaming_response(byte_stream, base_processor, 16)
|
||||
Box::new(base_processor)
|
||||
};
|
||||
|
||||
// Apply output filters if configured, then build the streaming response.
|
||||
let streaming_response = if has_output_filter {
|
||||
let output_chain = filter_pipeline.output.as_ref().unwrap().clone();
|
||||
create_streaming_response_with_output_filter(
|
||||
byte_stream,
|
||||
processor,
|
||||
output_chain,
|
||||
output_filter_request_headers.unwrap(),
|
||||
request_path.clone(),
|
||||
)
|
||||
} else {
|
||||
create_streaming_response(byte_stream, processor)
|
||||
};
|
||||
|
||||
match response.body(streaming_response.body) {
|
||||
|
|
@ -570,3 +678,9 @@ async fn get_provider_info(
|
|||
(hermesllm::ProviderId::OpenAI, None)
|
||||
}
|
||||
}
|
||||
|
||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||
Full::new(chunk.into())
|
||||
.map_err(|never| match never {})
|
||||
.boxed()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ pub mod pipeline_processor;
|
|||
pub mod response_handler;
|
||||
pub mod router_chat;
|
||||
pub mod routing_service;
|
||||
pub mod utils;
|
||||
pub mod streaming;
|
||||
|
||||
#[cfg(test)]
|
||||
mod integration_tests;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, AgentFilterChain};
|
||||
use common::consts::{
|
||||
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
||||
|
|
@ -35,8 +36,6 @@ pub enum PipelineError {
|
|||
NoResultInResponse(String),
|
||||
#[error("No structured content in response from agent '{0}'")]
|
||||
NoStructuredContentInResponse(String),
|
||||
#[error("No messages in response from agent '{0}'")]
|
||||
NoMessagesInResponse(String),
|
||||
#[error("Client error from agent '{agent}' (HTTP {status}): {body}")]
|
||||
ClientError {
|
||||
agent: String,
|
||||
|
|
@ -79,68 +78,6 @@ impl PipelineProcessor {
|
|||
}
|
||||
}
|
||||
|
||||
// /// Process the filter chain of agents (all except the terminal agent)
|
||||
// #[instrument(
|
||||
// skip(self, chat_history, agent_filter_chain, agent_map, request_headers),
|
||||
// fields(
|
||||
// filter_count = agent_filter_chain.filter_chain.as_ref().map(|fc| fc.len()).unwrap_or(0),
|
||||
// message_count = chat_history.len()
|
||||
// )
|
||||
// )]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn process_filter_chain(
|
||||
&mut self,
|
||||
chat_history: &[Message],
|
||||
agent_filter_chain: &AgentFilterChain,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
let mut chat_history_updated = chat_history.to_vec();
|
||||
|
||||
// If filter_chain is None or empty, proceed without filtering
|
||||
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
|
||||
Some(fc) if !fc.is_empty() => fc,
|
||||
_ => return Ok(chat_history_updated),
|
||||
};
|
||||
|
||||
for agent_name in filter_chain {
|
||||
debug!(agent = %agent_name, "processing filter agent");
|
||||
|
||||
let agent = agent_map
|
||||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
tool = %tool_name,
|
||||
url = %agent.url,
|
||||
agent_type = %agent.agent_type.as_deref().unwrap_or("mcp"),
|
||||
conversation_len = chat_history.len(),
|
||||
"executing filter"
|
||||
);
|
||||
|
||||
if agent.agent_type.as_deref().unwrap_or("mcp") == "mcp" {
|
||||
chat_history_updated = self
|
||||
.execute_mcp_filter(&chat_history_updated, agent, request_headers)
|
||||
.await?;
|
||||
} else {
|
||||
chat_history_updated = self
|
||||
.execute_http_filter(&chat_history_updated, agent, request_headers)
|
||||
.await?;
|
||||
}
|
||||
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
updated_len = chat_history_updated.len(),
|
||||
"filter completed"
|
||||
);
|
||||
}
|
||||
|
||||
Ok(chat_history_updated)
|
||||
}
|
||||
|
||||
/// Build common MCP headers for requests
|
||||
fn build_mcp_headers(
|
||||
&self,
|
||||
|
|
@ -261,14 +198,17 @@ impl PipelineProcessor {
|
|||
Ok(response)
|
||||
}
|
||||
|
||||
/// Build a tools/call JSON-RPC request
|
||||
fn build_tool_call_request(
|
||||
/// Build a tools/call JSON-RPC request with a full body dict and path hint.
|
||||
/// Used by execute_mcp_filter_raw so MCP tools receive the same contract as HTTP filters.
|
||||
fn build_tool_call_request_with_body(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
messages: &[Message],
|
||||
body: &serde_json::Value,
|
||||
path: &str,
|
||||
) -> Result<JsonRpcRequest, PipelineError> {
|
||||
let mut arguments = HashMap::new();
|
||||
arguments.insert("messages".to_string(), serde_json::to_value(messages)?);
|
||||
arguments.insert("body".to_string(), serde_json::to_value(body)?);
|
||||
arguments.insert("path".to_string(), serde_json::to_value(path)?);
|
||||
|
||||
let mut params = HashMap::new();
|
||||
params.insert("name".to_string(), serde_json::to_value(tool_name)?);
|
||||
|
|
@ -282,31 +222,24 @@ impl PipelineProcessor {
|
|||
})
|
||||
}
|
||||
|
||||
/// Send request to a specific agent and return the response content
|
||||
#[instrument(
|
||||
skip(self, messages, agent, request_headers),
|
||||
fields(
|
||||
agent_id = %agent.id,
|
||||
filter_name = %agent.id,
|
||||
message_count = messages.len()
|
||||
)
|
||||
)]
|
||||
async fn execute_mcp_filter(
|
||||
/// Like execute_mcp_filter_raw but passes the full raw body dict + path hint as MCP tool arguments.
|
||||
/// The MCP tool receives (body: dict, path: str) and returns the modified body dict.
|
||||
async fn execute_mcp_filter_raw(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
raw_bytes: &[u8],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
// Set service name for this filter span
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
set_service_name(operation_component::AGENT_FILTER);
|
||||
|
||||
// Update current span name to include filter name
|
||||
use opentelemetry::trace::get_active_span;
|
||||
get_active_span(|span| {
|
||||
span.update_name(format!("execute_mcp_filter ({})", agent.id));
|
||||
span.update_name(format!("execute_mcp_filter_raw ({})", agent.id));
|
||||
});
|
||||
|
||||
// Get or create MCP session
|
||||
let body: serde_json::Value =
|
||||
serde_json::from_slice(raw_bytes).map_err(PipelineError::ParseError)?;
|
||||
|
||||
let mcp_session_id = if let Some(session_id) = self.agent_id_session_map.get(&agent.id) {
|
||||
session_id.clone()
|
||||
} else {
|
||||
|
|
@ -321,11 +254,10 @@ impl PipelineProcessor {
|
|||
mcp_session_id, agent.id
|
||||
);
|
||||
|
||||
// Build JSON-RPC request
|
||||
let tool_name = agent.tool.as_deref().unwrap_or(&agent.id);
|
||||
let json_rpc_request = self.build_tool_call_request(tool_name, messages)?;
|
||||
let json_rpc_request =
|
||||
self.build_tool_call_request_with_body(tool_name, &body, request_path)?;
|
||||
|
||||
// Build headers
|
||||
let agent_headers =
|
||||
self.build_mcp_headers(request_headers, &agent.id, Some(&mcp_session_id))?;
|
||||
|
||||
|
|
@ -335,7 +267,6 @@ impl PipelineProcessor {
|
|||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
// Handle HTTP errors
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
|
|
@ -353,20 +284,12 @@ impl PipelineProcessor {
|
|||
});
|
||||
}
|
||||
|
||||
info!(
|
||||
"Response from agent {}: {}",
|
||||
agent.id,
|
||||
String::from_utf8_lossy(&response_bytes)
|
||||
);
|
||||
|
||||
// Parse SSE response
|
||||
let data_chunk = self.parse_sse_response(&response_bytes, &agent.id)?;
|
||||
let response: JsonRpcResponse = serde_json::from_str(&data_chunk)?;
|
||||
let response_result = response
|
||||
.result
|
||||
.ok_or_else(|| PipelineError::NoResultInResponse(agent.id.clone()))?;
|
||||
|
||||
// Check if error field is set in response result
|
||||
if response_result
|
||||
.get("isError")
|
||||
.and_then(|v| v.as_bool())
|
||||
|
|
@ -388,21 +311,28 @@ impl PipelineProcessor {
|
|||
});
|
||||
}
|
||||
|
||||
// Extract structured content and parse messages
|
||||
let response_json = response_result
|
||||
// FastMCP puts structured Pydantic return values in structuredContent.result,
|
||||
// but plain dicts land in content[0].text as a JSON string. Try both.
|
||||
let result = if let Some(structured) = response_result
|
||||
.get("structuredContent")
|
||||
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
|
||||
.and_then(|v| v.get("result"))
|
||||
.cloned()
|
||||
{
|
||||
structured
|
||||
} else {
|
||||
let text = response_result
|
||||
.get("content")
|
||||
.and_then(|v| v.as_array())
|
||||
.and_then(|arr| arr.first())
|
||||
.and_then(|v| v.get("text"))
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| PipelineError::NoStructuredContentInResponse(agent.id.clone()))?;
|
||||
serde_json::from_str(text).map_err(PipelineError::ParseError)?
|
||||
};
|
||||
|
||||
let messages: Vec<Message> = response_json
|
||||
.get("result")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| PipelineError::NoMessagesInResponse(agent.id.clone()))?
|
||||
.iter()
|
||||
.map(|msg_value| serde_json::from_value(msg_value.clone()))
|
||||
.collect::<Result<Vec<Message>, _>>()
|
||||
.map_err(PipelineError::ParseError)?;
|
||||
|
||||
Ok(messages)
|
||||
Ok(Bytes::from(
|
||||
serde_json::to_vec(&result).map_err(PipelineError::ParseError)?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Build an initialize JSON-RPC request
|
||||
|
|
@ -499,36 +429,34 @@ impl PipelineProcessor {
|
|||
session_id
|
||||
}
|
||||
|
||||
/// Execute a HTTP-based filter agent
|
||||
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
|
||||
/// Used for input and output filters where the full raw request/response is passed through.
|
||||
/// No MCP protocol wrapping; agent_type is ignored.
|
||||
#[instrument(
|
||||
skip(self, messages, agent, request_headers),
|
||||
skip(self, raw_bytes, agent, request_headers),
|
||||
fields(
|
||||
agent_id = %agent.id,
|
||||
agent_url = %agent.url,
|
||||
filter_name = %agent.id,
|
||||
message_count = messages.len()
|
||||
bytes_len = raw_bytes.len()
|
||||
)
|
||||
)]
|
||||
async fn execute_http_filter(
|
||||
async fn execute_raw_filter(
|
||||
&mut self,
|
||||
messages: &[Message],
|
||||
raw_bytes: &[u8],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Result<Vec<Message>, PipelineError> {
|
||||
// Set service name for this filter span
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
set_service_name(operation_component::AGENT_FILTER);
|
||||
|
||||
// Update current span name to include filter name
|
||||
use opentelemetry::trace::get_active_span;
|
||||
get_active_span(|span| {
|
||||
span.update_name(format!("execute_http_filter ({})", agent.id));
|
||||
span.update_name(format!("execute_raw_filter ({})", agent.id));
|
||||
});
|
||||
|
||||
// Build headers
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
// Inject OpenTelemetry trace context automatically
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
|
|
@ -541,40 +469,36 @@ impl PipelineProcessor {
|
|||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
agent_headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
debug!(
|
||||
"Sending HTTP request to agent {} at URL: {}",
|
||||
agent.id, agent.url
|
||||
);
|
||||
// Append the original request path so the filter endpoint encodes the API format.
|
||||
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
|
||||
// -> POST http://host/anonymize/v1/chat/completions
|
||||
let url = format!("{}{}", agent.url, request_path);
|
||||
debug!(agent = %agent.id, url = %url, "sending raw filter request");
|
||||
|
||||
// Send messages array directly as request body
|
||||
let response = self
|
||||
.client
|
||||
.post(&agent.url)
|
||||
.post(&url)
|
||||
.headers(agent_headers)
|
||||
.json(&messages)
|
||||
.body(raw_bytes.to_vec())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
// Handle HTTP errors
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
|
|
@ -592,17 +516,56 @@ impl PipelineProcessor {
|
|||
});
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Response from HTTP agent {}: {}",
|
||||
agent.id,
|
||||
String::from_utf8_lossy(&response_bytes)
|
||||
);
|
||||
debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received");
|
||||
Ok(response_bytes)
|
||||
}
|
||||
|
||||
// Parse response - expecting array of messages directly
|
||||
let messages: Vec<Message> =
|
||||
serde_json::from_slice(&response_bytes).map_err(PipelineError::ParseError)?;
|
||||
/// Process a chain of raw-bytes filters sequentially.
|
||||
/// Input: raw request or response bytes. Output: filtered bytes.
|
||||
/// Each agent receives the output of the previous one.
|
||||
pub async fn process_raw_filter_chain(
|
||||
&mut self,
|
||||
raw_bytes: &[u8],
|
||||
agent_filter_chain: &AgentFilterChain,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
let filter_chain = match agent_filter_chain.input_filters.as_ref() {
|
||||
Some(fc) if !fc.is_empty() => fc,
|
||||
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
|
||||
};
|
||||
|
||||
Ok(messages)
|
||||
let mut current_bytes = Bytes::copy_from_slice(raw_bytes);
|
||||
|
||||
for agent_name in filter_chain {
|
||||
debug!(agent = %agent_name, "processing raw filter agent");
|
||||
|
||||
let agent = agent_map
|
||||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
let agent_type = agent.agent_type.as_deref().unwrap_or("mcp");
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
url = %agent.url,
|
||||
agent_type = %agent_type,
|
||||
bytes_len = current_bytes.len(),
|
||||
"executing raw filter"
|
||||
);
|
||||
|
||||
current_bytes = if agent_type == "mcp" {
|
||||
self.execute_mcp_filter_raw(¤t_bytes, agent, request_headers, request_path)
|
||||
.await?
|
||||
} else {
|
||||
self.execute_raw_filter(¤t_bytes, agent, request_headers, request_path)
|
||||
.await?
|
||||
};
|
||||
|
||||
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
|
||||
}
|
||||
|
||||
Ok(current_bytes)
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
|
|
@ -661,24 +624,13 @@ impl PipelineProcessor {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use mockito::Server;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_test_message(role: Role, content: &str) -> Message {
|
||||
Message {
|
||||
role,
|
||||
content: Some(MessageContent::Text(content.to_string())),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_test_pipeline(agents: Vec<&str>) -> AgentFilterChain {
|
||||
AgentFilterChain {
|
||||
id: "test-agent".to_string(),
|
||||
filter_chain: Some(agents.iter().map(|s| s.to_string()).collect()),
|
||||
input_filters: Some(agents.iter().map(|s| s.to_string()).collect()),
|
||||
description: None,
|
||||
default: None,
|
||||
}
|
||||
|
|
@ -690,12 +642,19 @@ mod tests {
|
|||
let agent_map = HashMap::new();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hello")];
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
|
||||
let raw_bytes = serde_json::to_vec(&body).unwrap();
|
||||
|
||||
let pipeline = create_test_pipeline(vec!["nonexistent-agent", "terminal-agent"]);
|
||||
|
||||
let result = processor
|
||||
.process_filter_chain(&messages, &pipeline, &agent_map, &request_headers)
|
||||
.process_raw_filter_chain(
|
||||
&raw_bytes,
|
||||
&pipeline,
|
||||
&agent_map,
|
||||
&request_headers,
|
||||
"/v1/chat/completions",
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
|
|
@ -725,11 +684,12 @@ mod tests {
|
|||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hello")];
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hello"}]});
|
||||
let raw_bytes = serde_json::to_vec(&body).unwrap();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_mcp_filter(&messages, &agent, &request_headers)
|
||||
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
@ -764,11 +724,12 @@ mod tests {
|
|||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Ping")];
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Ping"}]});
|
||||
let raw_bytes = serde_json::to_vec(&body).unwrap();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_mcp_filter(&messages, &agent, &request_headers)
|
||||
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
@ -816,11 +777,12 @@ mod tests {
|
|||
agent_type: None,
|
||||
};
|
||||
|
||||
let messages = vec![create_test_message(Role::User, "Hi")];
|
||||
let body = serde_json::json!({"messages": [{"role": "user", "content": "Hi"}]});
|
||||
let raw_bytes = serde_json::to_vec(&body).unwrap();
|
||||
let request_headers = HeaderMap::new();
|
||||
|
||||
let result = processor
|
||||
.execute_mcp_filter(&messages, &agent, &request_headers)
|
||||
.execute_mcp_filter_raw(&raw_bytes, &agent, &request_headers, "/v1/chat/completions")
|
||||
.await;
|
||||
|
||||
match result {
|
||||
|
|
|
|||
|
|
@ -1,16 +1,21 @@
|
|||
use bytes::Bytes;
|
||||
use common::configuration::ResolvedFilterChain;
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::StreamBody;
|
||||
use hyper::body::Frame;
|
||||
use hyper::header::HeaderMap;
|
||||
use opentelemetry::trace::TraceContextExt;
|
||||
use opentelemetry::KeyValue;
|
||||
use std::time::Instant;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tokio_stream::StreamExt;
|
||||
use tracing::{info, warn, Instrument};
|
||||
use tracing::{debug, info, warn, Instrument};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
|
||||
const STREAM_BUFFER_SIZE: usize = 16;
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
|
@ -31,6 +36,21 @@ pub trait StreamProcessor: Send + 'static {
|
|||
fn on_error(&mut self, _error: &str) {}
|
||||
}
|
||||
|
||||
impl StreamProcessor for Box<dyn StreamProcessor> {
|
||||
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||
(**self).process_chunk(chunk)
|
||||
}
|
||||
fn on_first_bytes(&mut self) {
|
||||
(**self).on_first_bytes()
|
||||
}
|
||||
fn on_complete(&mut self) {
|
||||
(**self).on_complete()
|
||||
}
|
||||
fn on_error(&mut self, error: &str) {
|
||||
(**self).on_error(error)
|
||||
}
|
||||
}
|
||||
|
||||
/// A processor that tracks streaming metrics
|
||||
pub struct ObservableStreamProcessor {
|
||||
service_name: String,
|
||||
|
|
@ -206,16 +226,12 @@ pub struct StreamingResponse {
|
|||
pub processor_handle: tokio::task::JoinHandle<()>,
|
||||
}
|
||||
|
||||
pub fn create_streaming_response<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut processor: P,
|
||||
buffer_size: usize,
|
||||
) -> StreamingResponse
|
||||
pub fn create_streaming_response<S, P>(mut byte_stream: S, mut processor: P) -> StreamingResponse
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(buffer_size);
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
|
||||
|
||||
// Capture the current span so the spawned task inherits the request context
|
||||
let current_span = tracing::Span::current();
|
||||
|
|
@ -277,6 +293,108 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Creates a streaming response that processes each raw chunk through output filters.
|
||||
/// Filters receive the raw LLM response bytes and request path (any API shape; not limited to
|
||||
/// chat completions). On filter error mid-stream the original chunk is passed through (headers already sent).
|
||||
pub fn create_streaming_response_with_output_filter<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
output_chain: ResolvedFilterChain,
|
||||
request_headers: HeaderMap,
|
||||
request_path: String,
|
||||
) -> StreamingResponse
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
P: StreamProcessor,
|
||||
{
|
||||
let (tx, rx) = mpsc::channel::<Bytes>(STREAM_BUFFER_SIZE);
|
||||
let current_span = tracing::Span::current();
|
||||
|
||||
let processor_handle = tokio::spawn(
|
||||
async move {
|
||||
let mut is_first_chunk = true;
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let chain = output_chain.to_agent_filter_chain("output_filter");
|
||||
|
||||
while let Some(item) = byte_stream.next().await {
|
||||
let chunk = match item {
|
||||
Ok(chunk) => chunk,
|
||||
Err(err) => {
|
||||
let err_msg = format!("Error receiving chunk: {:?}", err);
|
||||
warn!(error = %err_msg, "stream error");
|
||||
inner_processor.on_error(&err_msg);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if is_first_chunk {
|
||||
inner_processor.on_first_bytes();
|
||||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Pass raw chunk bytes through the output filter chain
|
||||
let processed_chunk = match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chunk,
|
||||
&chain,
|
||||
&output_chain.agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered) => filtered,
|
||||
Err(PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"output filter client error, passing through original chunk"
|
||||
);
|
||||
chunk
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter error, passing through original chunk");
|
||||
chunk
|
||||
}
|
||||
};
|
||||
|
||||
// Pass through inner processor for metrics/observability
|
||||
match inner_processor.process_chunk(processed_chunk) {
|
||||
Ok(Some(final_chunk)) => {
|
||||
if tx.send(final_chunk).await.is_err() {
|
||||
warn!("receiver dropped");
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(None) => continue,
|
||||
Err(err) => {
|
||||
warn!("processor error: {}", err);
|
||||
inner_processor.on_error(&err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inner_processor.on_complete();
|
||||
debug!("output filter streaming completed");
|
||||
}
|
||||
.instrument(current_span),
|
||||
);
|
||||
|
||||
let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk)));
|
||||
let stream_body = BoxBody::new(StreamBody::new(stream));
|
||||
|
||||
StreamingResponse {
|
||||
body: stream_body,
|
||||
processor_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncates a message to the specified maximum length, adding "..." if truncated.
|
||||
pub fn truncate_message(message: &str, max_length: usize) -> String {
|
||||
if message.chars().count() > max_length {
|
||||
|
|
@ -10,7 +10,9 @@ use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
|||
use brightstaff::state::StateStorage;
|
||||
use brightstaff::utils::tracing::init_tracer;
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, Configuration};
|
||||
use common::configuration::{
|
||||
Agent, Configuration, FilterPipeline, ListenerType, ResolvedFilterChain,
|
||||
};
|
||||
use common::consts::{CHAT_COMPLETIONS_PATH, MESSAGES_PATH, OPENAI_RESPONSES_API_PATH};
|
||||
use common::llm_providers::LlmProviders;
|
||||
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
|
||||
|
|
@ -22,6 +24,7 @@ use hyper_util::rt::TokioIo;
|
|||
use opentelemetry::trace::FutureExt;
|
||||
use opentelemetry::{global, Context};
|
||||
use opentelemetry_http::HeaderExtractor;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::{env, fs};
|
||||
use tokio::net::TcpListener;
|
||||
|
|
@ -80,11 +83,49 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
.cloned()
|
||||
.collect();
|
||||
|
||||
// Build global agent map for resolving filter chain references
|
||||
let global_agent_map: HashMap<String, Agent> = all_agents
|
||||
.iter()
|
||||
.map(|a| (a.id.clone(), a.clone()))
|
||||
.collect();
|
||||
|
||||
// Create expanded provider list for /v1/models endpoint
|
||||
let llm_providers = LlmProviders::try_from(plano_config.model_providers.clone())
|
||||
.expect("Failed to create LlmProviders");
|
||||
let llm_providers = Arc::new(RwLock::new(llm_providers));
|
||||
let combined_agents_filters_list = Arc::new(RwLock::new(Some(all_agents)));
|
||||
|
||||
// Resolve model listener filter chain and agents at startup
|
||||
let model_listener_count = plano_config
|
||||
.listeners
|
||||
.iter()
|
||||
.filter(|l| l.listener_type == ListenerType::Model)
|
||||
.count();
|
||||
assert!(
|
||||
model_listener_count <= 1,
|
||||
"only one model listener is allowed, found {}",
|
||||
model_listener_count
|
||||
);
|
||||
let model_listener = plano_config
|
||||
.listeners
|
||||
.iter()
|
||||
.find(|l| l.listener_type == ListenerType::Model);
|
||||
let resolve_chain = |filter_ids: Option<Vec<String>>| -> Option<ResolvedFilterChain> {
|
||||
filter_ids.map(|ids| {
|
||||
let agents = ids
|
||||
.iter()
|
||||
.filter_map(|id| global_agent_map.get(id).map(|a| (id.clone(), a.clone())))
|
||||
.collect();
|
||||
ResolvedFilterChain {
|
||||
filter_ids: ids,
|
||||
agents,
|
||||
}
|
||||
})
|
||||
};
|
||||
let filter_pipeline = Arc::new(FilterPipeline {
|
||||
input: resolve_chain(model_listener.and_then(|l| l.input_filters.clone())),
|
||||
output: resolve_chain(model_listener.and_then(|l| l.output_filters.clone())),
|
||||
});
|
||||
let listeners = Arc::new(RwLock::new(plano_config.listeners.clone()));
|
||||
let llm_provider_url =
|
||||
env::var("LLM_PROVIDER_ENDPOINT").unwrap_or_else(|_| "http://localhost:12001".to_string());
|
||||
|
|
@ -200,6 +241,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
|
||||
let llm_providers = llm_providers.clone();
|
||||
let agents_list = combined_agents_filters_list.clone();
|
||||
let filter_pipeline = filter_pipeline.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -211,6 +253,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
let llm_providers = llm_providers.clone();
|
||||
let model_aliases = Arc::clone(&model_aliases);
|
||||
let agents_list = agents_list.clone();
|
||||
let filter_pipeline = filter_pipeline.clone();
|
||||
let listeners = listeners.clone();
|
||||
let span_attributes = span_attributes.clone();
|
||||
let state_storage = state_storage.clone();
|
||||
|
|
@ -269,6 +312,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
|||
llm_providers,
|
||||
span_attributes,
|
||||
state_storage,
|
||||
filter_pipeline,
|
||||
)
|
||||
.with_context(parent_cx)
|
||||
.await
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use std::io::Read;
|
|||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::handlers::utils::StreamProcessor;
|
||||
use crate::handlers::streaming::StreamProcessor;
|
||||
use crate::state::{OpenAIConversationState, StateStorage};
|
||||
|
||||
/// Processor that wraps another processor and handles v1/responses state management
|
||||
|
|
|
|||
|
|
@ -27,14 +27,66 @@ pub struct AgentFilterChain {
|
|||
pub id: String,
|
||||
pub default: Option<bool>,
|
||||
pub description: Option<String>,
|
||||
pub filter_chain: Option<Vec<String>>,
|
||||
pub input_filters: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
/// A filter chain with its agent references resolved to concrete Agent objects.
|
||||
/// Bundles the ordered filter IDs with the agent lookup map so they stay in sync.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ResolvedFilterChain {
|
||||
pub filter_ids: Vec<String>,
|
||||
pub agents: HashMap<String, Agent>,
|
||||
}
|
||||
|
||||
impl ResolvedFilterChain {
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.filter_ids.is_empty()
|
||||
}
|
||||
|
||||
pub fn to_agent_filter_chain(&self, id: &str) -> AgentFilterChain {
|
||||
AgentFilterChain {
|
||||
id: id.to_string(),
|
||||
default: None,
|
||||
description: None,
|
||||
input_filters: Some(self.filter_ids.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Holds resolved input and output filter chains for a model listener.
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct FilterPipeline {
|
||||
pub input: Option<ResolvedFilterChain>,
|
||||
pub output: Option<ResolvedFilterChain>,
|
||||
}
|
||||
|
||||
impl FilterPipeline {
|
||||
pub fn has_input_filters(&self) -> bool {
|
||||
self.input.as_ref().is_some_and(|c| !c.is_empty())
|
||||
}
|
||||
|
||||
pub fn has_output_filters(&self) -> bool {
|
||||
self.output.as_ref().is_some_and(|c| !c.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ListenerType {
|
||||
Model,
|
||||
Agent,
|
||||
Prompt,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Listener {
|
||||
#[serde(rename = "type")]
|
||||
pub listener_type: ListenerType,
|
||||
pub name: String,
|
||||
pub router: Option<String>,
|
||||
pub agents: Option<Vec<AgentFilterChain>>,
|
||||
pub input_filters: Option<Vec<String>>,
|
||||
pub output_filters: Option<Vec<String>>,
|
||||
pub port: u16,
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue