mirror of
https://github.com/katanemo/plano.git
synced 2026-06-29 15:49:40 +02:00
fix http_filter agent: request id + pre-commit
This commit is contained in:
parent
dffd5dd2ab
commit
bfe8d7b1aa
3 changed files with 23 additions and 8 deletions
|
|
@ -61,7 +61,10 @@ def load_knowledge_base():
|
||||||
|
|
||||||
|
|
||||||
async def find_relevant_passages(
|
async def find_relevant_passages(
|
||||||
query: str, traceparent: Optional[str] = None, request_id: Optional[str] = None, top_k: int = 3
|
query: str,
|
||||||
|
traceparent: Optional[str] = None,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
|
top_k: int = 3,
|
||||||
) -> List[Dict[str, str]]:
|
) -> List[Dict[str, str]]:
|
||||||
"""Use the LLM to find the most relevant passages from the knowledge base."""
|
"""Use the LLM to find the most relevant passages from the knowledge base."""
|
||||||
|
|
||||||
|
|
@ -133,7 +136,9 @@ async def find_relevant_passages(
|
||||||
|
|
||||||
|
|
||||||
async def augment_query_with_context(
|
async def augment_query_with_context(
|
||||||
messages: List[ChatMessage], traceparent: Optional[str] = None, request_id: Optional[str] = None
|
messages: List[ChatMessage],
|
||||||
|
traceparent: Optional[str] = None,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
) -> List[ChatMessage]:
|
) -> List[ChatMessage]:
|
||||||
"""Extract user query, find relevant context, and augment the messages."""
|
"""Extract user query, find relevant context, and augment the messages."""
|
||||||
|
|
||||||
|
|
@ -154,7 +159,9 @@ async def augment_query_with_context(
|
||||||
logger.info(f"Processing user query: '{last_user_message}'")
|
logger.info(f"Processing user query: '{last_user_message}'")
|
||||||
|
|
||||||
# Find relevant passages
|
# Find relevant passages
|
||||||
relevant_passages = await find_relevant_passages(last_user_message, traceparent, request_id)
|
relevant_passages = await find_relevant_passages(
|
||||||
|
last_user_message, traceparent, request_id
|
||||||
|
)
|
||||||
|
|
||||||
if not relevant_passages:
|
if not relevant_passages:
|
||||||
logger.info("No relevant passages found, returning original messages")
|
logger.info("No relevant passages found, returning original messages")
|
||||||
|
|
@ -208,7 +215,9 @@ async def context_builder(
|
||||||
logger.info("No traceparent header found")
|
logger.info("No traceparent header found")
|
||||||
|
|
||||||
# Augment the user query with relevant context
|
# Augment the user query with relevant context
|
||||||
updated_messages = await augment_query_with_context(messages, traceparent_header, request_id)
|
updated_messages = await augment_query_with_context(
|
||||||
|
messages, traceparent_header, request_id
|
||||||
|
)
|
||||||
|
|
||||||
# Return as dict to minimize text serialization
|
# Return as dict to minimize text serialization
|
||||||
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
return [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,9 @@ app = FastAPI(title="RAG Agent Input Guards", version="1.0.0")
|
||||||
|
|
||||||
|
|
||||||
async def validate_query_scope(
|
async def validate_query_scope(
|
||||||
messages: List[ChatMessage], traceparent_header: Optional[str] = None, request_id: Optional[str] = None
|
messages: List[ChatMessage],
|
||||||
|
traceparent_header: Optional[str] = None,
|
||||||
|
request_id: Optional[str] = None,
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Validate that the user query is within TechCorp's domain.
|
"""Validate that the user query is within TechCorp's domain.
|
||||||
|
|
||||||
|
|
@ -146,7 +148,9 @@ async def input_guards(
|
||||||
logger.info("No traceparent header found")
|
logger.info("No traceparent header found")
|
||||||
|
|
||||||
# Validate the query scope
|
# Validate the query scope
|
||||||
validation_result = await validate_query_scope(messages, traceparent_header, request_id)
|
validation_result = await validate_query_scope(
|
||||||
|
messages, traceparent_header, request_id
|
||||||
|
)
|
||||||
|
|
||||||
if not validation_result.get("is_valid", True):
|
if not validation_result.get("is_valid", True):
|
||||||
reason = validation_result.get("reason", "Query is outside TechCorp's domain")
|
reason = validation_result.get("reason", "Query is outside TechCorp's domain")
|
||||||
|
|
|
||||||
|
|
@ -76,7 +76,7 @@ async def chat_completion_http(request: Request, request_body: ChatCompletionReq
|
||||||
logger.info("No traceparent header found")
|
logger.info("No traceparent header found")
|
||||||
|
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
stream_chat_completions(request_body, traceparent_header,request_id),
|
stream_chat_completions(request_body, traceparent_header, request_id),
|
||||||
media_type="text/plain",
|
media_type="text/plain",
|
||||||
headers={
|
headers={
|
||||||
"content-type": "text/event-stream",
|
"content-type": "text/event-stream",
|
||||||
|
|
@ -86,7 +86,9 @@ async def chat_completion_http(request: Request, request_body: ChatCompletionReq
|
||||||
|
|
||||||
|
|
||||||
async def stream_chat_completions(
|
async def stream_chat_completions(
|
||||||
request_body: ChatCompletionRequest, traceparent_header: str = None, request_id: str = None
|
request_body: ChatCompletionRequest,
|
||||||
|
traceparent_header: str = None,
|
||||||
|
request_id: str = None,
|
||||||
):
|
):
|
||||||
"""Generate streaming chat completions."""
|
"""Generate streaming chat completions."""
|
||||||
# Prepare messages for response generation
|
# Prepare messages for response generation
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue