add input_guards and update prompt guards section

Adil Hafeez 2025-12-23 15:17:55 -08:00
parent bbadd61de0
commit 0533987a2f
11 changed files with 321 additions and 71 deletions


@ -61,6 +61,7 @@ pub async fn agent_chat(
body,
}) = &err
{
warn!(
"Client error from agent '{}' (HTTP {}): {}",
agent, status, body
@ -77,7 +78,7 @@ pub async fn agent_chat(
let json_string = error_json.to_string();
let mut response = Response::new(ResponseHandler::create_full_body(json_string));
*response.status_mut() = hyper::StatusCode::from_u16(*status)
.unwrap_or(hyper::StatusCode::INTERNAL_SERVER_ERROR);
.unwrap_or(hyper::StatusCode::BAD_REQUEST);
response.headers_mut().insert(
hyper::header::CONTENT_TYPE,
"application/json".parse().unwrap(),


@ -4,7 +4,7 @@ use std::collections::HashMap;
pub const JSON_RPC_VERSION: &str = "2.0";
pub const TOOL_CALL_METHOD: &str = "tools/call";
pub const MCP_INITIALIZE: &str = "initialize";
pub const MCP_INITIALIZE_NOTIFICATION: &str = "initialize/notification";
pub const MCP_INITIALIZE_NOTIFICATION: &str = "notifications/initialized";
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]


@ -551,7 +551,7 @@ impl PipelineProcessor {
return Err(PipelineError::ClientError {
agent: agent.id.clone(),
status: http_status.as_u16(),
status: hyper::StatusCode::BAD_REQUEST.as_u16(),
body: error_message,
});
}


@ -0,0 +1,26 @@
FROM python:3.13-slim
WORKDIR /app
# Install bash and uv
RUN apt-get update && apt-get install -y bash && rm -rf /var/lib/apt/lists/*
RUN pip install --no-cache-dir uv
# Copy dependency files
COPY pyproject.toml README.md ./
# Copy source code
COPY src/ ./src/
COPY start_agents.sh ./
# Install dependencies using uv
RUN uv pip install --system --no-cache click fastmcp pydantic fastapi uvicorn openai
# Make start script executable
RUN chmod +x start_agents.sh
# Expose ports for all agents
EXPOSE 10500 10501 10502 10505
# Run the start script with bash
CMD ["bash", "./start_agents.sh"]


@ -4,14 +4,21 @@ A multi-agent RAG system demonstrating archgw's agent filter chain with MCP prot
## Architecture
This demo consists of three components:
1. **Query Rewriter** (MCP filter) - Rewrites user queries for better retrieval
2. **Context Builder** (MCP filter) - Retrieves relevant context from knowledge base
3. **RAG Agent** (REST) - Generates final responses based on augmented context
This demo consists of four components:
1. **Input Guards** (MCP filter) - Validates queries are within TechCorp's domain
2. **Query Rewriter** (MCP filter) - Rewrites user queries for better retrieval
3. **Context Builder** (MCP filter) - Retrieves relevant context from knowledge base
4. **RAG Agent** (REST) - Generates final responses based on augmented context
## Components
### Query Rewriter Filter (MCP)
### Input Guards Filter (MCP)
- **Port**: 10500
- **Tool**: `input_guards`
- Validates queries are within TechCorp's domain
- Rejects queries about other companies or unrelated topics
### Query Rewriter Filter (MCP)
- **Port**: 10501
- **Tool**: `query_rewriter`
- Improves queries using LLM before retrieval
@ -34,6 +41,7 @@ This demo consists of three components:
```
This starts:
- Input Guards MCP server on port 10500
- Query Rewriter MCP server on port 10501
- Context Builder MCP server on port 10502
- RAG Agent REST server on port 10505
@ -59,29 +67,37 @@ The `arch_config.yaml` defines how agents are connected:
```yaml
filters:
- id: input_guards
url: http://host.docker.internal:10500
# type: mcp (default)
# tool: input_guards (default - same as filter id)
- id: query_rewriter
url: mcp://host.docker.internal:10500
tool: rewrite_query_with_archgw # MCP tool name
url: http://host.docker.internal:10501
# type: mcp (default)
- id: context_builder
url: mcp://host.docker.internal:10501
tool: chat_completions
url: http://host.docker.internal:10502
```
How It Works
## How It Works
1. User sends request to archgw listener on port 8001
2. Request passes through MCP filter chain:
- **Input Guards** validates the query is within TechCorp's domain
- **Query Rewriter** rewrites the query for better retrieval
- **Context Builder** augments query with relevant knowledge base passages
3. Augmented request is forwarded to **RAG Agent** REST endpoint
4. RAG Agent generates final response using LLM
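The flow above can be sketched as a sequence of functions that each take the message list and either return it (possibly modified) or raise to stop the chain. This is a toy, in-process illustration only: `Message`, `FilterRejected`, `run_chain`, and the keyword check are hypothetical stand-ins, not archgw's implementation, and the real filters call an LLM over MCP.

```python
from dataclasses import dataclass
from typing import Callable, List


@dataclass
class Message:
    role: str
    content: str


class FilterRejected(Exception):
    """Raised by a filter to stop the chain (stands in for a ToolError)."""


Filter = Callable[[List[Message]], List[Message]]


def input_guards(messages: List[Message]) -> List[Message]:
    # Toy scope check: the real guard asks an LLM instead of matching a keyword.
    last_user = next((m.content for m in reversed(messages) if m.role == "user"), "")
    if "techcorp" not in last_user.lower():
        raise FilterRejected("Query is outside TechCorp's domain")
    return messages


def query_rewriter(messages: List[Message]) -> List[Message]:
    # Toy rewrite: collapse whitespace in the last user message.
    rewritten = list(messages)
    for i in range(len(rewritten) - 1, -1, -1):
        if rewritten[i].role == "user":
            rewritten[i] = Message("user", " ".join(rewritten[i].content.split()))
            break
    return rewritten


def context_builder(messages: List[Message]) -> List[Message]:
    # Toy augmentation: prepend a placeholder context message.
    context = Message("system", "Context: (retrieved passages would go here)")
    return [context] + list(messages)


def run_chain(filters: List[Filter], messages: List[Message]) -> List[Message]:
    # Each filter sees the output of the previous one; an exception ends the chain.
    for apply_filter in filters:
        messages = apply_filter(messages)
    return messages
```

Placing the guard first means an out-of-scope query is rejected before any rewriting or retrieval work is done.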
## Configuration
## Additional Configuration
See `arch_config.yaml` for the complete filter chain setup. The MCP filters use default settings:
- `type: mcp` (default)
- `transport: streamable-http` (default)
- Tool name defaults to filter ID `sample_queries.md` for example queries to test the RAG system.
- Tool name defaults to filter ID
See `sample_queries.md` for example queries to test the RAG system.
Example request:
```bash


@ -5,12 +5,16 @@ agents:
url: http://host.docker.internal:10505
filters:
- id: query_rewriter
- id: input_guards
url: http://host.docker.internal:10500
type: http
# type: http or mcp, mcp is default
# transport: streamable-http # default is streamable-http
# tool: query_rewriter # default name is the filter id
# type: mcp (default)
# transport: streamable-http (default)
# tool: input_guards (default - same as filter id)
- id: query_rewriter
url: http://host.docker.internal:10501
# type: mcp (default)
# transport: streamable-http (default)
# tool: query_rewriter (default - same as filter id)
- id: context_builder
url: http://host.docker.internal:10502
@ -36,6 +40,7 @@ listeners:
- id: rag_agent
description: virtual assistant for retrieval augmented generation tasks
filter_chain:
- input_guards
- query_rewriter
- context_builder
tracing:


@ -1,4 +1,29 @@
services:
rag-agents:
build:
context: .
dockerfile: Dockerfile
ports:
- "10500:10500"
- "10501:10501"
- "10502:10502"
- "10505:10505"
environment:
- LLM_GATEWAY_ENDPOINT=${LLM_GATEWAY_ENDPOINT:-http://host.docker.internal:12000/v1}
- OPENAI_API_KEY=${OPENAI_API_KEY:?OPENAI_API_KEY environment variable is required but not set}
archgw:
build:
context: ../../../
dockerfile: arch/Dockerfile
ports:
- "12000:12000"
- "8001:8001"
environment:
- ARCH_CONFIG_PATH=/config/arch_config.yaml
- OPENAI_API_KEY=${OPENAI_API_KEY:?OPENAI_API_KEY environment variable is required but not set}
volumes:
- ./arch_config.yaml:/app/arch_config.yaml
- /etc/ssl/cert.pem:/etc/ssl/cert.pem
jaeger:
build:
context: ../../shared/jaeger


@ -37,6 +37,7 @@ def main(host, port, agent, transport, agent_name, rest_server, rest_port):
# Map friendly names to agent modules
agent_map = {
"input_guards": ("rag_agent.input_guards", "Input Guards Agent"),
"query_rewriter": ("rag_agent.query_rewriter", "Query Rewriter Agent"),
"context_builder": ("rag_agent.context_builder", "Context Builder Agent"),
"response_generator": (
@ -75,10 +76,12 @@ def main(host, port, agent, transport, agent_name, rest_server, rest_port):
print(f"Remove --rest-server flag to start {agent} as an MCP server.")
return
else:
# Only query_rewriter and context_builder support MCP
if agent not in ["query_rewriter", "context_builder"]:
# Only input_guards, query_rewriter and context_builder support MCP
if agent not in ["input_guards", "query_rewriter", "context_builder"]:
print(f"Error: Agent '{agent}' does not support MCP mode.")
print(f"MCP is only supported for: query_rewriter, context_builder")
print(
f"MCP is only supported for: input_guards, query_rewriter, context_builder"
)
print(f"Use --rest-server flag to start {agent} as a REST server.")
return


@ -0,0 +1,153 @@
import asyncio
import json
import time
from typing import List, Optional, Dict, Any
import uuid
from fastapi import FastAPI, Depends, Request
from fastmcp.exceptions import ToolError
from openai import AsyncOpenAI
import os
import logging
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
from . import mcp
from fastmcp.server.dependencies import get_http_headers
# Set up logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - [INPUT_GUARDS] - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Configuration for archgw LLM gateway
LLM_GATEWAY_ENDPOINT = os.getenv("LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1")
GUARD_MODEL = "gpt-4o-mini"
# Initialize OpenAI client for archgw
archgw_client = AsyncOpenAI(
base_url=LLM_GATEWAY_ENDPOINT,
api_key="EMPTY", # archgw doesn't require a real API key
)
app = FastAPI()
async def validate_query_scope(
messages: List[ChatMessage], traceparent_header: str
) -> Dict[str, Any]:
"""Validate that the user query is within TechCorp's domain.
Returns a dict with:
- is_valid: bool indicating if query is within scope
- reason: str explaining why query is out of scope (if applicable)
"""
system_prompt = """You are an input validation guard for TechCorp's customer support system.
Your job is to determine if a user's query is related to TechCorp and its services/products.
TechCorp is a technology company that provides:
- Cloud services and infrastructure
- SaaS products
- Technical support
- Service level agreements (SLAs)
- Uptime guarantees
- Enterprise solutions
ALLOW queries about:
- TechCorp's services, products, or offerings
- TechCorp's pricing, SLAs, uptime, or policies
- Technical support for TechCorp products
- General questions about TechCorp as a company
REJECT queries about:
- Other companies or their products
- General knowledge questions unrelated to TechCorp
- Personal advice or topics outside TechCorp's domain
- Anything that doesn't relate to TechCorp's business
Respond in JSON format:
{
"is_valid": true/false,
"reason": "brief explanation if invalid"
}"""
# Get the last user message for validation
last_user_message = None
for msg in reversed(messages):
if msg.role == "user":
last_user_message = msg.content
break
if not last_user_message:
return {"is_valid": True, "reason": ""}
# Prepare messages for the guard
guard_messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"Query to validate: {last_user_message}"},
]
try:
# Call archgw using OpenAI client
extra_headers = {"x-envoy-max-retries": "3"}
if traceparent_header:
extra_headers["traceparent"] = traceparent_header
logger.info(f"Validating query scope: '{last_user_message}'")
response = await archgw_client.chat.completions.create(
model=GUARD_MODEL,
messages=guard_messages,
temperature=0.1,
max_tokens=150,
extra_headers=extra_headers,
)
result_text = response.choices[0].message.content.strip()
# Parse JSON response
try:
result = json.loads(result_text)
logger.info(f"Validation result: {result}")
return result
except json.JSONDecodeError:
logger.error(f"Failed to parse validation response: {result_text}")
# Default to allowing if parsing fails
return {"is_valid": True, "reason": ""}
except Exception as e:
logger.error(f"Error validating query: {e}")
# Default to allowing if validation fails
return {"is_valid": True, "reason": ""}
@mcp.tool
async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
"""Input guard that validates queries are within TechCorp's domain.
If the query is out of scope, raises a ToolError carrying a rejection notice.
"""
logger.info(f"Received request with {len(messages)} messages")
# Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers()
traceparent_header = headers.get("traceparent")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
else:
logger.info("No traceparent header found")
# Validate the query scope
validation_result = await validate_query_scope(messages, traceparent_header)
if not validation_result.get("is_valid", True):
reason = validation_result.get("reason", "Query is outside TechCorp's domain")
logger.warning(f"Query rejected: {reason}")
# Raise ToolError to reject the request
error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support."
raise ToolError(error_message)
logger.info("Query validation passed - forwarding to next filter")
return messages
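The guard above fails open: if the model's reply is not valid JSON, or the call itself fails, the query is allowed through. A hedged sketch of a stricter parser follows, where that choice is an explicit parameter. `parse_guard_verdict` and its `fail_open` flag are hypothetical names and are not part of this commit.

```python
import json
from typing import Any, Dict


def parse_guard_verdict(result_text: str, fail_open: bool = True) -> Dict[str, Any]:
    """Parse the guard model's reply into {"is_valid": bool, "reason": str}.

    Tolerates extra text around the JSON object (for example a Markdown
    wrapper) and falls back to the configured default when the reply is
    unusable.
    """
    fallback = {
        "is_valid": fail_open,
        "reason": "" if fail_open else "Guard response could not be parsed",
    }

    # Take the outermost {...} span so surrounding prose is ignored.
    start, end = result_text.find("{"), result_text.rfind("}")
    if start == -1 or end <= start:
        return fallback

    try:
        parsed = json.loads(result_text[start : end + 1])
    except json.JSONDecodeError:
        return fallback

    # Require a real boolean so that values like "yes" are not trusted.
    if not isinstance(parsed, dict) or not isinstance(parsed.get("is_valid"), bool):
        return fallback

    return {"is_valid": parsed["is_valid"], "reason": str(parsed.get("reason", ""))}
```

Failing open keeps the demo usable when the LLM gateway is down. A deployment that treats the guard as a security control may prefer `fail_open=False`.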


@ -21,17 +21,11 @@ cleanup() {
trap cleanup EXIT
# log "Starting input guards filter on port 10500..."
# uv run python -m rag_agent --host 0.0.0.0 --port 10500 --agent input_guards &
# WAIT_FOR_PIDS+=($!)
log "Starting query_rewriter agent on port 10500/http..."
uv run python -m rag_agent --rest-server --host 0.0.0.0 --rest-port 10500 --agent query_rewriter &
log "Starting input_guards agent on port 10500/mcp..."
uv run python -m rag_agent --host 0.0.0.0 --port 10500 --agent input_guards &
WAIT_FOR_PIDS+=($!)
log "Starting query_parser agent on port 10501/mcp..."
log "Starting query_rewriter agent on port 10501/mcp..."
uv run python -m rag_agent --host 0.0.0.0 --port 10501 --agent query_rewriter &
WAIT_FOR_PIDS+=($!)


@ -24,66 +24,93 @@ more reliable, and easier to reason about.
How Guardrails Work
-------------------
In Plano, guardrails are usually implemented as filters that run as HTTP services. Each filter receives the incoming prompt and related metadata, evaluates it
against policy, and either lets the request continue (HTTP 200) or terminates it early with an appropriate error code (typically HTTP 4xx for policy failures).
In Plano, guardrails are implemented as MCP filters that validate incoming requests. Each filter receives the chat messages, evaluates them
against policy, and either lets the request continue or raises a ``ToolError`` to reject it with a helpful error message.
The example below shows a simple, plain-Python HTTP service that acts as a topicality guardrail: it rejects any prompt that is not related to the
"weather" domain.
The example below shows an input guard for TechCorp's customer support system that validates queries are within the company's domain:
.. code-block:: python
:caption: Example topicality guard filter in plain Python (FastAPI)
:caption: Example domain validation guard using FastMCP
from fastapi import FastAPI, Request, HTTPException
from typing import List
from fastmcp.exceptions import ToolError
from . import mcp
app = FastAPI()
@mcp.tool
async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
"""Validates queries are within TechCorp's domain."""
ALLOWED_KEYWORDS = {"weather", "forecast", "temperature", "rain", "snow", "humidity"}
# Get the user's query
user_query = next(
(msg.content for msg in reversed(messages) if msg.role == "user"),
""
)
@app.post("/guardrails/topic")
async def topic_guard(request: Request):
body = await request.json()
# Expecting an OpenAI-style request body with messages
messages = body.get("messages", [])
user_content = " ".join(
m["content"] for m in messages if m.get("role") == "user"
).lower()
# Use an LLM to validate the query scope (simplified)
is_valid = await validate_with_llm(user_query)
if not any(keyword in user_content for keyword in ALLOWED_KEYWORDS):
# Return 400 to indicate a policy failure (not a server error)
raise HTTPException(
status_code=400,
detail={
"error": "off_topic",
"message": "This assistant only answers weather-related questions.",
},
if not is_valid:
raise ToolError(
"I can only assist with questions related to TechCorp and its services. "
"Please ask about TechCorp's products, pricing, SLAs, or technical support."
)
# If the prompt is on-topic, just pass the original body through
return body
return messages
To wire this guardrail into Plano, you define a listener of ``type: agent`` and attach a filter chain with a single filter that points
to the Python service above.
To wire this guardrail into Plano, define the filter and add it to your agent's filter chain:
.. code-block:: yaml
:caption: Listener (type: agent) with a topicality guard filter
:caption: Plano configuration with input guard filter
filters:
- id: topicality_guard
url: http://topic-guard:8000/guardrails/topic
- id: input_guards
url: http://localhost:10500
listeners:
- type: agent
name: agent_listener
- type: agent
name: agent_1
port: 8001
router: arch_agent_router
router: plano_orchestrator_v1
agents:
- id: rag_agent
- id: rag_agent
description: virtual assistant for retrieval augmented generation tasks
filter_chain:
- topicality_guard
- input_guards
When a request arrives at ``agent_listener``, Plano will first call the ``topicality_guard`` filter. If the filter returns **HTTP 200**,
the request continues on to the configured agent or prompt target. If the filter returns **HTTP 400**, Plano returns that error back to
the caller and does not forward the request further—enforcing your domain guardrail without changing any application code.
When a request arrives at ``agent_1``, Plano invokes the ``input_guards`` filter first. If validation passes, the request continues to
the agent. If validation fails (``ToolError`` raised), Plano returns an error response to the caller.
Testing the Guardrail
---------------------
Here's an example of the guardrail in action, rejecting a query about Apple Corporation (outside TechCorp's domain):
.. code-block:: bash
:caption: Request that violates the guardrail policy
curl -X POST http://localhost:8001/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-4",
"messages": [
{
"role": "user",
"content": "what is sla for apple corporation?"
}
],
"stream": false
}'
.. code-block:: json
:caption: Error response from the guardrail
{
"error": "ClientError",
"agent": "input_guards",
"status": 400,
"agent_response": "I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. The query is about SLA for Apple Corporation, which is unrelated to TechCorp.\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support."
}
This prevents out-of-scope queries from reaching your agent while providing clear feedback to users about why their request was rejected.
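The same check can be scripted. The sketch below sends the request with Python's standard library and turns a guardrail rejection into a one-line summary. It assumes the listener URL and the error fields (``agent``, ``status``, ``agent_response``) match the sample above; `build_chat_request`, `describe_guard_error`, and `ask` are hypothetical helper names.

```python
import json
import urllib.error
import urllib.request


def build_chat_request(
    query: str, url: str = "http://localhost:8001/v1/chat/completions"
) -> urllib.request.Request:
    """Build the POST request shown in the curl example."""
    payload = {
        "model": "gpt-4",
        "messages": [{"role": "user", "content": query}],
        "stream": False,
    }
    return urllib.request.Request(
        url,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def describe_guard_error(body: str) -> str:
    """Summarize an error body shaped like the sample guardrail response."""
    error = json.loads(body)
    return (
        f"{error['agent']} rejected the request "
        f"({error['status']}): {error['agent_response']}"
    )


def ask(query: str) -> str:
    """Send the query; return the raw response body or a rejection summary."""
    try:
        with urllib.request.urlopen(build_chat_request(query)) as resp:
            return resp.read().decode("utf-8")
    except urllib.error.HTTPError as err:
        # urlopen raises HTTPError for 4xx/5xx; the body holds the guard's message.
        return describe_guard_error(err.read().decode("utf-8"))
```

Calling ``ask("what is sla for apple corporation?")`` against a running demo should take the ``HTTPError`` branch if the guard rejects the query as in the sample response.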