mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 08:26:21 +02:00
Add GATEWAY_SECRET support for MCP server to API gateway auth (#721)
Pass bearer token from GATEWAY_SECRET environment variable as a URL query parameter on websocket connections to the API gateway. When unset or empty, no auth is applied (backwards compatible).
This commit is contained in:
parent
97f5645ea0
commit
4164ef1c47
2 changed files with 27 additions and 10 deletions
|
|
@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager
|
||||||
class AppContext:
|
class AppContext:
|
||||||
sockets: dict[str, WebSocketManager]
|
sockets: dict[str, WebSocketManager]
|
||||||
websocket_url: str
|
websocket_url: str
|
||||||
|
gateway_token: str
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]:
|
async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Manage application lifecycle with type-safe context
|
Manage application lifecycle with type-safe context
|
||||||
|
|
@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8
|
||||||
sockets = {}
|
sockets = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield AppContext(sockets=sockets, websocket_url=websocket_url)
|
yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token)
|
||||||
finally:
|
finally:
|
||||||
|
|
||||||
# Cleanup on shutdown
|
# Cleanup on shutdown
|
||||||
|
|
@ -53,6 +54,7 @@ async def get_socket_manager(ctx, user):
|
||||||
lifespan_context = ctx.request_context.lifespan_context
|
lifespan_context = ctx.request_context.lifespan_context
|
||||||
sockets = lifespan_context.sockets
|
sockets = lifespan_context.sockets
|
||||||
websocket_url = lifespan_context.websocket_url
|
websocket_url = lifespan_context.websocket_url
|
||||||
|
gateway_token = lifespan_context.gateway_token
|
||||||
|
|
||||||
if user in sockets:
|
if user in sockets:
|
||||||
logging.info("Return existing socket manager")
|
logging.info("Return existing socket manager")
|
||||||
|
|
@ -61,7 +63,7 @@ async def get_socket_manager(ctx, user):
|
||||||
logging.info(f"Opening socket to {websocket_url}...")
|
logging.info(f"Opening socket to {websocket_url}...")
|
||||||
|
|
||||||
# Create manager with empty pending requests
|
# Create manager with empty pending requests
|
||||||
manager = WebSocketManager(websocket_url)
|
manager = WebSocketManager(websocket_url, token=gateway_token)
|
||||||
|
|
||||||
# Start reader task with the proper manager
|
# Start reader task with the proper manager
|
||||||
await manager.start()
|
await manager.start()
|
||||||
|
|
@ -193,13 +195,14 @@ class GetSystemPromptResponse:
|
||||||
prompt: str
|
prompt: str
|
||||||
|
|
||||||
class McpServer:
|
class McpServer:
|
||||||
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"):
|
def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""):
|
||||||
self.host = host
|
self.host = host
|
||||||
self.port = port
|
self.port = port
|
||||||
self.websocket_url = websocket_url
|
self.websocket_url = websocket_url
|
||||||
|
self.gateway_token = gateway_token
|
||||||
|
|
||||||
# Create a partial function to pass websocket_url to app_lifespan
|
# Create a partial function to pass websocket_url to app_lifespan
|
||||||
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url)
|
lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token)
|
||||||
|
|
||||||
self.mcp = FastMCP(
|
self.mcp = FastMCP(
|
||||||
"TrustGraph", dependencies=["trustgraph-base"],
|
"TrustGraph", dependencies=["trustgraph-base"],
|
||||||
|
|
@ -2060,8 +2063,11 @@ def main():
|
||||||
# Setup logging before creating server
|
# Setup logging before creating server
|
||||||
setup_logging(vars(args))
|
setup_logging(vars(args))
|
||||||
|
|
||||||
|
# Read gateway auth token from environment
|
||||||
|
gateway_token = os.environ.get("GATEWAY_SECRET", "")
|
||||||
|
|
||||||
# Create and run the MCP server
|
# Create and run the MCP server
|
||||||
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url)
|
server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token)
|
||||||
server.run()
|
server.run()
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from websockets.asyncio.client import connect
|
from websockets.asyncio.client import connect
|
||||||
|
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
|
|
@ -9,12 +10,22 @@ import time
|
||||||
|
|
||||||
class WebSocketManager:
|
class WebSocketManager:
|
||||||
|
|
||||||
def __init__(self, url):
|
def __init__(self, url, token=None):
|
||||||
self.url = url
|
self.url = url
|
||||||
|
self.token = token
|
||||||
self.socket = None
|
self.socket = None
|
||||||
|
|
||||||
|
def _build_url(self):
|
||||||
|
if not self.token:
|
||||||
|
return self.url
|
||||||
|
parsed = urlparse(self.url)
|
||||||
|
params = parse_qs(parsed.query)
|
||||||
|
params["token"] = [self.token]
|
||||||
|
new_query = urlencode(params, doseq=True)
|
||||||
|
return urlunparse(parsed._replace(query=new_query))
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.socket = await connect(self.url)
|
self.socket = await connect(self._build_url())
|
||||||
self.pending_requests = {}
|
self.pending_requests = {}
|
||||||
self.running = True
|
self.running = True
|
||||||
self.reader_task = asyncio.create_task(self.reader())
|
self.reader_task = asyncio.create_task(self.reader())
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue