mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-06-11 07:45:13 +02:00
Compare commits
1 commit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5420a20d29 |
2 changed files with 1044 additions and 996 deletions
File diff suppressed because it is too large
Load diff
|
|
@ -1,49 +1,110 @@
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from websockets.asyncio.client import connect
|
from websockets.asyncio.client import connect
|
||||||
from urllib.parse import urlencode, urlparse, urlunparse, parse_qs
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
import time
|
import hashlib
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _token_key(token):
|
||||||
|
"""Derive a dict key from a token without storing the raw secret."""
|
||||||
|
return hashlib.sha256(token.encode()).hexdigest()[:16]
|
||||||
|
|
||||||
|
|
||||||
class WebSocketManager:
|
class WebSocketManager:
|
||||||
|
"""Manages an authenticated WebSocket connection to the TrustGraph
|
||||||
|
gateway on behalf of a single caller.
|
||||||
|
|
||||||
def __init__(self, url, token=None):
|
Each caller token gets its own WebSocketManager so that gateway-side
|
||||||
|
identity, workspace, and capability scoping are preserved end-to-end.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, url, token):
|
||||||
self.url = url
|
self.url = url
|
||||||
|
# ── Security boundary: token storage ──
|
||||||
|
# This is the MCP caller's Bearer token, forwarded verbatim to
|
||||||
|
# the gateway. It MUST NOT be logged, persisted, or shared
|
||||||
|
# across callers. It is held only for the lifetime of this
|
||||||
|
# connection so that re-auth (e.g. after a reconnect) is
|
||||||
|
# possible.
|
||||||
self.token = token
|
self.token = token
|
||||||
self.socket = None
|
self.socket = None
|
||||||
|
self.identity = None
|
||||||
# FIXME: authentication is broken. The /api/v1/socket endpoint uses
|
self.last_used = None
|
||||||
# in-band auth (first-frame protocol via the Mux dispatcher), not
|
|
||||||
# query-parameter tokens. This query-string token is silently ignored.
|
|
||||||
# Fix: after connect(), send an auth frame with the bearer token as
|
|
||||||
# the first message, matching the gateway's in-band auth protocol.
|
|
||||||
def _build_url(self):
|
|
||||||
if not self.token:
|
|
||||||
return self.url
|
|
||||||
parsed = urlparse(self.url)
|
|
||||||
params = parse_qs(parsed.query)
|
|
||||||
params["token"] = [self.token]
|
|
||||||
new_query = urlencode(params, doseq=True)
|
|
||||||
return urlunparse(parsed._replace(query=new_query))
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
self.socket = await connect(self._build_url())
|
"""Connect and authenticate via the gateway's in-band auth
|
||||||
|
protocol. Raises on auth failure."""
|
||||||
|
|
||||||
|
# ── Security boundary: MCP server → gateway ──
|
||||||
|
# The WebSocket connects to the gateway and authenticates using
|
||||||
|
# the caller's Bearer token via the in-band first-frame auth
|
||||||
|
# protocol. The token belongs to the MCP client — we forward
|
||||||
|
# it as-is and never interpret its contents.
|
||||||
|
self.socket = await connect(self.url)
|
||||||
self.pending_requests = {}
|
self.pending_requests = {}
|
||||||
self.running = True
|
self.running = True
|
||||||
|
|
||||||
|
await self._authenticate()
|
||||||
|
|
||||||
self.reader_task = asyncio.create_task(self.reader())
|
self.reader_task = asyncio.create_task(self.reader())
|
||||||
|
|
||||||
|
async def _authenticate(self):
|
||||||
|
"""Send in-band auth frame and wait for auth-ok / auth-failed.
|
||||||
|
|
||||||
|
The gateway expects ``{"type": "auth", "token": "..."}`` as the
|
||||||
|
first frame on a new WebSocket. Any service frame sent before
|
||||||
|
auth-ok is rejected.
|
||||||
|
"""
|
||||||
|
await self.socket.send(json.dumps({
|
||||||
|
"type": "auth",
|
||||||
|
"token": self.token,
|
||||||
|
}))
|
||||||
|
|
||||||
|
response_text = await asyncio.wait_for(self.socket.recv(), 10)
|
||||||
|
response = json.loads(response_text)
|
||||||
|
|
||||||
|
if response.get("type") == "auth-ok":
|
||||||
|
logger.info(
|
||||||
|
"WebSocket authenticated, default workspace: %s",
|
||||||
|
response.get("workspace"),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Auth failed — close immediately, do not leave an
|
||||||
|
# unauthenticated socket open.
|
||||||
|
await self.socket.close()
|
||||||
|
self.socket = None
|
||||||
|
|
||||||
|
if response.get("type") == "auth-failed":
|
||||||
|
raise RuntimeError(
|
||||||
|
"Gateway rejected the authentication token"
|
||||||
|
)
|
||||||
|
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unexpected auth response type: {response.get('type')}"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def whoami(self):
|
||||||
|
"""Verify the token by calling the gateway's whoami endpoint.
|
||||||
|
Returns the identity dict and caches it on ``self.identity``.
|
||||||
|
"""
|
||||||
|
gen = self.request("iam", {"operation": "whoami"}, flow_id=None)
|
||||||
|
async for response in gen:
|
||||||
|
self.identity = response
|
||||||
|
return response
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
self.running = False
|
self.running = False
|
||||||
|
if hasattr(self, "reader_task"):
|
||||||
await self.reader_task
|
await self.reader_task
|
||||||
|
|
||||||
async def reader(self):
|
async def reader(self):
|
||||||
"""
|
"""Background task: read WebSocket frames and route them to the
|
||||||
Background task to read websocket responses and route to correct
|
correct pending-request queue by ``id``."""
|
||||||
request
|
|
||||||
"""
|
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
try:
|
try:
|
||||||
|
|
@ -59,23 +120,21 @@ class WebSocketManager:
|
||||||
|
|
||||||
request_id = response.get("id")
|
request_id = response.get("id")
|
||||||
if request_id and request_id in self.pending_requests:
|
if request_id and request_id in self.pending_requests:
|
||||||
# Put the response in the queue
|
|
||||||
queue = self.pending_requests[request_id]
|
queue = self.pending_requests[request_id]
|
||||||
await queue.put(response)
|
await queue.put(response)
|
||||||
else:
|
else:
|
||||||
logging.warning(
|
logger.warning(
|
||||||
f"Response for unknown request ID: {request_id}"
|
"Response for unknown request ID: %s", request_id
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
||||||
logging.error(f"Error in websocket reader: {e}")
|
logger.error("Error in websocket reader: %s", e)
|
||||||
|
|
||||||
# Put error in all pending queues
|
|
||||||
for queue in self.pending_requests.values():
|
for queue in self.pending_requests.values():
|
||||||
try:
|
try:
|
||||||
await queue.put({"error": str(e)})
|
await queue.put({"error": str(e)})
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
self.pending_requests.clear()
|
self.pending_requests.clear()
|
||||||
|
|
@ -86,25 +145,29 @@ class WebSocketManager:
|
||||||
|
|
||||||
async def request(
|
async def request(
|
||||||
self, service, request_data, flow_id="default",
|
self, service, request_data, flow_id="default",
|
||||||
|
workspace=None,
|
||||||
):
|
):
|
||||||
"""
|
"""Send a request via WebSocket and yield responses.
|
||||||
Send a request via websocket and handle single or streaming responses
|
|
||||||
|
Args:
|
||||||
|
service: Gateway service name (e.g. "graph-rag", "config").
|
||||||
|
request_data: Inner request payload.
|
||||||
|
flow_id: Optional flow identifier. ``None`` omits the field
|
||||||
|
(workspace-level services don't use flows).
|
||||||
|
workspace: Optional workspace override. When ``None`` the
|
||||||
|
gateway uses the caller's default workspace.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Generate unique request ID
|
import time
|
||||||
|
self.last_used = time.monotonic()
|
||||||
|
|
||||||
request_id = f"{uuid.uuid4()}"
|
request_id = f"{uuid.uuid4()}"
|
||||||
|
|
||||||
# Determine if this service streams responses
|
|
||||||
streaming_services = {"agent"}
|
|
||||||
is_streaming = service in streaming_services
|
|
||||||
|
|
||||||
# Create a queue for all responses (streaming and single)
|
|
||||||
response_queue = asyncio.Queue()
|
response_queue = asyncio.Queue()
|
||||||
self.pending_requests[request_id] = response_queue
|
self.pending_requests[request_id] = response_queue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Build request message
|
|
||||||
message = {
|
message = {
|
||||||
"id": request_id,
|
"id": request_id,
|
||||||
"service": service,
|
"service": service,
|
||||||
|
|
@ -114,7 +177,16 @@ class WebSocketManager:
|
||||||
if flow_id is not None:
|
if flow_id is not None:
|
||||||
message["flow"] = flow_id
|
message["flow"] = flow_id
|
||||||
|
|
||||||
# Send request
|
# ── Security boundary: workspace scoping ──
|
||||||
|
# When the caller supplies a workspace, we set it on the
|
||||||
|
# message envelope. The gateway's enforce_workspace()
|
||||||
|
# validates that the authenticated identity is permitted
|
||||||
|
# to access the target workspace — we MUST NOT skip or
|
||||||
|
# override that check. When workspace is None, the
|
||||||
|
# gateway default-fills from the identity's bound workspace.
|
||||||
|
if workspace is not None:
|
||||||
|
message["workspace"] = workspace
|
||||||
|
|
||||||
await self.socket.send(json.dumps(message))
|
await self.socket.send(json.dumps(message))
|
||||||
|
|
||||||
while self.running:
|
while self.running:
|
||||||
|
|
@ -127,19 +199,17 @@ class WebSocketManager:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if "error" in response:
|
if "error" in response:
|
||||||
if "message" in response["error"]:
|
if isinstance(response["error"], dict):
|
||||||
raise RuntimeError(response["error"]["text"])
|
raise RuntimeError(
|
||||||
|
response["error"].get("message", str(response["error"]))
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError(str(response["error"]))
|
raise RuntimeError(str(response["error"]))
|
||||||
|
|
||||||
yield response["response"]
|
yield response["response"]
|
||||||
|
|
||||||
if "complete" in response:
|
if response.get("complete"):
|
||||||
if response["complete"]:
|
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
finally:
|
||||||
# Clean up on error
|
|
||||||
self.pending_requests.pop(request_id, None)
|
self.pending_requests.pop(request_id, None)
|
||||||
raise e
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue