Add GATEWAY_SECRET support for MCP server to API gateway auth (#721)

Pass bearer token from GATEWAY_SECRET environment variable as a
URL query parameter on websocket connections to the API gateway.
When unset or empty, no auth is applied (backwards compatible).
This commit is contained in:
cybermaggedon 2026-03-26 10:49:28 +00:00 committed by GitHub
parent 97f5645ea0
commit 4164ef1c47
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 27 additions and 10 deletions

View file

@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager
class AppContext: class AppContext:
sockets: dict[str, WebSocketManager] sockets: dict[str, WebSocketManager]
websocket_url: str websocket_url: str
gateway_token: str
@asynccontextmanager @asynccontextmanager
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]: async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
""" """
Manage application lifecycle with type-safe context Manage application lifecycle with type-safe context
@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8
sockets = {} sockets = {}
try: try:
yield AppContext(sockets=sockets, websocket_url=websocket_url) yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
finally: finally:
# Cleanup on shutdown # Cleanup on shutdown
@ -53,6 +54,7 @@ async def get_socket_manager(ctx, user):
lifespan_context = ctx.request_context.lifespan_context lifespan_context = ctx.request_context.lifespan_context
sockets = lifespan_context.sockets sockets = lifespan_context.sockets
websocket_url = lifespan_context.websocket_url websocket_url = lifespan_context.websocket_url
gateway_token = lifespan_context.gateway_token
if user in sockets: if user in sockets:
logging.info("Return existing socket manager") logging.info("Return existing socket manager")
@ -61,7 +63,7 @@ async def get_socket_manager(ctx, user):
logging.info(f"Opening socket to {websocket_url}...") logging.info(f"Opening socket to {websocket_url}...")
# Create manager with empty pending requests # Create manager with empty pending requests
manager = WebSocketManager(websocket_url) manager = WebSocketManager(websocket_url, token=gateway_token)
# Start reader task with the proper manager # Start reader task with the proper manager
await manager.start() await manager.start()
@ -193,13 +195,14 @@ class GetSystemPromptResponse:
prompt: str prompt: str
class McpServer: class McpServer:
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"): def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
self.host = host self.host = host
self.port = port self.port = port
self.websocket_url = websocket_url self.websocket_url = websocket_url
self.gateway_token = gateway_token
# Create a partial function to pass websocket_url to app_lifespan # Create a partial function to pass websocket_url to app_lifespan
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url) lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
self.mcp = FastMCP( self.mcp = FastMCP(
"TrustGraph", dependencies=["trustgraph-base"], "TrustGraph", dependencies=["trustgraph-base"],
@ -2060,8 +2063,11 @@ def main():
# Setup logging before creating server # Setup logging before creating server
setup_logging(vars(args)) setup_logging(vars(args))
# Read gateway auth token from environment
gateway_token = os.environ.get("GATEWAY_SECRET", "")
# Create and run the MCP server # Create and run the MCP server
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url) server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
server.run() server.run()
def run(): def run():

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from websockets.asyncio.client import connect from websockets.asyncio.client import connect
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio import asyncio
import logging import logging
import json import json
@ -9,12 +10,22 @@ import time
class WebSocketManager: class WebSocketManager:
def __init__(self, url): def __init__(self, url, token=None):
self.url = url self.url = url
self.token = token
self.socket = None self.socket = None
def _build_url(self):
if not self.token:
return self.url
parsed = urlparse(self.url)
params = parse_qs(parsed.query)
params["token"] = [self.token]
new_query = urlencode(params, doseq=True)
return urlunparse(parsed._replace(query=new_query))
async def start(self): async def start(self):
self.socket = await connect(self.url) self.socket = await connect(self._build_url())
self.pending_requests = {} self.pending_requests = {}
self.running = True self.running = True
self.reader_task = asyncio.create_task(self.reader()) self.reader_task = asyncio.create_task(self.reader())