plano/tests/parity/signals/run_parity.py
Syed A. Hashmi c8079ac971
signals: feature parity with the latest Signals paper. Porting logic from python repo (#903)
* signals: port to layered taxonomy with dual-emit OTel

Made-with: Cursor

* fix: silence collapsible_match clippy lint (rustc 1.95)

Made-with: Cursor

* test: parity harness for rust vs python signals analyzer

Validates the brightstaff signals port against the katanemo/signals Python
reference on lmsys/lmsys-chat-1m. Adds a signals_replay bin emitting python-
compatible JSON, a pyarrow-based driver (bypasses the datasets loader pickle
bug on python 3.14), a 3-tier comparator, and an on-demand workflow_dispatch
CI job.

Made-with: Cursor

* Remove signals test from the gitops flow

* style: format parity harness with black

Made-with: Cursor

* signals: group summary by taxonomy, factor misalignment_ratio

Addresses #903 review feedback from @nehcgs:

- generate_summary() now renders explicit Interaction / Execution /
  Environment headers so the paper taxonomy is visible at a glance,
  even when no signals fired in a given layer. Quality-driving callouts
  (high misalignment rate, looping detected, escalation requested) are
  appended after the layer summary as an alerts tail.

- repair_ratio (legacy taxonomy name) renamed to misalignment_ratio
  and factored into a single InteractionSignals::misalignment_ratio()
  helper so assess_quality and generate_summary share one source of
  truth instead of recomputing the same divide twice.

Two new unit tests pin the layer headers and the (sev N) severity
suffix. Parity with the python reference is preserved at the Tier-A
level (per-type counts + overall_quality); only the human-readable
summary string diverges, which the parity comparator already classifies
as Tier-C.

Made-with: Cursor
2026-04-23 12:02:30 -07:00

332 lines
11 KiB
Python

#!/usr/bin/env python3
"""
Parity harness driver.
Samples conversations from `lmsys/lmsys-chat-1m`, runs both the Python
reference analyzer (in-process) and the Rust port (subprocess), writes both
reports to disk for `compare.py` to diff.
Usage:
python run_parity.py \\
--num-samples 2000 \\
--seed 42 \\
--dataset-revision <hf-revision-sha> \\
--rust-binary ../../../crates/target/release/signals_replay \\
--output-dir out/
"""
from __future__ import annotations

import argparse
import bisect
import hashlib
import importlib.metadata
import json
import random
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterator, List
try:
import pyarrow.parquet as pq
from huggingface_hub import hf_hub_download, list_repo_files
except ImportError:
print(
"error: install dependencies first: pip install -r requirements.txt",
file=sys.stderr,
)
sys.exit(2)
try:
from signals.analyzer import SignalAnalyzer
except ImportError:
print(
"error: the python `signals` package is not installed. "
"install it from your local checkout: pip install -e /path/to/signals",
file=sys.stderr,
)
sys.exit(2)
try:
from tqdm import tqdm
except ImportError:
def tqdm(it, **_kwargs): # type: ignore[no-redef]
return it
DATASET_NAME = "lmsys/lmsys-chat-1m"
def parse_args() -> argparse.Namespace:
    """Parse the harness CLI flags; see the module docstring for usage."""
    parser = argparse.ArgumentParser(description=__doc__)
    # Sampling controls (seed + count make the run reproducible).
    parser.add_argument("--num-samples", type=int, default=2000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--dataset-revision",
        default=None,
        help="HF dataset revision to pin (default: latest, NOT recommended for reproducibility)",
    )
    # Runner inputs/outputs.
    parser.add_argument(
        "--rust-binary",
        type=Path,
        required=True,
        help="path to the `signals_replay` binary built from crates/brightstaff",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("out"),
        help="directory to write the conversations + both runners' outputs",
    )
    parser.add_argument(
        "--max-conv-messages",
        type=int,
        default=200,
        help="drop conversations with more than this many messages (the analyzer "
        "truncates to last 100 anyway; this is a sanity cap on input parsing)",
    )
    return parser.parse_args()
def lmsys_to_sharegpt(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Convert lmsys-chat-1m's `[{role, content}]` to ShareGPT's `[{from, value}]`.

    lmsys uses `user` / `assistant` (no tools, no system role in `conversation`).
    Any other role is dropped defensively; non-string content is coerced to
    `str` (with `None` becoming the empty string).
    """
    role_to_sender = {"user": "human", "assistant": "gpt"}
    converted: List[Dict[str, str]] = []
    for message in conversation:
        sender = role_to_sender.get(message.get("role", ""))
        if sender is None:
            # lmsys is human/assistant only; skip anything else defensively.
            continue
        text = message.get("content", "")
        if not isinstance(text, str):
            text = "" if text is None else str(text)
        converted.append({"from": sender, "value": text})
    return converted
def _list_parquet_files(revision: str | None) -> List[str]:
    """Return the sorted parquet shard paths present in the dataset repo."""
    repo_files = list_repo_files(DATASET_NAME, repo_type="dataset", revision=revision)
    parquet_files = [name for name in repo_files if name.endswith(".parquet")]
    return sorted(parquet_files)
def _download_shards(paths: List[str], revision: str | None) -> List[Path]:
    """Fetch each parquet shard into the HF cache and return the local paths."""
    cached: List[Path] = []
    for shard in tqdm(paths, desc="downloading shards", unit="shard"):
        local_path = hf_hub_download(
            DATASET_NAME,
            filename=shard,
            repo_type="dataset",
            revision=revision,
        )
        cached.append(Path(local_path))
    return cached
def sample_conversations(
    *,
    num_samples: int,
    seed: int,
    revision: str | None,
    max_conv_messages: int,
) -> Iterator[Dict[str, Any]]:
    """Yield up to `num_samples` conversations sampled uniformly across the dataset.

    We bypass the `datasets` loader (which has a Python 3.14 pickle issue)
    and read the parquet shards directly via pyarrow.  Fewer than
    `num_samples` items may be yielded: empty, oversized, or unconvertible
    conversations are filtered out after sampling.
    """
    print(
        f"listing {DATASET_NAME}"
        f"{' @ ' + revision if revision else ' (no revision pinned!)'}",
        file=sys.stderr,
    )
    shard_paths = _list_parquet_files(revision)
    if not shard_paths:
        raise SystemExit(f"no parquet shards found for {DATASET_NAME}")
    local_paths = _download_shards(shard_paths, revision)
    # Collect row counts from parquet footers without reading any data pages.
    shard_row_counts: List[int] = []
    for p in local_paths:
        pf = pq.ParquetFile(str(p))
        shard_row_counts.append(pf.metadata.num_rows)
    total_rows = sum(shard_row_counts)
    print(
        f"dataset has {total_rows:,} rows across {len(local_paths)} shards",
        file=sys.stderr,
    )
    if num_samples > total_rows:
        # rng.sample would raise a bare ValueError; fail with a clearer message.
        raise SystemExit(
            f"requested {num_samples} samples but dataset only has {total_rows} rows"
        )
    rng = random.Random(seed)
    global_indices = sorted(rng.sample(range(total_rows), num_samples))
    # Prefix sums: shard_offsets[i] is the global index of shard i's first row.
    shard_offsets: List[int] = []
    cumulative = 0
    for c in shard_row_counts:
        shard_offsets.append(cumulative)
        cumulative += c
    # Bucket sampled indices by shard.  A binary search over the (sorted)
    # offsets replaces the previous per-index linear scan over every shard.
    by_shard: Dict[int, List[int]] = {}
    for gi in global_indices:
        si = bisect.bisect_right(shard_offsets, gi) - 1
        by_shard.setdefault(si, []).append(gi - shard_offsets[si])
    yielded = 0
    for si in sorted(by_shard.keys()):
        local_rows = by_shard[si]
        pf = pq.ParquetFile(str(local_paths[si]))
        # Only materialize the one column we need.
        table = pf.read(columns=["conversation"])
        conv_col = table.column("conversation")
        for local_idx in local_rows:
            raw = conv_col[local_idx].as_py()
            if not raw:
                continue
            conversation = raw if isinstance(raw, list) else raw.get("conversation", [])
            if len(conversation) > max_conv_messages:
                continue
            messages = lmsys_to_sharegpt(conversation)
            if not messages:
                continue
            global_idx = shard_offsets[si] + local_idx
            yield {
                "id": f"lmsys-{global_idx}",
                "messages": messages,
            }
            yielded += 1
    print(f"yielded {yielded} conversations after filtering", file=sys.stderr)
def write_conversations(out_path: Path, samples: Iterator[Dict[str, Any]]) -> int:
    """Stream the sampled conversations to `out_path` as JSONL; return the count."""
    count = 0
    with out_path.open("w") as sink:
        for sample in tqdm(samples, desc="sampling", unit="convo"):
            sink.write(json.dumps(sample, ensure_ascii=False) + "\n")
            count += 1
    return count
def run_rust(rust_binary: Path, conv_path: Path, out_path: Path) -> None:
    """Run the rust analyzer subprocess, piping conv_path -> stdin, stdout -> out_path.

    Raises SystemExit (after echoing the child's stderr) on a nonzero exit.
    """
    print(f"running rust analyzer: {rust_binary}", file=sys.stderr)
    started = time.monotonic()
    with conv_path.open("rb") as stdin_file, out_path.open("wb") as stdout_file:
        result = subprocess.run(
            [str(rust_binary)],
            stdin=stdin_file,
            stdout=stdout_file,
            stderr=subprocess.PIPE,
            check=False,
        )
    if result.returncode != 0:
        # Surface the child's stderr before bailing so failures are debuggable.
        sys.stderr.write(result.stderr.decode("utf-8", errors="replace"))
        raise SystemExit(f"rust runner exited {result.returncode}")
    print(f" rust runner: {time.monotonic() - started:.1f}s", file=sys.stderr)
def run_python(conv_path: Path, out_path: Path) -> None:
    """Run the Python reference analyzer over every JSONL conversation.

    Writes one JSON object per line to `out_path`: either
    `{"id", "report"}` on success or `{"id", "error"}` on failure, so a
    single bad conversation never aborts the whole run.
    """
    print("running python analyzer...", file=sys.stderr)
    t0 = time.monotonic()
    analyzer = SignalAnalyzer()
    with conv_path.open() as fin, out_path.open("w") as fout:
        for line in tqdm(fin, desc="python", unit="convo"):
            line = line.strip()
            if not line:
                continue
            # Parse separately from analysis: previously a JSON decode error
            # fell into an except branch that read `obj`, which was unbound on
            # the first line (NameError) or stale from the previous line
            # (wrong id attached to the error record).
            try:
                obj = json.loads(line)
            except Exception as e:
                fout.write(json.dumps({"id": None, "error": str(e)}))
                fout.write("\n")
                continue
            try:
                report = analyzer.analyze(obj["messages"])
                fout.write(
                    json.dumps(
                        {"id": obj["id"], "report": report.to_dict()},
                        ensure_ascii=False,
                    )
                )
            except Exception as e:
                fout.write(json.dumps({"id": obj.get("id"), "error": str(e)}))
            fout.write("\n")
    elapsed = time.monotonic() - t0
    print(f" python runner: {elapsed:.1f}s", file=sys.stderr)
def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -> None:
    """Write the input metadata so compare.py can include it in the report.

    Records everything needed to reproduce the run: dataset pin + seed,
    requested/actual sample counts, a SHA-256 of the exact rust binary,
    the plano git HEAD, and the installed python `signals` version.
    Missing information is stamped as "unknown" rather than failing.
    """
    binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
    try:
        plano_sha = (
            subprocess.check_output(
                ["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
            )
            .decode()
            .strip()
        )
    except Exception:
        # Not a git checkout (e.g. a source tarball) or git is missing.
        plano_sha = "unknown"
    try:
        # Query installed-package metadata directly instead of shelling out
        # to `pip show` and scraping its "Version:" line from stdout.
        signals_version = importlib.metadata.version("signals")
    except Exception:
        signals_version = "unknown"
    meta = {
        "dataset_name": DATASET_NAME,
        "dataset_revision": args.dataset_revision,
        "seed": args.seed,
        "num_samples_requested": args.num_samples,
        "num_samples_actual": n_samples,
        "rust_binary": str(args.rust_binary.resolve()),
        "rust_binary_sha256": binary_sha,
        "plano_git_sha": plano_sha,
        "signals_python_version": signals_version,
        "max_conv_messages": args.max_conv_messages,
    }
    (output_dir / "run_metadata.json").write_text(json.dumps(meta, indent=2))
    print(f"wrote {output_dir / 'run_metadata.json'}", file=sys.stderr)
def main() -> None:
    """Drive the full parity run: sample, run both analyzers, stamp metadata."""
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)
    # Fail fast before any (slow) dataset download happens.
    if not args.rust_binary.exists():
        raise SystemExit(f"rust binary not found at {args.rust_binary}")

    conversations_file = args.output_dir / "conversations.jsonl"
    rust_reports_file = args.output_dir / "rust_reports.jsonl"
    python_reports_file = args.output_dir / "python_reports.jsonl"

    sample_iter = sample_conversations(
        num_samples=args.num_samples,
        seed=args.seed,
        revision=args.dataset_revision,
        max_conv_messages=args.max_conv_messages,
    )
    written = write_conversations(conversations_file, sample_iter)
    print(f"wrote {written} conversations to {conversations_file}", file=sys.stderr)

    run_rust(args.rust_binary, conversations_file, rust_reports_file)
    run_python(conversations_file, python_reports_file)
    stamp_metadata(args, args.output_dir, written)
    print("done. now run: python compare.py --output-dir " + str(args.output_dir))


if __name__ == "__main__":
    main()