plano/tests/parity/signals/run_parity.py
Syed Hashmi 12b6b3d814 style: format parity harness with black
Made-with: Cursor
2026-04-22 21:34:23 -07:00

332 lines
11 KiB
Python

#!/usr/bin/env python3
"""
Parity harness driver.
Samples conversations from `lmsys/lmsys-chat-1m`, runs both the Python
reference analyzer (in-process) and the Rust port (subprocess), writes both
reports to disk for `compare.py` to diff.
Usage:
python run_parity.py \\
--num-samples 2000 \\
--seed 42 \\
--dataset-revision <hf-revision-sha> \\
--rust-binary ../../../crates/target/release/signals_replay \\
--output-dir out/
"""
from __future__ import annotations
import argparse
import hashlib
import json
import random
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterator, List
try:
import pyarrow.parquet as pq
from huggingface_hub import hf_hub_download, list_repo_files
except ImportError:
print(
"error: install dependencies first: pip install -r requirements.txt",
file=sys.stderr,
)
sys.exit(2)
try:
from signals.analyzer import SignalAnalyzer
except ImportError:
print(
"error: the python `signals` package is not installed. "
"install it from your local checkout: pip install -e /path/to/signals",
file=sys.stderr,
)
sys.exit(2)
try:
from tqdm import tqdm
except ImportError:
def tqdm(it, **_kwargs): # type: ignore[no-redef]
return it
DATASET_NAME = "lmsys/lmsys-chat-1m"
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--num-samples", type=int, default=2000)
p.add_argument("--seed", type=int, default=42)
p.add_argument(
"--dataset-revision",
default=None,
help="HF dataset revision to pin (default: latest, NOT recommended for reproducibility)",
)
p.add_argument(
"--rust-binary",
type=Path,
required=True,
help="path to the `signals_replay` binary built from crates/brightstaff",
)
p.add_argument(
"--output-dir",
type=Path,
default=Path("out"),
help="directory to write the conversations + both runners' outputs",
)
p.add_argument(
"--max-conv-messages",
type=int,
default=200,
help="drop conversations with more than this many messages (the analyzer "
"truncates to last 100 anyway; this is a sanity cap on input parsing)",
)
return p.parse_args()
def lmsys_to_sharegpt(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""Convert lmsys-chat-1m's `[{role, content}]` to ShareGPT's `[{from, value}]`.
lmsys uses `user` / `assistant` (no tools, no system role in `conversation`).
"""
out = []
for m in conversation:
role = m.get("role", "")
content = m.get("content", "")
if not isinstance(content, str):
content = str(content) if content is not None else ""
if role == "user":
from_ = "human"
elif role == "assistant":
from_ = "gpt"
else:
# lmsys is human/assistant only; skip anything else defensively.
continue
out.append({"from": from_, "value": content})
return out
def _list_parquet_files(revision: str | None) -> List[str]:
"""Return the list of parquet shard paths in the dataset repo."""
files = list_repo_files(DATASET_NAME, repo_type="dataset", revision=revision)
return sorted(f for f in files if f.endswith(".parquet"))
def _download_shards(paths: List[str], revision: str | None) -> List[Path]:
"""Download each parquet shard to the HF cache, return local paths."""
local: List[Path] = []
for rel in tqdm(paths, desc="downloading shards", unit="shard"):
p = hf_hub_download(
DATASET_NAME,
filename=rel,
repo_type="dataset",
revision=revision,
)
local.append(Path(p))
return local
def sample_conversations(
*,
num_samples: int,
seed: int,
revision: str | None,
max_conv_messages: int,
) -> Iterator[Dict[str, Any]]:
"""Yield `num_samples` conversations sampled uniformly across the dataset.
We bypass the `datasets` loader (which has a Python 3.14 pickle issue)
and read the parquet shards directly via pyarrow.
"""
print(
f"listing {DATASET_NAME}"
f"{' @ ' + revision if revision else ' (no revision pinned!)'}",
file=sys.stderr,
)
shard_paths = _list_parquet_files(revision)
if not shard_paths:
raise SystemExit(f"no parquet shards found for {DATASET_NAME}")
local_paths = _download_shards(shard_paths, revision)
# Collect row counts without reading data.
shard_row_counts: List[int] = []
for p in local_paths:
pf = pq.ParquetFile(str(p))
shard_row_counts.append(pf.metadata.num_rows)
total_rows = sum(shard_row_counts)
print(
f"dataset has {total_rows:,} rows across {len(local_paths)} shards",
file=sys.stderr,
)
rng = random.Random(seed)
global_indices = sorted(rng.sample(range(total_rows), num_samples))
# Bucket indices by shard.
by_shard: Dict[int, List[int]] = {}
cumulative = 0
shard_offsets = []
for c in shard_row_counts:
shard_offsets.append(cumulative)
cumulative += c
for gi in global_indices:
# Find which shard this index belongs to.
for si, off in enumerate(shard_offsets):
if gi < off + shard_row_counts[si]:
by_shard.setdefault(si, []).append(gi - off)
break
yielded = 0
for si in sorted(by_shard.keys()):
local_rows = by_shard[si]
pf = pq.ParquetFile(str(local_paths[si]))
table = pf.read(columns=["conversation"])
conv_col = table.column("conversation")
for local_idx in local_rows:
raw = conv_col[local_idx].as_py()
if not raw:
continue
conversation = raw if isinstance(raw, list) else raw.get("conversation", [])
if len(conversation) > max_conv_messages:
continue
messages = lmsys_to_sharegpt(conversation)
if not messages:
continue
global_idx = shard_offsets[si] + local_idx
yield {
"id": f"lmsys-{global_idx}",
"messages": messages,
}
yielded += 1
print(f"yielded {yielded} conversations after filtering", file=sys.stderr)
def write_conversations(out_path: Path, samples: Iterator[Dict[str, Any]]) -> int:
n = 0
with out_path.open("w") as f:
for s in tqdm(samples, desc="sampling", unit="convo"):
f.write(json.dumps(s, ensure_ascii=False))
f.write("\n")
n += 1
return n
def run_rust(rust_binary: Path, conv_path: Path, out_path: Path) -> None:
print(f"running rust analyzer: {rust_binary}", file=sys.stderr)
t0 = time.monotonic()
with conv_path.open("rb") as fin, out_path.open("wb") as fout:
proc = subprocess.run(
[str(rust_binary)],
stdin=fin,
stdout=fout,
stderr=subprocess.PIPE,
check=False,
)
if proc.returncode != 0:
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
raise SystemExit(f"rust runner exited {proc.returncode}")
elapsed = time.monotonic() - t0
print(f" rust runner: {elapsed:.1f}s", file=sys.stderr)
def run_python(conv_path: Path, out_path: Path) -> None:
print("running python analyzer...", file=sys.stderr)
t0 = time.monotonic()
analyzer = SignalAnalyzer()
with conv_path.open() as fin, out_path.open("w") as fout:
for line in tqdm(fin, desc="python", unit="convo"):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
report = analyzer.analyze(obj["messages"])
fout.write(
json.dumps(
{"id": obj["id"], "report": report.to_dict()},
ensure_ascii=False,
)
)
except Exception as e:
fout.write(json.dumps({"id": obj.get("id"), "error": str(e)}))
fout.write("\n")
elapsed = time.monotonic() - t0
print(f" python runner: {elapsed:.1f}s", file=sys.stderr)
def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -> None:
"""Write the input metadata so compare.py can include it in the report."""
binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
try:
plano_sha = (
subprocess.check_output(
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
)
.decode()
.strip()
)
except Exception:
plano_sha = "unknown"
try:
signals_version = subprocess.check_output(
[sys.executable, "-m", "pip", "show", "signals"]
).decode()
signals_version = next(
(
l.split(":", 1)[1].strip()
for l in signals_version.splitlines()
if l.startswith("Version")
),
"unknown",
)
except Exception:
signals_version = "unknown"
meta = {
"dataset_name": DATASET_NAME,
"dataset_revision": args.dataset_revision,
"seed": args.seed,
"num_samples_requested": args.num_samples,
"num_samples_actual": n_samples,
"rust_binary": str(args.rust_binary.resolve()),
"rust_binary_sha256": binary_sha,
"plano_git_sha": plano_sha,
"signals_python_version": signals_version,
"max_conv_messages": args.max_conv_messages,
}
(output_dir / "run_metadata.json").write_text(json.dumps(meta, indent=2))
print(f"wrote {output_dir / 'run_metadata.json'}", file=sys.stderr)
def main() -> None:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
if not args.rust_binary.exists():
raise SystemExit(f"rust binary not found at {args.rust_binary}")
conv_path = args.output_dir / "conversations.jsonl"
rust_path = args.output_dir / "rust_reports.jsonl"
py_path = args.output_dir / "python_reports.jsonl"
samples = sample_conversations(
num_samples=args.num_samples,
seed=args.seed,
revision=args.dataset_revision,
max_conv_messages=args.max_conv_messages,
)
n = write_conversations(conv_path, samples)
print(f"wrote {n} conversations to {conv_path}", file=sys.stderr)
run_rust(args.rust_binary, conv_path, rust_path)
run_python(conv_path, py_path)
stamp_metadata(args, args.output_dir, n)
print("done. now run: python compare.py --output-dir " + str(args.output_dir))
if __name__ == "__main__":
main()