ensure that request id is consistent

This commit is contained in:
Adil Hafeez 2026-01-06 14:36:51 -08:00
parent 745b36fdef
commit a3c80e8e90
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
9 changed files with 101 additions and 28 deletions

View file

@ -57,7 +57,10 @@ def load_knowledge_base():
async def find_relevant_passages(
query: str, traceparent: Optional[str] = None, top_k: int = 3
query: str,
traceparent: Optional[str] = None,
request_id: Optional[str] = None,
top_k: int = 3,
) -> List[Dict[str, str]]:
"""Use the LLM to find the most relevant passages from the knowledge base."""
@ -92,7 +95,11 @@ async def find_relevant_passages(
logger.info(f"Calling archgw to find relevant passages for query: '{query}'")
# Prepare extra headers if traceparent is provided
extra_headers = {"x-envoy-max-retries": "3"}
extra_headers = {
"x-envoy-max-retries": "3",
}
if request_id:
extra_headers["x-request-id"] = request_id
if traceparent:
extra_headers["traceparent"] = traceparent
@ -129,7 +136,9 @@ async def find_relevant_passages(
async def augment_query_with_context(
messages: List[ChatMessage], traceparent: Optional[str] = None
messages: List[ChatMessage],
traceparent: Optional[str] = None,
request_id: Optional[str] = None,
) -> List[ChatMessage]:
"""Extract user query, find relevant context, and augment the messages."""
@ -150,7 +159,9 @@ async def augment_query_with_context(
logger.info(f"Processing user query: '{last_user_message}'")
# Find relevant passages
relevant_passages = await find_relevant_passages(last_user_message, traceparent)
relevant_passages = await find_relevant_passages(
last_user_message, traceparent, request_id
)
if not relevant_passages:
logger.info("No relevant passages found, returning original messages")
@ -191,6 +202,8 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
# Get traceparent header from MCP request
headers = get_http_headers()
traceparent_header = headers.get("traceparent")
request_id = headers.get("x-request-id")
logger.info(f"Received request ID: {request_id}")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
@ -198,7 +211,9 @@ async def context_builder(messages: List[ChatMessage]) -> List[ChatMessage]:
logger.info("No traceparent header found")
# Augment the user query with relevant context
updated_messages = await augment_query_with_context(messages, traceparent_header)
updated_messages = await augment_query_with_context(
messages, traceparent_header, request_id
)
# Return as dict to minimize text serialization
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]

View file

@ -34,7 +34,9 @@ app = FastAPI()
async def validate_query_scope(
messages: List[ChatMessage], traceparent_header: str
messages: List[ChatMessage],
traceparent_header: str,
request_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Validate that the user query is within TechCorp's domain.
@ -94,6 +96,9 @@ Respond in JSON format:
if traceparent_header:
extra_headers["traceparent"] = traceparent_header
if request_id:
extra_headers["x-request-id"] = request_id
logger.info(f"Validating query scope: '{last_user_message}'")
response = await archgw_client.chat.completions.create(
model=GUARD_MODEL,
@ -132,6 +137,7 @@ async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
# Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers()
traceparent_header = headers.get("traceparent")
request_id = headers.get("x-request-id")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
@ -139,7 +145,9 @@ async def input_guards(messages: List[ChatMessage]) -> List[ChatMessage]:
logger.info("No traceparent header found")
# Validate the query scope
validation_result = await validate_query_scope(messages, traceparent_header)
validation_result = await validate_query_scope(
messages, traceparent_header, request_id
)
if not validation_result.get("is_valid", True):
reason = validation_result.get("reason", "Query is outside TechCorp's domain")

View file

@ -33,7 +33,9 @@ app = FastAPI()
async def rewrite_query_with_archgw(
messages: List[ChatMessage], traceparent_header: str
messages: List[ChatMessage],
traceparent_header: str,
request_id: Optional[str] = None,
) -> str:
"""Rewrite the user query using LLM for better retrieval."""
system_prompt = """You are a query rewriter that improves user queries for better retrieval.
@ -59,6 +61,8 @@ async def rewrite_query_with_archgw(
extra_headers = {"x-envoy-max-retries": "3"}
if traceparent_header:
extra_headers["traceparent"] = traceparent_header
if request_id:
extra_headers["x-request-id"] = request_id
logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to rewrite query")
response = await archgw_client.chat.completions.create(
model=QUERY_REWRITE_MODEL,
@ -93,6 +97,7 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
# Get traceparent header from HTTP request using FastMCP's dependency function
headers = get_http_headers()
traceparent_header = headers.get("traceparent")
request_id = headers.get("x-request-id")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
@ -100,7 +105,9 @@ async def query_rewriter(messages: List[ChatMessage]) -> List[ChatMessage]:
logger.info("No traceparent header found")
# Call archgw to rewrite the last user query
rewritten_query = await rewrite_query_with_archgw(messages, traceparent_header)
rewritten_query = await rewrite_query_with_archgw(
messages, traceparent_header, request_id
)
# Create updated messages with the rewritten query
updated_messages = messages.copy()

View file

@ -68,6 +68,7 @@ async def chat_completion_http(request: Request, request_body: ChatCompletionReq
# Get traceparent header from HTTP request
traceparent_header = request.headers.get("traceparent")
request_id = request.headers.get("x-request-id")
if traceparent_header:
logger.info(f"Received traceparent header: {traceparent_header}")
@ -75,7 +76,7 @@ async def chat_completion_http(request: Request, request_body: ChatCompletionReq
logger.info("No traceparent header found")
return StreamingResponse(
stream_chat_completions(request_body, traceparent_header),
stream_chat_completions(request_body, traceparent_header, request_id),
media_type="text/plain",
headers={
"content-type": "text/event-stream",
@ -84,7 +85,9 @@ async def chat_completion_http(request: Request, request_body: ChatCompletionReq
async def stream_chat_completions(
request_body: ChatCompletionRequest, traceparent_header: str = None
request_body: ChatCompletionRequest,
traceparent_header: str = None,
request_id: str = None,
):
"""Generate streaming chat completions."""
# Prepare messages for response generation
@ -96,8 +99,11 @@ async def stream_chat_completions(
f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to generate streaming response"
)
logger.info(f"rag_agent - request_id: {request_id}")
# Prepare extra headers if traceparent is provided
extra_headers = {"x-envoy-max-retries": "3"}
if request_id:
extra_headers["x-request-id"] = request_id
if traceparent_header:
extra_headers["traceparent"] = traceparent_header