mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 16:36:21 +02:00
Persistent websocket connections for socket clients and CLI tools (#723)
Replace per-request websocket connections in SocketClient and AsyncSocketClient with a single persistent connection that multiplexes requests by ID via a background reader task. This eliminates repeated TCP+WS handshakes which caused significant latency over proxies. Convert show_flows, show_flow_blueprints, and show_parameter_types CLI tools from sequential HTTP requests to concurrent websocket requests using AsyncSocketClient, reducing round trips from O(N) sequential to a small number of parallel batches. Also fix describe_interfaces bug in show_flows where response queue was reading the request field instead of the response field.
This commit is contained in:
parent
1ec081f42f
commit
9c55a0a0ff
6 changed files with 654 additions and 1067 deletions
|
|
@ -1,5 +1,6 @@
|
|||
|
||||
import json
|
||||
import asyncio
|
||||
import websockets
|
||||
from typing import Optional, Dict, Any, AsyncIterator, Union
|
||||
|
||||
|
|
@ -8,13 +9,29 @@ from . exceptions import ProtocolException, ApplicationException
|
|||
|
||||
|
||||
class AsyncSocketClient:
|
||||
"""Asynchronous WebSocket client"""
|
||||
"""Asynchronous WebSocket client with persistent connection.
|
||||
|
||||
Maintains a single websocket connection and multiplexes requests
|
||||
by ID, routing responses via a background reader task.
|
||||
|
||||
Use as an async context manager for proper lifecycle management:
|
||||
|
||||
async with AsyncSocketClient(url, timeout, token) as client:
|
||||
result = await client._send_request(...)
|
||||
|
||||
Or call connect()/aclose() manually.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, timeout: int, token: Optional[str]):
|
||||
self.url = self._convert_to_ws_url(url)
|
||||
self.timeout = timeout
|
||||
self.token = token
|
||||
self._request_counter = 0
|
||||
self._socket = None
|
||||
self._connect_cm = None
|
||||
self._reader_task = None
|
||||
self._pending = {} # request_id -> asyncio.Queue
|
||||
self._connected = False
|
||||
|
||||
def _convert_to_ws_url(self, url: str) -> str:
|
||||
"""Convert HTTP URL to WebSocket URL"""
|
||||
|
|
@ -25,82 +42,123 @@ class AsyncSocketClient:
|
|||
elif url.startswith("ws://") or url.startswith("wss://"):
|
||||
return url
|
||||
else:
|
||||
# Assume ws://
|
||||
return f"ws://{url}"
|
||||
|
||||
def _build_ws_url(self):
|
||||
ws_url = f"{self.url.rstrip('/')}/api/v1/socket"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
return ws_url
|
||||
|
||||
async def connect(self):
    """Establish the persistent websocket connection.

    Safe to call repeatedly and from several coroutines at once: a lock
    serialises callers, so only the first one opens a socket and starts
    the reader task, and the rest return once it is up. State is only
    recorded after the handshake succeeds, so a failed attempt leaves the
    client cleanly unconnected and the exception propagates to the caller.
    """
    # The lock is created on first use. There is no await between the
    # check and the assignment, so two coroutines cannot both create one.
    lock = getattr(self, "_connect_lock", None)
    if lock is None:
        lock = self._connect_lock = asyncio.Lock()

    async with lock:
        if self._connected:
            return

        connect_cm = websockets.connect(
            self._build_ws_url(), ping_interval=20, ping_timeout=self.timeout
        )
        # Enter the context manager by hand; aclose() performs the
        # matching __aexit__ when the client is shut down.
        socket = await connect_cm.__aenter__()

        self._connect_cm = connect_cm
        self._socket = socket
        self._connected = True
        self._reader_task = asyncio.create_task(self._reader())
|
||||
|
||||
async def __aenter__(self):
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.aclose()
|
||||
|
||||
async def _ensure_connected(self):
|
||||
"""Lazily connect if not already connected."""
|
||||
if not self._connected:
|
||||
await self.connect()
|
||||
|
||||
async def _reader(self):
|
||||
"""Background task to read responses and route by request ID."""
|
||||
try:
|
||||
async for raw_message in self._socket:
|
||||
response = json.loads(raw_message)
|
||||
request_id = response.get("id")
|
||||
|
||||
if request_id and request_id in self._pending:
|
||||
await self._pending[request_id].put(response)
|
||||
# Ignore messages for unknown request IDs
|
||||
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
except Exception as e:
|
||||
# Signal error to all pending requests
|
||||
for queue in self._pending.values():
|
||||
try:
|
||||
await queue.put({"error": str(e)})
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
self._connected = False
|
||||
|
||||
def _next_request_id(self):
|
||||
self._request_counter += 1
|
||||
return f"req-{self._request_counter}"
|
||||
|
||||
def flow(self, flow_id: str):
    """Return an AsyncSocketFlowInstance bound to this client and *flow_id*."""
    instance = AsyncSocketFlowInstance(self, flow_id)
    return instance
|
||||
|
||||
async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]):
|
||||
"""Async WebSocket request implementation (non-streaming)"""
|
||||
# Generate unique request ID
|
||||
self._request_counter += 1
|
||||
request_id = f"req-{self._request_counter}"
|
||||
"""Send a request and wait for a single response."""
|
||||
await self._ensure_connected()
|
||||
|
||||
# Build WebSocket URL with optional token
|
||||
ws_url = f"{self.url}/api/v1/socket"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
request_id = self._next_request_id()
|
||||
queue = asyncio.Queue()
|
||||
self._pending[request_id] = queue
|
||||
|
||||
# Build request message
|
||||
message = {
|
||||
"id": request_id,
|
||||
"service": service,
|
||||
"request": request
|
||||
}
|
||||
if flow:
|
||||
message["flow"] = flow
|
||||
try:
|
||||
message = {
|
||||
"id": request_id,
|
||||
"service": service,
|
||||
"request": request
|
||||
}
|
||||
if flow:
|
||||
message["flow"] = flow
|
||||
|
||||
# Connect and send request
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
await websocket.send(json.dumps(message))
|
||||
await self._socket.send(json.dumps(message))
|
||||
|
||||
# Wait for single response
|
||||
raw_message = await websocket.recv()
|
||||
response = json.loads(raw_message)
|
||||
|
||||
if response.get("id") != request_id:
|
||||
raise ProtocolException(f"Response ID mismatch")
|
||||
response = await queue.get()
|
||||
|
||||
if "error" in response:
|
||||
raise ApplicationException(response["error"])
|
||||
|
||||
if "response" not in response:
|
||||
raise ProtocolException(f"Missing response in message")
|
||||
raise ProtocolException("Missing response in message")
|
||||
|
||||
return response["response"]
|
||||
|
||||
finally:
|
||||
self._pending.pop(request_id, None)
|
||||
|
||||
async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]):
|
||||
"""Async WebSocket request implementation (streaming)"""
|
||||
# Generate unique request ID
|
||||
self._request_counter += 1
|
||||
request_id = f"req-{self._request_counter}"
|
||||
"""Send a request and yield streaming response chunks."""
|
||||
await self._ensure_connected()
|
||||
|
||||
# Build WebSocket URL with optional token
|
||||
ws_url = f"{self.url}/api/v1/socket"
|
||||
if self.token:
|
||||
ws_url = f"{ws_url}?token={self.token}"
|
||||
request_id = self._next_request_id()
|
||||
queue = asyncio.Queue()
|
||||
self._pending[request_id] = queue
|
||||
|
||||
# Build request message
|
||||
message = {
|
||||
"id": request_id,
|
||||
"service": service,
|
||||
"request": request
|
||||
}
|
||||
if flow:
|
||||
message["flow"] = flow
|
||||
try:
|
||||
message = {
|
||||
"id": request_id,
|
||||
"service": service,
|
||||
"request": request
|
||||
}
|
||||
if flow:
|
||||
message["flow"] = flow
|
||||
|
||||
# Connect and send request
|
||||
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
|
||||
await websocket.send(json.dumps(message))
|
||||
await self._socket.send(json.dumps(message))
|
||||
|
||||
# Yield chunks as they arrive
|
||||
async for raw_message in websocket:
|
||||
response = json.loads(raw_message)
|
||||
|
||||
if response.get("id") != request_id:
|
||||
continue # Ignore messages for other requests
|
||||
while True:
|
||||
response = await queue.get()
|
||||
|
||||
if "error" in response:
|
||||
raise ApplicationException(response["error"])
|
||||
|
|
@ -108,18 +166,16 @@ class AsyncSocketClient:
|
|||
if "response" in response:
|
||||
resp = response["response"]
|
||||
|
||||
# Parse different chunk types
|
||||
chunk = self._parse_chunk(resp)
|
||||
if chunk is not None: # Skip provenance messages in streaming
|
||||
if chunk is not None:
|
||||
yield chunk
|
||||
|
||||
# Check if this is the final message
|
||||
# end_of_session indicates entire session is complete (including provenance)
|
||||
# end_of_dialog is for agent dialogs
|
||||
# complete is from the gateway envelope
|
||||
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
|
||||
break
|
||||
|
||||
finally:
|
||||
self._pending.pop(request_id, None)
|
||||
|
||||
def _parse_chunk(self, resp: Dict[str, Any]):
|
||||
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
|
||||
chunk_type = resp.get("chunk_type")
|
||||
|
|
@ -127,7 +183,6 @@ class AsyncSocketClient:
|
|||
|
||||
# Handle new GraphRAG message format with message_type
|
||||
if message_type == "provenance":
|
||||
# Provenance messages are not yielded to user - they're metadata
|
||||
return None
|
||||
|
||||
if chunk_type == "thought":
|
||||
|
|
@ -147,25 +202,41 @@ class AsyncSocketClient:
|
|||
end_of_dialog=resp.get("end_of_dialog", False)
|
||||
)
|
||||
elif chunk_type == "action":
|
||||
# Agent action chunks - treat as thoughts for display purposes
|
||||
return AgentThought(
|
||||
content=resp.get("content", ""),
|
||||
end_of_message=resp.get("end_of_message", False)
|
||||
)
|
||||
else:
|
||||
# RAG-style chunk (or generic chunk with message_type="chunk")
|
||||
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
|
||||
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
|
||||
return RAGChunk(
|
||||
content=content,
|
||||
end_of_stream=resp.get("end_of_stream", False),
|
||||
error=None # Errors are always thrown, never stored
|
||||
error=None
|
||||
)
|
||||
|
||||
async def aclose(self):
    """Close the persistent WebSocket connection cleanly.

    Cancels the background reader, exits the websockets connection
    context manager, fails any requests still waiting for a response, and
    resets the client to its unconnected state. The state reset runs even
    if shutting the connection down raises. Calling this on a client that
    never connected only performs the reset.
    """
    try:
        # Stop the reader first so nothing is consuming the socket while
        # it is being shut down.
        if self._reader_task:
            self._reader_task.cancel()
            try:
                await self._reader_task
            except asyncio.CancelledError:
                pass

        # Exit the websockets context manager — this cleanly shuts down
        # the connection and its keepalive task
        if self._connect_cm:
            try:
                await self._connect_cm.__aexit__(None, None, None)
            except Exception:
                # Best-effort close: the connection may already be gone,
                # and a failure here must not prevent the reset below.
                pass
    finally:
        self._reader_task = None
        self._connect_cm = None
        self._socket = None
        self._connected = False

        # Wake anything still waiting so it fails instead of blocking on
        # a queue that will never be fed. The queues are unbounded, so
        # put_nowait cannot fail.
        for queue in self._pending.values():
            queue.put_nowait({"error": "WebSocket connection closed"})
        self._pending.clear()
|
||||
|
||||
|
||||
class AsyncSocketFlowInstance:
|
||||
|
|
@ -292,7 +363,6 @@ class AsyncSocketFlowInstance:
|
|||
|
||||
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
|
||||
"""Query graph embeddings for semantic search"""
|
||||
# First convert text to embedding vector
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
|
|
@ -362,7 +432,6 @@ class AsyncSocketFlowInstance:
|
|||
limit: int = 10, **kwargs
|
||||
):
|
||||
"""Query row embeddings for semantic search on structured data"""
|
||||
# First convert text to embedding vector
|
||||
emb_result = await self.embeddings(texts=[text])
|
||||
vector = emb_result.get("vectors", [[]])[0]
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
Loading…
Add table
Add a link
Reference in a new issue