Add GATEWAY_SECRET support for MCP server to API gateway auth (#721)

Pass bearer token from GATEWAY_SECRET environment variable as a
URL query parameter on websocket connections to the API gateway.
When unset or empty, no auth is applied (backwards compatible).
This commit is contained in:
cybermaggedon 2026-03-26 10:49:28 +00:00 committed by GitHub
parent 97f5645ea0
commit 4164ef1c47
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 27 additions and 10 deletions

View file

@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager
class AppContext:
sockets: dict[str, WebSocketManager]
websocket_url: str
gateway_token: str
@asynccontextmanager
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]:
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
"""
Manage application lifecycle with type-safe context
@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8
sockets = {}
try:
yield AppContext(sockets=sockets, websocket_url=websocket_url)
yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
finally:
# Cleanup on shutdown
@ -53,15 +54,16 @@ async def get_socket_manager(ctx, user):
lifespan_context = ctx.request_context.lifespan_context
sockets = lifespan_context.sockets
websocket_url = lifespan_context.websocket_url
gateway_token = lifespan_context.gateway_token
if user in sockets:
logging.info("Return existing socket manager")
return sockets[user]
logging.info(f"Opening socket to {websocket_url}...")
# Create manager with empty pending requests
manager = WebSocketManager(websocket_url)
manager = WebSocketManager(websocket_url, token=gateway_token)
# Start reader task with the proper manager
await manager.start()
@ -193,13 +195,14 @@ class GetSystemPromptResponse:
prompt: str
class McpServer:
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"):
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
self.host = host
self.port = port
self.websocket_url = websocket_url
self.gateway_token = gateway_token
# Create a partial function to pass websocket_url to app_lifespan
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url)
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
self.mcp = FastMCP(
"TrustGraph", dependencies=["trustgraph-base"],
@ -2060,8 +2063,11 @@ def main():
# Setup logging before creating server
setup_logging(vars(args))
# Read gateway auth token from environment
gateway_token = os.environ.get("GATEWAY_SECRET", "")
# Create and run the MCP server
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url)
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
server.run()
def run():

View file

@ -1,6 +1,7 @@
from dataclasses import dataclass
from websockets.asyncio.client import connect
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio
import logging
import json
@ -9,12 +10,22 @@ import time
class WebSocketManager:
def __init__(self, url):
def __init__(self, url, token=None):
self.url = url
self.token = token
self.socket = None
def _build_url(self):
if not self.token:
return self.url
parsed = urlparse(self.url)
params = parse_qs(parsed.query)
params["token"] = [self.token]
new_query = urlencode(params, doseq=True)
return urlunparse(parsed._replace(query=new_query))
async def start(self):
self.socket = await connect(self.url)
self.socket = await connect(self._build_url())
self.pending_requests = {}
self.running = True
self.reader_task = asyncio.create_task(self.reader())