fix http_filter agent: request id + pre-commit

This commit is contained in:
MeiyuZhong 2026-01-12 12:07:48 -08:00
parent dffd5dd2ab
commit bfe8d7b1aa
No known key found for this signature in database
3 changed files with 23 additions and 8 deletions

View file

@@ -61,7 +61,10 @@ def load_knowledge_base():
async def find_relevant_passages( async def find_relevant_passages(
query: str, traceparent: Optional[str] = None, request_id: Optional[str] = None, top_k: int = 3 query: str,
traceparent: Optional[str] = None,
request_id: Optional[str] = None,
top_k: int = 3,
) -> List[Dict[str, str]]: ) -> List[Dict[str, str]]:
"""Use the LLM to find the most relevant passages from the knowledge base.""" """Use the LLM to find the most relevant passages from the knowledge base."""
@@ -133,7 +136,9 @@ async def find_relevant_passages(
async def augment_query_with_context( async def augment_query_with_context(
messages: List[ChatMessage], traceparent: Optional[str] = None, request_id: Optional[str] = None messages: List[ChatMessage],
traceparent: Optional[str] = None,
request_id: Optional[str] = None,
) -> List[ChatMessage]: ) -> List[ChatMessage]:
"""Extract user query, find relevant context, and augment the messages.""" """Extract user query, find relevant context, and augment the messages."""
@@ -154,7 +159,9 @@ async def augment_query_with_context(
logger.info(f"Processing user query: '{last_user_message}'") logger.info(f"Processing user query: '{last_user_message}'")
# Find relevant passages # Find relevant passages
relevant_passages = await find_relevant_passages(last_user_message, traceparent, request_id) relevant_passages = await find_relevant_passages(
last_user_message, traceparent, request_id
)
if not relevant_passages: if not relevant_passages:
logger.info("No relevant passages found, returning original messages") logger.info("No relevant passages found, returning original messages")
@@ -208,7 +215,9 @@ async def context_builder(
logger.info("No traceparent header found") logger.info("No traceparent header found")
# Augment the user query with relevant context # Augment the user query with relevant context
updated_messages = await augment_query_with_context(messages, traceparent_header, request_id) updated_messages = await augment_query_with_context(
messages, traceparent_header, request_id
)
# Return as dict to minimize text serialization # Return as dict to minimize text serialization
return [{"role": msg.role, "content": msg.content} for msg in updated_messages] return [{"role": msg.role, "content": msg.content} for msg in updated_messages]

View file

@@ -36,7 +36,9 @@ app = FastAPI(title="RAG Agent Input Guards", version="1.0.0")
async def validate_query_scope( async def validate_query_scope(
messages: List[ChatMessage], traceparent_header: Optional[str] = None, request_id: Optional[str] = None messages: List[ChatMessage],
traceparent_header: Optional[str] = None,
request_id: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Validate that the user query is within TechCorp's domain. """Validate that the user query is within TechCorp's domain.
@@ -146,7 +148,9 @@ async def input_guards(
logger.info("No traceparent header found") logger.info("No traceparent header found")
# Validate the query scope # Validate the query scope
validation_result = await validate_query_scope(messages, traceparent_header, request_id) validation_result = await validate_query_scope(
messages, traceparent_header, request_id
)
if not validation_result.get("is_valid", True): if not validation_result.get("is_valid", True):
reason = validation_result.get("reason", "Query is outside TechCorp's domain") reason = validation_result.get("reason", "Query is outside TechCorp's domain")

View file

@@ -76,7 +76,7 @@ async def chat_completion_http(request: Request, request_body: ChatCompletionReq
logger.info("No traceparent header found") logger.info("No traceparent header found")
return StreamingResponse( return StreamingResponse(
stream_chat_completions(request_body, traceparent_header,request_id), stream_chat_completions(request_body, traceparent_header, request_id),
media_type="text/plain", media_type="text/plain",
headers={ headers={
"content-type": "text/event-stream", "content-type": "text/event-stream",
@@ -86,7 +86,9 @@ async def chat_completion_http(request: Request, request_body: ChatCompletionReq
async def stream_chat_completions( async def stream_chat_completions(
request_body: ChatCompletionRequest, traceparent_header: str = None, request_id: str = None request_body: ChatCompletionRequest,
traceparent_header: str = None,
request_id: str = None,
): ):
"""Generate streaming chat completions.""" """Generate streaming chat completions."""
# Prepare messages for response generation # Prepare messages for response generation