trustgraph/trustgraph-mcp/trustgraph/mcp_server/tg_socket.py
cybermaggedon 4164ef1c47
Add GATEWAY_SECRET support for MCP server to API gateway auth (#721)
Pass bearer token from GATEWAY_SECRET environment variable as a
URL query parameter on websocket connections to the API gateway.
When unset or empty, no auth is applied (backwards compatible).
2026-03-26 10:49:28 +00:00

140 lines
4 KiB
Python

from dataclasses import dataclass
from websockets.asyncio.client import connect
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
import asyncio
import logging
import json
import uuid
import time
class WebSocketManager:
    """
    Multiplex request/response traffic over a single websocket connection.

    Every request is tagged with a unique ID.  A background reader task
    receives messages from the socket and routes each one to the queue of
    the pending request whose ID it carries.  When a token is configured
    it is sent to the server as the ``token`` URL query parameter.
    """

    def __init__(self, url, token=None):
        """
        Record connection settings; no network I/O happens here.

        Args:
            url: Websocket URL to connect to.
            token: Optional token; when None or empty the URL is used
                unchanged.
        """
        self.url = url
        self.token = token
        self.socket = None

        # Initialised here (not only in start()) so that stop() and
        # request() behave predictably when called before start()
        self.pending_requests = {}
        self.running = False
        self.reader_task = None

    def _build_url(self):
        """
        Return the URL to connect to.

        Without a token this is ``self.url`` unchanged.  With a token,
        the ``token`` query parameter is set (replacing any existing
        ``token`` parameter) and all other query parameters are kept.
        """
        if not self.token:
            return self.url

        parsed = urlparse(self.url)

        # keep_blank_values: parse_qs drops parameters such as "flag="
        # by default, which would silently alter the caller's URL
        params = parse_qs(parsed.query, keep_blank_values=True)
        params["token"] = [self.token]
        new_query = urlencode(params, doseq=True)

        return urlunparse(parsed._replace(query=new_query))

    async def start(self):
        """
        Open the websocket connection and launch the background reader.
        """
        self.socket = await connect(self._build_url())
        self.pending_requests = {}
        self.running = True
        self.reader_task = asyncio.create_task(self.reader())

    async def stop(self):
        """
        Signal the reader task to finish and wait for it to exit.

        Does nothing beyond clearing the running flag when start() was
        never called.
        """
        self.running = False

        if self.reader_task is not None:
            await self.reader_task

    async def reader(self):
        """
        Background task to read websocket responses and route to correct
        request

        Runs until ``self.running`` is cleared or an error occurs.  On
        error, every pending request receives an ``{"error": ...}``
        message.  On exit the running flag is cleared and the socket is
        closed.
        """
        try:

            while self.running:

                try:

                    try:
                        response_text = await asyncio.wait_for(
                            self.socket.recv(), 0.5
                        )
                    except asyncio.TimeoutError:
                        # Builtin TimeoutError only aliases this from
                        # Python 3.11; the asyncio name works everywhere.
                        # The timeout exists so the running flag is
                        # re-checked regularly.
                        continue

                    response = json.loads(response_text)

                    request_id = response.get("id")

                    if request_id and request_id in self.pending_requests:
                        # Put the response in the queue
                        queue = self.pending_requests[request_id]
                        await queue.put(response)
                    else:
                        logging.warning(
                            f"Response for unknown request ID: {request_id}"
                        )

                except Exception as e:

                    logging.error(f"Error in websocket reader: {e}")

                    # Put error in all pending queues.  The queues are
                    # unbounded, so put_nowait cannot fail with QueueFull.
                    for queue in self.pending_requests.values():
                        queue.put_nowait({"error": str(e)})

                    self.pending_requests.clear()
                    break

        finally:

            # Update state before awaiting close() so that a concurrent
            # start() is not clobbered once close() returns
            self.running = False
            sock, self.socket = self.socket, None

            # Runs on cancellation too, so the connection is not leaked
            if sock is not None:
                await sock.close()

    async def request(
        self, service, request_data, flow_id="default",
    ):
        """
        Send a request via websocket and handle single or streaming responses

        This is an async generator.  It yields the ``"response"`` field
        of each message received for this request, and finishes after a
        message whose ``"complete"`` field is true.  It also finishes,
        after delivering anything already queued, once the manager is no
        longer running.

        Args:
            service: Service name placed in the request message.
            request_data: Payload placed in the ``"request"`` field.
            flow_id: Placed in the ``"flow"`` field; omitted when None.

        Raises:
            RuntimeError: If the socket is not connected, or a message
                for this request carries an ``"error"`` field.
        """
        if self.socket is None:
            raise RuntimeError("WebSocket is not connected")

        # Generate unique request ID
        request_id = f"{uuid.uuid4()}"

        # Create a queue for all responses (streaming and single)
        response_queue = asyncio.Queue()
        self.pending_requests[request_id] = response_queue

        try:

            # Build request message
            message = {
                "id": request_id,
                "service": service,
                "request": request_data,
            }

            if flow_id is not None:
                message["flow"] = flow_id

            # Send request
            await self.socket.send(json.dumps(message))

            while True:

                try:
                    response = await asyncio.wait_for(
                        response_queue.get(), 0.5
                    )
                except asyncio.TimeoutError:
                    # Only give up once the queue has been drained, so
                    # an error queued by the reader just before it shut
                    # down is still delivered
                    if not self.running:
                        break
                    continue

                if "error" in response:
                    error = response["error"]

                    # Structured errors carry their text under
                    # "message" ("text" accepted as a fallback); the
                    # reader task reports failures as a plain string
                    if isinstance(error, dict):
                        text = error.get("message", error.get("text"))
                        if text is not None:
                            raise RuntimeError(text)

                    raise RuntimeError(str(error))

                yield response["response"]

                if response.get("complete"):
                    break

        finally:
            # Always unregister: on success, on error, and when the
            # consumer closes the generator early
            self.pending_requests.pop(request_id, None)