From a3b070440dd180254850dbbc20781537957c1dcc Mon Sep 17 00:00:00 2001 From: MeiyuZhong Date: Thu, 8 Jan 2026 12:27:43 -0800 Subject: [PATCH] Fix pre-commit formatting --- .../src/rag_agent/context_builder.py | 6 +++++- .../http_filter/src/rag_agent/input_guards.py | 14 ++++++++++---- .../src/rag_agent/query_rewriter.py | 19 ++++++++++++------- demos/use_cases/http_filter/start_agents.sh | 1 - 4 files changed, 27 insertions(+), 13 deletions(-) diff --git a/demos/use_cases/http_filter/src/rag_agent/context_builder.py b/demos/use_cases/http_filter/src/rag_agent/context_builder.py index 78687e30..62d51468 100644 --- a/demos/use_cases/http_filter/src/rag_agent/context_builder.py +++ b/demos/use_cases/http_filter/src/rag_agent/context_builder.py @@ -8,6 +8,7 @@ from pathlib import Path from .api import ChatMessage from fastapi import Request, FastAPI + # from . import mcp # from fastmcp.server.dependencies import get_http_headers @@ -186,8 +187,11 @@ async def augment_query_with_context( # Load knowledge base on module import load_knowledge_base() + @app.post("/") -async def context_builder(messages: List[ChatMessage],request: Request) -> List[ChatMessage]: +async def context_builder( + messages: List[ChatMessage], request: Request +) -> List[ChatMessage]: """MCP tool that augments user queries with relevant context from the knowledge base.""" logger.info(f"Received chat completion request with {len(messages)} messages") diff --git a/demos/use_cases/http_filter/src/rag_agent/input_guards.py b/demos/use_cases/http_filter/src/rag_agent/input_guards.py index e98236fb..2b92940e 100644 --- a/demos/use_cases/http_filter/src/rag_agent/input_guards.py +++ b/demos/use_cases/http_filter/src/rag_agent/input_guards.py @@ -4,6 +4,7 @@ import time from typing import List, Optional, Dict, Any import uuid from fastapi import FastAPI, Depends, Request, HTTPException + # from fastmcp.exceptions import ToolError from openai import AsyncOpenAI import os @@ -11,6 +12,7 @@ import logging from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage from . import mcp + # from fastmcp.server.dependencies import get_http_headers # Set up logging @@ -123,7 +125,9 @@ Respond in JSON format: # @mcp.tool @app.post("/") -async def input_guards(messages: List[ChatMessage], request: Request) -> List[ChatMessage]: +async def input_guards( + messages: List[ChatMessage], request: Request +) -> List[ChatMessage]: """Input guard that validates queries are within TechCorp's domain. If the query is out of scope, replaces the user message with a rejection notice. @@ -136,7 +140,6 @@ async def input_guards(messages: List[ChatMessage], request: Request) -> List[Ch traceparent_header = request.headers.get("traceparent") request_id = request.headers.get("x-request-id") - if traceparent_header: logger.info(f"Received traceparent header: {traceparent_header}") else: @@ -152,11 +155,14 @@ async def input_guards(messages: List[ChatMessage], request: Request) -> List[Ch # Throw ToolError error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support." # raise ToolError(error_message) - raise HTTPException(status_code=400, detail={"error": "out_of_scope", "message": error_message}) + raise HTTPException( + status_code=400, detail={"error": "out_of_scope", "message": error_message} + ) logger.info("Query validation passed - forwarding to next filter") return messages + @app.get("/health") async def health(): - return {"status": "healthy"} \ No newline at end of file + return {"status": "healthy"} diff --git a/demos/use_cases/http_filter/src/rag_agent/query_rewriter.py b/demos/use_cases/http_filter/src/rag_agent/query_rewriter.py index f6796887..f35012d4 100644 --- a/demos/use_cases/http_filter/src/rag_agent/query_rewriter.py +++ b/demos/use_cases/http_filter/src/rag_agent/query_rewriter.py @@ -9,6 +9,7 @@ import os import logging from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage + # from . import mcp # from fastmcp.server.dependencies import get_http_headers @@ -82,9 +83,10 @@ Return only the rewritten query, nothing else.""" return "" - @app.post("/") -async def query_rewriter_http(messages: List[ChatMessage], request: Request) -> List[ChatMessage]: +async def query_rewriter_http( + messages: List[ChatMessage], request: Request +) -> List[ChatMessage]: """HTTP filter endpoint used by Plano (type: http).""" logger.info(f"Received request with {len(messages)} messages") @@ -96,7 +98,9 @@ async def query_rewriter_http(messages: List[ChatMessage], request: Request) -> else: logger.info("No traceparent header found") - rewritten_query = await rewrite_query_with_archgw(messages, traceparent_header, request_id) + rewritten_query = await rewrite_query_with_archgw( + messages, traceparent_header, request_id + ) # Create updated messages with the rewritten query updated_messages = messages.copy() @@ -109,22 +113,23 @@ async def query_rewriter_http(messages: List[ChatMessage], request: Request) -> f"Updated user query from '{original_query}' to '{rewritten_query}'" ) break - updated_messages_data = [{"role": msg.role, "content": msg.content} for msg in updated_messages] + updated_messages_data = [ + {"role": msg.role, "content": msg.content} for msg in updated_messages + ] updated_messages = [ChatMessage(**msg) for msg in updated_messages_data] logger.info("Returning rewritten chat completion response") return updated_messages + @app.get("/health") async def health(): return {"status": "healthy"} + def start_server(host: str = "0.0.0.0", port: int = 10501): """Start the FastAPI server for query rewriter.""" import uvicorn logger.info(f"Starting Query Rewriter REST server on {host}:{port}") uvicorn.run(app, host=host, port=port) - - - diff --git a/demos/use_cases/http_filter/start_agents.sh b/demos/use_cases/http_filter/start_agents.sh index 36260ce3..06cabeec 100644 --- a/demos/use_cases/http_filter/start_agents.sh +++ b/demos/use_cases/http_filter/start_agents.sh @@ -76,4 +76,3 @@ pids+=($!) for PID in "${pids[@]}"; do wait "$PID" done -