diff --git a/crates/brightstaff/src/handlers/llm.rs b/crates/brightstaff/src/handlers/llm.rs index b8beb7b8..487a7b22 100644 --- a/crates/brightstaff/src/handlers/llm.rs +++ b/crates/brightstaff/src/handlers/llm.rs @@ -268,12 +268,13 @@ async fn llm_chat_inner( } // === Input filters processing for model listener === + // Filters receive the raw request bytes and return (possibly modified) raw bytes. + // The returned bytes are re-parsed into a ProviderRequestType to continue the request. { if let Some(ref fc) = *input_filters { if !fc.is_empty() { debug!(input_filters = ?fc, "processing model listener input filters"); - // Create a temporary AgentFilterChain to reuse PipelineProcessor let temp_filter_chain = AgentFilterChain { id: "model_listener".to_string(), default: None, @@ -282,23 +283,34 @@ async fn llm_chat_inner( }; let mut pipeline_processor = PipelineProcessor::default(); - let messages = client_request.get_messages(); match pipeline_processor - .process_filter_chain( - &messages, + .process_raw_filter_chain( + &chat_request_bytes, &temp_filter_chain, &input_filter_agents, &request_headers, + &request_path, ) .await { - Ok(filtered_messages) => { - client_request.set_messages(&filtered_messages); - info!( - original_count = messages.len(), - filtered_count = filtered_messages.len(), - "filter chain processed successfully" - ); + Ok(filtered_bytes) => { + match ProviderRequestType::try_from(( + &filtered_bytes[..], + &SupportedAPIsFromClient::from_endpoint(request_path.as_str()).unwrap(), + )) { + Ok(updated_request) => { + client_request = updated_request; + info!("input filter chain processed successfully"); + } + Err(parse_err) => { + warn!(error = %parse_err, "input filter returned invalid request JSON"); + return Ok(BrightStaffError::InvalidRequest(format!( + "Input filter returned invalid request: {}", + parse_err + )) + .into_response()); + } + } } Err(super::pipeline_processor::PipelineError::ClientError { agent, @@ -508,21 +520,25 @@ async fn llm_chat_inner( propagator.inject_context(&cx, &mut HeaderInjector(&mut request_headers)); }); - // Output filters are only supported for /v1/chat/completions — the SSE content - // extraction logic is specific to that API shape (choices[].delta.content). let output_filters_configured = output_filters .as_ref() .as_ref() .map(|fc| !fc.is_empty()) .unwrap_or(false); - let has_output_filter = output_filters_configured - && request_path == common::consts::CHAT_COMPLETIONS_PATH; - if output_filters_configured && !has_output_filter { - warn!( - path = %request_path, - "output filters are configured but only supported for /v1/chat/completions, skipping" - ); - } + let has_output_filter = output_filters_configured; + + // Extract the upstream API path (e.g. "/v1/messages" from "https://api.anthropic.com/v1/messages"). + // Output filters are called at so they know the exact byte format. + let upstream_api_path = { + let after_scheme = full_qualified_llm_provider_url + .find("://") + .map(|i| &full_qualified_llm_provider_url[i + 3..]) + .unwrap_or(&full_qualified_llm_provider_url); + after_scheme + .find('/') + .map(|i| after_scheme[i..].to_string()) + .unwrap_or_else(|| "/".to_string()) + }; // Save request headers for output filters (before they're consumed by upstream request) let output_filter_request_headers = if has_output_filter { @@ -607,6 +623,7 @@ async fn llm_chat_inner( ofc, ofa, output_filter_request_headers.unwrap(), + upstream_api_path.clone(), ) } else { create_streaming_response(byte_stream, state_processor, 16) @@ -621,6 +638,7 @@ async fn llm_chat_inner( ofc, ofa, output_filter_request_headers.unwrap(), + upstream_api_path, ) } else { // Use base processor without state management diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index d9e87095..776a88f5 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use bytes::Bytes; use common::configuration::{Agent, AgentFilterChain}; use common::consts::{ ARCH_UPSTREAM_HOST_HEADER, BRIGHT_STAFF_SERVICE_NAME, ENVOY_RETRY_HEADER, TRACE_PARENT_HEADER, @@ -605,6 +606,139 @@ impl PipelineProcessor { Ok(messages) } + /// Execute a raw bytes filter — POST bytes to agent.url, receive bytes back. + /// Used for input and output filters where the full raw request/response is passed through. + /// No MCP protocol wrapping; agent_type is ignored. + #[instrument( + skip(self, raw_bytes, agent, request_headers), + fields( + agent_id = %agent.id, + agent_url = %agent.url, + filter_name = %agent.id, + bytes_len = raw_bytes.len() + ) + )] + async fn execute_raw_filter( + &mut self, + raw_bytes: &[u8], + agent: &Agent, + request_headers: &HeaderMap, + request_path: &str, + ) -> Result { + set_service_name(operation_component::AGENT_FILTER); + use opentelemetry::trace::get_active_span; + get_active_span(|span| { + span.update_name(format!("execute_raw_filter ({})", agent.id)); + }); + + let mut agent_headers = request_headers.clone(); + agent_headers.remove(hyper::header::CONTENT_LENGTH); + + agent_headers.remove(TRACE_PARENT_HEADER); + global::get_text_map_propagator(|propagator| { + let cx = + tracing_opentelemetry::OpenTelemetrySpanExt::context(&tracing::Span::current()); + propagator.inject_context(&cx, &mut HeaderInjector(&mut agent_headers)); + }); + + agent_headers.insert( + ARCH_UPSTREAM_HOST_HEADER, + hyper::header::HeaderValue::from_str(&agent.id) + .map_err(|_| PipelineError::AgentNotFound(agent.id.clone()))?, + ); + agent_headers.insert( + ENVOY_RETRY_HEADER, + hyper::header::HeaderValue::from_str("3").unwrap(), + ); + agent_headers.insert( + "Accept", + hyper::header::HeaderValue::from_static("application/json"), + ); + agent_headers.insert( + "Content-Type", + hyper::header::HeaderValue::from_static("application/json"), + ); + + // Append the original request path so the filter endpoint encodes the API format. + // e.g. agent.url="http://host/anonymize" + request_path="/v1/chat/completions" + // -> POST http://host/anonymize/v1/chat/completions + let url = format!("{}{}", agent.url, request_path); + debug!(agent = %agent.id, url = %url, "sending raw filter request"); + + let response = self + .client + .post(&url) + .headers(agent_headers) + .body(raw_bytes.to_vec()) + .send() + .await?; + + let http_status = response.status(); + let response_bytes = response.bytes().await?; + + if !http_status.is_success() { + let error_body = String::from_utf8_lossy(&response_bytes).to_string(); + return Err(if http_status.is_client_error() { + PipelineError::ClientError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + } else { + PipelineError::ServerError { + agent: agent.id.clone(), + status: http_status.as_u16(), + body: error_body, + } + }); + } + + debug!(agent = %agent.id, bytes_len = response_bytes.len(), "raw filter response received"); + Ok(response_bytes) + } + + /// Process a chain of raw-bytes filters sequentially. + /// Input: raw request or response bytes. Output: filtered bytes. + /// Each agent receives the output of the previous one. + pub async fn process_raw_filter_chain( + &mut self, + raw_bytes: &[u8], + agent_filter_chain: &AgentFilterChain, + agent_map: &HashMap, + request_headers: &HeaderMap, + request_path: &str, + ) -> Result { + let filter_chain = match agent_filter_chain.filter_chain.as_ref() { + Some(fc) if !fc.is_empty() => fc, + _ => return Ok(Bytes::copy_from_slice(raw_bytes)), + }; + + let mut current_bytes = Bytes::copy_from_slice(raw_bytes); + + for agent_name in filter_chain { + debug!(agent = %agent_name, "processing raw filter agent"); + + let agent = agent_map + .get(agent_name) + .ok_or_else(|| PipelineError::AgentNotFound(agent_name.clone()))?; + + info!( + agent = %agent_name, + url = %agent.url, + bytes_len = current_bytes.len(), + "executing raw filter" + ); + + current_bytes = self + .execute_raw_filter(¤t_bytes, agent, request_headers, request_path) + .await?; + + info!(agent = %agent_name, bytes_len = current_bytes.len(), "raw filter completed"); + } + + Ok(current_bytes) + } + /// Send request to terminal agent and return the raw response for streaming /// Note: The caller is responsible for creating the plano(agent) span that wraps /// both this call and the subsequent response consumption. diff --git a/crates/brightstaff/src/handlers/utils.rs b/crates/brightstaff/src/handlers/utils.rs index de548e7a..18a4adca 100644 --- a/crates/brightstaff/src/handlers/utils.rs +++ b/crates/brightstaff/src/handlers/utils.rs @@ -14,10 +14,10 @@ use tokio_stream::StreamExt; use tracing::{debug, info, warn, Instrument}; use tracing_opentelemetry::OpenTelemetrySpanExt; -use super::pipeline_processor::PipelineProcessor; +use super::pipeline_processor::{PipelineError, PipelineProcessor}; use crate::signals::{InteractionQuality, SignalAnalyzer, TextBasedSignalAnalyzer, FLAG_MARKER}; use crate::tracing::{llm, set_service_name, signals as signal_constants}; -use hermesllm::apis::openai::{Message, MessageContent, Role}; +use hermesllm::apis::openai::Message; /// Trait for processing streaming chunks /// Implementors can inject custom logic during streaming (e.g., hallucination detection, logging) @@ -281,184 +281,9 @@ where } } -/// Extract content text from an SSE chunk line (the JSON part after "data: "). -/// Returns the content string and whether it was found. -fn extract_sse_content(json_str: &str) -> Option { - let value: serde_json::Value = serde_json::from_str(json_str).ok()?; - value - .get("choices")? - .get(0)? - .get("delta")? - .get("content")? - .as_str() - .map(|s| s.to_string()) -} - -/// Replace content in an SSE JSON chunk with new content. -fn replace_sse_content(json_str: &str, new_content: &str) -> Option { - let mut value: serde_json::Value = serde_json::from_str(json_str).ok()?; - value - .get_mut("choices")? - .get_mut(0)? - .get_mut("delta")? - .as_object_mut()? - .insert( - "content".to_string(), - serde_json::Value::String(new_content.to_string()), - ); - serde_json::to_string(&value).ok() -} - -/// Process an SSE chunk through the output filter chain. -/// Parses each `data: ` line, extracts content, sends through filters, and reconstructs. -async fn filter_sse_chunk( - chunk_str: &str, - pipeline_processor: &mut PipelineProcessor, - filter_chain: &AgentFilterChain, - filter_agents: &HashMap, - request_headers: &HeaderMap, -) -> String { - let mut result = String::new(); - for line in chunk_str.split('\n') { - if let Some(json_str) = line.strip_prefix("data: ") { - if json_str.trim() == "[DONE]" { - result.push_str(line); - result.push('\n'); - continue; - } - if let Some(content) = extract_sse_content(json_str) { - if content.is_empty() { - result.push_str(line); - result.push('\n'); - continue; - } - // Send content through output filter chain - let messages = vec![Message { - role: Role::Assistant, - content: Some(MessageContent::Text(content)), - name: None, - tool_calls: None, - tool_call_id: None, - }]; - match pipeline_processor - .process_filter_chain(&messages, filter_chain, filter_agents, request_headers) - .await - { - Ok(filtered_messages) => { - if let Some(msg) = filtered_messages.first() { - let filtered_content = match &msg.content { - Some(MessageContent::Text(t)) => Some(t.clone()), - _ => None, - }; - if let Some(filtered_content) = filtered_content { - if let Some(new_json) = - replace_sse_content(json_str, &filtered_content) - { - result.push_str("data: "); - result.push_str(&new_json); - result.push('\n'); - continue; - } - } - } - // Fallback: pass through original - result.push_str(line); - result.push('\n'); - } - Err(e) => { - warn!(error = %e, "output filter chain error, passing through original chunk"); - result.push_str(line); - result.push('\n'); - } - } - } else { - // No content in this SSE line, pass through - result.push_str(line); - result.push('\n'); - } - } else { - result.push_str(line); - result.push('\n'); - } - } - // Remove trailing extra newline if the original didn't end with one - if !chunk_str.ends_with('\n') && result.ends_with('\n') { - result.pop(); - } - result -} - -/// Process a non-streaming JSON response through the output filter chain. -/// Extracts assistant message content, filters it, and reconstructs the response. -pub async fn filter_non_streaming_response( - response_bytes: &[u8], - pipeline_processor: &mut PipelineProcessor, - filter_chain: &AgentFilterChain, - filter_agents: &HashMap, - request_headers: &HeaderMap, -) -> Bytes { - let response_str = match std::str::from_utf8(response_bytes) { - Ok(s) => s, - Err(_) => return Bytes::from(response_bytes.to_vec()), - }; - - let mut value: serde_json::Value = match serde_json::from_str(response_str) { - Ok(v) => v, - Err(_) => return Bytes::from(response_bytes.to_vec()), - }; - - // Extract content from choices[0].message.content - let content = value - .get("choices") - .and_then(|c| c.get(0)) - .and_then(|c| c.get("message")) - .and_then(|m| m.get("content")) - .and_then(|c| c.as_str()) - .map(|s| s.to_string()); - - if let Some(content) = content { - let messages = vec![Message { - role: Role::Assistant, - content: Some(MessageContent::Text(content)), - name: None, - tool_calls: None, - tool_call_id: None, - }]; - match pipeline_processor - .process_filter_chain(&messages, filter_chain, filter_agents, request_headers) - .await - { - Ok(filtered_messages) => { - if let Some(msg) = filtered_messages.first() { - let filtered_content = match &msg.content { - Some(MessageContent::Text(t)) => Some(t.clone()), - _ => None, - }; - if let Some(filtered_content) = filtered_content { - if let Some(choices) = value.get_mut("choices") { - if let Some(choice) = choices.get_mut(0) { - if let Some(message) = choice.get_mut("message") { - message.as_object_mut().unwrap().insert( - "content".to_string(), - serde_json::Value::String(filtered_content), - ); - } - } - } - } - } - } - Err(e) => { - warn!(error = %e, "output filter chain error on non-streaming response"); - } - } - } - - Bytes::from(serde_json::to_string(&value).unwrap_or_else(|_| response_str.to_string())) -} - -/// Creates a streaming response that processes each chunk through output filters. -/// The output filter is called asynchronously for each SSE chunk's content. +/// Creates a streaming response that processes each raw chunk through output filters. +/// Filters receive the raw LLM response bytes and return (possibly modified) bytes. +/// On filter error mid-stream the original chunk is passed through (headers already sent). pub fn create_streaming_response_with_output_filter( mut byte_stream: S, mut inner_processor: P, @@ -466,6 +291,7 @@ pub fn create_streaming_response_with_output_filter( output_filters: Vec, output_filter_agents: HashMap, request_headers: HeaderMap, + upstream_path: String, ) -> StreamingResponse where S: StreamExt> + Send + Unpin + 'static, @@ -501,32 +327,35 @@ where is_first_chunk = false; } - // Try to process through output filter chain - let processed_chunk = if let Ok(chunk_str) = std::str::from_utf8(&chunk) { - if chunk_str.contains("data: ") { - let filtered = filter_sse_chunk( - chunk_str, - &mut pipeline_processor, - &temp_filter_chain, - &output_filter_agents, - &request_headers, - ) - .await; - Bytes::from(filtered) - } else { - // Non-SSE chunk (could be non-streaming JSON response) - let filtered = filter_non_streaming_response( - &chunk, - &mut pipeline_processor, - &temp_filter_chain, - &output_filter_agents, - &request_headers, - ) - .await; - filtered + // Pass raw chunk bytes through the output filter chain + let processed_chunk = match pipeline_processor + .process_raw_filter_chain( + &chunk, + &temp_filter_chain, + &output_filter_agents, + &request_headers, + &upstream_path, + ) + .await + { + Ok(filtered) => filtered, + Err(PipelineError::ClientError { + agent, + status, + body, + }) => { + warn!( + agent = %agent, + status = %status, + body = %body, + "output filter client error, passing through original chunk" + ); + chunk + } + Err(e) => { + warn!(error = %e, "output filter error, passing through original chunk"); + chunk } - } else { - chunk }; // Pass through inner processor for metrics/observability diff --git a/demos/filter_chains/model_listener_filter/README.md b/demos/filter_chains/model_listener_filter/README.md index 92d44695..fb49ee1e 100644 --- a/demos/filter_chains/model_listener_filter/README.md +++ b/demos/filter_chains/model_listener_filter/README.md @@ -3,7 +3,11 @@ Run content-safety filters on direct LLM requests — no agent layer required. This demo uses the `input_filters` feature on a **model-type listener** to intercept -`/v1/chat/completions` requests and block unsafe content before they reach the LLM provider. +requests and block unsafe content before they reach the LLM provider. Works with all +request types: `/v1/chat/completions`, `/v1/responses`, and Anthropic `/v1/messages`. + +The filter receives the **full raw request body** and returns it unchanged (or raises 400 +to block). No message extraction — the complete JSON payload flows through as-is. ## Architecture diff --git a/demos/filter_chains/model_listener_filter/content_guard.py b/demos/filter_chains/model_listener_filter/content_guard.py index b452b10d..2bd35b63 100644 --- a/demos/filter_chains/model_listener_filter/content_guard.py +++ b/demos/filter_chains/model_listener_filter/content_guard.py @@ -3,13 +3,15 @@ Content guard filter — keyword-based content safety for model listeners. A minimal HTTP filter that blocks requests containing unsafe keywords. No LLM calls required — keeps the demo self-contained and fast. + +Receives the full raw request body (any API format: /v1/chat/completions, +/v1/responses, /v1/messages) and returns it unchanged or raises 400. """ import logging -from typing import List +from typing import Any from fastapi import FastAPI, Request, HTTPException -from pydantic import BaseModel logging.basicConfig( level=logging.INFO, @@ -36,11 +38,6 @@ BLOCKED_KEYWORDS = [ ] -class ChatMessage(BaseModel): - role: str - content: str - - def check_content(text: str) -> str | None: """Return the matched keyword if blocked, else None.""" lower = text.lower() @@ -50,19 +47,58 @@ def check_content(text: str) -> str | None: return None -@app.post("/") -async def content_guard( - messages: List[ChatMessage], request: Request -) -> List[ChatMessage]: - """Block messages that contain unsafe keywords.""" - last_user_msg = None +def extract_last_user_text(body: dict[str, Any]) -> str | None: + """Extract the most recent user message text from any supported request format.""" + messages = body.get("messages", []) + # Anthropic /v1/messages and OpenAI /v1/chat/completions both use "messages" for msg in reversed(messages): - if msg.role == "user": - last_user_msg = msg.content - break + if msg.get("role") == "user": + content = msg.get("content", "") + if isinstance(content, str): + return content + if isinstance(content, list): + # Multimodal content parts + return " ".join( + part.get("text", "") + for part in content + if isinstance(part, dict) and part.get("type") == "text" + ) + + # OpenAI /v1/responses uses "input" instead of "messages" + input_val = body.get("input") + if isinstance(input_val, str): + return input_val + if isinstance(input_val, list): + for item in reversed(input_val): + if isinstance(item, dict) and item.get("role") == "user": + content = item.get("content", "") + if isinstance(content, str): + return content + + return None + + +@app.post("/{path:path}") +async def content_guard(path: str, request: Request) -> dict[str, Any]: + """Block requests containing unsafe keywords. Returns the full request body unchanged. + + The endpoint path encodes the API format: + /v1/chat/completions — check body["messages"] + /v1/responses — check body["input"] + /v1/messages — check body["messages"] (Anthropic format) + """ + endpoint = f"/{path}" + body = await request.json() + + # /v1/responses uses "input" instead of "messages" + if endpoint == "/v1/responses": + input_val = body.get("input", "") + last_user_msg = input_val if isinstance(input_val, str) else None + else: + last_user_msg = extract_last_user_text(body) if last_user_msg is None: - return messages + return body matched = check_content(last_user_msg) if matched: @@ -76,7 +112,7 @@ async def content_guard( ) logger.info("Content check passed — forwarding request") - return messages + return body @app.get("/health") diff --git a/demos/filter_chains/pii_anonymizer/README.md b/demos/filter_chains/pii_anonymizer/README.md index a288733b..7f4c0611 100644 --- a/demos/filter_chains/pii_anonymizer/README.md +++ b/demos/filter_chains/pii_anonymizer/README.md @@ -80,14 +80,26 @@ Check the PII filter service logs in the terminal running `start_agents.sh`. You | Email | standard email format | `user@example.com` | `[EMAIL_0]` | | Phone | US phone formats | `555-123-4567` | `[PHONE_0]` | +## Filter Contract + +**Input filter (`/anonymize`)** receives the **full raw request body** and returns the modified body: +```json +{"model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Contact john@example.com"}], "stream": true} +``` +→ returns the same structure with PII replaced in the `messages` array. + +**Output filter (`/deanonymize`)** receives the **raw LLM response bytes** and returns modified bytes: +- *Streaming*: raw SSE chunk, e.g. `data: {"choices":[{"delta":{"content":"Contact [EMAIL_0]"}}]}` +- *Non-streaming*: full JSON response body + ## How Streaming De-anonymization Works -For streaming responses, each SSE chunk is sent through the output filters as it arrives from the LLM: +For streaming responses, each raw SSE chunk is sent through the output filter as it arrives from the LLM: -1. Plano receives a chunk with content like `"The email [EMAIL_0] belongs to..."` -2. The chunk content is sent to the `/deanonymize` endpoint -3. The filter looks up the PII mapping (stored during anonymization) and replaces placeholders -4. The restored chunk `"The email john@example.com belongs to..."` is streamed to the client +1. Plano receives a raw SSE chunk like `data: {"choices":[{"delta":{"content":"The email [EMAIL_0] belongs to..."}}]}` +2. The raw chunk bytes are sent to the `/deanonymize` endpoint +3. The filter parses the SSE, looks up the PII mapping (stored during anonymization), and replaces placeholders in the delta content +4. The restored chunk is returned and streamed to the client Partial placeholders split across chunks (e.g., `[EMA` in one chunk, `IL_0]` in the next) are handled via internal buffering in the filter service. diff --git a/demos/filter_chains/pii_anonymizer/config.yaml b/demos/filter_chains/pii_anonymizer/config.yaml index 3bae7354..b183379f 100644 --- a/demos/filter_chains/pii_anonymizer/config.yaml +++ b/demos/filter_chains/pii_anonymizer/config.yaml @@ -12,6 +12,8 @@ model_providers: - model: openai/gpt-4o-mini access_key: $OPENAI_API_KEY default: true + - model: anthropic/claude-sonnet-4-20250514 + access_key: $ANTHROPIC_API_KEY listeners: - type: model diff --git a/demos/filter_chains/pii_anonymizer/pii.py b/demos/filter_chains/pii_anonymizer/pii.py new file mode 100644 index 00000000..0d1dd58e --- /dev/null +++ b/demos/filter_chains/pii_anonymizer/pii.py @@ -0,0 +1,88 @@ +"""PII detection and anonymization utilities.""" + +import re +from typing import Any, Dict, List, Tuple + +# Order matters: SSN before phone to avoid overlap +PII_PATTERNS = [ + ("SSN", re.compile(r"\b\d{3}-\d{2}-\d{4}\b")), + ("CREDIT_CARD", re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")), + ("EMAIL", re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")), + ("PHONE", re.compile(r"(\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}")), +] + + +def anonymize_text(text: str) -> Tuple[str, Dict[str, str]]: + """Replace PII with [TYPE_N] placeholders. Returns (anonymized_text, mapping).""" + mapping: Dict[str, str] = {} + counters: Dict[str, int] = {} + matched_spans: List[Tuple[int, int]] = [] + + for pii_type, pattern in PII_PATTERNS: + for match in pattern.finditer(text): + start, end = match.start(), match.end() + if any(s <= start < e or s < end <= e for s, e in matched_spans): + continue + matched_spans.append((start, end)) + idx = counters.get(pii_type, 0) + counters[pii_type] = idx + 1 + mapping[f"[{pii_type}_{idx}]"] = match.group() + + # Replace right-to-left to preserve span indices + matched_spans.sort(reverse=True) + result = text + for start, end in matched_spans: + placeholder = next(k for k, v in mapping.items() if v == text[start:end]) + result = result[:start] + placeholder + result[end:] + + return result, mapping + + +def deanonymize_text( + text: str, mapping: Dict[str, str], buffer: str = "" +) -> Tuple[str, str]: + """Replace placeholders back with original PII values. + + Handles partial placeholders at chunk boundaries via a buffer. + Returns (processed_text, remaining_buffer). + """ + combined = buffer + text + + # Build prefix set for all known placeholders (e.g. "[EMAIL_0" is a prefix of "[EMAIL_0]") + prefixes: set[str] = set() + for placeholder in mapping: + for i in range(1, len(placeholder)): + prefixes.add(placeholder[:i]) + + # If the tail looks like the start of a placeholder, hold it in the buffer + remaining_buffer = "" + last_bracket = combined.rfind("[") + if last_bracket != -1 and "]" not in combined[last_bracket:]: + tail = combined[last_bracket:] + if tail in prefixes: + remaining_buffer = tail + combined = combined[:last_bracket] + + for placeholder, original in mapping.items(): + combined = combined.replace(placeholder, original) + + return combined, remaining_buffer + + +def anonymize_message_content(content: Any, all_mappings: Dict[str, str]) -> Any: + """Anonymize string content or list of content parts.""" + if isinstance(content, str): + anonymized, mapping = anonymize_text(content) + all_mappings.update(mapping) + return anonymized + if isinstance(content, list): + result = [] + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + anonymized, mapping = anonymize_text(part.get("text", "")) + all_mappings.update(mapping) + result.append({**part, "text": anonymized}) + else: + result.append(part) + return result + return content diff --git a/demos/filter_chains/pii_anonymizer/pii_anonymizer.py b/demos/filter_chains/pii_anonymizer/pii_anonymizer.py index 1e390ce6..9adf16cd 100644 --- a/demos/filter_chains/pii_anonymizer/pii_anonymizer.py +++ b/demos/filter_chains/pii_anonymizer/pii_anonymizer.py @@ -5,18 +5,26 @@ Inspired by Uber's GenAI Gateway PII Redactor. Two endpoints: POST /anonymize — replace PII with placeholders (input filter) POST /deanonymize — restore original PII from placeholders (output filter) -Uses regex-based detection for: email, phone, SSN, credit card. -Correlates request/response via x-request-id header. +Input filter (/anonymize): + Receives the full raw request body (any API format). Anonymizes user message + content and returns the modified body. + +Output filter (/deanonymize): + Receives raw LLM response bytes — SSE (streaming) or full JSON (non-streaming). + De-anonymizes content and returns modified bytes. + +The path suffix encodes the upstream API format so each endpoint knows how to +parse the body (e.g. /anonymize/v1/chat/completions, /deanonymize/v1/messages). """ import logging -import re -import time -import threading -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict from fastapi import FastAPI, Request -from pydantic import BaseModel +from fastapi.responses import Response + +from pii import anonymize_text, anonymize_message_content +from store import get_mapping, store_mapping, deanonymize_sse, deanonymize_json logging.basicConfig( level=logging.INFO, @@ -26,205 +34,79 @@ logger = logging.getLogger(__name__) app = FastAPI(title="PII Anonymizer", version="1.0.0") -# --- PII patterns (order matters: SSN before phone to avoid overlap) --- -PII_PATTERNS = [ - ("SSN", re.compile(r"\b\d{3}-\d{2}-\d{4}\b")), - ("CREDIT_CARD", re.compile(r"\b(?:\d{4}[-\s]?){3}\d{4}\b")), - ("EMAIL", re.compile(r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}")), - ( - "PHONE", - re.compile(r"(\+?1[-.\s]?)?\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}"), - ), -] +@app.post("/anonymize/{path:path}") +async def anonymize(path: str, request: Request) -> dict[str, Any]: + """Anonymize PII in user messages. Receives and returns the full raw request body. -# --- In-memory mapping store (request_id -> mapping + timestamp) --- - -_store_lock = threading.Lock() -_mapping_store: Dict[str, Tuple[Dict[str, str], float]] = {} -# Buffer for partial placeholder matches during streaming de-anonymization -_buffer_store: Dict[str, str] = {} -MAPPING_TTL_SECONDS = 300 # 5 minutes - - -def _cleanup_expired(): - """Remove expired mappings.""" - now = time.time() - expired = [ - k for k, (_, ts) in _mapping_store.items() if now - ts > MAPPING_TTL_SECONDS - ] - for k in expired: - del _mapping_store[k] - _buffer_store.pop(k, None) - - -def _store_mapping(request_id: str, mapping: Dict[str, str]): - with _store_lock: - _cleanup_expired() - _mapping_store[request_id] = (mapping, time.time()) - - -def _get_mapping(request_id: str) -> Optional[Dict[str, str]]: - with _store_lock: - entry = _mapping_store.get(request_id) - if entry: - return entry[0] - return None - - -# --- Core logic --- - - -class ChatMessage(BaseModel): - role: str - content: str - - -def anonymize_text(text: str) -> Tuple[str, Dict[str, str]]: - """Replace PII with [TYPE_N] placeholders. Returns (anonymized_text, mapping).""" - mapping: Dict[str, str] = {} - counters: Dict[str, int] = {} - # Track spans already matched to avoid overlapping replacements - matched_spans: List[Tuple[int, int]] = [] - - for pii_type, pattern in PII_PATTERNS: - for match in pattern.finditer(text): - start, end = match.start(), match.end() - # Skip if this span overlaps with an already-matched span - if any(s <= start < e or s < end <= e for s, e in matched_spans): - continue - matched_spans.append((start, end)) - idx = counters.get(pii_type, 0) - counters[pii_type] = idx + 1 - placeholder = f"[{pii_type}_{idx}]" - mapping[placeholder] = match.group() - - # Replace from right to left to preserve indices - matched_spans.sort(reverse=True) - result = text - for start, end in matched_spans: - original = text[start:end] - # Find the placeholder for this original value - placeholder = next(k for k, v in mapping.items() if v == original) - result = result[:start] + placeholder + result[end:] - - return result, mapping - - -def deanonymize_text( - text: str, mapping: Dict[str, str], buffer: str = "" -) -> Tuple[str, str]: - """Replace placeholders back with original PII values. - - Handles partial placeholders across streaming chunks via a buffer. - Only buffers text that could be the prefix of an actual placeholder - from this request's mapping, not arbitrary ``[`` from normal text. - Returns (processed_text, remaining_buffer). + The endpoint path encodes the API format: + /anonymize/v1/chat/completions — anonymize body["messages"] + /anonymize/v1/responses — anonymize body["input"] (string or items list) + /anonymize/v1/messages — anonymize body["messages"] (Anthropic format) """ - combined = buffer + text - - # Build the set of all prefixes for placeholders in this request's mapping. - # e.g. for "[EMAIL_0]" -> {"[", "[E", "[EM", "[EMA", "[EMAI", "[EMAIL", "[EMAIL_", "[EMAIL_0"} - prefixes: set[str] = set() - for placeholder in mapping: - # Exclude the full placeholder (with closing ']') — that's a complete match, not partial - for i in range(1, len(placeholder)): - prefixes.add(placeholder[:i]) - - # Check if the end of the text could be a partial placeholder. - remaining_buffer = "" - last_bracket = combined.rfind("[") - if last_bracket != -1 and "]" not in combined[last_bracket:]: - tail = combined[last_bracket:] - if tail in prefixes: - remaining_buffer = tail - combined = combined[:last_bracket] - - # Replace all complete placeholders - for placeholder, original in mapping.items(): - combined = combined.replace(placeholder, original) - - return combined, remaining_buffer - - -# --- Endpoints --- - - -@app.post("/anonymize") -async def anonymize(messages: List[ChatMessage], request: Request) -> List[ChatMessage]: - """Anonymize PII in user messages. Stores mapping for later de-anonymization.""" request_id = request.headers.get("x-request-id", "unknown") + endpoint = f"/{path}" + body = await request.json() all_mappings: Dict[str, str] = {} - result_messages = [] - for msg in messages: - if msg.role == "user": - anonymized, mapping = anonymize_text(msg.content) + if endpoint == "/v1/responses": + input_val = body.get("input", "") + if isinstance(input_val, str): + anonymized, mapping = anonymize_text(input_val) all_mappings.update(mapping) - result_messages.append(ChatMessage(role=msg.role, content=anonymized)) - else: - result_messages.append(msg) + body = {**body, "input": anonymized} + elif isinstance(input_val, list): + items = [ + {**item, "content": anonymize_message_content(item.get("content", ""), all_mappings)} + if isinstance(item, dict) and item.get("role") == "user" + else item + for item in input_val + ] + body = {**body, "input": items} + else: + # /v1/chat/completions and /v1/messages both use "messages" + messages = [ + {**msg, "content": anonymize_message_content(msg.get("content", ""), all_mappings)} + if msg.get("role") == "user" + else msg + for msg in body.get("messages", []) + ] + if messages: + body = {**body, "messages": messages} if all_mappings: - _store_mapping(request_id, all_mappings) - logger.info( - "request_id=%s /anonymize mapping: %s", - request_id, - all_mappings, - ) + store_mapping(request_id, all_mappings) + logger.info("request_id=%s /anonymize mapping: %s", request_id, all_mappings) else: logger.info("request_id=%s no PII detected", request_id) - logger.info( - "request_id=%s /anonymize input: %s -> output: %s", - request_id, - [m.content for m in messages], - [m.content for m in result_messages], - ) - - return result_messages + return body -@app.post("/deanonymize") -async def deanonymize( - messages: List[ChatMessage], request: Request -) -> List[ChatMessage]: - """De-anonymize PII placeholders in response messages using stored mapping.""" +@app.post("/deanonymize/{path:path}") +async def deanonymize(path: str, request: Request) -> Response: + """De-anonymize PII placeholders in LLM response. Handles SSE (streaming) and JSON. + + The path encodes the upstream API format: + /deanonymize/v1/chat/completions — OpenAI chat completions + /deanonymize/v1/messages — Anthropic messages + /deanonymize/v1/responses — OpenAI responses API + """ + endpoint = f"/{path}" + is_anthropic = endpoint == "/v1/messages" request_id = request.headers.get("x-request-id", "unknown") - mapping = _get_mapping(request_id) + mapping = get_mapping(request_id) + raw_body = await request.body() if not mapping: logger.info("request_id=%s no mapping found, passing through", request_id) - return messages + return Response(content=raw_body, media_type="application/json") - result_messages = [] - for msg in messages: - if msg.role == "assistant" and msg.content: - with _store_lock: - buffer = _buffer_store.get(request_id, "") + body_str = raw_body.decode("utf-8", errors="replace") - restored, remaining = deanonymize_text(msg.content, mapping, buffer) - - with _store_lock: - if remaining: - _buffer_store[request_id] = remaining - else: - _buffer_store.pop(request_id, None) - - # Only log when a replacement actually happened - if restored != msg.content: - logger.info( - "request_id=%s /deanonymize '%s' -> '%s'", - request_id, - msg.content, - restored, - ) - - result_messages.append(ChatMessage(role=msg.role, content=restored)) - else: - result_messages.append(msg) - - return result_messages + if "data: " in body_str: + return deanonymize_sse(request_id, body_str, mapping, is_anthropic) + return deanonymize_json(request_id, raw_body, body_str, mapping, is_anthropic) @app.get("/health") diff --git a/demos/filter_chains/pii_anonymizer/store.py b/demos/filter_chains/pii_anonymizer/store.py new file mode 100644 index 00000000..5be58f68 --- /dev/null +++ b/demos/filter_chains/pii_anonymizer/store.py @@ -0,0 +1,115 @@ +"""In-memory mapping store and LLM response processors for PII de-anonymization.""" + +import json +import logging +import threading +import time +from typing import Dict, Optional, Tuple + +from fastapi.responses import Response + +from pii import deanonymize_text + +logger = logging.getLogger(__name__) + +MAPPING_TTL_SECONDS = 300 # 5 minutes + +_lock = threading.Lock() +_mappings: Dict[str, Tuple[Dict[str, str], float]] = {} +_buffers: Dict[str, str] = {} # partial placeholder buffers for streaming + + +def _cleanup_expired(): + now = time.time() + expired = [k for k, (_, ts) in _mappings.items() if now - ts > MAPPING_TTL_SECONDS] + for k in expired: + del _mappings[k] + _buffers.pop(k, None) + + +def store_mapping(request_id: str, mapping: Dict[str, str]): + with _lock: + _cleanup_expired() + _mappings[request_id] = (mapping, time.time()) + + +def get_mapping(request_id: str) -> Optional[Dict[str, str]]: + with _lock: + entry = _mappings.get(request_id) + return entry[0] if entry else None + + +def restore_streaming(request_id: str, content: str, mapping: Dict[str, str]) -> str: + """Restore PII in one streaming chunk, maintaining the per-request partial buffer.""" + with _lock: + buffer = _buffers.get(request_id, "") + restored, remaining = deanonymize_text(content, mapping, buffer) + with _lock: + if remaining: + _buffers[request_id] = remaining + else: + _buffers.pop(request_id, None) + if restored != content: + logger.info("request_id=%s restored '%s' -> '%s'", request_id, content, restored) + return restored + + +def deanonymize_sse( + request_id: str, body_str: str, mapping: Dict[str, str], is_anthropic: bool +) -> Response: + result_lines = [] + for line in body_str.split("\n"): + stripped = line.strip() + if not (stripped.startswith("data: ") and stripped[6:] != "[DONE]"): + result_lines.append(line) + continue + try: + chunk = json.loads(stripped[6:]) + if is_anthropic: + # {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "..."}} + if chunk.get("type") == "content_block_delta": + delta = chunk.get("delta", {}) + if delta.get("type") == "text_delta" and delta.get("text"): + delta["text"] = restore_streaming(request_id, delta["text"], mapping) + else: + # {"choices": [{"delta": {"content": "..."}}]} + for choice in chunk.get("choices", []): + delta = choice.get("delta", {}) + if delta.get("content"): + delta["content"] = restore_streaming(request_id, delta["content"], mapping) + result_lines.append("data: " + json.dumps(chunk)) + except json.JSONDecodeError: + result_lines.append(line) + return Response(content="\n".join(result_lines), media_type="text/plain") + + +def deanonymize_json( + request_id: str, + raw_body: bytes, + body_str: str, + mapping: Dict[str, str], + is_anthropic: bool, +) -> Response: + try: + body = json.loads(body_str) + if is_anthropic: + # {"content": [{"type": "text", "text": "..."}]} + for part in body.get("content", []): + if isinstance(part, dict) and part.get("type") == "text" and part.get("text"): + restored, _ = deanonymize_text(part["text"], mapping) + if restored != part["text"]: + logger.info("request_id=%s restored '%s' -> '%s'", request_id, part["text"], restored) + part["text"] = restored + else: + # {"choices": [{"message": {"content": "..."}}]} + for choice in body.get("choices", []): + message = choice.get("message", {}) + content = message.get("content") + if content and isinstance(content, str): + restored, _ = deanonymize_text(content, mapping) + if restored != content: + logger.info("request_id=%s restored '%s' -> '%s'", request_id, content, restored) + message["content"] = restored + return Response(content=json.dumps(body), media_type="application/json") + except json.JSONDecodeError: + return Response(content=raw_body, media_type="application/json") diff --git a/demos/filter_chains/pii_anonymizer/test.sh b/demos/filter_chains/pii_anonymizer/test.sh index 49bcb585..795de7e7 100755 --- a/demos/filter_chains/pii_anonymizer/test.sh +++ b/demos/filter_chains/pii_anonymizer/test.sh @@ -1,14 +1,14 @@ #!/usr/bin/env bash set -euo pipefail -BASE_URL="http://localhost:12000/v1" +BASE_URL="http://localhost:12000" PASS=0 FAIL=0 # ── Wait for Plano to be ready ────────────────────────────────────────────── echo "Waiting for Plano to be ready..." for i in $(seq 1 30); do - if curl -sf "$BASE_URL/models" > /dev/null 2>&1; then + if curl -sf "$BASE_URL/v1/models" > /dev/null 2>&1; then echo "Plano is ready." break fi @@ -22,11 +22,12 @@ done # ── Helper ─────────────────────────────────────────────────────────────────── run_test() { local name="$1" - local expected_code="$2" - local body="$3" + local path="$2" + local expected_code="$3" + local body="$4" http_code=$(curl -s -o /tmp/plano_test_body -w "%{http_code}" \ - -X POST "$BASE_URL/chat/completions" \ + -X POST "$BASE_URL$path" \ -H "Content-Type: application/json" \ -d "$body") @@ -40,34 +41,75 @@ run_test() { fi } -# ── Tests ──────────────────────────────────────────────────────────────────── +# ── /v1/chat/completions ───────────────────────────────────────────────────── echo "" -echo "Running tests..." +echo "=== /v1/chat/completions ===" -run_test "Non-streaming with PII (email + phone)" 200 '{ +run_test "Non-streaming with PII (email + phone)" /v1/chat/completions 200 '{ "model": "gpt-4o-mini", "messages": [{"role": "user", "content": "Contact me at john@example.com or call 555-123-4567"}], "stream": false }' -run_test "Streaming with PII (SSN)" 200 '{ +run_test "Streaming with PII (SSN)" /v1/chat/completions 200 '{ "model": "gpt-4o-mini", "messages": [{"role": "user", "content": "My SSN is 123-45-6789, please help me file taxes"}], "stream": true }' -run_test "No PII (clean message)" 200 '{ +run_test "No PII (clean message)" /v1/chat/completions 200 '{ "model": "gpt-4o-mini", "messages": [{"role": "user", "content": "What is 2+2?"}], "stream": false }' -run_test "Multiple PII types" 200 '{ +run_test "Multiple PII types" /v1/chat/completions 200 '{ "model": "gpt-4o-mini", - "messages": [{"role": "user", "content": "Email: test@test.com, Phone: 555-867-5309, SSN: 987-65-4321, Card: 4111 1111 1111 1111"}], + "messages": [{"role": "user", "content": "Email: test@test.com, SSN: 987-65-4321, Card: 4111 1111 1111 1111"}], "stream": false }' +# ── /v1/responses ──────────────────────────────────────────────────────────── +echo "" +echo "=== /v1/responses ===" + +run_test "Non-streaming with PII (email)" /v1/responses 200 '{ + "model": "gpt-4o-mini", + "input": "My email is jane@example.com — can you summarize it?" +}' + +run_test "Non-streaming with PII (credit card)" /v1/responses 200 '{ + "model": "gpt-4o-mini", + "input": "I need help disputing a charge on card 4111 1111 1111 1111" +}' + +run_test "No PII" /v1/responses 200 '{ + "model": "gpt-4o-mini", + "input": "What is the capital of France?" +}' + +# ── /v1/messages (Anthropic) ───────────────────────────────────────────────── +echo "" +echo "=== /v1/messages ===" + +run_test "Non-streaming with PII (phone)" /v1/messages 200 '{ + "model": "claude-sonnet-4-20250514", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Call me at 555-867-5309 to discuss my account"}] +}' + +run_test "Non-streaming with PII (SSN)" /v1/messages 200 '{ + "model": "claude-sonnet-4-20250514", + "max_tokens": 256, + "messages": [{"role": "user", "content": "My SSN is 123-45-6789"}] +}' + +run_test "No PII" /v1/messages 200 '{ + "model": "claude-sonnet-4-20250514", + "max_tokens": 256, + "messages": [{"role": "user", "content": "Hello, how are you?"}] +}' + # ── Summary ────────────────────────────────────────────────────────────────── echo "" echo "Results: $PASS passed, $FAIL failed"