Add GATEWAY_SECRET support for MCP server to API gateway auth (#718)

Add GATEWAY_SECRET support for MCP server to API gateway auth

Pass bearer token from GATEWAY_SECRET environment variable as a
URL query parameter on websocket connections to the API gateway.
When unset or empty, no auth is applied (backwards compatible).

Lock mistralai to fix incompatible API
This commit is contained in:
cybermaggedon 2026-03-26 10:51:44 +00:00 committed by GitHub
parent 6d8da748d7
commit 7c37c7569a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 28 additions and 11 deletions

View file

@ -27,7 +27,7 @@ dependencies = [
"langchain-text-splitters", "langchain-text-splitters",
"mcp", "mcp",
"minio", "minio",
"mistralai", "mistralai<2.0.0",
"neo4j", "neo4j",
"nltk", "nltk",
"ollama", "ollama",

View file

@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager
class AppContext: class AppContext:
sockets: dict[str, WebSocketManager] sockets: dict[str, WebSocketManager]
websocket_url: str websocket_url: str
gateway_token: str
@asynccontextmanager @asynccontextmanager
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]: async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
""" """
Manage application lifecycle with type-safe context Manage application lifecycle with type-safe context
@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8
sockets = {} sockets = {}
try: try:
yield AppContext(sockets=sockets, websocket_url=websocket_url) yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
finally: finally:
# Cleanup on shutdown # Cleanup on shutdown
@ -53,15 +54,16 @@ async def get_socket_manager(ctx, user):
lifespan_context = ctx.request_context.lifespan_context lifespan_context = ctx.request_context.lifespan_context
sockets = lifespan_context.sockets sockets = lifespan_context.sockets
websocket_url = lifespan_context.websocket_url websocket_url = lifespan_context.websocket_url
gateway_token = lifespan_context.gateway_token
if user in sockets: if user in sockets:
logging.info("Return existing socket manager") logging.info("Return existing socket manager")
return sockets[user] return sockets[user]
logging.info(f"Opening socket to {websocket_url}...") logging.info(f"Opening socket to {websocket_url}...")
# Create manager with empty pending requests # Create manager with empty pending requests
manager = WebSocketManager(websocket_url) manager = WebSocketManager(websocket_url, token=gateway_token)
# Start reader task with the proper manager # Start reader task with the proper manager
await manager.start() await manager.start()
@ -193,13 +195,14 @@ class GetSystemPromptResponse:
prompt: str prompt: str
class McpServer: class McpServer:
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"): def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
self.host = host self.host = host
self.port = port self.port = port
self.websocket_url = websocket_url self.websocket_url = websocket_url
self.gateway_token = gateway_token
# Create a partial function to pass websocket_url to app_lifespan # Create a partial function to pass websocket_url to app_lifespan
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url) lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
self.mcp = FastMCP( self.mcp = FastMCP(
"TrustGraph", dependencies=["trustgraph-base"], "TrustGraph", dependencies=["trustgraph-base"],
@ -2051,8 +2054,11 @@ def main():
# Setup logging before creating server # Setup logging before creating server
setup_logging(vars(args)) setup_logging(vars(args))
# Read gateway auth token from environment
gateway_token = os.environ.get("GATEWAY_SECRET", "")
# Create and run the MCP server # Create and run the MCP server
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url) server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
server.run() server.run()
def run(): def run():

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from websockets.asyncio.client import connect from websockets.asyncio.client import connect
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio import asyncio
import logging import logging
import json import json
@ -9,12 +10,22 @@ import time
class WebSocketManager: class WebSocketManager:
def __init__(self, url): def __init__(self, url, token=None):
self.url = url self.url = url
self.token = token
self.socket = None self.socket = None
def _build_url(self):
if not self.token:
return self.url
parsed = urlparse(self.url)
params = parse_qs(parsed.query)
params["token"] = [self.token]
new_query = urlencode(params, doseq=True)
return urlunparse(parsed._replace(query=new_query))
async def start(self): async def start(self):
self.socket = await connect(self.url) self.socket = await connect(self._build_url())
self.pending_requests = {} self.pending_requests = {}
self.running = True self.running = True
self.reader_task = asyncio.create_task(self.reader()) self.reader_task = asyncio.create_task(self.reader())