mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Add AUTH_GATEWAY support for MCP server to API gateway auth (#717)
Add GATEWAY_SECRET support for MCP server to API gateway auth Pass bearer token from GATEWAY_SECRET environment variable as a URL query parameter on websocket connections to the API gateway. When unset or empty, no auth is applied (backwards compatible). Lock mistralai to fix incompatible API
This commit is contained in:
parent
3bebd12ccc
commit
1febfccc8a
3 changed files with 28 additions and 11 deletions
|
|
@ -28,7 +28,7 @@ dependencies = [
|
|||
"langchain-text-splitters",
|
||||
"mcp",
|
||||
"minio",
|
||||
"mistralai",
|
||||
"mistralai<2.0.0",
|
||||
"neo4j",
|
||||
"nltk",
|
||||
"ollama",
|
||||
|
|
|
|||
|
|
@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager
|
|||
class AppContext:
|
||||
sockets: dict[str, WebSocketManager]
|
||||
websocket_url: str
|
||||
gateway_token: str
|
||||
|
||||
@asynccontextmanager
|
||||
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]:
|
||||
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
|
||||
|
||||
"""
|
||||
Manage application lifecycle with type-safe context
|
||||
|
|
@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8
|
|||
sockets = {}
|
||||
|
||||
try:
|
||||
yield AppContext(sockets=sockets, websocket_url=websocket_url)
|
||||
yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
|
||||
finally:
|
||||
|
||||
# Cleanup on shutdown
|
||||
|
|
@ -53,15 +54,16 @@ async def get_socket_manager(ctx, user):
|
|||
lifespan_context = ctx.request_context.lifespan_context
|
||||
sockets = lifespan_context.sockets
|
||||
websocket_url = lifespan_context.websocket_url
|
||||
gateway_token = lifespan_context.gateway_token
|
||||
|
||||
if user in sockets:
|
||||
logging.info("Return existing socket manager")
|
||||
return sockets[user]
|
||||
|
||||
logging.info(f"Opening socket to {websocket_url}...")
|
||||
|
||||
|
||||
# Create manager with empty pending requests
|
||||
manager = WebSocketManager(websocket_url)
|
||||
manager = WebSocketManager(websocket_url, token=gateway_token)
|
||||
|
||||
# Start reader task with the proper manager
|
||||
await manager.start()
|
||||
|
|
@ -193,13 +195,14 @@ class GetSystemPromptResponse:
|
|||
prompt: str
|
||||
|
||||
class McpServer:
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"):
|
||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
|
||||
self.host = host
|
||||
self.port = port
|
||||
self.websocket_url = websocket_url
|
||||
|
||||
self.gateway_token = gateway_token
|
||||
|
||||
# Create a partial function to pass websocket_url to app_lifespan
|
||||
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url)
|
||||
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
|
||||
|
||||
self.mcp = FastMCP(
|
||||
"TrustGraph", dependencies=["trustgraph-base"],
|
||||
|
|
@ -2051,8 +2054,11 @@ def main():
|
|||
# Setup logging before creating server
|
||||
setup_logging(vars(args))
|
||||
|
||||
# Read gateway auth token from environment
|
||||
gateway_token = os.environ.get("GATEWAY_SECRET", "")
|
||||
|
||||
# Create and run the MCP server
|
||||
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url)
|
||||
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
|
||||
server.run()
|
||||
|
||||
def run():
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
|
||||
from dataclasses import dataclass
|
||||
from websockets.asyncio.client import connect
|
||||
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
|
||||
import asyncio
|
||||
import logging
|
||||
import json
|
||||
|
|
@ -9,12 +10,22 @@ import time
|
|||
|
||||
class WebSocketManager:
|
||||
|
||||
def __init__(self, url):
|
||||
def __init__(self, url, token=None):
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.socket = None
|
||||
|
||||
def _build_url(self):
|
||||
if not self.token:
|
||||
return self.url
|
||||
parsed = urlparse(self.url)
|
||||
params = parse_qs(parsed.query)
|
||||
params["token"] = [self.token]
|
||||
new_query = urlencode(params, doseq=True)
|
||||
return urlunparse(parsed._replace(query=new_query))
|
||||
|
||||
async def start(self):
|
||||
self.socket = await connect(self.url)
|
||||
self.socket = await connect(self._build_url())
|
||||
self.pending_requests = {}
|
||||
self.running = True
|
||||
self.reader_task = asyncio.create_task(self.reader())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue