Smoke-test websocket tool (#852)

This commit is contained in:
cybermaggedon 2026-04-28 15:05:35 +01:00 committed by GitHub
parent 666af1c4b3
commit b15f1a167c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -0,0 +1,475 @@
#!/usr/bin/env python3
"""
WebSocket smoke / load test that hammers a TrustGraph gateway with a
mix of `embeddings`, `graph-embeddings`, and `triples` queries while
keeping a target number of in-flight requests at all times.
Useful for reproducing the "worker hangs after a while, all subsequent
requests time out" failure mode — leaves enough load on the system to
saturate worker concurrency and reports per-service success/timeout
rates and latency distributions over time.
Usage:
smoke_ws_queries.py --flow onto-rag --duration 120 --concurrency 20
Connects via /api/v1/socket using the first-frame auth protocol.
"""
import argparse
import asyncio
import json
import os
import random
import statistics
import sys
import time
import uuid
from collections import defaultdict
from typing import Any
import websockets
DEFAULT_TEXT = (
"What caused the space shuttle to explode and what were the "
"main factors leading to the disaster?"
)
class Stats:
"""Per-service rolling counters and latency samples."""
def __init__(self) -> None:
self.sent = 0
self.ok = 0
self.err = 0
self.timeout = 0
self.latencies_ms: list[float] = []
def record_ok(self, latency_ms: float) -> None:
self.ok += 1
self.latencies_ms.append(latency_ms)
def record_err(self) -> None:
self.err += 1
def record_timeout(self) -> None:
self.timeout += 1
def percentile(self, p: float) -> float:
if not self.latencies_ms:
return 0.0
s = sorted(self.latencies_ms)
idx = min(len(s) - 1, int(len(s) * p))
return s[idx]
def summary(self) -> str:
if self.latencies_ms:
mn = min(self.latencies_ms)
mx = max(self.latencies_ms)
mean = statistics.mean(self.latencies_ms)
p50 = self.percentile(0.50)
p95 = self.percentile(0.95)
p99 = self.percentile(0.99)
lat = (
f"min={mn:.0f} mean={mean:.0f} p50={p50:.0f} "
f"p95={p95:.0f} p99={p99:.0f} max={mx:.0f} ms"
)
else:
lat = "no successful samples"
return (
f"sent={self.sent} ok={self.ok} err={self.err} "
f"timeout={self.timeout} | {lat}"
)
class WSClient:
"""Thin async websocket client with first-frame auth and a shared
reader task that demuxes responses to per-request asyncio queues."""
def __init__(
self, url: str, token: str | None, workspace: str,
ping_timeout: int,
) -> None:
self.url = url
self.token = token
self.workspace = workspace
self.ping_timeout = ping_timeout
self._ws: Any = None
self._pending: dict[str, asyncio.Queue] = {}
self._reader_task: asyncio.Task | None = None
self._closed = asyncio.Event()
async def connect(self) -> None:
ws_url = self.url.rstrip("/") + "/api/v1/socket"
if ws_url.startswith("http://"):
ws_url = "ws://" + ws_url[len("http://"):]
elif ws_url.startswith("https://"):
ws_url = "wss://" + ws_url[len("https://"):]
elif not (
ws_url.startswith("ws://") or ws_url.startswith("wss://")
):
ws_url = "ws://" + ws_url
self._ws = await websockets.connect(
ws_url,
ping_interval=20,
ping_timeout=self.ping_timeout,
max_size=64 * 1024 * 1024,
)
if self.token:
# First-frame auth handshake.
await self._ws.send(json.dumps({
"type": "auth", "token": self.token,
}))
raw = await asyncio.wait_for(self._ws.recv(), timeout=10)
resp = json.loads(raw)
if resp.get("type") != "auth-ok":
await self._ws.close()
raise RuntimeError(f"auth failed: {resp}")
if "workspace" in resp:
# Server-resolved workspace overrides the user-supplied
# one, mirroring AsyncSocketClient behaviour.
self.workspace = resp["workspace"]
else:
print(
"WARNING: no token provided — skipping auth handshake. "
"Requests will be rejected unless the gateway is "
"running without IAM enforcement.",
file=sys.stderr,
)
self._reader_task = asyncio.create_task(self._reader())
async def _reader(self) -> None:
try:
async for raw in self._ws:
msg = json.loads(raw)
rid = msg.get("id")
if rid and rid in self._pending:
await self._pending[rid].put(msg)
except websockets.exceptions.ConnectionClosed:
pass
except Exception as e:
for q in list(self._pending.values()):
try:
q.put_nowait({"error": {"message": str(e)}})
except Exception:
pass
finally:
self._closed.set()
async def request(
self, service: str, flow: str | None, body: dict, timeout: float,
) -> tuple[dict | None, str | None, float]:
"""Send one request, await final response.
Returns ``(response, error, latency_ms)``. ``response`` is None
on error/timeout. ``error`` describes the failure category.
"""
rid = str(uuid.uuid4())
q: asyncio.Queue = asyncio.Queue()
self._pending[rid] = q
env = {
"id": rid,
"workspace": self.workspace,
"service": service,
"request": body,
}
if flow:
env["flow"] = flow
t0 = time.monotonic()
try:
await self._ws.send(json.dumps(env))
while True:
try:
msg = await asyncio.wait_for(q.get(), timeout=timeout)
except asyncio.TimeoutError:
return None, "timeout", (time.monotonic() - t0) * 1000
if "error" in msg and msg["error"]:
err = msg["error"]
err_msg = (
err.get("message") if isinstance(err, dict) else str(err)
)
return None, f"error: {err_msg}", (time.monotonic() - t0) * 1000
if msg.get("complete"):
return msg.get("response"), None, (time.monotonic() - t0) * 1000
# Otherwise an intermediate streaming chunk — keep waiting.
finally:
self._pending.pop(rid, None)
async def close(self) -> None:
if self._ws is not None:
await self._ws.close()
if self._reader_task is not None:
try:
await asyncio.wait_for(self._reader_task, timeout=2)
except (asyncio.TimeoutError, asyncio.CancelledError):
pass
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument(
"--url",
default=os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/"),
help="Gateway URL (http or ws). Default: %(default)s",
)
p.add_argument(
"--token",
default=os.getenv("TRUSTGRAPH_TOKEN"),
help="Auth token (or set TRUSTGRAPH_TOKEN). Optional — if "
"omitted, the auth handshake is skipped (only works "
"when the gateway is running without IAM enforcement).",
)
p.add_argument(
"--workspace", default="default",
help="Workspace. Default: %(default)s",
)
p.add_argument(
"--flow", required=True,
help="Flow id. Comma-separated for round-robin across flows "
"(e.g. onto-rag,doc-rag).",
)
p.add_argument(
"--duration", type=int, default=60,
help="Test duration in seconds. Default: %(default)s",
)
p.add_argument(
"--concurrency", type=int, default=15,
help="Target in-flight request count. Default: %(default)s",
)
p.add_argument(
"--services",
default="embeddings,graph-embeddings,triples",
help="Comma-separated services to exercise. "
"Default: %(default)s",
)
p.add_argument(
"--limit", type=int, default=3,
help="limit for triples / graph-embeddings queries. "
"Default: %(default)s",
)
p.add_argument(
"--collection", default="default",
help="Collection. Default: %(default)s",
)
p.add_argument(
"--text", default=DEFAULT_TEXT,
help="Text to embed for embeddings/seed.",
)
p.add_argument(
"--vector-dim", type=int, default=384,
help="Dimension of synthetic vector when --no-seed is used. "
"Default: %(default)s",
)
p.add_argument(
"--no-seed", action="store_true",
help="Skip the embeddings warm-up call. Use a random vector "
"for graph-embeddings queries instead.",
)
p.add_argument(
"--request-timeout", type=float, default=30.0,
help="Per-request timeout (seconds). Default: %(default)s",
)
p.add_argument(
"--report-interval", type=float, default=5.0,
help="How often to print stats (seconds). Default: %(default)s",
)
p.add_argument(
"--ping-timeout", type=int, default=120,
help="Websocket ping timeout. Default: %(default)s",
)
p.add_argument(
"--seed", type=int, default=None,
help="Random seed (for reproducibility).",
)
return p.parse_args()
async def seed_vector(
client: WSClient, flow: str, text: str, timeout: float,
) -> list[float]:
"""Issue one embeddings request to obtain a real vector that
later graph-embeddings calls can reuse."""
resp, err, _ = await client.request(
"embeddings", flow, {"texts": [text]}, timeout,
)
if err or not resp:
raise RuntimeError(f"seed embeddings failed: {err or resp}")
vectors = resp.get("vectors")
if not vectors:
raise RuntimeError(f"seed embeddings: no vectors in response: {resp}")
return vectors[0]
def make_request_body(
service: str, args: argparse.Namespace, vector: list[float],
) -> dict:
if service == "embeddings":
return {"texts": [args.text]}
if service == "graph-embeddings":
return {
"vector": vector,
"limit": args.limit,
"collection": args.collection,
}
if service == "triples":
return {
"limit": args.limit,
"collection": args.collection,
}
raise ValueError(f"Unknown service: {service}")
async def worker(
name: int,
client: WSClient,
flows: list[str],
services: list[str],
args: argparse.Namespace,
vector: list[float],
stats: dict[str, Stats],
in_flight: dict[str, int],
stop_at: float,
) -> None:
rng = random.Random((args.seed or 0) + name)
while time.monotonic() < stop_at:
svc = rng.choice(services)
flow = rng.choice(flows)
body = make_request_body(svc, args, vector)
stats[svc].sent += 1
in_flight[svc] += 1
try:
resp, err, lat = await client.request(
svc, flow, body, args.request_timeout,
)
if err == "timeout":
stats[svc].record_timeout()
elif err:
stats[svc].record_err()
else:
stats[svc].record_ok(lat)
except Exception as e:
stats[svc].record_err()
print(f"worker {name}: unexpected {svc} exception: {e}",
file=sys.stderr)
finally:
in_flight[svc] -= 1
async def reporter(
services: list[str],
stats: dict[str, Stats],
in_flight: dict[str, int],
stop_at: float,
interval: float,
) -> None:
started = time.monotonic()
last_sent = {s: 0 for s in services}
while time.monotonic() < stop_at:
await asyncio.sleep(interval)
now = time.monotonic()
elapsed = now - started
total_inflight = sum(in_flight.values())
print(
f"\n[{elapsed:6.1f}s] in-flight={total_inflight} "
f"per-svc={dict(in_flight)}"
)
for svc in services:
s = stats[svc]
delta = s.sent - last_sent[svc]
rate = delta / interval
last_sent[svc] = s.sent
print(f" {svc:20s} {rate:6.1f}/s | {s.summary()}")
async def run(args: argparse.Namespace) -> int:
if args.seed is not None:
random.seed(args.seed)
services = [s.strip() for s in args.services.split(",") if s.strip()]
flows = [f.strip() for f in args.flow.split(",") if f.strip()]
valid = {"embeddings", "graph-embeddings", "triples"}
bad = [s for s in services if s not in valid]
if bad:
print(f"ERROR: unknown service(s): {bad}. "
f"Supported: {sorted(valid)}", file=sys.stderr)
return 2
client = WSClient(
args.url, args.token, args.workspace, args.ping_timeout,
)
print(f"Connecting to {args.url} ...")
await client.connect()
print(f"Connected. workspace={client.workspace} flows={flows} "
f"services={services} concurrency={args.concurrency} "
f"duration={args.duration}s")
if "graph-embeddings" in services and not args.no_seed:
print("Seeding embedding vector ...")
vector = await seed_vector(
client, flows[0], args.text, args.request_timeout,
)
print(f"Got vector of length {len(vector)}")
else:
vector = [random.uniform(-1.0, 1.0) for _ in range(args.vector_dim)]
stats: dict[str, Stats] = defaultdict(Stats)
in_flight: dict[str, int] = defaultdict(int)
for svc in services:
stats[svc] # initialise
in_flight[svc] = 0
stop_at = time.monotonic() + args.duration
print(f"Starting load: {args.concurrency} workers for "
f"{args.duration}s ...")
workers = [
asyncio.create_task(
worker(
i, client, flows, services, args, vector,
stats, in_flight, stop_at,
)
)
for i in range(args.concurrency)
]
rep = asyncio.create_task(
reporter(services, stats, in_flight, stop_at, args.report_interval)
)
try:
await asyncio.gather(*workers)
finally:
rep.cancel()
try:
await rep
except asyncio.CancelledError:
pass
print("\n=== Final results ===")
any_failures = False
for svc in services:
s = stats[svc]
print(f" {svc:20s} {s.summary()}")
if s.timeout > 0 or s.err > 0:
any_failures = True
await client.close()
return 1 if any_failures else 0
def main() -> int:
args = parse_args()
try:
return asyncio.run(run(args))
except KeyboardInterrupt:
return 130
if __name__ == "__main__":
sys.exit(main())