diff --git a/dev-tools/tests/smoke/smoke_ws_queries.py b/dev-tools/tests/smoke/smoke_ws_queries.py new file mode 100755 index 00000000..c6a4dfb6 --- /dev/null +++ b/dev-tools/tests/smoke/smoke_ws_queries.py @@ -0,0 +1,475 @@ +#!/usr/bin/env python3 +""" +WebSocket smoke / load test that hammers a TrustGraph gateway with a +mix of `embeddings`, `graph-embeddings`, and `triples` queries while +keeping a target number of in-flight requests at all times. + +Useful for reproducing the "worker hangs after a while, all subsequent +requests time out" failure mode — leaves enough load on the system to +saturate worker concurrency and reports per-service success/timeout +rates and latency distributions over time. + +Usage: + smoke_ws_queries.py --flow onto-rag --duration 120 --concurrency 20 + +Connects via /api/v1/socket using the first-frame auth protocol. +""" + +import argparse +import asyncio +import json +import os +import random +import statistics +import sys +import time +import uuid +from collections import defaultdict +from typing import Any + +import websockets + + +DEFAULT_TEXT = ( + "What caused the space shuttle to explode and what were the " + "main factors leading to the disaster?" +) + + +class Stats: + """Per-service rolling counters and latency samples.""" + + def __init__(self) -> None: + self.sent = 0 + self.ok = 0 + self.err = 0 + self.timeout = 0 + self.latencies_ms: list[float] = [] + + def record_ok(self, latency_ms: float) -> None: + self.ok += 1 + self.latencies_ms.append(latency_ms) + + def record_err(self) -> None: + self.err += 1 + + def record_timeout(self) -> None: + self.timeout += 1 + + def percentile(self, p: float) -> float: + if not self.latencies_ms: + return 0.0 + s = sorted(self.latencies_ms) + idx = min(len(s) - 1, int(len(s) * p)) + return s[idx] + + def summary(self) -> str: + if self.latencies_ms: + mn = min(self.latencies_ms) + mx = max(self.latencies_ms) + mean = statistics.mean(self.latencies_ms) + p50 = self.percentile(0.50) + p95 = self.percentile(0.95) + p99 = self.percentile(0.99) + lat = ( + f"min={mn:.0f} mean={mean:.0f} p50={p50:.0f} " + f"p95={p95:.0f} p99={p99:.0f} max={mx:.0f} ms" + ) + else: + lat = "no successful samples" + return ( + f"sent={self.sent} ok={self.ok} err={self.err} " + f"timeout={self.timeout} | {lat}" + ) + + +class WSClient: + """Thin async websocket client with first-frame auth and a shared + reader task that demuxes responses to per-request asyncio queues.""" + + def __init__( + self, url: str, token: str | None, workspace: str, + ping_timeout: int, + ) -> None: + self.url = url + self.token = token + self.workspace = workspace + self.ping_timeout = ping_timeout + self._ws: Any = None + self._pending: dict[str, asyncio.Queue] = {} + self._reader_task: asyncio.Task | None = None + self._closed = asyncio.Event() + + async def connect(self) -> None: + ws_url = self.url.rstrip("/") + "/api/v1/socket" + if ws_url.startswith("http://"): + ws_url = "ws://" + ws_url[len("http://"):] + elif ws_url.startswith("https://"): + ws_url = "wss://" + ws_url[len("https://"):] + elif not ( + ws_url.startswith("ws://") or ws_url.startswith("wss://") + ): + ws_url = "ws://" + ws_url + + self._ws = await websockets.connect( + ws_url, + ping_interval=20, + ping_timeout=self.ping_timeout, + max_size=64 * 1024 * 1024, + ) + + if self.token: + # First-frame auth handshake. + await self._ws.send(json.dumps({ + "type": "auth", "token": self.token, + })) + raw = await asyncio.wait_for(self._ws.recv(), timeout=10) + resp = json.loads(raw) + if resp.get("type") != "auth-ok": + await self._ws.close() + raise RuntimeError(f"auth failed: {resp}") + if "workspace" in resp: + # Server-resolved workspace overrides the user-supplied + # one, mirroring AsyncSocketClient behaviour. + self.workspace = resp["workspace"] + else: + print( + "WARNING: no token provided — skipping auth handshake. " + "Requests will be rejected unless the gateway is " + "running without IAM enforcement.", + file=sys.stderr, + ) + + self._reader_task = asyncio.create_task(self._reader()) + + async def _reader(self) -> None: + try: + async for raw in self._ws: + msg = json.loads(raw) + rid = msg.get("id") + if rid and rid in self._pending: + await self._pending[rid].put(msg) + except websockets.exceptions.ConnectionClosed: + pass + except Exception as e: + for q in list(self._pending.values()): + try: + q.put_nowait({"error": {"message": str(e)}}) + except Exception: + pass + finally: + self._closed.set() + + async def request( + self, service: str, flow: str | None, body: dict, timeout: float, + ) -> tuple[dict | None, str | None, float]: + """Send one request, await final response. + + Returns ``(response, error, latency_ms)``. ``response`` is None + on error/timeout. ``error`` describes the failure category. + """ + rid = str(uuid.uuid4()) + q: asyncio.Queue = asyncio.Queue() + self._pending[rid] = q + env = { + "id": rid, + "workspace": self.workspace, + "service": service, + "request": body, + } + if flow: + env["flow"] = flow + + t0 = time.monotonic() + try: + await self._ws.send(json.dumps(env)) + while True: + try: + msg = await asyncio.wait_for(q.get(), timeout=timeout) + except asyncio.TimeoutError: + return None, "timeout", (time.monotonic() - t0) * 1000 + if "error" in msg and msg["error"]: + err = msg["error"] + err_msg = ( + err.get("message") if isinstance(err, dict) else str(err) + ) + return None, f"error: {err_msg}", (time.monotonic() - t0) * 1000 + if msg.get("complete"): + return msg.get("response"), None, (time.monotonic() - t0) * 1000 + # Otherwise an intermediate streaming chunk — keep waiting. + finally: + self._pending.pop(rid, None) + + async def close(self) -> None: + if self._ws is not None: + await self._ws.close() + if self._reader_task is not None: + try: + await asyncio.wait_for(self._reader_task, timeout=2) + except (asyncio.TimeoutError, asyncio.CancelledError): + pass + + +def parse_args() -> argparse.Namespace: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument( + "--url", + default=os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/"), + help="Gateway URL (http or ws). Default: %(default)s", + ) + p.add_argument( + "--token", + default=os.getenv("TRUSTGRAPH_TOKEN"), + help="Auth token (or set TRUSTGRAPH_TOKEN). Optional — if " + "omitted, the auth handshake is skipped (only works " + "when the gateway is running without IAM enforcement).", + ) + p.add_argument( + "--workspace", default="default", + help="Workspace. Default: %(default)s", + ) + p.add_argument( + "--flow", required=True, + help="Flow id. Comma-separated for round-robin across flows " + "(e.g. onto-rag,doc-rag).", + ) + p.add_argument( + "--duration", type=int, default=60, + help="Test duration in seconds. Default: %(default)s", + ) + p.add_argument( + "--concurrency", type=int, default=15, + help="Target in-flight request count. Default: %(default)s", + ) + p.add_argument( + "--services", + default="embeddings,graph-embeddings,triples", + help="Comma-separated services to exercise. " + "Default: %(default)s", + ) + p.add_argument( + "--limit", type=int, default=3, + help="limit for triples / graph-embeddings queries. " + "Default: %(default)s", + ) + p.add_argument( + "--collection", default="default", + help="Collection. Default: %(default)s", + ) + p.add_argument( + "--text", default=DEFAULT_TEXT, + help="Text to embed for embeddings/seed.", + ) + p.add_argument( + "--vector-dim", type=int, default=384, + help="Dimension of synthetic vector when --no-seed is used. " + "Default: %(default)s", + ) + p.add_argument( + "--no-seed", action="store_true", + help="Skip the embeddings warm-up call. Use a random vector " + "for graph-embeddings queries instead.", + ) + p.add_argument( + "--request-timeout", type=float, default=30.0, + help="Per-request timeout (seconds). Default: %(default)s", + ) + p.add_argument( + "--report-interval", type=float, default=5.0, + help="How often to print stats (seconds). Default: %(default)s", + ) + p.add_argument( + "--ping-timeout", type=int, default=120, + help="Websocket ping timeout. Default: %(default)s", + ) + p.add_argument( + "--seed", type=int, default=None, + help="Random seed (for reproducibility).", + ) + return p.parse_args() + + +async def seed_vector( + client: WSClient, flow: str, text: str, timeout: float, +) -> list[float]: + """Issue one embeddings request to obtain a real vector that + later graph-embeddings calls can reuse.""" + resp, err, _ = await client.request( + "embeddings", flow, {"texts": [text]}, timeout, + ) + if err or not resp: + raise RuntimeError(f"seed embeddings failed: {err or resp}") + vectors = resp.get("vectors") + if not vectors: + raise RuntimeError(f"seed embeddings: no vectors in response: {resp}") + return vectors[0] + + +def make_request_body( + service: str, args: argparse.Namespace, vector: list[float], +) -> dict: + if service == "embeddings": + return {"texts": [args.text]} + if service == "graph-embeddings": + return { + "vector": vector, + "limit": args.limit, + "collection": args.collection, + } + if service == "triples": + return { + "limit": args.limit, + "collection": args.collection, + } + raise ValueError(f"Unknown service: {service}") + + +async def worker( + name: int, + client: WSClient, + flows: list[str], + services: list[str], + args: argparse.Namespace, + vector: list[float], + stats: dict[str, Stats], + in_flight: dict[str, int], + stop_at: float, +) -> None: + rng = random.Random((args.seed or 0) + name) + while time.monotonic() < stop_at: + svc = rng.choice(services) + flow = rng.choice(flows) + body = make_request_body(svc, args, vector) + + stats[svc].sent += 1 + in_flight[svc] += 1 + try: + resp, err, lat = await client.request( + svc, flow, body, args.request_timeout, + ) + if err == "timeout": + stats[svc].record_timeout() + elif err: + stats[svc].record_err() + else: + stats[svc].record_ok(lat) + except Exception as e: + stats[svc].record_err() + print(f"worker {name}: unexpected {svc} exception: {e}", + file=sys.stderr) + finally: + in_flight[svc] -= 1 + + +async def reporter( + services: list[str], + stats: dict[str, Stats], + in_flight: dict[str, int], + stop_at: float, + interval: float, +) -> None: + started = time.monotonic() + last_sent = {s: 0 for s in services} + while time.monotonic() < stop_at: + await asyncio.sleep(interval) + now = time.monotonic() + elapsed = now - started + total_inflight = sum(in_flight.values()) + print( + f"\n[{elapsed:6.1f}s] in-flight={total_inflight} " + f"per-svc={dict(in_flight)}" + ) + for svc in services: + s = stats[svc] + delta = s.sent - last_sent[svc] + rate = delta / interval + last_sent[svc] = s.sent + print(f" {svc:20s} {rate:6.1f}/s | {s.summary()}") + + +async def run(args: argparse.Namespace) -> int: + if args.seed is not None: + random.seed(args.seed) + + services = [s.strip() for s in args.services.split(",") if s.strip()] + flows = [f.strip() for f in args.flow.split(",") if f.strip()] + valid = {"embeddings", "graph-embeddings", "triples"} + bad = [s for s in services if s not in valid] + if bad: + print(f"ERROR: unknown service(s): {bad}. " + f"Supported: {sorted(valid)}", file=sys.stderr) + return 2 + + client = WSClient( + args.url, args.token, args.workspace, args.ping_timeout, + ) + print(f"Connecting to {args.url} ...") + await client.connect() + print(f"Connected. workspace={client.workspace} flows={flows} " + f"services={services} concurrency={args.concurrency} " + f"duration={args.duration}s") + + if "graph-embeddings" in services and not args.no_seed: + print("Seeding embedding vector ...") + vector = await seed_vector( + client, flows[0], args.text, args.request_timeout, + ) + print(f"Got vector of length {len(vector)}") + else: + vector = [random.uniform(-1.0, 1.0) for _ in range(args.vector_dim)] + + stats: dict[str, Stats] = defaultdict(Stats) + in_flight: dict[str, int] = defaultdict(int) + for svc in services: + stats[svc] # initialise + in_flight[svc] = 0 + + stop_at = time.monotonic() + args.duration + print(f"Starting load: {args.concurrency} workers for " + f"{args.duration}s ...") + + workers = [ + asyncio.create_task( + worker( + i, client, flows, services, args, vector, + stats, in_flight, stop_at, + ) + ) + for i in range(args.concurrency) + ] + rep = asyncio.create_task( + reporter(services, stats, in_flight, stop_at, args.report_interval) + ) + + try: + await asyncio.gather(*workers) + finally: + rep.cancel() + try: + await rep + except asyncio.CancelledError: + pass + + print("\n=== Final results ===") + any_failures = False + for svc in services: + s = stats[svc] + print(f" {svc:20s} {s.summary()}") + if s.timeout > 0 or s.err > 0: + any_failures = True + + await client.close() + + return 1 if any_failures else 0 + + +def main() -> int: + args = parse_args() + try: + return asyncio.run(run(args)) + except KeyboardInterrupt: + return 130 + + +if __name__ == "__main__": + sys.exit(main())