From 0533987a2f47394e2f50a14a0781744ea18b04fc Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 23 Dec 2025 15:17:55 -0800 Subject: [PATCH] add input_guards and update prompt guards section --- .../src/handlers/agent_chat_completions.rs | 3 +- crates/brightstaff/src/handlers/jsonrpc.rs | 2 +- .../src/handlers/pipeline_processor.rs | 2 +- demos/use_cases/mcp_filter/Dockerfile | 26 +++ demos/use_cases/mcp_filter/README.md | 40 +++-- demos/use_cases/mcp_filter/arch_config.yaml | 15 +- .../use_cases/mcp_filter/docker-compose.yaml | 25 +++ .../mcp_filter/src/rag_agent/__init__.py | 9 +- .../mcp_filter/src/rag_agent/input_guards.py | 153 ++++++++++++++++++ demos/use_cases/mcp_filter/start_agents.sh | 12 +- docs/source/guides/prompt_guard.rst | 105 +++++++----- 11 files changed, 321 insertions(+), 71 deletions(-) create mode 100644 demos/use_cases/mcp_filter/Dockerfile create mode 100644 demos/use_cases/mcp_filter/src/rag_agent/input_guards.py diff --git a/crates/brightstaff/src/handlers/agent_chat_completions.rs b/crates/brightstaff/src/handlers/agent_chat_completions.rs index 34a9ce65..0c1232a2 100644 --- a/crates/brightstaff/src/handlers/agent_chat_completions.rs +++ b/crates/brightstaff/src/handlers/agent_chat_completions.rs @@ -61,6 +61,7 @@ pub async fn agent_chat( body, }) = &err { + warn!( "Client error from agent '{}' (HTTP {}): {}", agent, status, body @@ -77,7 +78,7 @@ pub async fn agent_chat( let json_string = error_json.to_string(); let mut response = Response::new(ResponseHandler::create_full_body(json_string)); *response.status_mut() = hyper::StatusCode::from_u16(*status) - .unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR); + .unwrap_or(hyper::StatusCode::BAD_REQUEST); response.headers_mut().insert( hyper::header::CONTENT_TYPE, "application/json".parse().unwrap(), diff --git a/crates/brightstaff/src/handlers/jsonrpc.rs b/crates/brightstaff/src/handlers/jsonrpc.rs index 0f8b9373..a34167fe 100644 --- a/crates/brightstaff/src/handlers/jsonrpc.rs +++ b/crates/brightstaff/src/handlers/jsonrpc.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; pub const JSON_RPC_VERSION: &str = "2.0"; pub const TOOL_CALL_METHOD : &str = "tools/call"; pub const MCP_INITIALIZE: &str = "initialize"; -pub const MCP_INITIALIZE_NOTIFICATION: &str = "initialize/notification"; +pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized"; #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(untagged)] diff --git a/crates/brightstaff/src/handlers/pipeline_processor.rs b/crates/brightstaff/src/handlers/pipeline_processor.rs index 04aa427a..b3a07559 100644 --- a/crates/brightstaff/src/handlers/pipeline_processor.rs +++ b/crates/brightstaff/src/handlers/pipeline_processor.rs @@ -551,7 +551,7 @@ impl PipelineProcessor { return Err(PipelineError::ClientError { agent: agent.id.clone(), - status: http_status.as_u16(), + status: hyper::StatusCode::BAD_REQUEST.as_u16(), body: error_message, }); } diff --git a/demos/use_cases/mcp_filter/Dockerfile b/demos/use_cases/mcp_filter/Dockerfile new file mode 100644 index 00000000..5882714a --- /dev/null +++ b/demos/use_cases/mcp_filter/Dockerfile @@ -0,0 +1,26 @@ +FROM python:3.13-slim + +WORKDIR /app + +# Install bash and uv +RUN apt-get update && apt-get install -y bash && rm -rf /var/lib/apt/lists/* +RUN pip install --no-cache-dir uv + +# Copy dependency files +COPY pyproject.toml README.md ./ + +# Copy source code +COPY src/ ./src/ +COPY start_agents.sh ./ + +# Install dependencies using uv +RUN uv pip install --system --no-cache click fastmcp pydantic fastapi uvicorn openai + +# Make start script executable +RUN chmod +x start_agents.sh + +# Expose ports for all agents +EXPOSE 10500 10501 10502 10505 + +# Run the start script with bash +CMD ["bash", "./start_agents.sh"] diff --git a/demos/use_cases/mcp_filter/README.md b/demos/use_cases/mcp_filter/README.md index a524c1b4..508060ae 100644 --- a/demos/use_cases/mcp_filter/README.md +++ b/demos/use_cases/mcp_filter/README.md @@ -4,14 +4,21 @@ A multi-agent RAG system demonstrating archgw's agent filter chain with MCP prot ## Architecture -This demo consists of three components: -1. **Query Rewriter** (MCP filter) - Rewrites user queries for better retrieval -2. **Context Builder** (MCP filter) - Retrieves relevant context from knowledge base -3. **RAG Agent** (REST) - Generates final responses based on augmented context +This demo consists of four components: +1. **Input Guards** (MCP filter) - Validates queries are within TechCorp's domain +2. **Query Rewriter** (MCP filter) - Rewrites user queries for better retrieval +3. **Context Builder** (MCP filter) - Retrieves relevant context from knowledge base +4. **RAG Agent** (REST) - Generates final responses based on augmented context ## Components -### Query Rewriter Filter (MCP) +### Input Guards Filter (MCP) +- **Port**: 10500 +- **Tool**: `input_guards` +- Validates queries are within TechCorp's domain +- Rejects queries about other companies or unrelated topics + +### Query Rewrit3r Filter (MCP) - **Port**: 10501 - **Tool**: `query_rewriter` - Improves queries using LLM before retrieval @@ -34,6 +41,7 @@ This demo consists of three components: ``` This starts: +- Input Guards MCP server on port 10500 - Query Rewriter MCP server on port 10501 - Context Builder MCP server on port 10502 - RAG Agent REST server on port 10505 @@ -59,29 +67,37 @@ The `arch_config.yaml` defines how agents are connected: ```yaml filters: + - id: input_guards + url: http://host.docker.internal:10500 + # type: mcp (default) + # tool: input_guards (default - same as filter id) + - id: query_rewriter - url: mcp://host.docker.internal:10500 - tool: rewrite_query_with_archgw # MCP tool name + url: http://host.docker.internal:10501 + # type: mcp (default) - id: context_builder - url: mcp://host.docker.internal:10501 - tool: chat_completions + url: http://host.docker.internal:10502 ``` -How It Works + +## How It Works 1. User sends request to archgw listener on port 8001 2. Request passes through MCP filter chain: + - **Input Guards** validates the query is within TechCorp's domain - **Query Rewriter** rewrites the query for better retrieval - **Context Builder** augments query with relevant knowledge base passages 3. Augmented request is forwarded to **RAG Agent** REST endpoint 4. RAG Agent generates final response using LLM -## Configuration +## Additional Configuration See `arch_config.yaml` for the complete filter chain setup. The MCP filters use default settings: - `type: mcp` (default) - `transport: streamable-http` (default) -- Tool name defaults to filter ID `sample_queries.md` for example queries to test the RAG system. +- Tool name defaults to filter ID + +See `sample_queries.md` for example queries to test the RAG system. Example request: ```bash diff --git a/demos/use_cases/mcp_filter/arch_config.yaml b/demos/use_cases/mcp_filter/arch_config.yaml index b70d817b..0b2b58a6 100644 --- a/demos/use_cases/mcp_filter/arch_config.yaml +++ b/demos/use_cases/mcp_filter/arch_config.yaml @@ -5,12 +5,16 @@ agents: url: http://host.docker.internal:10505 filters: - - id: query_rewriter + - id: input_guards url: http://host.docker.internal:10500 - type: http - # type: http or mcp, mcp is default - # transport: streamable-http # default is streamable-http - # tool: query_rewriter # default name is the filter id + # type: mcp (default) + # transport: streamable-http (default) + # tool: input_guards (default - same as filter id) + - id: query_rewriter + url: http://host.docker.internal:10501 + # type: mcp (default) + # transport: streamable-http (default) + # tool: query_rewriter (default - same as filter id) - id: context_builder url: http://host.docker.internal:10502 @@ -36,6 +40,7 @@ listeners: - id: rag_agent description: virtual assistant for retrieval augmented generation tasks filter_chain: + - input_guards - query_rewriter - context_builder tracing: diff --git a/demos/use_cases/mcp_filter/docker-compose.yaml b/demos/use_cases/mcp_filter/docker-compose.yaml index a5d45ed9..e6644321 100644 --- a/demos/use_cases/mcp_filter/docker-compose.yaml +++ b/demos/use_cases/mcp_filter/docker-compose.yaml @@ -1,4 +1,29 @@ services: + rag-agents: + build: + context: . + dockerfile: Dockerfile + ports: + - "10500:10500" + - "10501:10501" + - "10502:10502" + - "10505:10505" + environment: + - LLM_GATEWAY_ENDPOINT=${LLM_GATEWAY_ENDPOINT:-http://host.docker.internal:12000/v1} + - OPENAI_API_KEY=${OPENAI_API_KEY:?OPENAI_API_KEY environment variable is required but not set} + archgw: + build: + context: ../../../ + dockerfile: arch/Dockerfile + ports: + - "12000:12000" + - "8001:8001" + environment: + - ARCH_CONFIG_PATH=/config/arch_config.yaml + - OPENAI_API_KEY=${OPENAI_API_KEY:?OPENAI_API_KEY environment variable is required but not set} + volumes: + - ./arch_config.yaml:/app/arch_config.yaml + - /etc/ssl/cert.pem:/etc/ssl/cert.pem jaeger: build: context: ../../shared/jaeger diff --git a/demos/use_cases/mcp_filter/src/rag_agent/__init__.py b/demos/use_cases/mcp_filter/src/rag_agent/__init__.py index aa601877..6589e775 100644 --- a/demos/use_cases/mcp_filter/src/rag_agent/__init__.py +++ b/demos/use_cases/mcp_filter/src/rag_agent/__init__.py @@ -37,6 +37,7 @@ def main(host, port, agent, transport, agent_name, rest_server, rest_port): # Map friendly names to agent modules agent_map = { + "input_guards": ("rag_agent.input_guards", "Input Guards Agent"), "query_rewriter": ("rag_agent.query_rewriter", "Query Rewriter Agent"), "context_builder": ("rag_agent.context_builder", "Context Builder Agent"), "response_generator": ( @@ -75,10 +76,12 @@ def main(host, port, agent, transport, agent_name, rest_server, rest_port): print(f"Remove --rest-server flag to start {agent} as an MCP server.") return else: - # Only query_rewriter and context_builder support MCP - if agent not in ["query_rewriter", "context_builder"]: + # Only input_guards, query_rewriter and context_builder support MCP + if agent not in ["input_guards", "query_rewriter", "context_builder"]: print(f"Error: Agent '{agent}' does not support MCP mode.") - print(f"MCP is only supported for: query_rewriter, context_builder") + print( + f"MCP is only supported for: input_guards, query_rewriter, context_builder" + ) print(f"Use --rest-server flag to start {agent} as a REST server.") return diff --git a/demos/use_cases/mcp_filter/src/rag_agent/input_guards.py b/demos/use_cases/mcp_filter/src/rag_agent/input_guards.py new file mode 100644 index 00000000..633ab5d0 --- /dev/null +++ b/demos/use_cases/mcp_filter/src/rag_agent/input_guards.py @@ -0,0 +1,153 @@ +import asyncio +import json +import time +from typing import List, Optional, Dict, Any +import uuid +from fastapi import FastAPI, Depends, Request +from fastmcp.exceptions import ToolError +from openai import AsyncOpenAI +import os +import logging + +from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage +from . import mcp +from fastmcp.server.dependencies import get_http_headers + +# Set up logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - [INPUT_GUARDS] - %(levelname)s - %(message)s", +) +logger = logging.getLogger(__name__) + +# Configuration for archgw LLM gateway +LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1") +GUARD_MODEL = "gpt-4o-mini" + +# Initialize OpenAI client for archgw +archgw_client = AsyncOpenAI( + base_url=LLM_GATEWAY_ENDPOINT, + api_key="EMPTY", # archgw doesn't require a real API key +) + +app = FastAPI() + + +async def validate_query_scope( + messages: List[ChatMessage], traceparent_header: str +) -> Dict[str, Any]: + """Validate that the user query is within TechCorp's domain. + + Returns a dict with: + - is_valid: bool indicating if query is within scope + - reason: str explaining why query is out of scope (if applicable) + """ + system_prompt = """You are an input validation guard for TechCorp's customer support system. + +Your job is to determine if a user's query is related to TechCorp and its services/products. + +TechCorp is a technology company that provides: +- Cloud services and infrastructure +- SaaS products +- Technical support +- Service level agreements (SLAs) +- Uptime guarantees +- Enterprise solutions + +ALLOW queries about: +- TechCorp's services, products, or offerings +- TechCorp's pricing, SLAs, uptime, or policies +- Technical support for TechCorp products +- General questions about TechCorp as a company + +REJECT queries about: +- Other companies or their products +- General knowledge questions unrelated to TechCorp +- Personal advice or topics outside TechCorp's domain +- Anything that doesn't relate to TechCorp's business + +Respond in JSON format: +{ + "is_valid": true/false, + "reason": "brief explanation if invalid" +}""" + + # Get the last user message for validation + last_user_message = None + for msg in reversed(messages): + if msg.role == "user": + last_user_message = msg.content + break + + if not last_user_message: + return {"is_valid": True, "reason": ""} + + # Prepare messages for the guard + guard_messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"Query to validate: {last_user_message}"}, + ] + + try: + # Call archgw using OpenAI client + extra_headers = {"x-envoy-max-retries": "3"} + if traceparent_header: + extra_headers["traceparent"] = traceparent_header + + logger.info(f"Validating query scope: '{last_user_message}'") + response = await archgw_client.chat.completions.create( + model=GUARD_MODEL, + messages=guard_messages, + temperature=0.1, + max_tokens=150, + extra_headers=extra_headers, + ) + + result_text = response.choices[0].message.content.strip() + + # Parse JSON response + try: + result = json.loads(result_text) + logger.info(f"Validation result: {result}") + return result + except json.JSONDecodeError: + logger.error(f"Failed to parse validation response: {result_text}") + # Default to allowing if parsing fails + return {"is_valid": True, "reason": ""} + + except Exception as e: + logger.error(f"Error validating query: {e}") + # Default to allowing if validation fails + return {"is_valid": True, "reason": ""} + + +@mcp.tool +async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]: + """Input guard that validates queries are within TechCorp's domain. + + If the query is out of scope, replaces the user message with a rejection notice. + """ + logger.info(f"Received request with {len(messages)} messages") + + # Get traceparent header from HTTP request using FastMCP's dependency function + headers = get_http_headers() + traceparent_header = headers.get("traceparent") + + if traceparent_header: + logger.info(f"Received traceparent header: {traceparent_header}") + else: + logger.info("No traceparent header found") + + # Validate the query scope + validation_result = await validate_query_scope(messages, traceparent_header) + + if not validation_result.get("is_valid", True): + reason = validation_result.get("reason", "Query is outside TechCorp's domain") + logger.warning(f"Query rejected: {reason}") + + # Throw ToolError + error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support." + raise ToolError(error_message) + + logger.info("Query validation passed - forwarding to next filter") + return messages diff --git a/demos/use_cases/mcp_filter/start_agents.sh b/demos/use_cases/mcp_filter/start_agents.sh index 2c2f446e..9a1d9b25 100644 --- a/demos/use_cases/mcp_filter/start_agents.sh +++ b/demos/use_cases/mcp_filter/start_agents.sh @@ -21,17 +21,11 @@ cleanup() { trap cleanup EXIT -# log "Starting input guards filter on port 10500..." -# uv run python -m rag_agent --host 0.0.0.0 --port 10500 --agent input_guards & -# WAIT_FOR_PIDS+=($!) - - -log "Starting query_rewriter agent on port 10500/http..." -uv run python -m rag_agent --rest-server --host 0.0.0.0 --rest-port 10500 --agent query_rewriter & +log "Starting input_guards agent on port 10500/mcp..." +uv run python -m rag_agent --host 0.0.0.0 --port 10500 --agent input_guards & WAIT_FOR_PIDS+=($!) - -log "Starting query_parser agent on port 10501/mcp..." +log "Starting query_rewriter agent on port 10501/mcp..." uv run python -m rag_agent --host 0.0.0.0 --port 10501 --agent query_rewriter & WAIT_FOR_PIDS+=($!) diff --git a/docs/source/guides/prompt_guard.rst b/docs/source/guides/prompt_guard.rst index 9db88ce5..81545a91 100644 --- a/docs/source/guides/prompt_guard.rst +++ b/docs/source/guides/prompt_guard.rst @@ -24,66 +24,93 @@ more reliable, and easier to reason about. How Guardrails Work ------------------- -In Plano, guardrails are usually implemented as filters that run as HTTP services. Each filter receives the incoming prompt and related metadata, evaluates it -against policy, and either lets the request continue (HTTP 200) or terminates it early with an appropriate error code (typically HTTP 4xx for policy failures). +In Plano, guardrails are implemented as MCP filters that validate incoming requests. Each filter receives the chat messages, evaluates them +against policy, and either lets the request continue or raises a ``ToolError`` to reject it with a helpful error message. -The example below shows a simple, plain-Python HTTP service that acts as a topicality guardrail: it rejects any prompt that is not related to the -"weather" domain. +The example below shows an input guard for TechCorp's customer support system that validates queries are within the company's domain: .. code-block:: python - :caption: Example topicality guard filter in plain Python (FastAPI) + :caption: Example domain validation guard using FastMCP - from fastapi import FastAPI, Request, HTTPException + from typing import List + from fastmcp.exceptions import ToolError + from . import mcp - app = FastAPI() + @mcp.tool + async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]: + """Validates queries are within TechCorp's domain.""" - ALLOWED_KEYWORDS = {"weather", "forecast", "temperature", "rain", "snow", "humidity"} + # Get the user's query + user_query = next( + (msg.content for msg in reversed(messages) if msg.role == "user"), + "" + ) - @app.post("/guardrails/topic") - async def topic_guard(request: Request): - body = await request.json() - # Expecting an OpenAI-style request body with messages - messages = body.get("messages", []) - user_content = " ".join( - m["content"] for m in messages if m.get("role") == "user" - ).lower() + # Use an LLM to validate the query scope (simplified) + is_valid = await validate_with_llm(user_query) - if not any(keyword in user_content for keyword in ALLOWED_KEYWORDS): - # Return 400 to indicate a policy failure (not a server error) - raise HTTPException( - status_code=400, - detail={ - "error": "off_topic", - "message": "This assistant only answers weather-related questions.", - }, + if not is_valid: + raise ToolError( + "I can only assist with questions related to TechCorp and its services. " + "Please ask about TechCorp's products, pricing, SLAs, or technical support." ) - # If the prompt is on-topic, just pass the original body through - return body + return messages -To wire this guardrail into Plano, you define a listener of ``type: agent`` and attach a filter chain with a single filter that points -to the Python service above. +To wire this guardrail into Plano, define the filter and add it to your agent's filter chain: .. code-block:: yaml - :caption: Listener (type: agent) with a topicality guard filter + :caption: Plano configuration with input guard filter filters: - - id: topicality_guard - url: http://topic-guard:8000/guardrails/topic + - id: input_guards + url: http://localhost:10500 listeners: - - type: agent - name: agent_listener + - type: agent + name: agent_1 port: 8001 - router: arch_agent_router + router: plano_orchestrator_v1 agents: - - id: rag_agent + - id: rag_agent description: virtual assistant for retrieval augmented generation tasks filter_chain: - - topicality_guard + - input_guards -When a request arrives at ``agent_listener``, Plano will first call the ``topicality_guard`` filter. If the filter returns **HTTP 200**, -the request continues on to the configured agent or prompt target. If the filter returns **HTTP 400**, Plano returns that error back to -the caller and does not forward the request further—enforcing your domain guardrail without changing any application code. +When a request arrives at ``agent_1``, Plano invokes the ``input_guards`` filter first. If validation passes, the request continues to +the agent. If validation fails (``ToolError`` raised), Plano returns an error response to the caller. + +Testing the Guardrail +--------------------- + +Here's an example of the guardrail in action, rejecting a query about Apple Corporation (outside TechCorp's domain): + +.. code-block:: bash + :caption: Request that violates the guardrail policy + + curl -X POST http://localhost:8001/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "gpt-4", + "messages": [ + { + "role": "user", + "content": "what is sla for apple corporation?" + } + ], + "stream": false + }' + +.. code-block:: json + :caption: Error response from the guardrail + + { + "error": "ClientError", + "agent": "input_guards", + "status": 400, + "agent_response": "I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. The query is about SLA for Apple Corporation, which is unrelated to TechCorp.\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support." + } + +This prevents out-of-scope queries from reaching your agent while providing clear feedback to users about why their request was rejected.