Add AUTH_GATEWAY support for MCP server to API gateway auth (#717)

Add GATEWAY_SECRET support for MCP server to API gateway auth

Pass bearer token from GATEWAY_SECRET environment variable as a
URL query parameter on websocket connections to the API gateway.
When unset or empty, no auth is applied (backwards compatible).

Lock mistralai to fix incompatible API
This commit is contained in:
cybermaggedon 2026-03-26 10:51:25 +00:00 committed by GitHub
parent 3bebd12ccc
commit 1febfccc8a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 28 additions and 11 deletions

View file

@ -28,7 +28,7 @@ dependencies = [
"langchain-text-splitters",
"mcp",
"minio",
"mistralai",
"mistralai<2.0.0",
"neo4j",
"nltk",
"ollama",

View file

@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager
class AppContext:
sockets: dict[str, WebSocketManager]
websocket_url: str
gateway_token: str
@asynccontextmanager
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]:
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
"""
Manage application lifecycle with type-safe context
@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8
sockets = {}
try:
yield AppContext(sockets=sockets, websocket_url=websocket_url)
yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
finally:
# Cleanup on shutdown
@ -53,15 +54,16 @@ async def get_socket_manager(ctx, user):
lifespan_context = ctx.request_context.lifespan_context
sockets = lifespan_context.sockets
websocket_url = lifespan_context.websocket_url
gateway_token = lifespan_context.gateway_token
if user in sockets:
logging.info("Return existing socket manager")
return sockets[user]
logging.info(f"Opening socket to {websocket_url}...")
# Create manager with empty pending requests
manager = WebSocketManager(websocket_url)
manager = WebSocketManager(websocket_url, token=gateway_token)
# Start reader task with the proper manager
await manager.start()
@ -193,13 +195,14 @@ class GetSystemPromptResponse:
prompt: str
class McpServer:
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"):
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
self.host = host
self.port = port
self.websocket_url = websocket_url
self.gateway_token = gateway_token
# Create a partial function to pass websocket_url to app_lifespan
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url)
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
self.mcp = FastMCP(
"TrustGraph", dependencies=["trustgraph-base"],
@ -2051,8 +2054,11 @@ def main():
# Setup logging before creating server
setup_logging(vars(args))
# Read gateway auth token from environment
gateway_token = os.environ.get("GATEWAY_SECRET", "")
# Create and run the MCP server
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url)
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
server.run()
def run():

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass
from websockets.asyncio.client import connect
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio
import logging
import json
@ -9,12 +10,22 @@ import time
class WebSocketManager:
def __init__(self, url):
def __init__(self, url, token=None):
self.url = url
self.token = token
self.socket = None
def _build_url(self):
if not self.token:
return self.url
parsed = urlparse(self.url)
params = parse_qs(parsed.query)
params["token"] = [self.token]
new_query = urlencode(params, doseq=True)
return urlunparse(parsed._replace(query=new_query))
async def start(self):
self.socket = await connect(self.url)
self.socket = await connect(self._build_url())
self.pending_requests = {}
self.running = True
self.reader_task = asyncio.create_task(self.reader())