mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
Fix pre-commit formatting
This commit is contained in:
parent
91ddaa992b
commit
a3b070440d
4 changed files with 27 additions and 13 deletions
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
|
||||
from .api import ChatMessage
|
||||
from fastapi import Request, FastAPI
|
||||
|
||||
# from . import mcp
|
||||
# from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
|
|
@ -186,8 +187,11 @@ async def augment_query_with_context(
|
|||
# Load knowledge base on module import
|
||||
load_knowledge_base()
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def context_builder(messages: List[ChatMessage],request: Request) -> List[ChatMessage]:
|
||||
async def context_builder(
|
||||
messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
"""MCP tool that augments user queries with relevant context from the knowledge base."""
|
||||
logger.info(f"Received chat completion request with {len(messages)} messages")
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import time
|
|||
from typing import List, Optional, Dict, Any
|
||||
import uuid
|
||||
from fastapi import FastAPI, Depends, Request, HTTPException
|
||||
|
||||
# from fastmcp.exceptions import ToolError
|
||||
from openai import AsyncOpenAI
|
||||
import os
|
||||
|
|
@ -11,6 +12,7 @@ import logging
|
|||
|
||||
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
|
||||
from . import mcp
|
||||
|
||||
# from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
# Set up logging
|
||||
|
|
@ -123,7 +125,9 @@ Respond in JSON format:
|
|||
|
||||
# @mcp.tool
|
||||
@app.post("/")
|
||||
async def input_guards(messages: List[ChatMessage], request: Request) -> List[ChatMessage]:
|
||||
async def input_guards(
|
||||
messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
"""Input guard that validates queries are within TechCorp's domain.
|
||||
|
||||
If the query is out of scope, replaces the user message with a rejection notice.
|
||||
|
|
@ -136,7 +140,6 @@ async def input_guards(messages: List[ChatMessage], request: Request) -> List[Ch
|
|||
traceparent_header = request.headers.get("traceparent")
|
||||
request_id = request.headers.get("x-request-id")
|
||||
|
||||
|
||||
if traceparent_header:
|
||||
logger.info(f"Received traceparent header: {traceparent_header}")
|
||||
else:
|
||||
|
|
@ -152,11 +155,14 @@ async def input_guards(messages: List[ChatMessage], request: Request) -> List[Ch
|
|||
# Throw ToolError
|
||||
error_message = f"I apologize, but I can only assist with questions related to TechCorp and its services. Your query appears to be outside this scope. {reason}\n\nPlease ask me about TechCorp's products, services, pricing, SLAs, or technical support."
|
||||
# raise ToolError(error_message)
|
||||
raise HTTPException(status_code=400, detail={"error": "out_of_scope", "message": error_message})
|
||||
raise HTTPException(
|
||||
status_code=400, detail={"error": "out_of_scope", "message": error_message}
|
||||
)
|
||||
|
||||
logger.info("Query validation passed - forwarding to next filter")
|
||||
return messages
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
return {"status": "healthy"}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import os
|
|||
import logging
|
||||
|
||||
from .api import ChatCompletionRequest, ChatCompletionResponse, ChatMessage
|
||||
|
||||
# from . import mcp
|
||||
# from fastmcp.server.dependencies import get_http_headers
|
||||
|
||||
|
|
@ -82,9 +83,10 @@ Return only the rewritten query, nothing else."""
|
|||
return ""
|
||||
|
||||
|
||||
|
||||
@app.post("/")
|
||||
async def query_rewriter_http(messages: List[ChatMessage], request: Request) -> List[ChatMessage]:
|
||||
async def query_rewriter_http(
|
||||
messages: List[ChatMessage], request: Request
|
||||
) -> List[ChatMessage]:
|
||||
"""HTTP filter endpoint used by Plano (type: http)."""
|
||||
logger.info(f"Received request with {len(messages)} messages")
|
||||
|
||||
|
|
@ -96,7 +98,9 @@ async def query_rewriter_http(messages: List[ChatMessage], request: Request) ->
|
|||
else:
|
||||
logger.info("No traceparent header found")
|
||||
|
||||
rewritten_query = await rewrite_query_with_archgw(messages, traceparent_header, request_id)
|
||||
rewritten_query = await rewrite_query_with_archgw(
|
||||
messages, traceparent_header, request_id
|
||||
)
|
||||
# Create updated messages with the rewritten query
|
||||
updated_messages = messages.copy()
|
||||
|
||||
|
|
@ -109,22 +113,23 @@ async def query_rewriter_http(messages: List[ChatMessage], request: Request) ->
|
|||
f"Updated user query from '{original_query}' to '{rewritten_query}'"
|
||||
)
|
||||
break
|
||||
updated_messages_data = [{"role": msg.role, "content": msg.content} for msg in updated_messages]
|
||||
updated_messages_data = [
|
||||
{"role": msg.role, "content": msg.content} for msg in updated_messages
|
||||
]
|
||||
updated_messages = [ChatMessage(**msg) for msg in updated_messages_data]
|
||||
|
||||
logger.info("Returning rewritten chat completion response")
|
||||
return updated_messages
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "healthy"}
|
||||
|
||||
|
||||
def start_server(host: str = "0.0.0.0", port: int = 10501):
|
||||
"""Start the FastAPI server for query rewriter."""
|
||||
import uvicorn
|
||||
|
||||
logger.info(f"Starting Query Rewriter REST server on {host}:{port}")
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -76,4 +76,3 @@ pids+=($!)
|
|||
for PID in "${pids[@]}"; do
|
||||
wait "$PID"
|
||||
done
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue