Persistent websocket connections for socket clients and CLI tools (#723)

Replace per-request websocket connections in SocketClient and
AsyncSocketClient with a single persistent connection that
multiplexes requests by ID via a background reader task. This
eliminates repeated TCP+WS handshakes which caused significant
latency over proxies.

Convert show_flows, show_flow_blueprints, and
show_parameter_types CLI tools from sequential HTTP requests to
concurrent websocket requests using AsyncSocketClient, reducing
round trips from O(N) sequential to a small number of parallel
batches.

Also fix describe_interfaces bug in show_flows where response
queue was reading the request field instead of the response
field.
This commit is contained in:
cybermaggedon 2026-03-26 16:46:28 +00:00 committed by GitHub
parent 1ec081f42f
commit 9c55a0a0ff
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 654 additions and 1067 deletions

View file

@ -1,5 +1,6 @@
import json
import asyncio
import websockets
from typing import Optional, Dict, Any, AsyncIterator, Union
@ -8,13 +9,29 @@ from . exceptions import ProtocolException, ApplicationException
class AsyncSocketClient:
"""Asynchronous WebSocket client"""
"""Asynchronous WebSocket client with persistent connection.
Maintains a single websocket connection and multiplexes requests
by ID, routing responses via a background reader task.
Use as an async context manager for proper lifecycle management:
async with AsyncSocketClient(url, timeout, token) as client:
result = await client._send_request(...)
Or call connect()/aclose() manually.
"""
def __init__(self, url: str, timeout: int, token: Optional[str]):
self.url = self._convert_to_ws_url(url)
self.timeout = timeout
self.token = token
self._request_counter = 0
self._socket = None
self._connect_cm = None
self._reader_task = None
self._pending = {} # request_id -> asyncio.Queue
self._connected = False
def _convert_to_ws_url(self, url: str) -> str:
"""Convert HTTP URL to WebSocket URL"""
@ -25,82 +42,123 @@ class AsyncSocketClient:
elif url.startswith("ws://") or url.startswith("wss://"):
return url
else:
# Assume ws://
return f"ws://{url}"
def _build_ws_url(self):
ws_url = f"{self.url.rstrip('/')}/api/v1/socket"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
return ws_url
async def connect(self):
"""Establish the persistent websocket connection."""
if self._connected:
return
ws_url = self._build_ws_url()
self._connect_cm = websockets.connect(
ws_url, ping_interval=20, ping_timeout=self.timeout
)
self._socket = await self._connect_cm.__aenter__()
self._connected = True
self._reader_task = asyncio.create_task(self._reader())
async def __aenter__(self):
await self.connect()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.aclose()
async def _ensure_connected(self):
"""Lazily connect if not already connected."""
if not self._connected:
await self.connect()
async def _reader(self):
"""Background task to read responses and route by request ID."""
try:
async for raw_message in self._socket:
response = json.loads(raw_message)
request_id = response.get("id")
if request_id and request_id in self._pending:
await self._pending[request_id].put(response)
# Ignore messages for unknown request IDs
except websockets.exceptions.ConnectionClosed:
pass
except Exception as e:
# Signal error to all pending requests
for queue in self._pending.values():
try:
await queue.put({"error": str(e)})
except:
pass
finally:
self._connected = False
def _next_request_id(self):
self._request_counter += 1
return f"req-{self._request_counter}"
def flow(self, flow_id: str):
"""Get async flow instance for WebSocket operations"""
return AsyncSocketFlowInstance(self, flow_id)
async def _send_request(self, service: str, flow: Optional[str], request: Dict[str, Any]):
"""Async WebSocket request implementation (non-streaming)"""
# Generate unique request ID
self._request_counter += 1
request_id = f"req-{self._request_counter}"
"""Send a request and wait for a single response."""
await self._ensure_connected()
# Build WebSocket URL with optional token
ws_url = f"{self.url}/api/v1/socket"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
request_id = self._next_request_id()
queue = asyncio.Queue()
self._pending[request_id] = queue
# Build request message
message = {
"id": request_id,
"service": service,
"request": request
}
if flow:
message["flow"] = flow
try:
message = {
"id": request_id,
"service": service,
"request": request
}
if flow:
message["flow"] = flow
# Connect and send request
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
await websocket.send(json.dumps(message))
await self._socket.send(json.dumps(message))
# Wait for single response
raw_message = await websocket.recv()
response = json.loads(raw_message)
if response.get("id") != request_id:
raise ProtocolException(f"Response ID mismatch")
response = await queue.get()
if "error" in response:
raise ApplicationException(response["error"])
if "response" not in response:
raise ProtocolException(f"Missing response in message")
raise ProtocolException("Missing response in message")
return response["response"]
finally:
self._pending.pop(request_id, None)
async def _send_request_streaming(self, service: str, flow: Optional[str], request: Dict[str, Any]):
"""Async WebSocket request implementation (streaming)"""
# Generate unique request ID
self._request_counter += 1
request_id = f"req-{self._request_counter}"
"""Send a request and yield streaming response chunks."""
await self._ensure_connected()
# Build WebSocket URL with optional token
ws_url = f"{self.url}/api/v1/socket"
if self.token:
ws_url = f"{ws_url}?token={self.token}"
request_id = self._next_request_id()
queue = asyncio.Queue()
self._pending[request_id] = queue
# Build request message
message = {
"id": request_id,
"service": service,
"request": request
}
if flow:
message["flow"] = flow
try:
message = {
"id": request_id,
"service": service,
"request": request
}
if flow:
message["flow"] = flow
# Connect and send request
async with websockets.connect(ws_url, ping_interval=20, ping_timeout=self.timeout) as websocket:
await websocket.send(json.dumps(message))
await self._socket.send(json.dumps(message))
# Yield chunks as they arrive
async for raw_message in websocket:
response = json.loads(raw_message)
if response.get("id") != request_id:
continue # Ignore messages for other requests
while True:
response = await queue.get()
if "error" in response:
raise ApplicationException(response["error"])
@ -108,18 +166,16 @@ class AsyncSocketClient:
if "response" in response:
resp = response["response"]
# Parse different chunk types
chunk = self._parse_chunk(resp)
if chunk is not None: # Skip provenance messages in streaming
if chunk is not None:
yield chunk
# Check if this is the final message
# end_of_session indicates entire session is complete (including provenance)
# end_of_dialog is for agent dialogs
# complete is from the gateway envelope
if resp.get("end_of_session") or resp.get("end_of_dialog") or response.get("complete"):
break
finally:
self._pending.pop(request_id, None)
def _parse_chunk(self, resp: Dict[str, Any]):
"""Parse response chunk into appropriate type. Returns None for non-content messages."""
chunk_type = resp.get("chunk_type")
@ -127,7 +183,6 @@ class AsyncSocketClient:
# Handle new GraphRAG message format with message_type
if message_type == "provenance":
# Provenance messages are not yielded to user - they're metadata
return None
if chunk_type == "thought":
@ -147,25 +202,41 @@ class AsyncSocketClient:
end_of_dialog=resp.get("end_of_dialog", False)
)
elif chunk_type == "action":
# Agent action chunks - treat as thoughts for display purposes
return AgentThought(
content=resp.get("content", ""),
end_of_message=resp.get("end_of_message", False)
)
else:
# RAG-style chunk (or generic chunk with message_type="chunk")
# Text-completion uses "response" field, RAG uses "chunk" field, Prompt uses "text" field
content = resp.get("response", resp.get("chunk", resp.get("text", "")))
return RAGChunk(
content=content,
end_of_stream=resp.get("end_of_stream", False),
error=None # Errors are always thrown, never stored
error=None
)
async def aclose(self):
"""Close WebSocket connection"""
# Cleanup handled by context manager
pass
"""Close the persistent WebSocket connection cleanly."""
# Wait for reader to finish (socket close will cause it to exit)
if self._reader_task:
self._reader_task.cancel()
try:
await self._reader_task
except asyncio.CancelledError:
pass
self._reader_task = None
# Exit the websockets context manager — this cleanly shuts down
# the connection and its keepalive task
if self._connect_cm:
try:
await self._connect_cm.__aexit__(None, None, None)
except Exception:
pass
self._connect_cm = None
self._socket = None
self._connected = False
self._pending.clear()
class AsyncSocketFlowInstance:
@ -292,7 +363,6 @@ class AsyncSocketFlowInstance:
async def graph_embeddings_query(self, text: str, user: str, collection: str, limit: int = 10, **kwargs):
"""Query graph embeddings for semantic search"""
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vector = emb_result.get("vectors", [[]])[0]
@ -362,7 +432,6 @@ class AsyncSocketFlowInstance:
limit: int = 10, **kwargs
):
"""Query row embeddings for semantic search on structured data"""
# First convert text to embedding vector
emb_result = await self.embeddings(texts=[text])
vector = emb_result.get("vectors", [[]])[0]

File diff suppressed because it is too large Load diff