mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
pass raw bytes through input/output filter chains
This commit is contained in:
parent
80dfb41cad
commit
d26abbfb9c
11 changed files with 610 additions and 448 deletions
|
|
@ -268,12 +268,13 @@ async fn llm_chat_inner(
|
||||||
}
|
}
|
||||||
|
|
||||||
// === Input filters processing for model listener ===
|
// === Input filters processing for model listener ===
|
||||||
|
// Filters receive the raw request bytes and return (possibly modified) raw bytes.
|
||||||
|
// The returned bytes are re-parsed into a ProviderRequestType to continue the request.
|
||||||
{
|
{
|
||||||
if let Some(ref fc) = *input_filters {
|
if let Some(ref fc) = *input_filters {
|
||||||
if !fc.is_empty() {
|
if !fc.is_empty() {
|
||||||
debug!(input_filters = ?fc, "processing model listener input filters");
|
debug!(input_filters = ?fc, "processing model listener input filters");
|
||||||
|
|
||||||
// Create a temporary AgentFilterChain to reuse PipelineProcessor
|
|
||||||
let temp_filter_chain = AgentFilterChain {
|
let temp_filter_chain = AgentFilterChain {
|
||||||
id: "model_listener".to_string(),
|
id: "model_listener".to_string(),
|
||||||
default: None,
|
default: None,
|
||||||
|
|
@ -282,23 +283,34 @@ async fn llm_chat_inner(
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut pipeline_processor = PipelineProcessor::default();
|
let mut pipeline_processor = PipelineProcessor::default();
|
||||||
let messages = client_request.get_messages();
|
|
||||||
match pipeline_processor
|
match pipeline_processor
|
||||||
.process_filter_chain(
|
.process_raw_filter_chain(
|
||||||
&messages,
|
&chat_request_bytes,
|
||||||
&temp_filter_chain,
|
&temp_filter_chain,
|
||||||
&input_filter_agents,
|
&input_filter_agents,
|
||||||
&request_headers,
|
&request_headers,
|
||||||
|
&request_path,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(filtered_messages) => {
|
Ok(filtered_bytes) => {
|
||||||
client_request.set_messages(&filtered_messages);
|
match ProviderRequestType::try_from((
|
||||||
info!(
|
&filtered_bytes[..],
|
||||||
original_count = messages.len(),
|
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||||
filtered_count = filtered_messages.len(),
|
)) {
|
||||||
"filter chain processed successfully"
|
Ok(updated_request) => {
|
||||||
);
|
client_request = updated_request;
|
||||||
|
info!("input filter chain processed successfully");
|
||||||
|
}
|
||||||
|
Err(parse_err) => {
|
||||||
|
warn!(error = %parse_err, "input filter returned invalid request JSON");
|
||||||
|
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||||
|
"Input filter returned invalid request: {}",
|
||||||
|
parse_err
|
||||||
|
))
|
||||||
|
.into_response());
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Err(super::pipeline_processor::PipelineError::ClientError {
|
Err(super::pipeline_processor::PipelineError::ClientError {
|
||||||
agent,
|
agent,
|
||||||
|
|
@ -508,21 +520,25 @@ async fn llm_chat_inner(
|
||||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
||||||
});
|
});
|
||||||
|
|
||||||
// Output filters are only supported for /v1/chat/completions — the SSE content
|
|
||||||
// extraction logic is specific to that API shape (choices[].delta.content).
|
|
||||||
let output_filters_configured = output_filters
|
let output_filters_configured = output_filters
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.map(|fc| !fc.is_empty())
|
.map(|fc| !fc.is_empty())
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
let has_output_filter = output_filters_configured
|
let has_output_filter = output_filters_configured;
|
||||||
&& request_path == common::consts::CHAT_COMPLETIONS_PATH;
|
|
||||||
if output_filters_configured && !has_output_filter {
|
// Extract the upstream API path (e.g. "/v1/messages" from "https://api.anthropic.com/v1/messages").
|
||||||
warn!(
|
// Output filters are called at <agent.url><upstream_api_path> so they know the exact byte format.
|
||||||
path = %request_path,
|
let upstream_api_path = {
|
||||||
"output filters are configured but only supported for /v1/chat/completions, skipping"
|
let after_scheme = full_qualified_llm_provider_url
|
||||||
);
|
.find("://")
|
||||||
}
|
.map(|i| &full_qualified_llm_provider_url[i + 3..])
|
||||||
|
.unwrap_or(&full_qualified_llm_provider_url);
|
||||||
|
after_scheme
|
||||||
|
.find('/')
|
||||||
|
.map(|i| after_scheme[i..].to_string())
|
||||||
|
.unwrap_or_else(|| "/".to_string())
|
||||||
|
};
|
||||||
|
|
||||||
// Save request headers for output filters (before they're consumed by upstream request)
|
// Save request headers for output filters (before they're consumed by upstream request)
|
||||||
let output_filter_request_headers = if has_output_filter {
|
let output_filter_request_headers = if has_output_filter {
|
||||||
|
|
@ -607,6 +623,7 @@ async fn llm_chat_inner(
|
||||||
ofc,
|
ofc,
|
||||||
ofa,
|
ofa,
|
||||||
output_filter_request_headers.unwrap(),
|
output_filter_request_headers.unwrap(),
|
||||||
|
upstream_api_path.clone(),
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
create_streaming_response(byte_stream, state_processor, 16)
|
create_streaming_response(byte_stream, state_processor, 16)
|
||||||
|
|
@ -621,6 +638,7 @@ async fn llm_chat_inner(
|
||||||
ofc,
|
ofc,
|
||||||
ofa,
|
ofa,
|
||||||
output_filter_request_headers.unwrap(),
|
output_filter_request_headers.unwrap(),
|
||||||
|
upstream_api_path,
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
// Use base processor without state management
|
// Use base processor without state management
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use bytes::Bytes;
|
||||||
use common::configuration::{Agent, AgentFilterChain};
|
use common::configuration::{Agent, AgentFilterChain};
|
||||||
use common::consts::{
|
use common::consts::{
|
||||||
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
||||||
|
|
@ -605,6 +606,139 @@ impl PipelineProcessor {
|
||||||
Ok(messages)
|
Ok(messages)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
|
||||||
|
/// Used for input and output filters where the full raw request/response is passed through.
|
||||||
|
/// No MCP protocol wrapping; agent_type is ignored.
|
||||||
|
#[instrument(
|
||||||
|
skip(self, raw_bytes, agent, request_headers),
|
||||||
|
fields(
|
||||||
|
agent_id = %agent.id,
|
||||||
|
agent_url = %agent.url,
|
||||||
|
filter_name = %agent.id,
|
||||||
|
bytes_len = raw_bytes.len()
|
||||||
|
)
|
||||||
|
)]
|
||||||
|
async fn execute_raw_filter(
|
||||||
|
&mut self,
|
||||||
|
raw_bytes: &[u8],
|
||||||
|
agent: &Agent,
|
||||||
|
request_headers: &HeaderMap,
|
||||||
|
request_path: &str,
|
||||||
|
) -> Result<Bytes, PipelineError> {
|
||||||
|
set_service_name(operation_component::AGENT_FILTER);
|
||||||
|
use opentelemetry::trace::get_active_span;
|
||||||
|
get_active_span(|span| {
|
||||||
|
span.update_name(format!("execute_raw_filter ({})", agent.id));
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut agent_headers = request_headers.clone();
|
||||||
|
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||||
|
|
||||||
|
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||||
|
global::get_text_map_propagator(|propagator| {
|
||||||
|
let cx =
|
||||||
|
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||||
|
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
||||||
|
});
|
||||||
|
|
||||||
|
agent_headers.insert(
|
||||||
|
ARCH_UPSTREAM_HOST_HEADER,
|
||||||
|
hyper::header::HeaderValue::from_str(&agent.id)
|
||||||
|
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||||
|
);
|
||||||
|
agent_headers.insert(
|
||||||
|
ENVOY_RETRY_HEADER,
|
||||||
|
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||||
|
);
|
||||||
|
agent_headers.insert(
|
||||||
|
"Accept",
|
||||||
|
hyper::header::HeaderValue::from_static("application/json"),
|
||||||
|
);
|
||||||
|
agent_headers.insert(
|
||||||
|
"Content-Type",
|
||||||
|
hyper::header::HeaderValue::from_static("application/json"),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Append the original request path so the filter endpoint encodes the API format.
|
||||||
|
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
|
||||||
|
// -> POST http://host/anonymize/v1/chat/completions
|
||||||
|
let url = format!("{}{}", agent.url, request_path);
|
||||||
|
debug!(agent = %agent.id, url = %url, "sending raw filter request");
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(&url)
|
||||||
|
.headers(agent_headers)
|
||||||
|
.body(raw_bytes.to_vec())
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let http_status = response.status();
|
||||||
|
let response_bytes = response.bytes().await?;
|
||||||
|
|
||||||
|
if !http_status.is_success() {
|
||||||
|
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||||
|
return Err(if http_status.is_client_error() {
|
||||||
|
PipelineError::ClientError {
|
||||||
|
agent: agent.id.clone(),
|
||||||
|
status: http_status.as_u16(),
|
||||||
|
body: error_body,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
PipelineError::ServerError {
|
||||||
|
agent: agent.id.clone(),
|
||||||
|
status: http_status.as_u16(),
|
||||||
|
body: error_body,
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received");
|
||||||
|
Ok(response_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Process a chain of raw-bytes filters sequentially.
|
||||||
|
/// Input: raw request or response bytes. Output: filtered bytes.
|
||||||
|
/// Each agent receives the output of the previous one.
|
||||||
|
pub async fn process_raw_filter_chain(
|
||||||
|
&mut self,
|
||||||
|
raw_bytes: &[u8],
|
||||||
|
agent_filter_chain: &AgentFilterChain,
|
||||||
|
agent_map: &HashMap<String, Agent>,
|
||||||
|
request_headers: &HeaderMap,
|
||||||
|
request_path: &str,
|
||||||
|
) -> Result<Bytes, PipelineError> {
|
||||||
|
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
|
||||||
|
Some(fc) if !fc.is_empty() => fc,
|
||||||
|
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut current_bytes = Bytes::copy_from_slice(raw_bytes);
|
||||||
|
|
||||||
|
for agent_name in filter_chain {
|
||||||
|
debug!(agent = %agent_name, "processing raw filter agent");
|
||||||
|
|
||||||
|
let agent = agent_map
|
||||||
|
.get(agent_name)
|
||||||
|
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||||
|
|
||||||
|
info!(
|
||||||
|
agent = %agent_name,
|
||||||
|
url = %agent.url,
|
||||||
|
bytes_len = current_bytes.len(),
|
||||||
|
"executing raw filter"
|
||||||
|
);
|
||||||
|
|
||||||
|
current_bytes = self
|
||||||
|
.execute_raw_filter(¤t_bytes, agent, request_headers, request_path)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(current_bytes)
|
||||||
|
}
|
||||||
|
|
||||||
/// Send request to terminal agent and return the raw response for streaming
|
/// Send request to terminal agent and return the raw response for streaming
|
||||||
/// Note: The caller is responsible for creating the plano(agent) span that wraps
|
/// Note: The caller is responsible for creating the plano(agent) span that wraps
|
||||||
/// both this call and the subsequent response consumption.
|
/// both this call and the subsequent response consumption.
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,10 @@ use tokio_stream::StreamExt;
|
||||||
use tracing::{debug, info, warn, Instrument};
|
use tracing::{debug, info, warn, Instrument};
|
||||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||||
|
|
||||||
use super::pipeline_processor::PipelineProcessor;
|
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
use hermesllm::apis::openai::Message;
|
||||||
|
|
||||||
/// Trait for processing streaming chunks
|
/// Trait for processing streaming chunks
|
||||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||||
|
|
@ -281,184 +281,9 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract content text from an SSE chunk line (the JSON part after "data: ").
|
/// Creates a streaming response that processes each raw chunk through output filters.
|
||||||
/// Returns the content string and whether it was found.
|
/// Filters receive the raw LLM response bytes and return (possibly modified) bytes.
|
||||||
fn extract_sse_content(json_str: &str) -> Option<String> {
|
/// On filter error mid-stream the original chunk is passed through (headers already sent).
|
||||||
let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
|
||||||
value
|
|
||||||
.get("choices")?
|
|
||||||
.get(0)?
|
|
||||||
.get("delta")?
|
|
||||||
.get("content")?
|
|
||||||
.as_str()
|
|
||||||
.map(|s| s.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Replace content in an SSE JSON chunk with new content.
|
|
||||||
fn replace_sse_content(json_str: &str, new_content: &str) -> Option<String> {
|
|
||||||
let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
|
||||||
value
|
|
||||||
.get_mut("choices")?
|
|
||||||
.get_mut(0)?
|
|
||||||
.get_mut("delta")?
|
|
||||||
.as_object_mut()?
|
|
||||||
.insert(
|
|
||||||
"content".to_string(),
|
|
||||||
serde_json::Value::String(new_content.to_string()),
|
|
||||||
);
|
|
||||||
serde_json::to_string(&value).ok()
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Process an SSE chunk through the output filter chain.
|
|
||||||
/// Parses each `data: ` line, extracts content, sends through filters, and reconstructs.
|
|
||||||
async fn filter_sse_chunk(
|
|
||||||
chunk_str: &str,
|
|
||||||
pipeline_processor: &mut PipelineProcessor,
|
|
||||||
filter_chain: &AgentFilterChain,
|
|
||||||
filter_agents: &HashMap<String, Agent>,
|
|
||||||
request_headers: &HeaderMap,
|
|
||||||
) -> String {
|
|
||||||
let mut result = String::new();
|
|
||||||
for line in chunk_str.split('\n') {
|
|
||||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
|
||||||
if json_str.trim() == "[DONE]" {
|
|
||||||
result.push_str(line);
|
|
||||||
result.push('\n');
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
if let Some(content) = extract_sse_content(json_str) {
|
|
||||||
if content.is_empty() {
|
|
||||||
result.push_str(line);
|
|
||||||
result.push('\n');
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
// Send content through output filter chain
|
|
||||||
let messages = vec![Message {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content: Some(MessageContent::Text(content)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
}];
|
|
||||||
match pipeline_processor
|
|
||||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(filtered_messages) => {
|
|
||||||
if let Some(msg) = filtered_messages.first() {
|
|
||||||
let filtered_content = match &msg.content {
|
|
||||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
if let Some(filtered_content) = filtered_content {
|
|
||||||
if let Some(new_json) =
|
|
||||||
replace_sse_content(json_str, &filtered_content)
|
|
||||||
{
|
|
||||||
result.push_str("data: ");
|
|
||||||
result.push_str(&new_json);
|
|
||||||
result.push('\n');
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Fallback: pass through original
|
|
||||||
result.push_str(line);
|
|
||||||
result.push('\n');
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!(error = %e, "output filter chain error, passing through original chunk");
|
|
||||||
result.push_str(line);
|
|
||||||
result.push('\n');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// No content in this SSE line, pass through
|
|
||||||
result.push_str(line);
|
|
||||||
result.push('\n');
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
result.push_str(line);
|
|
||||||
result.push('\n');
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Remove trailing extra newline if the original didn't end with one
|
|
||||||
if !chunk_str.ends_with('\n') && result.ends_with('\n') {
|
|
||||||
result.pop();
|
|
||||||
}
|
|
||||||
result
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Process a non-streaming JSON response through the output filter chain.
|
|
||||||
/// Extracts assistant message content, filters it, and reconstructs the response.
|
|
||||||
pub async fn filter_non_streaming_response(
|
|
||||||
response_bytes: &[u8],
|
|
||||||
pipeline_processor: &mut PipelineProcessor,
|
|
||||||
filter_chain: &AgentFilterChain,
|
|
||||||
filter_agents: &HashMap<String, Agent>,
|
|
||||||
request_headers: &HeaderMap,
|
|
||||||
) -> Bytes {
|
|
||||||
let response_str = match std::str::from_utf8(response_bytes) {
|
|
||||||
Ok(s) => s,
|
|
||||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut value: serde_json::Value = match serde_json::from_str(response_str) {
|
|
||||||
Ok(v) => v,
|
|
||||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
|
||||||
};
|
|
||||||
|
|
||||||
// Extract content from choices[0].message.content
|
|
||||||
let content = value
|
|
||||||
.get("choices")
|
|
||||||
.and_then(|c| c.get(0))
|
|
||||||
.and_then(|c| c.get("message"))
|
|
||||||
.and_then(|m| m.get("content"))
|
|
||||||
.and_then(|c| c.as_str())
|
|
||||||
.map(|s| s.to_string());
|
|
||||||
|
|
||||||
if let Some(content) = content {
|
|
||||||
let messages = vec![Message {
|
|
||||||
role: Role::Assistant,
|
|
||||||
content: Some(MessageContent::Text(content)),
|
|
||||||
name: None,
|
|
||||||
tool_calls: None,
|
|
||||||
tool_call_id: None,
|
|
||||||
}];
|
|
||||||
match pipeline_processor
|
|
||||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(filtered_messages) => {
|
|
||||||
if let Some(msg) = filtered_messages.first() {
|
|
||||||
let filtered_content = match &msg.content {
|
|
||||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
if let Some(filtered_content) = filtered_content {
|
|
||||||
if let Some(choices) = value.get_mut("choices") {
|
|
||||||
if let Some(choice) = choices.get_mut(0) {
|
|
||||||
if let Some(message) = choice.get_mut("message") {
|
|
||||||
message.as_object_mut().unwrap().insert(
|
|
||||||
"content".to_string(),
|
|
||||||
serde_json::Value::String(filtered_content),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!(error = %e, "output filter chain error on non-streaming response");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string()))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Creates a streaming response that processes each chunk through output filters.
|
|
||||||
/// The output filter is called asynchronously for each SSE chunk's content.
|
|
||||||
pub fn create_streaming_response_with_output_filter<S, P>(
|
pub fn create_streaming_response_with_output_filter<S, P>(
|
||||||
mut byte_stream: S,
|
mut byte_stream: S,
|
||||||
mut inner_processor: P,
|
mut inner_processor: P,
|
||||||
|
|
@ -466,6 +291,7 @@ pub fn create_streaming_response_with_output_filter<S, P>(
|
||||||
output_filters: Vec<String>,
|
output_filters: Vec<String>,
|
||||||
output_filter_agents: HashMap<String, Agent>,
|
output_filter_agents: HashMap<String, Agent>,
|
||||||
request_headers: HeaderMap,
|
request_headers: HeaderMap,
|
||||||
|
upstream_path: String,
|
||||||
) -> StreamingResponse
|
) -> StreamingResponse
|
||||||
where
|
where
|
||||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||||
|
|
@ -501,32 +327,35 @@ where
|
||||||
is_first_chunk = false;
|
is_first_chunk = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try to process through output filter chain
|
// Pass raw chunk bytes through the output filter chain
|
||||||
let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) {
|
let processed_chunk = match pipeline_processor
|
||||||
if chunk_str.contains("data: ") {
|
.process_raw_filter_chain(
|
||||||
let filtered = filter_sse_chunk(
|
&chunk,
|
||||||
chunk_str,
|
&temp_filter_chain,
|
||||||
&mut pipeline_processor,
|
&output_filter_agents,
|
||||||
&temp_filter_chain,
|
&request_headers,
|
||||||
&output_filter_agents,
|
&upstream_path,
|
||||||
&request_headers,
|
)
|
||||||
)
|
.await
|
||||||
.await;
|
{
|
||||||
Bytes::from(filtered)
|
Ok(filtered) => filtered,
|
||||||
} else {
|
Err(PipelineError::ClientError {
|
||||||
// Non-SSE chunk (could be non-streaming JSON response)
|
agent,
|
||||||
let filtered = filter_non_streaming_response(
|
status,
|
||||||
&chunk,
|
body,
|
||||||
&mut pipeline_processor,
|
}) => {
|
||||||
&temp_filter_chain,
|
warn!(
|
||||||
&output_filter_agents,
|
agent = %agent,
|
||||||
&request_headers,
|
status = %status,
|
||||||
)
|
body = %body,
|
||||||
.await;
|
"output filter client error, passing through original chunk"
|
||||||
filtered
|
);
|
||||||
|
chunk
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(error = %e, "output filter error, passing through original chunk");
|
||||||
|
chunk
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
chunk
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Pass through inner processor for metrics/observability
|
// Pass through inner processor for metrics/observability
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,11 @@
|
||||||
Run content-safety filters on direct LLM requests — no agent layer required.
|
Run content-safety filters on direct LLM requests — no agent layer required.
|
||||||
|
|
||||||
This demo uses the `input_filters` feature on a **model-type listener** to intercept
|
This demo uses the `input_filters` feature on a **model-type listener** to intercept
|
||||||
`/v1/chat/completions` requests and block unsafe content before they reach the LLM provider.
|
requests and block unsafe content before they reach the LLM provider. Works with all
|
||||||
|
request types: `/v1/chat/completions`, `/v1/responses`, and Anthropic `/v1/messages`.
|
||||||
|
|
||||||
|
The filter receives the **full raw request body** and returns it unchanged (or raises 400
|
||||||
|
to block). No message extraction — the complete JSON payload flows through as-is.
|
||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,13 +3,15 @@ Content guard filter — keyword-based content safety for model listeners.
|
||||||
|
|
||||||
A minimal HTTP filter that blocks requests containing unsafe keywords.
|
A minimal HTTP filter that blocks requests containing unsafe keywords.
|
||||||
No LLM calls required — keeps the demo self-contained and fast.
|
No LLM calls required — keeps the demo self-contained and fast.
|
||||||
|
|
||||||
|
Receives the full raw request body (any API format: /v1/chat/completions,
|
||||||
|
/v1/responses, /v1/messages) and returns it unchanged or raises 400.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import List
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import FastAPI, Request, HTTPException
|
from fastapi import FastAPI, Request, HTTPException
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
|
|
@ -36,11 +38,6 @@ BLOCKED_KEYWORDS = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
|
||||||
role: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
def check_content(text: str) -> str | None:
|
def check_content(text: str) -> str | None:
|
||||||
"""Return the matched keyword if blocked, else None."""
|
"""Return the matched keyword if blocked, else None."""
|
||||||
lower = text.lower()
|
lower = text.lower()
|
||||||
|
|
@ -50,19 +47,58 @@ def check_content(text: str) -> str | None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@app.post("/")
|
def extract_last_user_text(body: dict[str, Any]) -> str | None:
|
||||||
async def content_guard(
|
"""Extract the most recent user message text from any supported request format."""
|
||||||
messages: List[ChatMessage], request: Request
|
messages = body.get("messages", [])
|
||||||
) -> List[ChatMessage]:
|
# Anthropic /v1/messages and OpenAI /v1/chat/completions both use "messages"
|
||||||
"""Block messages that contain unsafe keywords."""
|
|
||||||
last_user_msg = None
|
|
||||||
for msg in reversed(messages):
|
for msg in reversed(messages):
|
||||||
if msg.role == "user":
|
if msg.get("role") == "user":
|
||||||
last_user_msg = msg.content
|
content = msg.get("content", "")
|
||||||
break
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
# Multimodal content parts
|
||||||
|
return " ".join(
|
||||||
|
part.get("text", "")
|
||||||
|
for part in content
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text"
|
||||||
|
)
|
||||||
|
|
||||||
|
# OpenAI /v1/responses uses "input" instead of "messages"
|
||||||
|
input_val = body.get("input")
|
||||||
|
if isinstance(input_val, str):
|
||||||
|
return input_val
|
||||||
|
if isinstance(input_val, list):
|
||||||
|
for item in reversed(input_val):
|
||||||
|
if isinstance(item, dict) and item.get("role") == "user":
|
||||||
|
content = item.get("content", "")
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/{path:path}")
|
||||||
|
async def content_guard(path: str, request: Request) -> dict[str, Any]:
|
||||||
|
"""Block requests containing unsafe keywords. Returns the full request body unchanged.
|
||||||
|
|
||||||
|
The endpoint path encodes the API format:
|
||||||
|
/v1/chat/completions — check body["messages"]
|
||||||
|
/v1/responses — check body["input"]
|
||||||
|
/v1/messages — check body["messages"] (Anthropic format)
|
||||||
|
"""
|
||||||
|
endpoint = f"/{path}"
|
||||||
|
body = await request.json()
|
||||||
|
|
||||||
|
# /v1/responses uses "input" instead of "messages"
|
||||||
|
if endpoint == "/v1/responses":
|
||||||
|
input_val = body.get("input", "")
|
||||||
|
last_user_msg = input_val if isinstance(input_val, str) else None
|
||||||
|
else:
|
||||||
|
last_user_msg = extract_last_user_text(body)
|
||||||
|
|
||||||
if last_user_msg is None:
|
if last_user_msg is None:
|
||||||
return messages
|
return body
|
||||||
|
|
||||||
matched = check_content(last_user_msg)
|
matched = check_content(last_user_msg)
|
||||||
if matched:
|
if matched:
|
||||||
|
|
@ -76,7 +112,7 @@ async def content_guard(
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Content check passed — forwarding request")
|
logger.info("Content check passed — forwarding request")
|
||||||
return messages
|
return body
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
|
|
|
||||||
|
|
@ -80,14 +80,26 @@ Check the PII filter service logs in the terminal running `start_agents.sh`. You
|
||||||
| Email | standard email format | `user@example.com` | `[EMAIL_0]` |
|
| Email | standard email format | `user@example.com` | `[EMAIL_0]` |
|
||||||
| Phone | US phone formats | `555-123-4567` | `[PHONE_0]` |
|
| Phone | US phone formats | `555-123-4567` | `[PHONE_0]` |
|
||||||
|
|
||||||
|
## Filter Contract
|
||||||
|
|
||||||
|
**Input filter (`/anonymize`)** receives the **full raw request body** and returns the modified body:
|
||||||
|
```json
|
||||||
|
{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Contact john@example.com"}], "stream": true}
|
||||||
|
```
|
||||||
|
→ returns the same structure with PII replaced in the `messages` array.
|
||||||
|
|
||||||
|
**Output filter (`/deanonymize`)** receives the **raw LLM response bytes** and returns modified bytes:
|
||||||
|
- *Streaming*: raw SSE chunk, e.g. `data: {"choices":[{"delta":{"content":"Contact [EMAIL_0]"}}]}`
|
||||||
|
- *Non-streaming*: full JSON response body
|
||||||
|
|
||||||
## How Streaming De-anonymization Works
|
## How Streaming De-anonymization Works
|
||||||
|
|
||||||
For streaming responses, each SSE chunk is sent through the output filters as it arrives from the LLM:
|
For streaming responses, each raw SSE chunk is sent through the output filter as it arrives from the LLM:
|
||||||
|
|
||||||
1. Plano receives a chunk with content like `"The email [EMAIL_0] belongs to..."`
|
1. Plano receives a raw SSE chunk like `data: {"choices":[{"delta":{"content":"The email [EMAIL_0] belongs to..."}}]}`
|
||||||
2. The chunk content is sent to the `/deanonymize` endpoint
|
2. The raw chunk bytes are sent to the `/deanonymize` endpoint
|
||||||
3. The filter looks up the PII mapping (stored during anonymization) and replaces placeholders
|
3. The filter parses the SSE, looks up the PII mapping (stored during anonymization), and replaces placeholders in the delta content
|
||||||
4. The restored chunk `"The email john@example.com belongs to..."` is streamed to the client
|
4. The restored chunk is returned and streamed to the client
|
||||||
|
|
||||||
Partial placeholders split across chunks (e.g., `[EMA` in one chunk, `IL_0]` in the next) are handled via internal buffering in the filter service.
|
Partial placeholders split across chunks (e.g., `[EMA` in one chunk, `IL_0]` in the next) are handled via internal buffering in the filter service.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ model_providers:
|
||||||
- model: openai/gpt-4o-mini
|
- model: openai/gpt-4o-mini
|
||||||
access_key: $OPENAI_API_KEY
|
access_key: $OPENAI_API_KEY
|
||||||
default: true
|
default: true
|
||||||
|
- model: anthropic/claude-sonnet-4-20250514
|
||||||
|
access_key: $ANTHROPIC_API_KEY
|
||||||
|
|
||||||
listeners:
|
listeners:
|
||||||
- type: model
|
- type: model
|
||||||
|
|
|
||||||
88
demos/filter_chains/pii_anonymizer/pii.py
Normal file
88
demos/filter_chains/pii_anonymizer/pii.py
Normal file
|
|
@ -0,0 +1,88 @@
|
||||||
|
"""PII detection and anonymization utilities."""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
# Order matters: SSN before phone to avoid overlap
|
||||||
|
PII_PATTERNS = [
|
||||||
|
("SSN", re.compile(r"\b\d{3}-\d{2}-\d{4}\b")),
|
||||||
|
("CREDIT_CARD", re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")),
|
||||||
|
("EMAIL", re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")),
|
||||||
|
("PHONE", re.compile(r"(\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def anonymize_text(text: str) -> Tuple[str, Dict[str, str]]:
|
||||||
|
"""Replace PII with [TYPE_N] placeholders. Returns (anonymized_text, mapping)."""
|
||||||
|
mapping: Dict[str, str] = {}
|
||||||
|
counters: Dict[str, int] = {}
|
||||||
|
matched_spans: List[Tuple[int, int]] = []
|
||||||
|
|
||||||
|
for pii_type, pattern in PII_PATTERNS:
|
||||||
|
for match in pattern.finditer(text):
|
||||||
|
start, end = match.start(), match.end()
|
||||||
|
if any(s <= start < e or s < end <= e for s, e in matched_spans):
|
||||||
|
continue
|
||||||
|
matched_spans.append((start, end))
|
||||||
|
idx = counters.get(pii_type, 0)
|
||||||
|
counters[pii_type] = idx + 1
|
||||||
|
mapping[f"[{pii_type}_{idx}]"] = match.group()
|
||||||
|
|
||||||
|
# Replace right-to-left to preserve span indices
|
||||||
|
matched_spans.sort(reverse=True)
|
||||||
|
result = text
|
||||||
|
for start, end in matched_spans:
|
||||||
|
placeholder = next(k for k, v in mapping.items() if v == text[start:end])
|
||||||
|
result = result[:start] + placeholder + result[end:]
|
||||||
|
|
||||||
|
return result, mapping
|
||||||
|
|
||||||
|
|
||||||
|
def deanonymize_text(
|
||||||
|
text: str, mapping: Dict[str, str], buffer: str = ""
|
||||||
|
) -> Tuple[str, str]:
|
||||||
|
"""Replace placeholders back with original PII values.
|
||||||
|
|
||||||
|
Handles partial placeholders at chunk boundaries via a buffer.
|
||||||
|
Returns (processed_text, remaining_buffer).
|
||||||
|
"""
|
||||||
|
combined = buffer + text
|
||||||
|
|
||||||
|
# Build prefix set for all known placeholders (e.g. "[EMAIL_0" is a prefix of "[EMAIL_0]")
|
||||||
|
prefixes: set[str] = set()
|
||||||
|
for placeholder in mapping:
|
||||||
|
for i in range(1, len(placeholder)):
|
||||||
|
prefixes.add(placeholder[:i])
|
||||||
|
|
||||||
|
# If the tail looks like the start of a placeholder, hold it in the buffer
|
||||||
|
remaining_buffer = ""
|
||||||
|
last_bracket = combined.rfind("[")
|
||||||
|
if last_bracket != -1 and "]" not in combined[last_bracket:]:
|
||||||
|
tail = combined[last_bracket:]
|
||||||
|
if tail in prefixes:
|
||||||
|
remaining_buffer = tail
|
||||||
|
combined = combined[:last_bracket]
|
||||||
|
|
||||||
|
for placeholder, original in mapping.items():
|
||||||
|
combined = combined.replace(placeholder, original)
|
||||||
|
|
||||||
|
return combined, remaining_buffer
|
||||||
|
|
||||||
|
|
||||||
|
def anonymize_message_content(content: Any, all_mappings: Dict[str, str]) -> Any:
|
||||||
|
"""Anonymize string content or list of content parts."""
|
||||||
|
if isinstance(content, str):
|
||||||
|
anonymized, mapping = anonymize_text(content)
|
||||||
|
all_mappings.update(mapping)
|
||||||
|
return anonymized
|
||||||
|
if isinstance(content, list):
|
||||||
|
result = []
|
||||||
|
for part in content:
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text":
|
||||||
|
anonymized, mapping = anonymize_text(part.get("text", ""))
|
||||||
|
all_mappings.update(mapping)
|
||||||
|
result.append({**part, "text": anonymized})
|
||||||
|
else:
|
||||||
|
result.append(part)
|
||||||
|
return result
|
||||||
|
return content
|
||||||
|
|
@ -5,18 +5,26 @@ Inspired by Uber's GenAI Gateway PII Redactor. Two endpoints:
|
||||||
POST /anonymize — replace PII with placeholders (input filter)
|
POST /anonymize — replace PII with placeholders (input filter)
|
||||||
POST /deanonymize — restore original PII from placeholders (output filter)
|
POST /deanonymize — restore original PII from placeholders (output filter)
|
||||||
|
|
||||||
Uses regex-based detection for: email, phone, SSN, credit card.
|
Input filter (/anonymize):
|
||||||
Correlates request/response via x-request-id header.
|
Receives the full raw request body (any API format). Anonymizes user message
|
||||||
|
content and returns the modified body.
|
||||||
|
|
||||||
|
Output filter (/deanonymize):
|
||||||
|
Receives raw LLM response bytes — SSE (streaming) or full JSON (non-streaming).
|
||||||
|
De-anonymizes content and returns modified bytes.
|
||||||
|
|
||||||
|
The path suffix encodes the upstream API format so each endpoint knows how to
|
||||||
|
parse the body (e.g. /anonymize/v1/chat/completions, /deanonymize/v1/messages).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
from typing import Any, Dict
|
||||||
import time
|
|
||||||
import threading
|
|
||||||
from typing import Dict, List, Optional, Tuple
|
|
||||||
|
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from pydantic import BaseModel
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from pii import anonymize_text, anonymize_message_content
|
||||||
|
from store import get_mapping, store_mapping, deanonymize_sse, deanonymize_json
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
|
|
@ -26,205 +34,79 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
app = FastAPI(title="PII Anonymizer", version="1.0.0")
|
app = FastAPI(title="PII Anonymizer", version="1.0.0")
|
||||||
|
|
||||||
# --- PII patterns (order matters: SSN before phone to avoid overlap) ---
|
|
||||||
|
|
||||||
PII_PATTERNS = [
|
@app.post("/anonymize/{path:path}")
|
||||||
("SSN", re.compile(r"\b\d{3}-\d{2}-\d{4}\b")),
|
async def anonymize(path: str, request: Request) -> dict[str, Any]:
|
||||||
("CREDIT_CARD", re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")),
|
"""Anonymize PII in user messages. Receives and returns the full raw request body.
|
||||||
("EMAIL", re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")),
|
|
||||||
(
|
|
||||||
"PHONE",
|
|
||||||
re.compile(r"(\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}"),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
# --- In-memory mapping store (request_id -> mapping + timestamp) ---
|
The endpoint path encodes the API format:
|
||||||
|
/anonymize/v1/chat/completions — anonymize body["messages"]
|
||||||
_store_lock = threading.Lock()
|
/anonymize/v1/responses — anonymize body["input"] (string or items list)
|
||||||
_mapping_store: Dict[str, Tuple[Dict[str, str], float]] = {}
|
/anonymize/v1/messages — anonymize body["messages"] (Anthropic format)
|
||||||
# Buffer for partial placeholder matches during streaming de-anonymization
|
|
||||||
_buffer_store: Dict[str, str] = {}
|
|
||||||
MAPPING_TTL_SECONDS = 300 # 5 minutes
|
|
||||||
|
|
||||||
|
|
||||||
def _cleanup_expired():
|
|
||||||
"""Remove expired mappings."""
|
|
||||||
now = time.time()
|
|
||||||
expired = [
|
|
||||||
k for k, (_, ts) in _mapping_store.items() if now - ts > MAPPING_TTL_SECONDS
|
|
||||||
]
|
|
||||||
for k in expired:
|
|
||||||
del _mapping_store[k]
|
|
||||||
_buffer_store.pop(k, None)
|
|
||||||
|
|
||||||
|
|
||||||
def _store_mapping(request_id: str, mapping: Dict[str, str]):
|
|
||||||
with _store_lock:
|
|
||||||
_cleanup_expired()
|
|
||||||
_mapping_store[request_id] = (mapping, time.time())
|
|
||||||
|
|
||||||
|
|
||||||
def _get_mapping(request_id: str) -> Optional[Dict[str, str]]:
|
|
||||||
with _store_lock:
|
|
||||||
entry = _mapping_store.get(request_id)
|
|
||||||
if entry:
|
|
||||||
return entry[0]
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# --- Core logic ---
|
|
||||||
|
|
||||||
|
|
||||||
class ChatMessage(BaseModel):
|
|
||||||
role: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
def anonymize_text(text: str) -> Tuple[str, Dict[str, str]]:
|
|
||||||
"""Replace PII with [TYPE_N] placeholders. Returns (anonymized_text, mapping)."""
|
|
||||||
mapping: Dict[str, str] = {}
|
|
||||||
counters: Dict[str, int] = {}
|
|
||||||
# Track spans already matched to avoid overlapping replacements
|
|
||||||
matched_spans: List[Tuple[int, int]] = []
|
|
||||||
|
|
||||||
for pii_type, pattern in PII_PATTERNS:
|
|
||||||
for match in pattern.finditer(text):
|
|
||||||
start, end = match.start(), match.end()
|
|
||||||
# Skip if this span overlaps with an already-matched span
|
|
||||||
if any(s <= start < e or s < end <= e for s, e in matched_spans):
|
|
||||||
continue
|
|
||||||
matched_spans.append((start, end))
|
|
||||||
idx = counters.get(pii_type, 0)
|
|
||||||
counters[pii_type] = idx + 1
|
|
||||||
placeholder = f"[{pii_type}_{idx}]"
|
|
||||||
mapping[placeholder] = match.group()
|
|
||||||
|
|
||||||
# Replace from right to left to preserve indices
|
|
||||||
matched_spans.sort(reverse=True)
|
|
||||||
result = text
|
|
||||||
for start, end in matched_spans:
|
|
||||||
original = text[start:end]
|
|
||||||
# Find the placeholder for this original value
|
|
||||||
placeholder = next(k for k, v in mapping.items() if v == original)
|
|
||||||
result = result[:start] + placeholder + result[end:]
|
|
||||||
|
|
||||||
return result, mapping
|
|
||||||
|
|
||||||
|
|
||||||
def deanonymize_text(
|
|
||||||
text: str, mapping: Dict[str, str], buffer: str = ""
|
|
||||||
) -> Tuple[str, str]:
|
|
||||||
"""Replace placeholders back with original PII values.
|
|
||||||
|
|
||||||
Handles partial placeholders across streaming chunks via a buffer.
|
|
||||||
Only buffers text that could be the prefix of an actual placeholder
|
|
||||||
from this request's mapping, not arbitrary ``[`` from normal text.
|
|
||||||
Returns (processed_text, remaining_buffer).
|
|
||||||
"""
|
"""
|
||||||
combined = buffer + text
|
|
||||||
|
|
||||||
# Build the set of all prefixes for placeholders in this request's mapping.
|
|
||||||
# e.g. for "[EMAIL_0]" -> {"[", "[E", "[EM", "[EMA", "[EMAI", "[EMAIL", "[EMAIL_", "[EMAIL_0"}
|
|
||||||
prefixes: set[str] = set()
|
|
||||||
for placeholder in mapping:
|
|
||||||
# Exclude the full placeholder (with closing ']') — that's a complete match, not partial
|
|
||||||
for i in range(1, len(placeholder)):
|
|
||||||
prefixes.add(placeholder[:i])
|
|
||||||
|
|
||||||
# Check if the end of the text could be a partial placeholder.
|
|
||||||
remaining_buffer = ""
|
|
||||||
last_bracket = combined.rfind("[")
|
|
||||||
if last_bracket != -1 and "]" not in combined[last_bracket:]:
|
|
||||||
tail = combined[last_bracket:]
|
|
||||||
if tail in prefixes:
|
|
||||||
remaining_buffer = tail
|
|
||||||
combined = combined[:last_bracket]
|
|
||||||
|
|
||||||
# Replace all complete placeholders
|
|
||||||
for placeholder, original in mapping.items():
|
|
||||||
combined = combined.replace(placeholder, original)
|
|
||||||
|
|
||||||
return combined, remaining_buffer
|
|
||||||
|
|
||||||
|
|
||||||
# --- Endpoints ---
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/anonymize")
|
|
||||||
async def anonymize(messages: List[ChatMessage], request: Request) -> List[ChatMessage]:
|
|
||||||
"""Anonymize PII in user messages. Stores mapping for later de-anonymization."""
|
|
||||||
request_id = request.headers.get("x-request-id", "unknown")
|
request_id = request.headers.get("x-request-id", "unknown")
|
||||||
|
endpoint = f"/{path}"
|
||||||
|
body = await request.json()
|
||||||
all_mappings: Dict[str, str] = {}
|
all_mappings: Dict[str, str] = {}
|
||||||
result_messages = []
|
|
||||||
|
|
||||||
for msg in messages:
|
if endpoint == "/v1/responses":
|
||||||
if msg.role == "user":
|
input_val = body.get("input", "")
|
||||||
anonymized, mapping = anonymize_text(msg.content)
|
if isinstance(input_val, str):
|
||||||
|
anonymized, mapping = anonymize_text(input_val)
|
||||||
all_mappings.update(mapping)
|
all_mappings.update(mapping)
|
||||||
result_messages.append(ChatMessage(role=msg.role, content=anonymized))
|
body = {**body, "input": anonymized}
|
||||||
else:
|
elif isinstance(input_val, list):
|
||||||
result_messages.append(msg)
|
items = [
|
||||||
|
{**item, "content": anonymize_message_content(item.get("content", ""), all_mappings)}
|
||||||
|
if isinstance(item, dict) and item.get("role") == "user"
|
||||||
|
else item
|
||||||
|
for item in input_val
|
||||||
|
]
|
||||||
|
body = {**body, "input": items}
|
||||||
|
else:
|
||||||
|
# /v1/chat/completions and /v1/messages both use "messages"
|
||||||
|
messages = [
|
||||||
|
{**msg, "content": anonymize_message_content(msg.get("content", ""), all_mappings)}
|
||||||
|
if msg.get("role") == "user"
|
||||||
|
else msg
|
||||||
|
for msg in body.get("messages", [])
|
||||||
|
]
|
||||||
|
if messages:
|
||||||
|
body = {**body, "messages": messages}
|
||||||
|
|
||||||
if all_mappings:
|
if all_mappings:
|
||||||
_store_mapping(request_id, all_mappings)
|
store_mapping(request_id, all_mappings)
|
||||||
logger.info(
|
logger.info("request_id=%s /anonymize mapping: %s", request_id, all_mappings)
|
||||||
"request_id=%s /anonymize mapping: %s",
|
|
||||||
request_id,
|
|
||||||
all_mappings,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info("request_id=%s no PII detected", request_id)
|
logger.info("request_id=%s no PII detected", request_id)
|
||||||
|
|
||||||
logger.info(
|
return body
|
||||||
"request_id=%s /anonymize input: %s -> output: %s",
|
|
||||||
request_id,
|
|
||||||
[m.content for m in messages],
|
|
||||||
[m.content for m in result_messages],
|
|
||||||
)
|
|
||||||
|
|
||||||
return result_messages
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/deanonymize")
|
@app.post("/deanonymize/{path:path}")
|
||||||
async def deanonymize(
|
async def deanonymize(path: str, request: Request) -> Response:
|
||||||
messages: List[ChatMessage], request: Request
|
"""De-anonymize PII placeholders in LLM response. Handles SSE (streaming) and JSON.
|
||||||
) -> List[ChatMessage]:
|
|
||||||
"""De-anonymize PII placeholders in response messages using stored mapping."""
|
The path encodes the upstream API format:
|
||||||
|
/deanonymize/v1/chat/completions — OpenAI chat completions
|
||||||
|
/deanonymize/v1/messages — Anthropic messages
|
||||||
|
/deanonymize/v1/responses — OpenAI responses API
|
||||||
|
"""
|
||||||
|
endpoint = f"/{path}"
|
||||||
|
is_anthropic = endpoint == "/v1/messages"
|
||||||
request_id = request.headers.get("x-request-id", "unknown")
|
request_id = request.headers.get("x-request-id", "unknown")
|
||||||
mapping = _get_mapping(request_id)
|
mapping = get_mapping(request_id)
|
||||||
|
raw_body = await request.body()
|
||||||
|
|
||||||
if not mapping:
|
if not mapping:
|
||||||
logger.info("request_id=%s no mapping found, passing through", request_id)
|
logger.info("request_id=%s no mapping found, passing through", request_id)
|
||||||
return messages
|
return Response(content=raw_body, media_type="application/json")
|
||||||
|
|
||||||
result_messages = []
|
body_str = raw_body.decode("utf-8", errors="replace")
|
||||||
for msg in messages:
|
|
||||||
if msg.role == "assistant" and msg.content:
|
|
||||||
with _store_lock:
|
|
||||||
buffer = _buffer_store.get(request_id, "")
|
|
||||||
|
|
||||||
restored, remaining = deanonymize_text(msg.content, mapping, buffer)
|
if "data: " in body_str:
|
||||||
|
return deanonymize_sse(request_id, body_str, mapping, is_anthropic)
|
||||||
with _store_lock:
|
return deanonymize_json(request_id, raw_body, body_str, mapping, is_anthropic)
|
||||||
if remaining:
|
|
||||||
_buffer_store[request_id] = remaining
|
|
||||||
else:
|
|
||||||
_buffer_store.pop(request_id, None)
|
|
||||||
|
|
||||||
# Only log when a replacement actually happened
|
|
||||||
if restored != msg.content:
|
|
||||||
logger.info(
|
|
||||||
"request_id=%s /deanonymize '%s' -> '%s'",
|
|
||||||
request_id,
|
|
||||||
msg.content,
|
|
||||||
restored,
|
|
||||||
)
|
|
||||||
|
|
||||||
result_messages.append(ChatMessage(role=msg.role, content=restored))
|
|
||||||
else:
|
|
||||||
result_messages.append(msg)
|
|
||||||
|
|
||||||
return result_messages
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
|
|
|
||||||
115
demos/filter_chains/pii_anonymizer/store.py
Normal file
115
demos/filter_chains/pii_anonymizer/store.py
Normal file
|
|
@ -0,0 +1,115 @@
|
||||||
|
"""In-memory mapping store and LLM response processors for PII de-anonymization."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Dict, Optional, Tuple
|
||||||
|
|
||||||
|
from fastapi.responses import Response
|
||||||
|
|
||||||
|
from pii import deanonymize_text
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
MAPPING_TTL_SECONDS = 300 # 5 minutes
|
||||||
|
|
||||||
|
_lock = threading.Lock()
|
||||||
|
_mappings: Dict[str, Tuple[Dict[str, str], float]] = {}
|
||||||
|
_buffers: Dict[str, str] = {} # partial placeholder buffers for streaming
|
||||||
|
|
||||||
|
|
||||||
|
def _cleanup_expired():
|
||||||
|
now = time.time()
|
||||||
|
expired = [k for k, (_, ts) in _mappings.items() if now - ts > MAPPING_TTL_SECONDS]
|
||||||
|
for k in expired:
|
||||||
|
del _mappings[k]
|
||||||
|
_buffers.pop(k, None)
|
||||||
|
|
||||||
|
|
||||||
|
def store_mapping(request_id: str, mapping: Dict[str, str]):
|
||||||
|
with _lock:
|
||||||
|
_cleanup_expired()
|
||||||
|
_mappings[request_id] = (mapping, time.time())
|
||||||
|
|
||||||
|
|
||||||
|
def get_mapping(request_id: str) -> Optional[Dict[str, str]]:
|
||||||
|
with _lock:
|
||||||
|
entry = _mappings.get(request_id)
|
||||||
|
return entry[0] if entry else None
|
||||||
|
|
||||||
|
|
||||||
|
def restore_streaming(request_id: str, content: str, mapping: Dict[str, str]) -> str:
|
||||||
|
"""Restore PII in one streaming chunk, maintaining the per-request partial buffer."""
|
||||||
|
with _lock:
|
||||||
|
buffer = _buffers.get(request_id, "")
|
||||||
|
restored, remaining = deanonymize_text(content, mapping, buffer)
|
||||||
|
with _lock:
|
||||||
|
if remaining:
|
||||||
|
_buffers[request_id] = remaining
|
||||||
|
else:
|
||||||
|
_buffers.pop(request_id, None)
|
||||||
|
if restored != content:
|
||||||
|
logger.info("request_id=%s restored '%s' -> '%s'", request_id, content, restored)
|
||||||
|
return restored
|
||||||
|
|
||||||
|
|
||||||
|
def deanonymize_sse(
|
||||||
|
request_id: str, body_str: str, mapping: Dict[str, str], is_anthropic: bool
|
||||||
|
) -> Response:
|
||||||
|
result_lines = []
|
||||||
|
for line in body_str.split("\n"):
|
||||||
|
stripped = line.strip()
|
||||||
|
if not (stripped.startswith("data: ") and stripped[6:] != "[DONE]"):
|
||||||
|
result_lines.append(line)
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
chunk = json.loads(stripped[6:])
|
||||||
|
if is_anthropic:
|
||||||
|
# {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "..."}}
|
||||||
|
if chunk.get("type") == "content_block_delta":
|
||||||
|
delta = chunk.get("delta", {})
|
||||||
|
if delta.get("type") == "text_delta" and delta.get("text"):
|
||||||
|
delta["text"] = restore_streaming(request_id, delta["text"], mapping)
|
||||||
|
else:
|
||||||
|
# {"choices": [{"delta": {"content": "..."}}]}
|
||||||
|
for choice in chunk.get("choices", []):
|
||||||
|
delta = choice.get("delta", {})
|
||||||
|
if delta.get("content"):
|
||||||
|
delta["content"] = restore_streaming(request_id, delta["content"], mapping)
|
||||||
|
result_lines.append("data: " + json.dumps(chunk))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
result_lines.append(line)
|
||||||
|
return Response(content="\n".join(result_lines), media_type="text/plain")
|
||||||
|
|
||||||
|
|
||||||
|
def deanonymize_json(
|
||||||
|
request_id: str,
|
||||||
|
raw_body: bytes,
|
||||||
|
body_str: str,
|
||||||
|
mapping: Dict[str, str],
|
||||||
|
is_anthropic: bool,
|
||||||
|
) -> Response:
|
||||||
|
try:
|
||||||
|
body = json.loads(body_str)
|
||||||
|
if is_anthropic:
|
||||||
|
# {"content": [{"type": "text", "text": "..."}]}
|
||||||
|
for part in body.get("content", []):
|
||||||
|
if isinstance(part, dict) and part.get("type") == "text" and part.get("text"):
|
||||||
|
restored, _ = deanonymize_text(part["text"], mapping)
|
||||||
|
if restored != part["text"]:
|
||||||
|
logger.info("request_id=%s restored '%s' -> '%s'", request_id, part["text"], restored)
|
||||||
|
part["text"] = restored
|
||||||
|
else:
|
||||||
|
# {"choices": [{"message": {"content": "..."}}]}
|
||||||
|
for choice in body.get("choices", []):
|
||||||
|
message = choice.get("message", {})
|
||||||
|
content = message.get("content")
|
||||||
|
if content and isinstance(content, str):
|
||||||
|
restored, _ = deanonymize_text(content, mapping)
|
||||||
|
if restored != content:
|
||||||
|
logger.info("request_id=%s restored '%s' -> '%s'", request_id, content, restored)
|
||||||
|
message["content"] = restored
|
||||||
|
return Response(content=json.dumps(body), media_type="application/json")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return Response(content=raw_body, media_type="application/json")
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
#!/usr/bin/env bash
|
#!/usr/bin/env bash
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
BASE_URL="http://localhost:12000/v1"
|
BASE_URL="http://localhost:12000"
|
||||||
PASS=0
|
PASS=0
|
||||||
FAIL=0
|
FAIL=0
|
||||||
|
|
||||||
# ── Wait for Plano to be ready ──────────────────────────────────────────────
|
# ── Wait for Plano to be ready ──────────────────────────────────────────────
|
||||||
echo "Waiting for Plano to be ready..."
|
echo "Waiting for Plano to be ready..."
|
||||||
for i in $(seq 1 30); do
|
for i in $(seq 1 30); do
|
||||||
if curl -sf "$BASE_URL/models" > /dev/null 2>&1; then
|
if curl -sf "$BASE_URL/v1/models" > /dev/null 2>&1; then
|
||||||
echo "Plano is ready."
|
echo "Plano is ready."
|
||||||
break
|
break
|
||||||
fi
|
fi
|
||||||
|
|
@ -22,11 +22,12 @@ done
|
||||||
# ── Helper ───────────────────────────────────────────────────────────────────
|
# ── Helper ───────────────────────────────────────────────────────────────────
|
||||||
run_test() {
|
run_test() {
|
||||||
local name="$1"
|
local name="$1"
|
||||||
local expected_code="$2"
|
local path="$2"
|
||||||
local body="$3"
|
local expected_code="$3"
|
||||||
|
local body="$4"
|
||||||
|
|
||||||
http_code=$(curl -s -o /tmp/plano_test_body -w "%{http_code}" \
|
http_code=$(curl -s -o /tmp/plano_test_body -w "%{http_code}" \
|
||||||
-X POST "$BASE_URL/chat/completions" \
|
-X POST "$BASE_URL$path" \
|
||||||
-H "Content-Type: application/json" \
|
-H "Content-Type: application/json" \
|
||||||
-d "$body")
|
-d "$body")
|
||||||
|
|
||||||
|
|
@ -40,34 +41,75 @@ run_test() {
|
||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── Tests ────────────────────────────────────────────────────────────────────
|
# ── /v1/chat/completions ─────────────────────────────────────────────────────
|
||||||
echo ""
|
echo ""
|
||||||
echo "Running tests..."
|
echo "=== /v1/chat/completions ==="
|
||||||
|
|
||||||
run_test "Non-streaming with PII (email + phone)" 200 '{
|
run_test "Non-streaming with PII (email + phone)" /v1/chat/completions 200 '{
|
||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
"messages": [{"role": "user", "content": "Contact me at john@example.com or call 555-123-4567"}],
|
"messages": [{"role": "user", "content": "Contact me at john@example.com or call 555-123-4567"}],
|
||||||
"stream": false
|
"stream": false
|
||||||
}'
|
}'
|
||||||
|
|
||||||
run_test "Streaming with PII (SSN)" 200 '{
|
run_test "Streaming with PII (SSN)" /v1/chat/completions 200 '{
|
||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
"messages": [{"role": "user", "content": "My SSN is 123-45-6789, please help me file taxes"}],
|
"messages": [{"role": "user", "content": "My SSN is 123-45-6789, please help me file taxes"}],
|
||||||
"stream": true
|
"stream": true
|
||||||
}'
|
}'
|
||||||
|
|
||||||
run_test "No PII (clean message)" 200 '{
|
run_test "No PII (clean message)" /v1/chat/completions 200 '{
|
||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
"messages": [{"role": "user", "content": "What is 2+2?"}],
|
"messages": [{"role": "user", "content": "What is 2+2?"}],
|
||||||
"stream": false
|
"stream": false
|
||||||
}'
|
}'
|
||||||
|
|
||||||
run_test "Multiple PII types" 200 '{
|
run_test "Multiple PII types" /v1/chat/completions 200 '{
|
||||||
"model": "gpt-4o-mini",
|
"model": "gpt-4o-mini",
|
||||||
"messages": [{"role": "user", "content": "Email: test@test.com, Phone: 555-867-5309, SSN: 987-65-4321, Card: 4111 1111 1111 1111"}],
|
"messages": [{"role": "user", "content": "Email: test@test.com, SSN: 987-65-4321, Card: 4111 1111 1111 1111"}],
|
||||||
"stream": false
|
"stream": false
|
||||||
}'
|
}'
|
||||||
|
|
||||||
|
# ── /v1/responses ────────────────────────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "=== /v1/responses ==="
|
||||||
|
|
||||||
|
run_test "Non-streaming with PII (email)" /v1/responses 200 '{
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"input": "My email is jane@example.com — can you summarize it?"
|
||||||
|
}'
|
||||||
|
|
||||||
|
run_test "Non-streaming with PII (credit card)" /v1/responses 200 '{
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"input": "I need help disputing a charge on card 4111 1111 1111 1111"
|
||||||
|
}'
|
||||||
|
|
||||||
|
run_test "No PII" /v1/responses 200 '{
|
||||||
|
"model": "gpt-4o-mini",
|
||||||
|
"input": "What is the capital of France?"
|
||||||
|
}'
|
||||||
|
|
||||||
|
# ── /v1/messages (Anthropic) ─────────────────────────────────────────────────
|
||||||
|
echo ""
|
||||||
|
echo "=== /v1/messages ==="
|
||||||
|
|
||||||
|
run_test "Non-streaming with PII (phone)" /v1/messages 200 '{
|
||||||
|
"model": "claude-sonnet-4-20250514",
|
||||||
|
"max_tokens": 256,
|
||||||
|
"messages": [{"role": "user", "content": "Call me at 555-867-5309 to discuss my account"}]
|
||||||
|
}'
|
||||||
|
|
||||||
|
run_test "Non-streaming with PII (SSN)" /v1/messages 200 '{
|
||||||
|
"model": "claude-sonnet-4-20250514",
|
||||||
|
"max_tokens": 256,
|
||||||
|
"messages": [{"role": "user", "content": "My SSN is 123-45-6789"}]
|
||||||
|
}'
|
||||||
|
|
||||||
|
run_test "No PII" /v1/messages 200 '{
|
||||||
|
"model": "claude-sonnet-4-20250514",
|
||||||
|
"max_tokens": 256,
|
||||||
|
"messages": [{"role": "user", "content": "Hello, how are you?"}]
|
||||||
|
}'
|
||||||
|
|
||||||
# ── Summary ──────────────────────────────────────────────────────────────────
|
# ── Summary ──────────────────────────────────────────────────────────────────
|
||||||
echo ""
|
echo ""
|
||||||
echo "Results: $PASS passed, $FAIL failed"
|
echo "Results: $PASS passed, $FAIL failed"
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue