mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Added request id across agents
This commit is contained in:
parent
2cf2b821f2
commit
dffd5dd2ab
4 changed files with 14 additions and 14 deletions
|
|
@ -61,7 +61,7 @@ def load_knowledge_base():
|
|||
|
||||
|
||||
async def find_relevant_passages(
|
||||
query: str, traceparent: Optional[str] = None, top_k: int = 3
|
||||
query: str, traceparent: Optional[str] = None, request_id: Optional[str] = None, top_k: int = 3
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Use the LLM to find the most relevant passages from the knowledge base."""
|
||||
|
||||
|
|
@ -96,7 +96,7 @@ async def find_relevant_passages(
|
|||
logger.info(f"Calling archgw to find relevant passages for query: '{query}'")
|
||||
|
||||
# Prepare extra headers if traceparent is provided
|
||||
extra_headers = {"x-envoy-max-retries": "3"}
|
||||
extra_headers = {"x-envoy-max-retries": "3", "x-request-id": request_id}
|
||||
if traceparent:
|
||||
extra_headers["traceparent"] = traceparent
|
||||
|
||||
|
|
@ -133,7 +133,7 @@ async def find_relevant_passages(
|
|||
|
||||
|
||||
async def augment_query_with_context(
|
||||
messages: List[ChatMessage], traceparent: Optional[str] = None
|
||||
messages: List[ChatMessage], traceparent: Optional[str] = None, request_id: Optional[str] = None
|
||||
) -> List[ChatMessage]:
|
||||
"""Extract user query, find relevant context, and augment the messages."""
|
||||
|
||||
|
|
@ -154,7 +154,7 @@ async def augment_query_with_context(
|
|||
logger.info(f"Processing user query: '{last_user_message}'")
|
||||
|
||||
# Find relevant passages
|
||||
relevant_passages = await find_relevant_passages(last_user_message, traceparent)
|
||||
relevant_passages = await find_relevant_passages(last_user_message, traceparent, request_id)
|
||||
|
||||
if not relevant_passages:
|
||||
logger.info("No relevant passages found, returning original messages")
|
||||
|
|
@ -208,7 +208,7 @@ async def context_builder(
|
|||
logger.info("No traceparent header found")
|
||||
|
||||
# Augment the user query with relevant context
|
||||
updated_messages = await augment_query_with_context(messages, traceparent_header)
|
||||
updated_messages = await augment_query_with_context(messages, traceparent_header, request_id)
|
||||
|
||||
# Return as dict to minimize text serialization
|
||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ app = FastAPI(title="RAG Agent Input Guards", version="1.0.0")
|
|||
|
||||
|
||||
async def validate_query_scope(
|
||||
messages: List[ChatMessage], traceparent_header: str
|
||||
messages: List[ChatMessage], traceparent_header: Optional[str] = None, request_id: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Validate that the user query is within TechCorp's domain.
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ Respond in JSON format:
|
|||
|
||||
try:
|
||||
# Call archgw using OpenAI client
|
||||
extra_headers = {"x-envoy-max-retries": "3"}
|
||||
extra_headers = {"x-envoy-max-retries": "3", "x-request-id": request_id}
|
||||
if traceparent_header:
|
||||
extra_headers["traceparent"] = traceparent_header
|
||||
|
||||
|
|
@ -146,7 +146,7 @@ async def input_guards(
|
|||
logger.info("No traceparent header found")
|
||||
|
||||
# Validate the query scope
|
||||
validation_result = await validate_query_scope(messages, traceparent_header)
|
||||
validation_result = await validate_query_scope(messages, traceparent_header, request_id)
|
||||
|
||||
if not validation_result.get("is_valid", True):
|
||||
reason = validation_result.get("reason", "Query is outside TechCorp's domain")
|
||||
|
|
|
|||
|
|
@ -54,11 +54,9 @@ Return only the rewritten query, nothing else."""
|
|||
for msg in messages:
|
||||
rewrite_messages.append({"role": msg.role, "content": msg.content})
|
||||
|
||||
extra_headers = {"x-envoy-max-retries": "3"}
|
||||
extra_headers = {"x-envoy-max-retries": "3", "x-request-id": request_id}
|
||||
if traceparent_header:
|
||||
extra_headers["traceparent"] = traceparent_header
|
||||
if request_id:
|
||||
extra_headers["x-request-id"] = request_id
|
||||
|
||||
try:
|
||||
logger.info(f"Calling archgw at {LLM_GATEWAY_ENDPOINT} to rewrite query")
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ async def chat_completion_http(request: Request, request_body: ChatCompletionReq
|
|||
|
||||
# Get traceparent header from HTTP request
|
||||
traceparent_header = request.headers.get("traceparent")
|
||||
request_id = request.headers.get("x-request-id") or f"req-{uuid.uuid4().hex}"
|
||||
|
||||
if traceparent_header:
|
||||
logger.info(f"Received traceparent header: {traceparent_header}")
|
||||
|
|
@ -75,16 +76,17 @@ async def chat_completion_http(request: Request, request_body: ChatCompletionReq
|
|||
logger.info("No traceparent header found")
|
||||
|
||||
return StreamingResponse(
|
||||
stream_chat_completions(request_body, traceparent_header),
|
||||
stream_chat_completions(request_body, traceparent_header,request_id),
|
||||
media_type="text/plain",
|
||||
headers={
|
||||
"content-type": "text/event-stream",
|
||||
"x-request-id": request_id,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
async def stream_chat_completions(
|
||||
request_body: ChatCompletionRequest, traceparent_header: str = None
|
||||
request_body: ChatCompletionRequest, traceparent_header: str = None, request_id: str = None
|
||||
):
|
||||
"""Generate streaming chat completions."""
|
||||
# Prepare messages for response generation
|
||||
|
|
@ -97,7 +99,7 @@ async def stream_chat_completions(
|
|||
)
|
||||
|
||||
# Prepare extra headers if traceparent is provided
|
||||
extra_headers = {"x-envoy-max-retries": "3"}
|
||||
extra_headers = {"x-envoy-max-retries": "3", "x-request-id": request_id}
|
||||
if traceparent_header:
|
||||
extra_headers["traceparent"] = traceparent_header
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue