diff --git a/trustgraph-mcp/trustgraph/mcp_server/mcp.py b/trustgraph-mcp/trustgraph/mcp_server/mcp.py index e551ed5d..eadd841b 100755 --- a/trustgraph-mcp/trustgraph/mcp_server/mcp.py +++ b/trustgraph-mcp/trustgraph/mcp_server/mcp.py @@ -24,9 +24,10 @@ from . tg_socket import WebSocketManager class AppContext: sockets: dict[str, WebSocketManager] websocket_url: str + gateway_token: str @asynccontextmanager -async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket") -> AsyncIterator[AppContext]: +async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = "") -> AsyncIterator[AppContext]: """ Manage application lifecycle with type-safe context @@ -36,7 +37,7 @@ async def app_lifespan(server: FastMCP, websocket_url: str = "ws://api-gateway:8 sockets = {} try: - yield AppContext(sockets=sockets, websocket_url=websocket_url) + yield AppContext(sockets=sockets, websocket_url=websocket_url, gateway_token=gateway_token) finally: # Cleanup on shutdown @@ -53,15 +54,16 @@ async def get_socket_manager(ctx, user): lifespan_context = ctx.request_context.lifespan_context sockets = lifespan_context.sockets websocket_url = lifespan_context.websocket_url + gateway_token = lifespan_context.gateway_token if user in sockets: logging.info("Return existing socket manager") return sockets[user] logging.info(f"Opening socket to {websocket_url}...") - + # Create manager with empty pending requests - manager = WebSocketManager(websocket_url) + manager = WebSocketManager(websocket_url, token=gateway_token) # Start reader task with the proper manager await manager.start() @@ -193,13 +195,14 @@ class GetSystemPromptResponse: prompt: str class McpServer: - def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket"): + def __init__(self, host: str = "0.0.0.0", port: int = 8000, websocket_url: str = "ws://api-gateway:8088/api/v1/socket", gateway_token: str = ""): self.host = host self.port = port self.websocket_url = websocket_url - + self.gateway_token = gateway_token + # Create a partial function to pass websocket_url to app_lifespan - lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url) + lifespan_with_url = partial(app_lifespan, websocket_url=websocket_url, gateway_token=gateway_token) self.mcp = FastMCP( "TrustGraph", dependencies=["trustgraph-base"], @@ -2060,8 +2063,11 @@ def main(): # Setup logging before creating server setup_logging(vars(args)) + # Read gateway auth token from environment + gateway_token = os.environ.get("GATEWAY_SECRET", "") + # Create and run the MCP server - server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url) + server = McpServer(host=args.host, port=args.port, websocket_url=args.websocket_url, gateway_token=gateway_token) server.run() def run(): diff --git a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py index 44f1bf2e..d255ae14 100644 --- a/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py +++ b/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from websockets.asyncio.client import connect +from urllib.parse import urlencode, urlparse, urlunparse, parse_qs import asyncio import logging import json @@ -9,12 +10,22 @@ import time class WebSocketManager: - def __init__(self, url): + def __init__(self, url, token=None): self.url = url + self.token = token self.socket = None + def _build_url(self): + if not self.token: + return self.url + parsed = urlparse(self.url) + params = parse_qs(parsed.query) + params["token"] = [self.token] + new_query = urlencode(params, doseq=True) + return urlunparse(parsed._replace(query=new_query)) + async def start(self): - self.socket = await connect(self.url) + self.socket = await connect(self._build_url()) self.pending_requests = {} self.running = True self.reader_task = asyncio.create_task(self.reader())