Add GATEWAY_SECRET support for MCP server to API gateway auth (#721)

Pass bearer token from GATEWAY_SECRET environment variable as a
URL query parameter on websocket connections to the API gateway.
When unset or empty, no auth is applied (backwards compatible).
This commit is contained in:
cybermaggedon 2026-03-26 10:49:28 +00:00 committed by GitHub
parent 97f5645ea0
commit 4164ef1c47
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 27 additions and 10 deletions

View file

@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager
class AppContext:
sockets: dict[str, WebSocketManager]
websocket_url: str
gateway_token: str
@asynccontextmanager
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]:
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
"""
Manage application lifecycle with type-safe context
@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8
sockets = {}
try:
yield AppContext(sockets=sockets, websocket_url=websocket_url)
yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
finally:
# Cleanup on shutdown
@ -53,15 +54,16 @@ async def get_socket_manager(ctx, user):
lifespan_context = ctx.request_context.lifespan_context
sockets = lifespan_context.sockets
websocket_url = lifespan_context.websocket_url
gateway_token = lifespan_context.gateway_token
if user in sockets:
logging.info("Return existing socket manager")
return sockets[user]
logging.info(f"Opening socket to {websocket_url}...")
# Create manager with empty pending requests
manager = WebSocketManager(websocket_url)
manager = WebSocketManager(websocket_url, token=gateway_token)
# Start reader task with the proper manager
await manager.start()
@ -193,13 +195,14 @@ class GetSystemPromptResponse:
prompt: str
class McpServer:
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"):
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
self.host = host
self.port = port
self.websocket_url = websocket_url
self.gateway_token = gateway_token
# Create a partial function to pass websocket_url to app_lifespan
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url)
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
self.mcp = FastMCP(
"TrustGraph", dependencies=["trustgraph-base"],
@ -2060,8 +2063,11 @@ def main():
# Setup logging before creating server
setup_logging(vars(args))
# Read gateway auth token from environment
gateway_token = os.environ.get("GATEWAY_SECRET", "")
# Create and run the MCP server
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url)
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
server.run()
def run():