mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-29 10:26:21 +02:00
Smoke-test websocket tool (#852)
This commit is contained in:
parent
666af1c4b3
commit
b15f1a167c
1 changed files with 475 additions and 0 deletions
475
dev-tools/tests/smoke/smoke_ws_queries.py
Executable file
475
dev-tools/tests/smoke/smoke_ws_queries.py
Executable file
|
|
@ -0,0 +1,475 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
WebSocket smoke / load test that hammers a TrustGraph gateway with a
|
||||
mix of `embeddings`, `graph-embeddings`, and `triples` queries while
|
||||
keeping a target number of in-flight requests at all times.
|
||||
|
||||
Useful for reproducing the "worker hangs after a while, all subsequent
|
||||
requests time out" failure mode — leaves enough load on the system to
|
||||
saturate worker concurrency and reports per-service success/timeout
|
||||
rates and latency distributions over time.
|
||||
|
||||
Usage:
|
||||
smoke_ws_queries.py --flow onto-rag --duration 120 --concurrency 20
|
||||
|
||||
Connects via /api/v1/socket using the first-frame auth protocol.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import statistics
|
||||
import sys
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
import websockets
|
||||
|
||||
|
||||
DEFAULT_TEXT = (
|
||||
"What caused the space shuttle to explode and what were the "
|
||||
"main factors leading to the disaster?"
|
||||
)
|
||||
|
||||
|
||||
class Stats:
|
||||
"""Per-service rolling counters and latency samples."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.sent = 0
|
||||
self.ok = 0
|
||||
self.err = 0
|
||||
self.timeout = 0
|
||||
self.latencies_ms: list[float] = []
|
||||
|
||||
def record_ok(self, latency_ms: float) -> None:
|
||||
self.ok += 1
|
||||
self.latencies_ms.append(latency_ms)
|
||||
|
||||
def record_err(self) -> None:
|
||||
self.err += 1
|
||||
|
||||
def record_timeout(self) -> None:
|
||||
self.timeout += 1
|
||||
|
||||
def percentile(self, p: float) -> float:
|
||||
if not self.latencies_ms:
|
||||
return 0.0
|
||||
s = sorted(self.latencies_ms)
|
||||
idx = min(len(s) - 1, int(len(s) * p))
|
||||
return s[idx]
|
||||
|
||||
def summary(self) -> str:
|
||||
if self.latencies_ms:
|
||||
mn = min(self.latencies_ms)
|
||||
mx = max(self.latencies_ms)
|
||||
mean = statistics.mean(self.latencies_ms)
|
||||
p50 = self.percentile(0.50)
|
||||
p95 = self.percentile(0.95)
|
||||
p99 = self.percentile(0.99)
|
||||
lat = (
|
||||
f"min={mn:.0f} mean={mean:.0f} p50={p50:.0f} "
|
||||
f"p95={p95:.0f} p99={p99:.0f} max={mx:.0f} ms"
|
||||
)
|
||||
else:
|
||||
lat = "no successful samples"
|
||||
return (
|
||||
f"sent={self.sent} ok={self.ok} err={self.err} "
|
||||
f"timeout={self.timeout} | {lat}"
|
||||
)
|
||||
|
||||
|
||||
class WSClient:
|
||||
"""Thin async websocket client with first-frame auth and a shared
|
||||
reader task that demuxes responses to per-request asyncio queues."""
|
||||
|
||||
def __init__(
|
||||
self, url: str, token: str | None, workspace: str,
|
||||
ping_timeout: int,
|
||||
) -> None:
|
||||
self.url = url
|
||||
self.token = token
|
||||
self.workspace = workspace
|
||||
self.ping_timeout = ping_timeout
|
||||
self._ws: Any = None
|
||||
self._pending: dict[str, asyncio.Queue] = {}
|
||||
self._reader_task: asyncio.Task | None = None
|
||||
self._closed = asyncio.Event()
|
||||
|
||||
async def connect(self) -> None:
|
||||
ws_url = self.url.rstrip("/") + "/api/v1/socket"
|
||||
if ws_url.startswith("http://"):
|
||||
ws_url = "ws://" + ws_url[len("http://"):]
|
||||
elif ws_url.startswith("https://"):
|
||||
ws_url = "wss://" + ws_url[len("https://"):]
|
||||
elif not (
|
||||
ws_url.startswith("ws://") or ws_url.startswith("wss://")
|
||||
):
|
||||
ws_url = "ws://" + ws_url
|
||||
|
||||
self._ws = await websockets.connect(
|
||||
ws_url,
|
||||
ping_interval=20,
|
||||
ping_timeout=self.ping_timeout,
|
||||
max_size=64 * 1024 * 1024,
|
||||
)
|
||||
|
||||
if self.token:
|
||||
# First-frame auth handshake.
|
||||
await self._ws.send(json.dumps({
|
||||
"type": "auth", "token": self.token,
|
||||
}))
|
||||
raw = await asyncio.wait_for(self._ws.recv(), timeout=10)
|
||||
resp = json.loads(raw)
|
||||
if resp.get("type") != "auth-ok":
|
||||
await self._ws.close()
|
||||
raise RuntimeError(f"auth failed: {resp}")
|
||||
if "workspace" in resp:
|
||||
# Server-resolved workspace overrides the user-supplied
|
||||
# one, mirroring AsyncSocketClient behaviour.
|
||||
self.workspace = resp["workspace"]
|
||||
else:
|
||||
print(
|
||||
"WARNING: no token provided — skipping auth handshake. "
|
||||
"Requests will be rejected unless the gateway is "
|
||||
"running without IAM enforcement.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
self._reader_task = asyncio.create_task(self._reader())
|
||||
|
||||
async def _reader(self) -> None:
|
||||
try:
|
||||
async for raw in self._ws:
|
||||
msg = json.loads(raw)
|
||||
rid = msg.get("id")
|
||||
if rid and rid in self._pending:
|
||||
await self._pending[rid].put(msg)
|
||||
except websockets.exceptions.ConnectionClosed:
|
||||
pass
|
||||
except Exception as e:
|
||||
for q in list(self._pending.values()):
|
||||
try:
|
||||
q.put_nowait({"error": {"message": str(e)}})
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self._closed.set()
|
||||
|
||||
async def request(
|
||||
self, service: str, flow: str | None, body: dict, timeout: float,
|
||||
) -> tuple[dict | None, str | None, float]:
|
||||
"""Send one request, await final response.
|
||||
|
||||
Returns ``(response, error, latency_ms)``. ``response`` is None
|
||||
on error/timeout. ``error`` describes the failure category.
|
||||
"""
|
||||
rid = str(uuid.uuid4())
|
||||
q: asyncio.Queue = asyncio.Queue()
|
||||
self._pending[rid] = q
|
||||
env = {
|
||||
"id": rid,
|
||||
"workspace": self.workspace,
|
||||
"service": service,
|
||||
"request": body,
|
||||
}
|
||||
if flow:
|
||||
env["flow"] = flow
|
||||
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
await self._ws.send(json.dumps(env))
|
||||
while True:
|
||||
try:
|
||||
msg = await asyncio.wait_for(q.get(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
return None, "timeout", (time.monotonic() - t0) * 1000
|
||||
if "error" in msg and msg["error"]:
|
||||
err = msg["error"]
|
||||
err_msg = (
|
||||
err.get("message") if isinstance(err, dict) else str(err)
|
||||
)
|
||||
return None, f"error: {err_msg}", (time.monotonic() - t0) * 1000
|
||||
if msg.get("complete"):
|
||||
return msg.get("response"), None, (time.monotonic() - t0) * 1000
|
||||
# Otherwise an intermediate streaming chunk — keep waiting.
|
||||
finally:
|
||||
self._pending.pop(rid, None)
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._ws is not None:
|
||||
await self._ws.close()
|
||||
if self._reader_task is not None:
|
||||
try:
|
||||
await asyncio.wait_for(self._reader_task, timeout=2)
|
||||
except (asyncio.TimeoutError, asyncio.CancelledError):
|
||||
pass
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument(
|
||||
"--url",
|
||||
default=os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/"),
|
||||
help="Gateway URL (http or ws). Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--token",
|
||||
default=os.getenv("TRUSTGRAPH_TOKEN"),
|
||||
help="Auth token (or set TRUSTGRAPH_TOKEN). Optional — if "
|
||||
"omitted, the auth handshake is skipped (only works "
|
||||
"when the gateway is running without IAM enforcement).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--workspace", default="default",
|
||||
help="Workspace. Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--flow", required=True,
|
||||
help="Flow id. Comma-separated for round-robin across flows "
|
||||
"(e.g. onto-rag,doc-rag).",
|
||||
)
|
||||
p.add_argument(
|
||||
"--duration", type=int, default=60,
|
||||
help="Test duration in seconds. Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--concurrency", type=int, default=15,
|
||||
help="Target in-flight request count. Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--services",
|
||||
default="embeddings,graph-embeddings,triples",
|
||||
help="Comma-separated services to exercise. "
|
||||
"Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--limit", type=int, default=3,
|
||||
help="limit for triples / graph-embeddings queries. "
|
||||
"Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--collection", default="default",
|
||||
help="Collection. Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--text", default=DEFAULT_TEXT,
|
||||
help="Text to embed for embeddings/seed.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--vector-dim", type=int, default=384,
|
||||
help="Dimension of synthetic vector when --no-seed is used. "
|
||||
"Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--no-seed", action="store_true",
|
||||
help="Skip the embeddings warm-up call. Use a random vector "
|
||||
"for graph-embeddings queries instead.",
|
||||
)
|
||||
p.add_argument(
|
||||
"--request-timeout", type=float, default=30.0,
|
||||
help="Per-request timeout (seconds). Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--report-interval", type=float, default=5.0,
|
||||
help="How often to print stats (seconds). Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--ping-timeout", type=int, default=120,
|
||||
help="Websocket ping timeout. Default: %(default)s",
|
||||
)
|
||||
p.add_argument(
|
||||
"--seed", type=int, default=None,
|
||||
help="Random seed (for reproducibility).",
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
async def seed_vector(
|
||||
client: WSClient, flow: str, text: str, timeout: float,
|
||||
) -> list[float]:
|
||||
"""Issue one embeddings request to obtain a real vector that
|
||||
later graph-embeddings calls can reuse."""
|
||||
resp, err, _ = await client.request(
|
||||
"embeddings", flow, {"texts": [text]}, timeout,
|
||||
)
|
||||
if err or not resp:
|
||||
raise RuntimeError(f"seed embeddings failed: {err or resp}")
|
||||
vectors = resp.get("vectors")
|
||||
if not vectors:
|
||||
raise RuntimeError(f"seed embeddings: no vectors in response: {resp}")
|
||||
return vectors[0]
|
||||
|
||||
|
||||
def make_request_body(
|
||||
service: str, args: argparse.Namespace, vector: list[float],
|
||||
) -> dict:
|
||||
if service == "embeddings":
|
||||
return {"texts": [args.text]}
|
||||
if service == "graph-embeddings":
|
||||
return {
|
||||
"vector": vector,
|
||||
"limit": args.limit,
|
||||
"collection": args.collection,
|
||||
}
|
||||
if service == "triples":
|
||||
return {
|
||||
"limit": args.limit,
|
||||
"collection": args.collection,
|
||||
}
|
||||
raise ValueError(f"Unknown service: {service}")
|
||||
|
||||
|
||||
async def worker(
|
||||
name: int,
|
||||
client: WSClient,
|
||||
flows: list[str],
|
||||
services: list[str],
|
||||
args: argparse.Namespace,
|
||||
vector: list[float],
|
||||
stats: dict[str, Stats],
|
||||
in_flight: dict[str, int],
|
||||
stop_at: float,
|
||||
) -> None:
|
||||
rng = random.Random((args.seed or 0) + name)
|
||||
while time.monotonic() < stop_at:
|
||||
svc = rng.choice(services)
|
||||
flow = rng.choice(flows)
|
||||
body = make_request_body(svc, args, vector)
|
||||
|
||||
stats[svc].sent += 1
|
||||
in_flight[svc] += 1
|
||||
try:
|
||||
resp, err, lat = await client.request(
|
||||
svc, flow, body, args.request_timeout,
|
||||
)
|
||||
if err == "timeout":
|
||||
stats[svc].record_timeout()
|
||||
elif err:
|
||||
stats[svc].record_err()
|
||||
else:
|
||||
stats[svc].record_ok(lat)
|
||||
except Exception as e:
|
||||
stats[svc].record_err()
|
||||
print(f"worker {name}: unexpected {svc} exception: {e}",
|
||||
file=sys.stderr)
|
||||
finally:
|
||||
in_flight[svc] -= 1
|
||||
|
||||
|
||||
async def reporter(
|
||||
services: list[str],
|
||||
stats: dict[str, Stats],
|
||||
in_flight: dict[str, int],
|
||||
stop_at: float,
|
||||
interval: float,
|
||||
) -> None:
|
||||
started = time.monotonic()
|
||||
last_sent = {s: 0 for s in services}
|
||||
while time.monotonic() < stop_at:
|
||||
await asyncio.sleep(interval)
|
||||
now = time.monotonic()
|
||||
elapsed = now - started
|
||||
total_inflight = sum(in_flight.values())
|
||||
print(
|
||||
f"\n[{elapsed:6.1f}s] in-flight={total_inflight} "
|
||||
f"per-svc={dict(in_flight)}"
|
||||
)
|
||||
for svc in services:
|
||||
s = stats[svc]
|
||||
delta = s.sent - last_sent[svc]
|
||||
rate = delta / interval
|
||||
last_sent[svc] = s.sent
|
||||
print(f" {svc:20s} {rate:6.1f}/s | {s.summary()}")
|
||||
|
||||
|
||||
async def run(args: argparse.Namespace) -> int:
|
||||
if args.seed is not None:
|
||||
random.seed(args.seed)
|
||||
|
||||
services = [s.strip() for s in args.services.split(",") if s.strip()]
|
||||
flows = [f.strip() for f in args.flow.split(",") if f.strip()]
|
||||
valid = {"embeddings", "graph-embeddings", "triples"}
|
||||
bad = [s for s in services if s not in valid]
|
||||
if bad:
|
||||
print(f"ERROR: unknown service(s): {bad}. "
|
||||
f"Supported: {sorted(valid)}", file=sys.stderr)
|
||||
return 2
|
||||
|
||||
client = WSClient(
|
||||
args.url, args.token, args.workspace, args.ping_timeout,
|
||||
)
|
||||
print(f"Connecting to {args.url} ...")
|
||||
await client.connect()
|
||||
print(f"Connected. workspace={client.workspace} flows={flows} "
|
||||
f"services={services} concurrency={args.concurrency} "
|
||||
f"duration={args.duration}s")
|
||||
|
||||
if "graph-embeddings" in services and not args.no_seed:
|
||||
print("Seeding embedding vector ...")
|
||||
vector = await seed_vector(
|
||||
client, flows[0], args.text, args.request_timeout,
|
||||
)
|
||||
print(f"Got vector of length {len(vector)}")
|
||||
else:
|
||||
vector = [random.uniform(-1.0, 1.0) for _ in range(args.vector_dim)]
|
||||
|
||||
stats: dict[str, Stats] = defaultdict(Stats)
|
||||
in_flight: dict[str, int] = defaultdict(int)
|
||||
for svc in services:
|
||||
stats[svc] # initialise
|
||||
in_flight[svc] = 0
|
||||
|
||||
stop_at = time.monotonic() + args.duration
|
||||
print(f"Starting load: {args.concurrency} workers for "
|
||||
f"{args.duration}s ...")
|
||||
|
||||
workers = [
|
||||
asyncio.create_task(
|
||||
worker(
|
||||
i, client, flows, services, args, vector,
|
||||
stats, in_flight, stop_at,
|
||||
)
|
||||
)
|
||||
for i in range(args.concurrency)
|
||||
]
|
||||
rep = asyncio.create_task(
|
||||
reporter(services, stats, in_flight, stop_at, args.report_interval)
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.gather(*workers)
|
||||
finally:
|
||||
rep.cancel()
|
||||
try:
|
||||
await rep
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
print("\n=== Final results ===")
|
||||
any_failures = False
|
||||
for svc in services:
|
||||
s = stats[svc]
|
||||
print(f" {svc:20s} {s.summary()}")
|
||||
if s.timeout > 0 or s.err > 0:
|
||||
any_failures = True
|
||||
|
||||
await client.close()
|
||||
|
||||
return 1 if any_failures else 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
try:
|
||||
return asyncio.run(run(args))
|
||||
except KeyboardInterrupt:
|
||||
return 130
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
Loading…
Add table
Add a link
Reference in a new issue