mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
pass raw bytes through input/output filter chains
This commit is contained in:
parent
80dfb41cad
commit
d26abbfb9c
11 changed files with 610 additions and 448 deletions
|
|
@ -268,12 +268,13 @@ async fn llm_chat_inner(
|
|||
}
|
||||
|
||||
// === Input filters processing for model listener ===
|
||||
// Filters receive the raw request bytes and return (possibly modified) raw bytes.
|
||||
// The returned bytes are re-parsed into a ProviderRequestType to continue the request.
|
||||
{
|
||||
if let Some(ref fc) = *input_filters {
|
||||
if !fc.is_empty() {
|
||||
debug!(input_filters = ?fc, "processing model listener input filters");
|
||||
|
||||
// Create a temporary AgentFilterChain to reuse PipelineProcessor
|
||||
let temp_filter_chain = AgentFilterChain {
|
||||
id: "model_listener".to_string(),
|
||||
default: None,
|
||||
|
|
@ -282,23 +283,34 @@ async fn llm_chat_inner(
|
|||
};
|
||||
|
||||
let mut pipeline_processor = PipelineProcessor::default();
|
||||
let messages = client_request.get_messages();
|
||||
match pipeline_processor
|
||||
.process_filter_chain(
|
||||
&messages,
|
||||
.process_raw_filter_chain(
|
||||
&chat_request_bytes,
|
||||
&temp_filter_chain,
|
||||
&input_filter_agents,
|
||||
&request_headers,
|
||||
&request_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
client_request.set_messages(&filtered_messages);
|
||||
info!(
|
||||
original_count = messages.len(),
|
||||
filtered_count = filtered_messages.len(),
|
||||
"filter chain processed successfully"
|
||||
);
|
||||
Ok(filtered_bytes) => {
|
||||
match ProviderRequestType::try_from((
|
||||
&filtered_bytes[..],
|
||||
&SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(),
|
||||
)) {
|
||||
Ok(updated_request) => {
|
||||
client_request = updated_request;
|
||||
info!("input filter chain processed successfully");
|
||||
}
|
||||
Err(parse_err) => {
|
||||
warn!(error = %parse_err, "input filter returned invalid request JSON");
|
||||
return Ok(BrightStaffError::InvalidRequest(format!(
|
||||
"Input filter returned invalid request: {}",
|
||||
parse_err
|
||||
))
|
||||
.into_response());
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(super::pipeline_processor::PipelineError::ClientError {
|
||||
agent,
|
||||
|
|
@ -508,21 +520,25 @@ async fn llm_chat_inner(
|
|||
propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers));
|
||||
});
|
||||
|
||||
// Output filters are only supported for /v1/chat/completions — the SSE content
|
||||
// extraction logic is specific to that API shape (choices[].delta.content).
|
||||
let output_filters_configured = output_filters
|
||||
.as_ref()
|
||||
.as_ref()
|
||||
.map(|fc| !fc.is_empty())
|
||||
.unwrap_or(false);
|
||||
let has_output_filter = output_filters_configured
|
||||
&& request_path == common::consts::CHAT_COMPLETIONS_PATH;
|
||||
if output_filters_configured && !has_output_filter {
|
||||
warn!(
|
||||
path = %request_path,
|
||||
"output filters are configured but only supported for /v1/chat/completions, skipping"
|
||||
);
|
||||
}
|
||||
let has_output_filter = output_filters_configured;
|
||||
|
||||
// Extract the upstream API path (e.g. "/v1/messages" from "https://api.anthropic.com/v1/messages").
|
||||
// Output filters are called at <agent.url><upstream_api_path> so they know the exact byte format.
|
||||
let upstream_api_path = {
|
||||
let after_scheme = full_qualified_llm_provider_url
|
||||
.find("://")
|
||||
.map(|i| &full_qualified_llm_provider_url[i + 3..])
|
||||
.unwrap_or(&full_qualified_llm_provider_url);
|
||||
after_scheme
|
||||
.find('/')
|
||||
.map(|i| after_scheme[i..].to_string())
|
||||
.unwrap_or_else(|| "/".to_string())
|
||||
};
|
||||
|
||||
// Save request headers for output filters (before they're consumed by upstream request)
|
||||
let output_filter_request_headers = if has_output_filter {
|
||||
|
|
@ -607,6 +623,7 @@ async fn llm_chat_inner(
|
|||
ofc,
|
||||
ofa,
|
||||
output_filter_request_headers.unwrap(),
|
||||
upstream_api_path.clone(),
|
||||
)
|
||||
} else {
|
||||
create_streaming_response(byte_stream, state_processor, 16)
|
||||
|
|
@ -621,6 +638,7 @@ async fn llm_chat_inner(
|
|||
ofc,
|
||||
ofa,
|
||||
output_filter_request_headers.unwrap(),
|
||||
upstream_api_path,
|
||||
)
|
||||
} else {
|
||||
// Use base processor without state management
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use bytes::Bytes;
|
||||
use common::configuration::{Agent, AgentFilterChain};
|
||||
use common::consts::{
|
||||
ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER,
|
||||
|
|
@ -605,6 +606,139 @@ impl PipelineProcessor {
|
|||
Ok(messages)
|
||||
}
|
||||
|
||||
/// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back.
|
||||
/// Used for input and output filters where the full raw request/response is passed through.
|
||||
/// No MCP protocol wrapping; agent_type is ignored.
|
||||
#[instrument(
|
||||
skip(self, raw_bytes, agent, request_headers),
|
||||
fields(
|
||||
agent_id = %agent.id,
|
||||
agent_url = %agent.url,
|
||||
filter_name = %agent.id,
|
||||
bytes_len = raw_bytes.len()
|
||||
)
|
||||
)]
|
||||
async fn execute_raw_filter(
|
||||
&mut self,
|
||||
raw_bytes: &[u8],
|
||||
agent: &Agent,
|
||||
request_headers: &HeaderMap,
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
set_service_name(operation_component::AGENT_FILTER);
|
||||
use opentelemetry::trace::get_active_span;
|
||||
get_active_span(|span| {
|
||||
span.update_name(format!("execute_raw_filter ({})", agent.id));
|
||||
});
|
||||
|
||||
let mut agent_headers = request_headers.clone();
|
||||
agent_headers.remove(hyper::header::CONTENT_LENGTH);
|
||||
|
||||
agent_headers.remove(TRACE_PARENT_HEADER);
|
||||
global::get_text_map_propagator(|propagator| {
|
||||
let cx =
|
||||
tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current());
|
||||
propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers));
|
||||
});
|
||||
|
||||
agent_headers.insert(
|
||||
ARCH_UPSTREAM_HOST_HEADER,
|
||||
hyper::header::HeaderValue::from_str(&agent.id)
|
||||
.map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?,
|
||||
);
|
||||
agent_headers.insert(
|
||||
ENVOY_RETRY_HEADER,
|
||||
hyper::header::HeaderValue::from_str("3").unwrap(),
|
||||
);
|
||||
agent_headers.insert(
|
||||
"Accept",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
agent_headers.insert(
|
||||
"Content-Type",
|
||||
hyper::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
// Append the original request path so the filter endpoint encodes the API format.
|
||||
// e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions"
|
||||
// -> POST http://host/anonymize/v1/chat/completions
|
||||
let url = format!("{}{}", agent.url, request_path);
|
||||
debug!(agent = %agent.id, url = %url, "sending raw filter request");
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.headers(agent_headers)
|
||||
.body(raw_bytes.to_vec())
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let http_status = response.status();
|
||||
let response_bytes = response.bytes().await?;
|
||||
|
||||
if !http_status.is_success() {
|
||||
let error_body = String::from_utf8_lossy(&response_bytes).to_string();
|
||||
return Err(if http_status.is_client_error() {
|
||||
PipelineError::ClientError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
} else {
|
||||
PipelineError::ServerError {
|
||||
agent: agent.id.clone(),
|
||||
status: http_status.as_u16(),
|
||||
body: error_body,
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received");
|
||||
Ok(response_bytes)
|
||||
}
|
||||
|
||||
/// Process a chain of raw-bytes filters sequentially.
|
||||
/// Input: raw request or response bytes. Output: filtered bytes.
|
||||
/// Each agent receives the output of the previous one.
|
||||
pub async fn process_raw_filter_chain(
|
||||
&mut self,
|
||||
raw_bytes: &[u8],
|
||||
agent_filter_chain: &AgentFilterChain,
|
||||
agent_map: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
request_path: &str,
|
||||
) -> Result<Bytes, PipelineError> {
|
||||
let filter_chain = match agent_filter_chain.filter_chain.as_ref() {
|
||||
Some(fc) if !fc.is_empty() => fc,
|
||||
_ => return Ok(Bytes::copy_from_slice(raw_bytes)),
|
||||
};
|
||||
|
||||
let mut current_bytes = Bytes::copy_from_slice(raw_bytes);
|
||||
|
||||
for agent_name in filter_chain {
|
||||
debug!(agent = %agent_name, "processing raw filter agent");
|
||||
|
||||
let agent = agent_map
|
||||
.get(agent_name)
|
||||
.ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?;
|
||||
|
||||
info!(
|
||||
agent = %agent_name,
|
||||
url = %agent.url,
|
||||
bytes_len = current_bytes.len(),
|
||||
"executing raw filter"
|
||||
);
|
||||
|
||||
current_bytes = self
|
||||
.execute_raw_filter(¤t_bytes, agent, request_headers, request_path)
|
||||
.await?;
|
||||
|
||||
info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed");
|
||||
}
|
||||
|
||||
Ok(current_bytes)
|
||||
}
|
||||
|
||||
/// Send request to terminal agent and return the raw response for streaming
|
||||
/// Note: The caller is responsible for creating the plano(agent) span that wraps
|
||||
/// both this call and the subsequent response consumption.
|
||||
|
|
|
|||
|
|
@ -14,10 +14,10 @@ use tokio_stream::StreamExt;
|
|||
use tracing::{debug, info, warn, Instrument};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
use super::pipeline_processor::PipelineProcessor;
|
||||
use super::pipeline_processor::{PipelineError, PipelineProcessor};
|
||||
use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER};
|
||||
use crate::tracing::{llm, set_service_name, signals as signal_constants};
|
||||
use hermesllm::apis::openai::{Message, MessageContent, Role};
|
||||
use hermesllm::apis::openai::Message;
|
||||
|
||||
/// Trait for processing streaming chunks
|
||||
/// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging)
|
||||
|
|
@ -281,184 +281,9 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Extract content text from an SSE chunk line (the JSON part after "data: ").
|
||||
/// Returns the content string and whether it was found.
|
||||
fn extract_sse_content(json_str: &str) -> Option<String> {
|
||||
let value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
||||
value
|
||||
.get("choices")?
|
||||
.get(0)?
|
||||
.get("delta")?
|
||||
.get("content")?
|
||||
.as_str()
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
/// Replace content in an SSE JSON chunk with new content.
|
||||
fn replace_sse_content(json_str: &str, new_content: &str) -> Option<String> {
|
||||
let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?;
|
||||
value
|
||||
.get_mut("choices")?
|
||||
.get_mut(0)?
|
||||
.get_mut("delta")?
|
||||
.as_object_mut()?
|
||||
.insert(
|
||||
"content".to_string(),
|
||||
serde_json::Value::String(new_content.to_string()),
|
||||
);
|
||||
serde_json::to_string(&value).ok()
|
||||
}
|
||||
|
||||
/// Process an SSE chunk through the output filter chain.
|
||||
/// Parses each `data: ` line, extracts content, sends through filters, and reconstructs.
|
||||
async fn filter_sse_chunk(
|
||||
chunk_str: &str,
|
||||
pipeline_processor: &mut PipelineProcessor,
|
||||
filter_chain: &AgentFilterChain,
|
||||
filter_agents: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> String {
|
||||
let mut result = String::new();
|
||||
for line in chunk_str.split('\n') {
|
||||
if let Some(json_str) = line.strip_prefix("data: ") {
|
||||
if json_str.trim() == "[DONE]" {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
if let Some(content) = extract_sse_content(json_str) {
|
||||
if content.is_empty() {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
// Send content through output filter chain
|
||||
let messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
match pipeline_processor
|
||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
if let Some(msg) = filtered_messages.first() {
|
||||
let filtered_content = match &msg.content {
|
||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(filtered_content) = filtered_content {
|
||||
if let Some(new_json) =
|
||||
replace_sse_content(json_str, &filtered_content)
|
||||
{
|
||||
result.push_str("data: ");
|
||||
result.push_str(&new_json);
|
||||
result.push('\n');
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: pass through original
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter chain error, passing through original chunk");
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// No content in this SSE line, pass through
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
} else {
|
||||
result.push_str(line);
|
||||
result.push('\n');
|
||||
}
|
||||
}
|
||||
// Remove trailing extra newline if the original didn't end with one
|
||||
if !chunk_str.ends_with('\n') && result.ends_with('\n') {
|
||||
result.pop();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Process a non-streaming JSON response through the output filter chain.
|
||||
/// Extracts assistant message content, filters it, and reconstructs the response.
|
||||
pub async fn filter_non_streaming_response(
|
||||
response_bytes: &[u8],
|
||||
pipeline_processor: &mut PipelineProcessor,
|
||||
filter_chain: &AgentFilterChain,
|
||||
filter_agents: &HashMap<String, Agent>,
|
||||
request_headers: &HeaderMap,
|
||||
) -> Bytes {
|
||||
let response_str = match std::str::from_utf8(response_bytes) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
||||
};
|
||||
|
||||
let mut value: serde_json::Value = match serde_json::from_str(response_str) {
|
||||
Ok(v) => v,
|
||||
Err(_) => return Bytes::from(response_bytes.to_vec()),
|
||||
};
|
||||
|
||||
// Extract content from choices[0].message.content
|
||||
let content = value
|
||||
.get("choices")
|
||||
.and_then(|c| c.get(0))
|
||||
.and_then(|c| c.get("message"))
|
||||
.and_then(|m| m.get("content"))
|
||||
.and_then(|c| c.as_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
if let Some(content) = content {
|
||||
let messages = vec![Message {
|
||||
role: Role::Assistant,
|
||||
content: Some(MessageContent::Text(content)),
|
||||
name: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}];
|
||||
match pipeline_processor
|
||||
.process_filter_chain(&messages, filter_chain, filter_agents, request_headers)
|
||||
.await
|
||||
{
|
||||
Ok(filtered_messages) => {
|
||||
if let Some(msg) = filtered_messages.first() {
|
||||
let filtered_content = match &msg.content {
|
||||
Some(MessageContent::Text(t)) => Some(t.clone()),
|
||||
_ => None,
|
||||
};
|
||||
if let Some(filtered_content) = filtered_content {
|
||||
if let Some(choices) = value.get_mut("choices") {
|
||||
if let Some(choice) = choices.get_mut(0) {
|
||||
if let Some(message) = choice.get_mut("message") {
|
||||
message.as_object_mut().unwrap().insert(
|
||||
"content".to_string(),
|
||||
serde_json::Value::String(filtered_content),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter chain error on non-streaming response");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string()))
|
||||
}
|
||||
|
||||
/// Creates a streaming response that processes each chunk through output filters.
|
||||
/// The output filter is called asynchronously for each SSE chunk's content.
|
||||
/// Creates a streaming response that processes each raw chunk through output filters.
|
||||
/// Filters receive the raw LLM response bytes and return (possibly modified) bytes.
|
||||
/// On filter error mid-stream the original chunk is passed through (headers already sent).
|
||||
pub fn create_streaming_response_with_output_filter<S, P>(
|
||||
mut byte_stream: S,
|
||||
mut inner_processor: P,
|
||||
|
|
@ -466,6 +291,7 @@ pub fn create_streaming_response_with_output_filter<S, P>(
|
|||
output_filters: Vec<String>,
|
||||
output_filter_agents: HashMap<String, Agent>,
|
||||
request_headers: HeaderMap,
|
||||
upstream_path: String,
|
||||
) -> StreamingResponse
|
||||
where
|
||||
S: StreamExt<Item = Result<Bytes, reqwest::Error>> + Send + Unpin + 'static,
|
||||
|
|
@ -501,32 +327,35 @@ where
|
|||
is_first_chunk = false;
|
||||
}
|
||||
|
||||
// Try to process through output filter chain
|
||||
let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) {
|
||||
if chunk_str.contains("data: ") {
|
||||
let filtered = filter_sse_chunk(
|
||||
chunk_str,
|
||||
&mut pipeline_processor,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await;
|
||||
Bytes::from(filtered)
|
||||
} else {
|
||||
// Non-SSE chunk (could be non-streaming JSON response)
|
||||
let filtered = filter_non_streaming_response(
|
||||
&chunk,
|
||||
&mut pipeline_processor,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
)
|
||||
.await;
|
||||
filtered
|
||||
// Pass raw chunk bytes through the output filter chain
|
||||
let processed_chunk = match pipeline_processor
|
||||
.process_raw_filter_chain(
|
||||
&chunk,
|
||||
&temp_filter_chain,
|
||||
&output_filter_agents,
|
||||
&request_headers,
|
||||
&upstream_path,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(filtered) => filtered,
|
||||
Err(PipelineError::ClientError {
|
||||
agent,
|
||||
status,
|
||||
body,
|
||||
}) => {
|
||||
warn!(
|
||||
agent = %agent,
|
||||
status = %status,
|
||||
body = %body,
|
||||
"output filter client error, passing through original chunk"
|
||||
);
|
||||
chunk
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "output filter error, passing through original chunk");
|
||||
chunk
|
||||
}
|
||||
} else {
|
||||
chunk
|
||||
};
|
||||
|
||||
// Pass through inner processor for metrics/observability
|
||||
|
|
|
|||
|
|
@ -3,7 +3,11 @@
|
|||
Run content-safety filters on direct LLM requests — no agent layer required.
|
||||
|
||||
This demo uses the `input_filters` feature on a **model-type listener** to intercept
|
||||
`/v1/chat/completions` requests and block unsafe content before they reach the LLM provider.
|
||||
requests and block unsafe content before they reach the LLM provider. Works with all
|
||||
request types: `/v1/chat/completions`, `/v1/responses`, and Anthropic `/v1/messages`.
|
||||
|
||||
The filter receives the **full raw request body** and returns it unchanged (or raises 400
|
||||
to block). No message extraction — the complete JSON payload flows through as-is.
|
||||
|
||||
## Architecture
|
||||
|
||||
|
|
|
|||
|
|
@ -3,13 +3,15 @@ Content guard filter — keyword-based content safety for model listeners.
|
|||
|
||||
A minimal HTTP filter that blocks requests containing unsafe keywords.
|
||||
No LLM calls required — keeps the demo self-contained and fast.
|
||||
|
||||
Receives the full raw request body (any API format: /v1/chat/completions,
|
||||
/v1/responses, /v1/messages) and returns it unchanged or raises 400.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import List
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -36,11 +38,6 @@ BLOCKED_KEYWORDS = [
|
|||
]
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
def check_content(text: str) -> str | None:
|
||||
"""Return the matched keyword if blocked, else None."""
|
||||
lower = text.lower()
|
||||
|
|
@ -50,19 +47,58 @@ def check_content(text: str) -> str | None:
|
|||
return None
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def content_guard(
|
||||
messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
"""Block messages that contain unsafe keywords."""
|
||||
last_user_msg = None
|
||||
def extract_last_user_text(body: dict[str, Any]) -> str | None:
|
||||
"""Extract the most recent user message text from any supported request format."""
|
||||
messages = body.get("messages", [])
|
||||
# Anthropic /v1/messages and OpenAI /v1/chat/completions both use "messages"
|
||||
for msg in reversed(messages):
|
||||
if msg.role == "user":
|
||||
last_user_msg = msg.content
|
||||
break
|
||||
if msg.get("role") == "user":
|
||||
content = msg.get("content", "")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
# Multimodal content parts
|
||||
return " ".join(
|
||||
part.get("text", "")
|
||||
for part in content
|
||||
if isinstance(part, dict) and part.get("type") == "text"
|
||||
)
|
||||
|
||||
# OpenAI /v1/responses uses "input" instead of "messages"
|
||||
input_val = body.get("input")
|
||||
if isinstance(input_val, str):
|
||||
return input_val
|
||||
if isinstance(input_val, list):
|
||||
for item in reversed(input_val):
|
||||
if isinstance(item, dict) and item.get("role") == "user":
|
||||
content = item.get("content", "")
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@app.post("/{path:path}")
|
||||
async def content_guard(path: str, request: Request) -> dict[str, Any]:
|
||||
"""Block requests containing unsafe keywords. Returns the full request body unchanged.
|
||||
|
||||
The endpoint path encodes the API format:
|
||||
/v1/chat/completions — check body["messages"]
|
||||
/v1/responses — check body["input"]
|
||||
/v1/messages — check body["messages"] (Anthropic format)
|
||||
"""
|
||||
endpoint = f"/{path}"
|
||||
body = await request.json()
|
||||
|
||||
# /v1/responses uses "input" instead of "messages"
|
||||
if endpoint == "/v1/responses":
|
||||
input_val = body.get("input", "")
|
||||
last_user_msg = input_val if isinstance(input_val, str) else None
|
||||
else:
|
||||
last_user_msg = extract_last_user_text(body)
|
||||
|
||||
if last_user_msg is None:
|
||||
return messages
|
||||
return body
|
||||
|
||||
matched = check_content(last_user_msg)
|
||||
if matched:
|
||||
|
|
@ -76,7 +112,7 @@ async def content_guard(
|
|||
)
|
||||
|
||||
logger.info("Content check passed — forwarding request")
|
||||
return messages
|
||||
return body
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
|
|
|||
|
|
@ -80,14 +80,26 @@ Check the PII filter service logs in the terminal running `start_agents.sh`. You
|
|||
| Email | standard email format | `user@example.com` | `[EMAIL_0]` |
|
||||
| Phone | US phone formats | `555-123-4567` | `[PHONE_0]` |
|
||||
|
||||
## Filter Contract
|
||||
|
||||
**Input filter (`/anonymize`)** receives the **full raw request body** and returns the modified body:
|
||||
```json
|
||||
{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Contact john@example.com"}], "stream": true}
|
||||
```
|
||||
→ returns the same structure with PII replaced in the `messages` array.
|
||||
|
||||
**Output filter (`/deanonymize`)** receives the **raw LLM response bytes** and returns modified bytes:
|
||||
- *Streaming*: raw SSE chunk, e.g. `data: {"choices":[{"delta":{"content":"Contact [EMAIL_0]"}}]}`
|
||||
- *Non-streaming*: full JSON response body
|
||||
|
||||
## How Streaming De-anonymization Works
|
||||
|
||||
For streaming responses, each SSE chunk is sent through the output filters as it arrives from the LLM:
|
||||
For streaming responses, each raw SSE chunk is sent through the output filter as it arrives from the LLM:
|
||||
|
||||
1. Plano receives a chunk with content like `"The email [EMAIL_0] belongs to..."`
|
||||
2. The chunk content is sent to the `/deanonymize` endpoint
|
||||
3. The filter looks up the PII mapping (stored during anonymization) and replaces placeholders
|
||||
4. The restored chunk `"The email john@example.com belongs to..."` is streamed to the client
|
||||
1. Plano receives a raw SSE chunk like `data: {"choices":[{"delta":{"content":"The email [EMAIL_0] belongs to..."}}]}`
|
||||
2. The raw chunk bytes are sent to the `/deanonymize` endpoint
|
||||
3. The filter parses the SSE, looks up the PII mapping (stored during anonymization), and replaces placeholders in the delta content
|
||||
4. The restored chunk is returned and streamed to the client
|
||||
|
||||
Partial placeholders split across chunks (e.g., `[EMA` in one chunk, `IL_0]` in the next) are handled via internal buffering in the filter service.
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ model_providers:
|
|||
- model: openai/gpt-4o-mini
|
||||
access_key: $OPENAI_API_KEY
|
||||
default: true
|
||||
- model: anthropic/claude-sonnet-4-20250514
|
||||
access_key: $ANTHROPIC_API_KEY
|
||||
|
||||
listeners:
|
||||
- type: model
|
||||
|
|
|
|||
88
demos/filter_chains/pii_anonymizer/pii.py
Normal file
88
demos/filter_chains/pii_anonymizer/pii.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
"""PII detection and anonymization utilities."""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
# Order matters: SSN before phone to avoid overlap
|
||||
PII_PATTERNS = [
|
||||
("SSN", re.compile(r"\b\d{3}-\d{2}-\d{4}\b")),
|
||||
("CREDIT_CARD", re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")),
|
||||
("EMAIL", re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")),
|
||||
("PHONE", re.compile(r"(\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")),
|
||||
]
|
||||
|
||||
|
||||
def anonymize_text(text: str) -> Tuple[str, Dict[str, str]]:
|
||||
"""Replace PII with [TYPE_N] placeholders. Returns (anonymized_text, mapping)."""
|
||||
mapping: Dict[str, str] = {}
|
||||
counters: Dict[str, int] = {}
|
||||
matched_spans: List[Tuple[int, int]] = []
|
||||
|
||||
for pii_type, pattern in PII_PATTERNS:
|
||||
for match in pattern.finditer(text):
|
||||
start, end = match.start(), match.end()
|
||||
if any(s <= start < e or s < end <= e for s, e in matched_spans):
|
||||
continue
|
||||
matched_spans.append((start, end))
|
||||
idx = counters.get(pii_type, 0)
|
||||
counters[pii_type] = idx + 1
|
||||
mapping[f"[{pii_type}_{idx}]"] = match.group()
|
||||
|
||||
# Replace right-to-left to preserve span indices
|
||||
matched_spans.sort(reverse=True)
|
||||
result = text
|
||||
for start, end in matched_spans:
|
||||
placeholder = next(k for k, v in mapping.items() if v == text[start:end])
|
||||
result = result[:start] + placeholder + result[end:]
|
||||
|
||||
return result, mapping
|
||||
|
||||
|
||||
def deanonymize_text(
|
||||
text: str, mapping: Dict[str, str], buffer: str = ""
|
||||
) -> Tuple[str, str]:
|
||||
"""Replace placeholders back with original PII values.
|
||||
|
||||
Handles partial placeholders at chunk boundaries via a buffer.
|
||||
Returns (processed_text, remaining_buffer).
|
||||
"""
|
||||
combined = buffer + text
|
||||
|
||||
# Build prefix set for all known placeholders (e.g. "[EMAIL_0" is a prefix of "[EMAIL_0]")
|
||||
prefixes: set[str] = set()
|
||||
for placeholder in mapping:
|
||||
for i in range(1, len(placeholder)):
|
||||
prefixes.add(placeholder[:i])
|
||||
|
||||
# If the tail looks like the start of a placeholder, hold it in the buffer
|
||||
remaining_buffer = ""
|
||||
last_bracket = combined.rfind("[")
|
||||
if last_bracket != -1 and "]" not in combined[last_bracket:]:
|
||||
tail = combined[last_bracket:]
|
||||
if tail in prefixes:
|
||||
remaining_buffer = tail
|
||||
combined = combined[:last_bracket]
|
||||
|
||||
for placeholder, original in mapping.items():
|
||||
combined = combined.replace(placeholder, original)
|
||||
|
||||
return combined, remaining_buffer
|
||||
|
||||
|
||||
def anonymize_message_content(content: Any, all_mappings: Dict[str, str]) -> Any:
|
||||
"""Anonymize string content or list of content parts."""
|
||||
if isinstance(content, str):
|
||||
anonymized, mapping = anonymize_text(content)
|
||||
all_mappings.update(mapping)
|
||||
return anonymized
|
||||
if isinstance(content, list):
|
||||
result = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
anonymized, mapping = anonymize_text(part.get("text", ""))
|
||||
all_mappings.update(mapping)
|
||||
result.append({**part, "text": anonymized})
|
||||
else:
|
||||
result.append(part)
|
||||
return result
|
||||
return content
|
||||
|
|
@ -5,18 +5,26 @@ Inspired by Uber's GenAI Gateway PII Redactor. Two endpoints:
|
|||
POST /anonymize — replace PII with placeholders (input filter)
|
||||
POST /deanonymize — restore original PII from placeholders (output filter)
|
||||
|
||||
Uses regex-based detection for: email, phone, SSN, credit card.
|
||||
Correlates request/response via x-request-id header.
|
||||
Input filter (/anonymize):
|
||||
Receives the full raw request body (any API format). Anonymizes user message
|
||||
content and returns the modified body.
|
||||
|
||||
Output filter (/deanonymize):
|
||||
Receives raw LLM response bytes — SSE (streaming) or full JSON (non-streaming).
|
||||
De-anonymizes content and returns modified bytes.
|
||||
|
||||
The path suffix encodes the upstream API format so each endpoint knows how to
|
||||
parse the body (e.g. /anonymize/v1/chat/completions, /deanonymize/v1/messages).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import FastAPI, Request
|
||||
from pydantic import BaseModel
|
||||
from fastapi.responses import Response
|
||||
|
||||
from pii import anonymize_text, anonymize_message_content
|
||||
from store import get_mapping, store_mapping, deanonymize_sse, deanonymize_json
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -26,205 +34,79 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
app = FastAPI(title="PII Anonymizer", version="1.0.0")
|
||||
|
||||
# --- PII patterns (order matters: SSN before phone to avoid overlap) ---
|
||||
|
||||
PII_PATTERNS = [
|
||||
("SSN", re.compile(r"\b\d{3}-\d{2}-\d{4}\b")),
|
||||
("CREDIT_CARD", re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")),
|
||||
("EMAIL", re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")),
|
||||
(
|
||||
"PHONE",
|
||||
re.compile(r"(\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}"),
|
||||
),
|
||||
]
|
||||
@app.post("/anonymize/{path:path}")
|
||||
async def anonymize(path: str, request: Request) -> dict[str, Any]:
|
||||
"""Anonymize PII in user messages. Receives and returns the full raw request body.
|
||||
|
||||
# --- In-memory mapping store (request_id -> mapping + timestamp) ---
|
||||
|
||||
_store_lock = threading.Lock()
|
||||
_mapping_store: Dict[str, Tuple[Dict[str, str], float]] = {}
|
||||
# Buffer for partial placeholder matches during streaming de-anonymization
|
||||
_buffer_store: Dict[str, str] = {}
|
||||
MAPPING_TTL_SECONDS = 300 # 5 minutes
|
||||
|
||||
|
||||
def _cleanup_expired():
|
||||
"""Remove expired mappings."""
|
||||
now = time.time()
|
||||
expired = [
|
||||
k for k, (_, ts) in _mapping_store.items() if now - ts > MAPPING_TTL_SECONDS
|
||||
]
|
||||
for k in expired:
|
||||
del _mapping_store[k]
|
||||
_buffer_store.pop(k, None)
|
||||
|
||||
|
||||
def _store_mapping(request_id: str, mapping: Dict[str, str]):
|
||||
with _store_lock:
|
||||
_cleanup_expired()
|
||||
_mapping_store[request_id] = (mapping, time.time())
|
||||
|
||||
|
||||
def _get_mapping(request_id: str) -> Optional[Dict[str, str]]:
|
||||
with _store_lock:
|
||||
entry = _mapping_store.get(request_id)
|
||||
if entry:
|
||||
return entry[0]
|
||||
return None
|
||||
|
||||
|
||||
# --- Core logic ---
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
|
||||
def anonymize_text(text: str) -> Tuple[str, Dict[str, str]]:
|
||||
"""Replace PII with [TYPE_N] placeholders. Returns (anonymized_text, mapping)."""
|
||||
mapping: Dict[str, str] = {}
|
||||
counters: Dict[str, int] = {}
|
||||
# Track spans already matched to avoid overlapping replacements
|
||||
matched_spans: List[Tuple[int, int]] = []
|
||||
|
||||
for pii_type, pattern in PII_PATTERNS:
|
||||
for match in pattern.finditer(text):
|
||||
start, end = match.start(), match.end()
|
||||
# Skip if this span overlaps with an already-matched span
|
||||
if any(s <= start < e or s < end <= e for s, e in matched_spans):
|
||||
continue
|
||||
matched_spans.append((start, end))
|
||||
idx = counters.get(pii_type, 0)
|
||||
counters[pii_type] = idx + 1
|
||||
placeholder = f"[{pii_type}_{idx}]"
|
||||
mapping[placeholder] = match.group()
|
||||
|
||||
# Replace from right to left to preserve indices
|
||||
matched_spans.sort(reverse=True)
|
||||
result = text
|
||||
for start, end in matched_spans:
|
||||
original = text[start:end]
|
||||
# Find the placeholder for this original value
|
||||
placeholder = next(k for k, v in mapping.items() if v == original)
|
||||
result = result[:start] + placeholder + result[end:]
|
||||
|
||||
return result, mapping
|
||||
|
||||
|
||||
def deanonymize_text(
|
||||
text: str, mapping: Dict[str, str], buffer: str = ""
|
||||
) -> Tuple[str, str]:
|
||||
"""Replace placeholders back with original PII values.
|
||||
|
||||
Handles partial placeholders across streaming chunks via a buffer.
|
||||
Only buffers text that could be the prefix of an actual placeholder
|
||||
from this request's mapping, not arbitrary ``[`` from normal text.
|
||||
Returns (processed_text, remaining_buffer).
|
||||
The endpoint path encodes the API format:
|
||||
/anonymize/v1/chat/completions — anonymize body["messages"]
|
||||
/anonymize/v1/responses — anonymize body["input"] (string or items list)
|
||||
/anonymize/v1/messages — anonymize body["messages"] (Anthropic format)
|
||||
"""
|
||||
combined = buffer + text
|
||||
|
||||
# Build the set of all prefixes for placeholders in this request's mapping.
|
||||
# e.g. for "[EMAIL_0]" -> {"[", "[E", "[EM", "[EMA", "[EMAI", "[EMAIL", "[EMAIL_", "[EMAIL_0"}
|
||||
prefixes: set[str] = set()
|
||||
for placeholder in mapping:
|
||||
# Exclude the full placeholder (with closing ']') — that's a complete match, not partial
|
||||
for i in range(1, len(placeholder)):
|
||||
prefixes.add(placeholder[:i])
|
||||
|
||||
# Check if the end of the text could be a partial placeholder.
|
||||
remaining_buffer = ""
|
||||
last_bracket = combined.rfind("[")
|
||||
if last_bracket != -1 and "]" not in combined[last_bracket:]:
|
||||
tail = combined[last_bracket:]
|
||||
if tail in prefixes:
|
||||
remaining_buffer = tail
|
||||
combined = combined[:last_bracket]
|
||||
|
||||
# Replace all complete placeholders
|
||||
for placeholder, original in mapping.items():
|
||||
combined = combined.replace(placeholder, original)
|
||||
|
||||
return combined, remaining_buffer
|
||||
|
||||
|
||||
# --- Endpoints ---
|
||||
|
||||
|
||||
@app.post("/anonymize")
|
||||
async def anonymize(messages: List[ChatMessage], request: Request) -> List[ChatMessage]:
|
||||
"""Anonymize PII in user messages. Stores mapping for later de-anonymization."""
|
||||
request_id = request.headers.get("x-request-id", "unknown")
|
||||
endpoint = f"/{path}"
|
||||
body = await request.json()
|
||||
all_mappings: Dict[str, str] = {}
|
||||
result_messages = []
|
||||
|
||||
for msg in messages:
|
||||
if msg.role == "user":
|
||||
anonymized, mapping = anonymize_text(msg.content)
|
||||
if endpoint == "/v1/responses":
|
||||
input_val = body.get("input", "")
|
||||
if isinstance(input_val, str):
|
||||
anonymized, mapping = anonymize_text(input_val)
|
||||
all_mappings.update(mapping)
|
||||
result_messages.append(ChatMessage(role=msg.role, content=anonymized))
|
||||
else:
|
||||
result_messages.append(msg)
|
||||
body = {**body, "input": anonymized}
|
||||
elif isinstance(input_val, list):
|
||||
items = [
|
||||
{**item, "content": anonymize_message_content(item.get("content", ""), all_mappings)}
|
||||
if isinstance(item, dict) and item.get("role") == "user"
|
||||
else item
|
||||
for item in input_val
|
||||
]
|
||||
body = {**body, "input": items}
|
||||
else:
|
||||
# /v1/chat/completions and /v1/messages both use "messages"
|
||||
messages = [
|
||||
{**msg, "content": anonymize_message_content(msg.get("content", ""), all_mappings)}
|
||||
if msg.get("role") == "user"
|
||||
else msg
|
||||
for msg in body.get("messages", [])
|
||||
]
|
||||
if messages:
|
||||
body = {**body, "messages": messages}
|
||||
|
||||
if all_mappings:
|
||||
_store_mapping(request_id, all_mappings)
|
||||
logger.info(
|
||||
"request_id=%s /anonymize mapping: %s",
|
||||
request_id,
|
||||
all_mappings,
|
||||
)
|
||||
store_mapping(request_id, all_mappings)
|
||||
logger.info("request_id=%s /anonymize mapping: %s", request_id, all_mappings)
|
||||
else:
|
||||
logger.info("request_id=%s no PII detected", request_id)
|
||||
|
||||
logger.info(
|
||||
"request_id=%s /anonymize input: %s -> output: %s",
|
||||
request_id,
|
||||
[m.content for m in messages],
|
||||
[m.content for m in result_messages],
|
||||
)
|
||||
|
||||
return result_messages
|
||||
return body
|
||||
|
||||
|
||||
@app.post("/deanonymize")
|
||||
async def deanonymize(
|
||||
messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
"""De-anonymize PII placeholders in response messages using stored mapping."""
|
||||
@app.post("/deanonymize/{path:path}")
|
||||
async def deanonymize(path: str, request: Request) -> Response:
|
||||
"""De-anonymize PII placeholders in LLM response. Handles SSE (streaming) and JSON.
|
||||
|
||||
The path encodes the upstream API format:
|
||||
/deanonymize/v1/chat/completions — OpenAI chat completions
|
||||
/deanonymize/v1/messages — Anthropic messages
|
||||
/deanonymize/v1/responses — OpenAI responses API
|
||||
"""
|
||||
endpoint = f"/{path}"
|
||||
is_anthropic = endpoint == "/v1/messages"
|
||||
request_id = request.headers.get("x-request-id", "unknown")
|
||||
mapping = _get_mapping(request_id)
|
||||
mapping = get_mapping(request_id)
|
||||
raw_body = await request.body()
|
||||
|
||||
if not mapping:
|
||||
logger.info("request_id=%s no mapping found, passing through", request_id)
|
||||
return messages
|
||||
return Response(content=raw_body, media_type="application/json")
|
||||
|
||||
result_messages = []
|
||||
for msg in messages:
|
||||
if msg.role == "assistant" and msg.content:
|
||||
with _store_lock:
|
||||
buffer = _buffer_store.get(request_id, "")
|
||||
body_str = raw_body.decode("utf-8", errors="replace")
|
||||
|
||||
restored, remaining = deanonymize_text(msg.content, mapping, buffer)
|
||||
|
||||
with _store_lock:
|
||||
if remaining:
|
||||
_buffer_store[request_id] = remaining
|
||||
else:
|
||||
_buffer_store.pop(request_id, None)
|
||||
|
||||
# Only log when a replacement actually happened
|
||||
if restored != msg.content:
|
||||
logger.info(
|
||||
"request_id=%s /deanonymize '%s' -> '%s'",
|
||||
request_id,
|
||||
msg.content,
|
||||
restored,
|
||||
)
|
||||
|
||||
result_messages.append(ChatMessage(role=msg.role, content=restored))
|
||||
else:
|
||||
result_messages.append(msg)
|
||||
|
||||
return result_messages
|
||||
if "data: " in body_str:
|
||||
return deanonymize_sse(request_id, body_str, mapping, is_anthropic)
|
||||
return deanonymize_json(request_id, raw_body, body_str, mapping, is_anthropic)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
|
|
|
|||
115
demos/filter_chains/pii_anonymizer/store.py
Normal file
115
demos/filter_chains/pii_anonymizer/store.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
"""In-memory mapping store and LLM response processors for PII de-anonymization."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from fastapi.responses import Response
|
||||
|
||||
from pii import deanonymize_text
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAPPING_TTL_SECONDS = 300 # 5 minutes
|
||||
|
||||
_lock = threading.Lock()
|
||||
_mappings: Dict[str, Tuple[Dict[str, str], float]] = {}
|
||||
_buffers: Dict[str, str] = {} # partial placeholder buffers for streaming
|
||||
|
||||
|
||||
def _cleanup_expired():
|
||||
now = time.time()
|
||||
expired = [k for k, (_, ts) in _mappings.items() if now - ts > MAPPING_TTL_SECONDS]
|
||||
for k in expired:
|
||||
del _mappings[k]
|
||||
_buffers.pop(k, None)
|
||||
|
||||
|
||||
def store_mapping(request_id: str, mapping: Dict[str, str]):
|
||||
with _lock:
|
||||
_cleanup_expired()
|
||||
_mappings[request_id] = (mapping, time.time())
|
||||
|
||||
|
||||
def get_mapping(request_id: str) -> Optional[Dict[str, str]]:
|
||||
with _lock:
|
||||
entry = _mappings.get(request_id)
|
||||
return entry[0] if entry else None
|
||||
|
||||
|
||||
def restore_streaming(request_id: str, content: str, mapping: Dict[str, str]) -> str:
|
||||
"""Restore PII in one streaming chunk, maintaining the per-request partial buffer."""
|
||||
with _lock:
|
||||
buffer = _buffers.get(request_id, "")
|
||||
restored, remaining = deanonymize_text(content, mapping, buffer)
|
||||
with _lock:
|
||||
if remaining:
|
||||
_buffers[request_id] = remaining
|
||||
else:
|
||||
_buffers.pop(request_id, None)
|
||||
if restored != content:
|
||||
logger.info("request_id=%s restored '%s' -> '%s'", request_id, content, restored)
|
||||
return restored
|
||||
|
||||
|
||||
def deanonymize_sse(
|
||||
request_id: str, body_str: str, mapping: Dict[str, str], is_anthropic: bool
|
||||
) -> Response:
|
||||
result_lines = []
|
||||
for line in body_str.split("\n"):
|
||||
stripped = line.strip()
|
||||
if not (stripped.startswith("data: ") and stripped[6:] != "[DONE]"):
|
||||
result_lines.append(line)
|
||||
continue
|
||||
try:
|
||||
chunk = json.loads(stripped[6:])
|
||||
if is_anthropic:
|
||||
# {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "..."}}
|
||||
if chunk.get("type") == "content_block_delta":
|
||||
delta = chunk.get("delta", {})
|
||||
if delta.get("type") == "text_delta" and delta.get("text"):
|
||||
delta["text"] = restore_streaming(request_id, delta["text"], mapping)
|
||||
else:
|
||||
# {"choices": [{"delta": {"content": "..."}}]}
|
||||
for choice in chunk.get("choices", []):
|
||||
delta = choice.get("delta", {})
|
||||
if delta.get("content"):
|
||||
delta["content"] = restore_streaming(request_id, delta["content"], mapping)
|
||||
result_lines.append("data: " + json.dumps(chunk))
|
||||
except json.JSONDecodeError:
|
||||
result_lines.append(line)
|
||||
return Response(content="\n".join(result_lines), media_type="text/plain")
|
||||
|
||||
|
||||
def deanonymize_json(
|
||||
request_id: str,
|
||||
raw_body: bytes,
|
||||
body_str: str,
|
||||
mapping: Dict[str, str],
|
||||
is_anthropic: bool,
|
||||
) -> Response:
|
||||
try:
|
||||
body = json.loads(body_str)
|
||||
if is_anthropic:
|
||||
# {"content": [{"type": "text", "text": "..."}]}
|
||||
for part in body.get("content", []):
|
||||
if isinstance(part, dict) and part.get("type") == "text" and part.get("text"):
|
||||
restored, _ = deanonymize_text(part["text"], mapping)
|
||||
if restored != part["text"]:
|
||||
logger.info("request_id=%s restored '%s' -> '%s'", request_id, part["text"], restored)
|
||||
part["text"] = restored
|
||||
else:
|
||||
# {"choices": [{"message": {"content": "..."}}]}
|
||||
for choice in body.get("choices", []):
|
||||
message = choice.get("message", {})
|
||||
content = message.get("content")
|
||||
if content and isinstance(content, str):
|
||||
restored, _ = deanonymize_text(content, mapping)
|
||||
if restored != content:
|
||||
logger.info("request_id=%s restored '%s' -> '%s'", request_id, content, restored)
|
||||
message["content"] = restored
|
||||
return Response(content=json.dumps(body), media_type="application/json")
|
||||
except json.JSONDecodeError:
|
||||
return Response(content=raw_body, media_type="application/json")
|
||||
|
|
@ -1,14 +1,14 @@
|
|||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
BASE_URL="http://localhost:12000/v1"
|
||||
BASE_URL="http://localhost:12000"
|
||||
PASS=0
|
||||
FAIL=0
|
||||
|
||||
# ── Wait for Plano to be ready ──────────────────────────────────────────────
|
||||
echo "Waiting for Plano to be ready..."
|
||||
for i in $(seq 1 30); do
|
||||
if curl -sf "$BASE_URL/models" > /dev/null 2>&1; then
|
||||
if curl -sf "$BASE_URL/v1/models" > /dev/null 2>&1; then
|
||||
echo "Plano is ready."
|
||||
break
|
||||
fi
|
||||
|
|
@ -22,11 +22,12 @@ done
|
|||
# ── Helper ───────────────────────────────────────────────────────────────────
|
||||
run_test() {
|
||||
local name="$1"
|
||||
local expected_code="$2"
|
||||
local body="$3"
|
||||
local path="$2"
|
||||
local expected_code="$3"
|
||||
local body="$4"
|
||||
|
||||
http_code=$(curl -s -o /tmp/plano_test_body -w "%{http_code}" \
|
||||
-X POST "$BASE_URL/chat/completions" \
|
||||
-X POST "$BASE_URL$path" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d "$body")
|
||||
|
||||
|
|
@ -40,34 +41,75 @@ run_test() {
|
|||
fi
|
||||
}
|
||||
|
||||
# ── Tests ────────────────────────────────────────────────────────────────────
|
||||
# ── /v1/chat/completions ─────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "Running tests..."
|
||||
echo "=== /v1/chat/completions ==="
|
||||
|
||||
run_test "Non-streaming with PII (email + phone)" 200 '{
|
||||
run_test "Non-streaming with PII (email + phone)" /v1/chat/completions 200 '{
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": "Contact me at john@example.com or call 555-123-4567"}],
|
||||
"stream": false
|
||||
}'
|
||||
|
||||
run_test "Streaming with PII (SSN)" 200 '{
|
||||
run_test "Streaming with PII (SSN)" /v1/chat/completions 200 '{
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": "My SSN is 123-45-6789, please help me file taxes"}],
|
||||
"stream": true
|
||||
}'
|
||||
|
||||
run_test "No PII (clean message)" 200 '{
|
||||
run_test "No PII (clean message)" /v1/chat/completions 200 '{
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": "What is 2+2?"}],
|
||||
"stream": false
|
||||
}'
|
||||
|
||||
run_test "Multiple PII types" 200 '{
|
||||
run_test "Multiple PII types" /v1/chat/completions 200 '{
|
||||
"model": "gpt-4o-mini",
|
||||
"messages": [{"role": "user", "content": "Email: test@test.com, Phone: 555-867-5309, SSN: 987-65-4321, Card: 4111 1111 1111 1111"}],
|
||||
"messages": [{"role": "user", "content": "Email: test@test.com, SSN: 987-65-4321, Card: 4111 1111 1111 1111"}],
|
||||
"stream": false
|
||||
}'
|
||||
|
||||
# ── /v1/responses ────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== /v1/responses ==="
|
||||
|
||||
run_test "Non-streaming with PII (email)" /v1/responses 200 '{
|
||||
"model": "gpt-4o-mini",
|
||||
"input": "My email is jane@example.com — can you summarize it?"
|
||||
}'
|
||||
|
||||
run_test "Non-streaming with PII (credit card)" /v1/responses 200 '{
|
||||
"model": "gpt-4o-mini",
|
||||
"input": "I need help disputing a charge on card 4111 1111 1111 1111"
|
||||
}'
|
||||
|
||||
run_test "No PII" /v1/responses 200 '{
|
||||
"model": "gpt-4o-mini",
|
||||
"input": "What is the capital of France?"
|
||||
}'
|
||||
|
||||
# ── /v1/messages (Anthropic) ─────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=== /v1/messages ==="
|
||||
|
||||
run_test "Non-streaming with PII (phone)" /v1/messages 200 '{
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role": "user", "content": "Call me at 555-867-5309 to discuss my account"}]
|
||||
}'
|
||||
|
||||
run_test "Non-streaming with PII (SSN)" /v1/messages 200 '{
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role": "user", "content": "My SSN is 123-45-6789"}]
|
||||
}'
|
||||
|
||||
run_test "No PII" /v1/messages 200 '{
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"max_tokens": 256,
|
||||
"messages": [{"role": "user", "content": "Hello, how are you?"}]
|
||||
}'
|
||||
|
||||
# ── Summary ──────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "Results: $PASS passed, $FAIL failed"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue