mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-01 03:16:23 +02:00
476 lines
15 KiB
Python
476 lines
15 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
WebSocket smoke / load test that hammers a TrustGraph gateway with a
|
||
|
|
mix of `embeddings`, `graph-embeddings`, and `triples` queries while
|
||
|
|
keeping a target number of in-flight requests at all times.
|
||
|
|
|
||
|
|
Useful for reproducing the "worker hangs after a while, all subsequent
|
||
|
|
requests time out" failure mode — leaves enough load on the system to
|
||
|
|
saturate worker concurrency and reports per-service success/timeout
|
||
|
|
rates and latency distributions over time.
|
||
|
|
|
||
|
|
Usage:
|
||
|
|
smoke_ws_queries.py --flow onto-rag --duration 120 --concurrency 20
|
||
|
|
|
||
|
|
Connects via /api/v1/socket using the first-frame auth protocol.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import os
|
||
|
|
import random
|
||
|
|
import statistics
|
||
|
|
import sys
|
||
|
|
import time
|
||
|
|
import uuid
|
||
|
|
from collections import defaultdict
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
import websockets
|
||
|
|
|
||
|
|
|
||
|
|
# Default query text: used for the embeddings warm-up ("seed") call and as
# the payload of `embeddings` requests when --text is not supplied.
DEFAULT_TEXT = (
    "What caused the space shuttle to explode and what were the "
    "main factors leading to the disaster?"
)
|
||
|
|
|
||
|
|
|
||
|
|
class Stats:
    """Per-service rolling counters and latency samples."""

    def __init__(self) -> None:
        # Counters are bumped by worker tasks; latency samples are kept
        # only for successful requests.
        self.sent = 0
        self.ok = 0
        self.err = 0
        self.timeout = 0
        self.latencies_ms: list[float] = []

    def record_ok(self, latency_ms: float) -> None:
        """Count one success and keep its latency sample."""
        self.ok += 1
        self.latencies_ms.append(latency_ms)

    def record_err(self) -> None:
        """Count one failed request."""
        self.err += 1

    def record_timeout(self) -> None:
        """Count one timed-out request."""
        self.timeout += 1

    def percentile(self, p: float) -> float:
        """Approximate p-th percentile (p in 0..1) of latency samples.

        Returns 0.0 when no samples have been recorded yet.
        """
        samples = self.latencies_ms
        if not samples:
            return 0.0
        ordered = sorted(samples)
        # Nearest-rank style index, clamped to the last element.
        pos = int(len(ordered) * p)
        if pos > len(ordered) - 1:
            pos = len(ordered) - 1
        return ordered[pos]

    def summary(self) -> str:
        """Render a one-line human-readable summary of this service."""
        if not self.latencies_ms:
            lat = "no successful samples"
        else:
            lat = (
                f"min={min(self.latencies_ms):.0f} "
                f"mean={statistics.mean(self.latencies_ms):.0f} "
                f"p50={self.percentile(0.50):.0f} "
                f"p95={self.percentile(0.95):.0f} "
                f"p99={self.percentile(0.99):.0f} "
                f"max={max(self.latencies_ms):.0f} ms"
            )
        return (
            f"sent={self.sent} ok={self.ok} err={self.err} "
            f"timeout={self.timeout} | {lat}"
        )
|
||
|
|
|
||
|
|
|
||
|
|
class WSClient:
    """Thin async websocket client with first-frame auth and a shared
    reader task that demuxes responses to per-request asyncio queues."""

    def __init__(
        self, url: str, token: str | None, workspace: str,
        ping_timeout: int,
    ) -> None:
        """Store connection settings; no I/O happens until connect().

        url          -- gateway base URL (http/https/ws/wss or bare host)
        token        -- bearer token for the first-frame auth handshake,
                        or None to skip auth entirely
        workspace    -- workspace id sent with every request (may be
                        replaced by the server-resolved one during auth)
        ping_timeout -- websocket keepalive ping timeout, seconds
        """
        self.url = url
        self.token = token
        self.workspace = workspace
        self.ping_timeout = ping_timeout
        # Live websocket connection, set by connect().
        self._ws: Any = None
        # request-id -> queue of frames for that request; the reader task
        # routes each incoming message here by its "id" field.
        self._pending: dict[str, asyncio.Queue] = {}
        self._reader_task: asyncio.Task | None = None
        # Set once the reader task exits (connection closed or errored).
        self._closed = asyncio.Event()

    async def connect(self) -> None:
        """Open the websocket, perform the auth handshake, start the reader.

        Raises RuntimeError if the server does not answer the auth frame
        with "auth-ok".
        """
        # Normalize any URL scheme to ws:// or wss:// and append the
        # gateway socket endpoint.
        ws_url = self.url.rstrip("/") + "/api/v1/socket"
        if ws_url.startswith("http://"):
            ws_url = "ws://" + ws_url[len("http://"):]
        elif ws_url.startswith("https://"):
            ws_url = "wss://" + ws_url[len("https://"):]
        elif not (
            ws_url.startswith("ws://") or ws_url.startswith("wss://")
        ):
            # Bare host:port — assume plaintext websocket.
            ws_url = "ws://" + ws_url

        self._ws = await websockets.connect(
            ws_url,
            ping_interval=20,
            ping_timeout=self.ping_timeout,
            # Large cap so big embedding/triples responses are not dropped.
            max_size=64 * 1024 * 1024,
        )

        if self.token:
            # First-frame auth handshake.
            await self._ws.send(json.dumps({
                "type": "auth", "token": self.token,
            }))
            raw = await asyncio.wait_for(self._ws.recv(), timeout=10)
            resp = json.loads(raw)
            if resp.get("type") != "auth-ok":
                await self._ws.close()
                raise RuntimeError(f"auth failed: {resp}")
            if "workspace" in resp:
                # Server-resolved workspace overrides the user-supplied
                # one, mirroring AsyncSocketClient behaviour.
                self.workspace = resp["workspace"]
        else:
            print(
                "WARNING: no token provided — skipping auth handshake. "
                "Requests will be rejected unless the gateway is "
                "running without IAM enforcement.",
                file=sys.stderr,
            )

        # Single shared reader demuxes all responses to pending queues.
        self._reader_task = asyncio.create_task(self._reader())

    async def _reader(self) -> None:
        """Background task: route incoming frames to per-request queues.

        Frames whose "id" has no pending queue (e.g. late replies after a
        timeout removed the entry) are silently dropped.
        """
        try:
            async for raw in self._ws:
                msg = json.loads(raw)
                rid = msg.get("id")
                if rid and rid in self._pending:
                    await self._pending[rid].put(msg)
        except websockets.exceptions.ConnectionClosed:
            # Normal shutdown path — close() or the server ended the socket.
            pass
        except Exception as e:
            # Unexpected failure: fan an error frame out to every waiter so
            # in-flight request() calls fail fast instead of timing out.
            for q in list(self._pending.values()):
                try:
                    q.put_nowait({"error": {"message": str(e)}})
                except Exception:
                    pass
        finally:
            self._closed.set()

    async def request(
        self, service: str, flow: str | None, body: dict, timeout: float,
    ) -> tuple[dict | None, str | None, float]:
        """Send one request, await final response.

        Returns ``(response, error, latency_ms)``. ``response`` is None
        on error/timeout. ``error`` describes the failure category.
        """
        rid = str(uuid.uuid4())
        q: asyncio.Queue = asyncio.Queue()
        self._pending[rid] = q
        env = {
            "id": rid,
            "workspace": self.workspace,
            "service": service,
            "request": body,
        }
        if flow:
            env["flow"] = flow

        t0 = time.monotonic()
        try:
            await self._ws.send(json.dumps(env))
            # Consume frames for this id until an error or the final
            # ("complete") frame arrives; intermediate frames are
            # streaming chunks.
            while True:
                try:
                    msg = await asyncio.wait_for(q.get(), timeout=timeout)
                except asyncio.TimeoutError:
                    return None, "timeout", (time.monotonic() - t0) * 1000
                if "error" in msg and msg["error"]:
                    err = msg["error"]
                    err_msg = (
                        err.get("message") if isinstance(err, dict) else str(err)
                    )
                    return None, f"error: {err_msg}", (time.monotonic() - t0) * 1000
                if msg.get("complete"):
                    return msg.get("response"), None, (time.monotonic() - t0) * 1000
                # Otherwise an intermediate streaming chunk — keep waiting.
        finally:
            # Always deregister so the reader drops any late frames.
            self._pending.pop(rid, None)

    async def close(self) -> None:
        """Close the socket and wait briefly for the reader task to end."""
        if self._ws is not None:
            await self._ws.close()
        if self._reader_task is not None:
            try:
                # Bounded wait: don't hang shutdown on a stuck reader.
                await asyncio.wait_for(self._reader_task, timeout=2)
            except (asyncio.TimeoutError, asyncio.CancelledError):
                pass
|
||
|
|
|
||
|
|
|
||
|
|
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface for the smoke test."""
    parser = argparse.ArgumentParser(description=__doc__)
    add = parser.add_argument

    add(
        "--url",
        default=os.getenv("TRUSTGRAPH_URL", "http://localhost:8088/"),
        help="Gateway URL (http or ws). Default: %(default)s",
    )
    add(
        "--token",
        default=os.getenv("TRUSTGRAPH_TOKEN"),
        help="Auth token (or set TRUSTGRAPH_TOKEN). Optional — if "
             "omitted, the auth handshake is skipped (only works "
             "when the gateway is running without IAM enforcement).",
    )
    add("--workspace", default="default",
        help="Workspace. Default: %(default)s")
    add("--flow", required=True,
        help="Flow id. Comma-separated for round-robin across flows "
             "(e.g. onto-rag,doc-rag).")
    add("--duration", type=int, default=60,
        help="Test duration in seconds. Default: %(default)s")
    add("--concurrency", type=int, default=15,
        help="Target in-flight request count. Default: %(default)s")
    add("--services", default="embeddings,graph-embeddings,triples",
        help="Comma-separated services to exercise. "
             "Default: %(default)s")
    add("--limit", type=int, default=3,
        help="limit for triples / graph-embeddings queries. "
             "Default: %(default)s")
    add("--collection", default="default",
        help="Collection. Default: %(default)s")
    add("--text", default=DEFAULT_TEXT,
        help="Text to embed for embeddings/seed.")
    add("--vector-dim", type=int, default=384,
        help="Dimension of synthetic vector when --no-seed is used. "
             "Default: %(default)s")
    add("--no-seed", action="store_true",
        help="Skip the embeddings warm-up call. Use a random vector "
             "for graph-embeddings queries instead.")
    add("--request-timeout", type=float, default=30.0,
        help="Per-request timeout (seconds). Default: %(default)s")
    add("--report-interval", type=float, default=5.0,
        help="How often to print stats (seconds). Default: %(default)s")
    add("--ping-timeout", type=int, default=120,
        help="Websocket ping timeout. Default: %(default)s")
    add("--seed", type=int, default=None,
        help="Random seed (for reproducibility).")

    return parser.parse_args()
|
||
|
|
|
||
|
|
|
||
|
|
async def seed_vector(
    client: WSClient, flow: str, text: str, timeout: float,
) -> list[float]:
    """Run one embeddings request and return its first vector.

    The vector is reused by subsequent graph-embeddings queries so the
    load test exercises realistic lookups. Raises RuntimeError when the
    request fails or the response carries no vectors.
    """
    response, error, _ = await client.request(
        "embeddings", flow, {"texts": [text]}, timeout,
    )
    if error or not response:
        raise RuntimeError(f"seed embeddings failed: {error or response}")

    vectors = response.get("vectors")
    if not vectors:
        raise RuntimeError(
            f"seed embeddings: no vectors in response: {response}"
        )
    return vectors[0]
|
||
|
|
|
||
|
|
|
||
|
|
def make_request_body(
|
||
|
|
service: str, args: argparse.Namespace, vector: list[float],
|
||
|
|
) -> dict:
|
||
|
|
if service == "embeddings":
|
||
|
|
return {"texts": [args.text]}
|
||
|
|
if service == "graph-embeddings":
|
||
|
|
return {
|
||
|
|
"vector": vector,
|
||
|
|
"limit": args.limit,
|
||
|
|
"collection": args.collection,
|
||
|
|
}
|
||
|
|
if service == "triples":
|
||
|
|
return {
|
||
|
|
"limit": args.limit,
|
||
|
|
"collection": args.collection,
|
||
|
|
}
|
||
|
|
raise ValueError(f"Unknown service: {service}")
|
||
|
|
|
||
|
|
|
||
|
|
async def worker(
    name: int,
    client: WSClient,
    flows: list[str],
    services: list[str],
    args: argparse.Namespace,
    vector: list[float],
    stats: dict[str, Stats],
    in_flight: dict[str, int],
    stop_at: float,
) -> None:
    """Issue a random mix of requests until *stop_at*, recording stats.

    Each worker keeps exactly one request in flight, so total in-flight
    load equals the worker count.
    """
    # Per-worker RNG so --seed yields a reproducible request mix.
    rng = random.Random((args.seed or 0) + name)

    while time.monotonic() < stop_at:
        service = rng.choice(services)
        flow = rng.choice(flows)
        payload = make_request_body(service, args, vector)

        tracker = stats[service]
        tracker.sent += 1
        in_flight[service] += 1
        try:
            _, error, latency = await client.request(
                service, flow, payload, args.request_timeout,
            )
            if error == "timeout":
                tracker.record_timeout()
            elif error:
                tracker.record_err()
            else:
                tracker.record_ok(latency)
        except Exception as exc:
            # Transport-level surprises count as errors but keep the
            # worker loop alive.
            tracker.record_err()
            print(f"worker {name}: unexpected {service} exception: {exc}",
                  file=sys.stderr)
        finally:
            in_flight[service] -= 1
|
||
|
|
|
||
|
|
|
||
|
|
async def reporter(
    services: list[str],
    stats: dict[str, Stats],
    in_flight: dict[str, int],
    stop_at: float,
    interval: float,
) -> None:
    """Print in-flight counts, request rates and summaries periodically."""
    started = time.monotonic()
    # Snapshot of each service's sent counter at the previous tick, used
    # to derive the per-interval request rate.
    prev_sent = dict.fromkeys(services, 0)

    while time.monotonic() < stop_at:
        await asyncio.sleep(interval)
        elapsed = time.monotonic() - started
        total = sum(in_flight.values())
        print(
            f"\n[{elapsed:6.1f}s] in-flight={total} "
            f"per-svc={dict(in_flight)}"
        )
        for name in services:
            tracker = stats[name]
            rate = (tracker.sent - prev_sent[name]) / interval
            prev_sent[name] = tracker.sent
            print(f" {name:20s} {rate:6.1f}/s | {tracker.summary()}")
|
||
|
|
|
||
|
|
|
||
|
|
async def run(args: argparse.Namespace) -> int:
    """Execute the whole smoke test; return a shell-style exit code.

    0 = all requests succeeded, 1 = at least one error or timeout,
    2 = invalid command-line arguments.
    """
    if args.seed is not None:
        random.seed(args.seed)

    services = [item.strip() for item in args.services.split(",") if item.strip()]
    flows = [item.strip() for item in args.flow.split(",") if item.strip()]

    supported = {"embeddings", "graph-embeddings", "triples"}
    unknown = [svc for svc in services if svc not in supported]
    if unknown:
        print(f"ERROR: unknown service(s): {unknown}. "
              f"Supported: {sorted(supported)}", file=sys.stderr)
        return 2

    client = WSClient(
        args.url, args.token, args.workspace, args.ping_timeout,
    )
    print(f"Connecting to {args.url} ...")
    await client.connect()
    print(f"Connected. workspace={client.workspace} flows={flows} "
          f"services={services} concurrency={args.concurrency} "
          f"duration={args.duration}s")

    # Obtain a realistic query vector via a warm-up embeddings call,
    # unless told to fabricate a synthetic one.
    if "graph-embeddings" in services and not args.no_seed:
        print("Seeding embedding vector ...")
        vector = await seed_vector(
            client, flows[0], args.text, args.request_timeout,
        )
        print(f"Got vector of length {len(vector)}")
    else:
        vector = [random.uniform(-1.0, 1.0) for _ in range(args.vector_dim)]

    stats: dict[str, Stats] = defaultdict(Stats)
    in_flight: dict[str, int] = defaultdict(int)
    for svc in services:
        stats[svc]  # touch so reporting sees every service from t=0
        in_flight[svc] = 0

    stop_at = time.monotonic() + args.duration
    print(f"Starting load: {args.concurrency} workers for "
          f"{args.duration}s ...")

    tasks = []
    for idx in range(args.concurrency):
        tasks.append(asyncio.create_task(worker(
            idx, client, flows, services, args, vector,
            stats, in_flight, stop_at,
        )))
    report_task = asyncio.create_task(
        reporter(services, stats, in_flight, stop_at, args.report_interval)
    )

    try:
        await asyncio.gather(*tasks)
    finally:
        # The reporter sleeps on the wall clock; cancel it explicitly
        # once the workers are done (or one of them raised).
        report_task.cancel()
        try:
            await report_task
        except asyncio.CancelledError:
            pass

    print("\n=== Final results ===")
    any_failures = False
    for svc in services:
        tracker = stats[svc]
        print(f" {svc:20s} {tracker.summary()}")
        if tracker.timeout > 0 or tracker.err > 0:
            any_failures = True

    await client.close()

    return 1 if any_failures else 0
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> int:
    """CLI entry point: parse options, run the test, map to an exit code."""
    options = parse_args()
    try:
        exit_code = asyncio.run(run(options))
    except KeyboardInterrupt:
        # 128 + SIGINT, the conventional interrupted-by-user status.
        exit_code = 130
    return exit_code
|
||
|
|
|
||
|
|
|
||
|
|
# Propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|