test: parity harness for rust vs python signals analyzer

Validates the brightstaff signals port against the katanemo/signals Python
reference on lmsys/lmsys-chat-1m. Adds a signals_replay bin emitting python-
compatible JSON, a pyarrow-based driver (bypasses the datasets loader pickle
bug on python 3.14), a 3-tier comparator, and an on-demand workflow_dispatch
CI job.

Made-with: Cursor
This commit is contained in:
Syed Hashmi 2026-04-22 12:28:22 -07:00
parent cb2fdc9dbf
commit ffaa4da22c
No known key found for this signature in database
9 changed files with 1118 additions and 0 deletions

97
.github/workflows/parity-signals.yml vendored Normal file
View file

@ -0,0 +1,97 @@
name: parity-signals
# On-demand parity validation of the Rust signals port against the Python
# reference (https://github.com/katanemo/signals). Not run on every PR
# because it downloads several GB of dataset content.
on:
workflow_dispatch:
inputs:
num_samples:
description: "Number of conversations to sample from lmsys-chat-1m"
required: true
default: "2000"
seed:
description: "Sampling seed (use the same value for reproducibility)"
required: true
default: "42"
dataset_revision:
description: "HF dataset revision (commit SHA). Empty = latest (NOT pinned)."
required: false
default: ""
signals_ref:
description: "Git ref of katanemo/signals to install"
required: true
default: "main"
permissions:
contents: read
jobs:
parity:
runs-on: ubuntu-latest
timeout-minutes: 90
steps:
- uses: actions/checkout@v6
- name: Install Rust (stable)
uses: dtolnay/rust-toolchain@stable
- name: Cache cargo
uses: Swatinem/rust-cache@v2
with:
workspaces: crates
- name: Build signals_replay
working-directory: crates
run: cargo build --release -p brightstaff --bin signals_replay
- uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Install harness deps
working-directory: tests/parity/signals
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install "signals @ git+https://github.com/katanemo/signals@${{ inputs.signals_ref }}"
- name: Authenticate Hugging Face (lmsys-chat-1m is gated)
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
run: |
if [ -z "$HF_TOKEN" ]; then
echo "::error::HF_TOKEN secret is required to download lmsys-chat-1m"
exit 1
fi
mkdir -p ~/.cache/huggingface
echo -n "$HF_TOKEN" > ~/.cache/huggingface/token
- name: Run parity harness
working-directory: tests/parity/signals
env:
DATASET_REV: ${{ inputs.dataset_revision }}
run: |
ARGS=(
--num-samples "${{ inputs.num_samples }}"
--seed "${{ inputs.seed }}"
--rust-binary "${GITHUB_WORKSPACE}/crates/target/release/signals_replay"
--output-dir out/
)
if [ -n "$DATASET_REV" ]; then
ARGS+=(--dataset-revision "$DATASET_REV")
fi
python run_parity.py "${ARGS[@]}"
- name: Compare reports
working-directory: tests/parity/signals
run: python compare.py --output-dir out/
- name: Upload artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: signals-parity-out
path: tests/parity/signals/out/
if-no-files-found: warn

View file

@ -3,6 +3,14 @@ name = "brightstaff"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
[[bin]]
name = "brightstaff"
path = "src/main.rs"
[[bin]]
name = "signals_replay"
path = "src/bin/signals_replay.rs"
[dependencies] [dependencies]
async-openai = "0.30.1" async-openai = "0.30.1"
async-trait = "0.1" async-trait = "0.1"

View file

@ -0,0 +1,175 @@
//! `signals-replay` — batch driver for the `brightstaff` signal analyzer.
//!
//! Reads JSONL conversations from stdin (one per line) and emits matching
//! JSONL reports on stdout, one per input conversation, in the same order.
//!
//! Input shape (per line):
//! ```json
//! {"id": "convo-42", "messages": [{"from": "human", "value": "..."}, ...]}
//! ```
//!
//! Output shape (per line, success):
//! ```json
//! {"id": "convo-42", "report": { ...python-compatible SignalReport dict... }}
//! ```
//!
//! On per-line failure (parse / analyzer error), emits:
//! ```json
//! {"id": "convo-42", "error": "..."}
//! ```
//!
//! The output report dict is shaped to match the Python reference's
//! `SignalReport.to_dict()` byte-for-byte so the parity comparator can do a
//! direct structural diff.
use std::io::{self, BufRead, BufWriter, Write};
use serde::Deserialize;
use serde_json::{json, Map, Value};
use brightstaff::signals::{SignalAnalyzer, SignalGroup, SignalReport};
#[derive(Debug, Deserialize)]
struct InputLine {
id: Value,
messages: Vec<MessageRow>,
}
#[derive(Debug, Deserialize)]
struct MessageRow {
#[serde(default)]
from: String,
#[serde(default)]
value: String,
}
fn main() {
let stdin = io::stdin();
let stdout = io::stdout();
let mut out = BufWriter::new(stdout.lock());
let analyzer = SignalAnalyzer::default();
for line in stdin.lock().lines() {
let line = match line {
Ok(l) => l,
Err(e) => {
eprintln!("read error: {e}");
std::process::exit(1);
}
};
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
let result = process_line(&analyzer, trimmed);
// Always emit one line per input line so id ordering stays aligned.
if let Err(e) = writeln!(out, "{result}") {
eprintln!("write error: {e}");
std::process::exit(1);
}
// Flush periodically isn't strictly needed — BufWriter handles it,
// and the parent process reads the whole stream when we're done.
}
let _ = out.flush();
}
fn process_line(analyzer: &SignalAnalyzer, line: &str) -> Value {
let parsed: InputLine = match serde_json::from_str(line) {
Ok(p) => p,
Err(e) => {
return json!({
"id": Value::Null,
"error": format!("input parse: {e}"),
});
}
};
let id = parsed.id.clone();
let view: Vec<brightstaff::signals::analyzer::ShareGptMessage<'_>> = parsed
.messages
.iter()
.map(|m| brightstaff::signals::analyzer::ShareGptMessage {
from: m.from.as_str(),
value: m.value.as_str(),
})
.collect();
let report = analyzer.analyze_sharegpt(&view);
let report_dict = report_to_python_dict(&report);
json!({
"id": id,
"report": report_dict,
})
}
/// Convert a `SignalReport` into the Python reference's `to_dict()` shape.
///
/// Ordering of category keys in each layer dict follows the Python source
/// exactly so even string-equality comparisons behave deterministically.
fn report_to_python_dict(r: &SignalReport) -> Value {
let mut interaction = Map::new();
interaction.insert(
"misalignment".to_string(),
signal_group_to_python(&r.interaction.misalignment),
);
interaction.insert(
"stagnation".to_string(),
signal_group_to_python(&r.interaction.stagnation),
);
interaction.insert(
"disengagement".to_string(),
signal_group_to_python(&r.interaction.disengagement),
);
interaction.insert(
"satisfaction".to_string(),
signal_group_to_python(&r.interaction.satisfaction),
);
let mut execution = Map::new();
execution.insert(
"failure".to_string(),
signal_group_to_python(&r.execution.failure),
);
execution.insert(
"loops".to_string(),
signal_group_to_python(&r.execution.loops),
);
let mut environment = Map::new();
environment.insert(
"exhaustion".to_string(),
signal_group_to_python(&r.environment.exhaustion),
);
json!({
"interaction_signals": Value::Object(interaction),
"execution_signals": Value::Object(execution),
"environment_signals": Value::Object(environment),
"overall_quality": r.overall_quality.as_str(),
"summary": r.summary,
})
}
fn signal_group_to_python(g: &SignalGroup) -> Value {
let signals: Vec<Value> = g
.signals
.iter()
.map(|s| {
json!({
"signal_type": s.signal_type.as_str(),
"message_index": s.message_index,
"snippet": s.snippet,
"confidence": s.confidence,
"metadata": s.metadata,
})
})
.collect();
json!({
"category": g.category,
"count": g.count,
"severity": g.severity,
"signals": signals,
})
}

4
tests/parity/signals/.gitignore vendored Normal file
View file

@ -0,0 +1,4 @@
out/
.venv/
__pycache__/
*.pyc

View file

@ -0,0 +1,98 @@
# Signals Parity Harness
Validates that `crates/brightstaff/src/signals/` (Rust port) produces the same
`SignalReport` as the Python reference at <https://github.com/katanemo/signals>
on a fixed sample of `lmsys/lmsys-chat-1m` conversations.
This harness is **not** part of normal CI. It downloads several GB and is run
on demand to gate releases of the signals subsystem (or to investigate
regressions reported in production).
## What gets compared
For each conversation, both analyzers emit a `SignalReport`. The comparator
classifies any divergence into three tiers:
| Tier | Field | Action on divergence |
|------|------------------------------------------------|----------------------|
| A | set of `SignalType` present, per-type counts, `overall_quality` | Fail the run |
| B | per-instance `message_index`, instance counts per type | Log + collect, do not fail |
| C | metadata, snippet text, summary | Information only |
Quality buckets are compared by string (`excellent` / `good` / ...).
## What this harness does *not* cover
`lmsys-chat-1m` is plain user/assistant chat. It exercises the **interaction**
layer well (misalignment, stagnation, disengagement, satisfaction) but does
**not** exercise:
- `execution.failure.*`
- `execution.loops.*`
- `environment.exhaustion.*`
Those signals require `function_call` / `observation` ShareGPT roles. They are
covered by the Rust unit tests and the Python repo's own test fixtures, both
of which run on every PR. A synthetic tool-trace dataset for full coverage is
deferred to a follow-up.
## One-time setup
```bash
# 1. Build the Rust replay binary.
cd ../../../crates && cargo build --release -p brightstaff --bin signals_replay
# 2. Set up the Python environment for the harness driver.
cd ../tests/parity/signals
python3 -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
# 3. Install the Python signals reference.
# Either point at a local checkout:
pip install -e /path/to/signals
# or pull from git:
pip install 'signals @ git+https://github.com/katanemo/signals@<sha>'
```
## Running
```bash
source .venv/bin/activate
python run_parity.py \
--num-samples 2000 \
--seed 42 \
--dataset-revision <hf-dataset-revision-sha> \
--rust-binary ../../../crates/target/release/signals_replay \
--output-dir out/
python compare.py --output-dir out/
```
`run_parity.py` will:
1. Download `lmsys/lmsys-chat-1m` (cached in `~/.cache/huggingface`).
2. Pick `--num-samples` rows under `--seed`.
3. Convert each to ShareGPT, write `out/conversations.jsonl`.
4. Run the Rust binary as a subprocess → `out/rust_reports.jsonl`.
5. Run the Python analyzer in-process → `out/python_reports.jsonl`.
`compare.py` reads both report files and writes:
- `out/diffs.jsonl` — one record per mismatched conversation, with tier + structural diff
- `out/metrics.json` — agreement %, per-`SignalType` confusion matrix, quality-bucket confusion matrix
- `out/summary.md` — human-readable PR-ready report
Exit code is non-zero iff any Tier-A divergence is observed.
## Reproducibility
Every run pins:
- `dataset_revision` — the HF dataset commit
- `seed` — RNG seed for sampling
- `signals_python_version``pip show signals` version
- `plano_git_sha``git rev-parse HEAD` of this repo
- `signals_replay_binary_sha256` — the hash of the Rust bin
All are stamped into `metrics.json`.

View file

@ -0,0 +1,97 @@
#!/usr/bin/env python3
"""
Local smoke test for the parity harness runs both runners on a tiny
hand-picked set of conversations without touching the lmsys dataset.
Run from this directory:
python _smoke_test.py --rust-binary <path>
"""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
from pathlib import Path
from signals.analyzer import SignalAnalyzer
SAMPLES = [
{
"id": "smoke-gratitude",
"messages": [
{"from": "human", "value": "What is the weather in Istanbul?"},
{"from": "gpt", "value": "Istanbul is 14C and partly cloudy."},
{"from": "human", "value": "That worked, exactly what I needed. Thanks!"},
],
},
{
"id": "smoke-escalation",
"messages": [
{"from": "human", "value": "This isn't helpful at all"},
{"from": "gpt", "value": "I'm sorry, can you tell me more?"},
{"from": "human", "value": "Get me a human, this is useless"},
],
},
{
"id": "smoke-correction",
"messages": [
{"from": "human", "value": "Book me a flight to NYC for tomorrow"},
{"from": "gpt", "value": "Sure, here are flights to NYC for Friday."},
{"from": "human", "value": "No, I meant flights for Saturday, not tomorrow"},
],
},
{
"id": "smoke-clean",
"messages": [
{"from": "human", "value": "Hi"},
{"from": "gpt", "value": "Hello, how can I help?"},
],
},
{
"id": "smoke-rephrase",
"messages": [
{"from": "human", "value": "Can you summarize the news please"},
{"from": "gpt", "value": "Sure, here is a summary."},
{"from": "human", "value": "Could you please summarize the news"},
],
},
]
def main() -> int:
p = argparse.ArgumentParser()
p.add_argument("--rust-binary", required=True, type=Path)
args = p.parse_args()
out_dir = Path("out_smoke")
out_dir.mkdir(exist_ok=True)
conv_path = out_dir / "conversations.jsonl"
rust_path = out_dir / "rust_reports.jsonl"
py_path = out_dir / "python_reports.jsonl"
with conv_path.open("w") as f:
for s in SAMPLES:
f.write(json.dumps(s) + "\n")
with conv_path.open("rb") as fin, rust_path.open("wb") as fout:
proc = subprocess.run([str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE)
if proc.returncode != 0:
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
return 2
analyzer = SignalAnalyzer()
with conv_path.open() as fin, py_path.open("w") as fout:
for line in fin:
obj = json.loads(line)
r = analyzer.analyze(obj["messages"])
fout.write(json.dumps({"id": obj["id"], "report": r.to_dict()}) + "\n")
rc = subprocess.call(
[sys.executable, "compare.py", "--output-dir", str(out_dir)],
)
return rc
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,320 @@
#!/usr/bin/env python3
"""
Diff Rust vs Python signal reports produced by run_parity.py.
See README.md for the tier definitions. Exits non-zero iff any Tier-A
divergence is found.
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
CATEGORIES_BY_LAYER = {
"interaction_signals": ["misalignment", "stagnation", "disengagement", "satisfaction"],
"execution_signals": ["failure", "loops"],
"environment_signals": ["exhaustion"],
}
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--output-dir", type=Path, default=Path("out"))
return p.parse_args()
def load_jsonl(path: Path) -> Dict[str, Dict[str, Any]]:
"""Load a JSONL file keyed by `id`. Lines with errors are still indexed."""
out: Dict[str, Dict[str, Any]] = {}
with path.open() as f:
for line in f:
line = line.strip()
if not line:
continue
obj = json.loads(line)
out[str(obj.get("id"))] = obj
return out
def per_type_counts(report: Dict[str, Any]) -> Dict[str, int]:
"""Return {signal_type: count} across all groups in a report dict."""
counts: Counter[str] = Counter()
for layer in CATEGORIES_BY_LAYER:
groups = report.get(layer, {}) or {}
for category in CATEGORIES_BY_LAYER[layer]:
group = groups.get(category)
if not group:
continue
for sig in group.get("signals", []) or []:
counts[sig["signal_type"]] += 1
return dict(counts)
def per_type_indices(report: Dict[str, Any]) -> Dict[str, List[int]]:
out: Dict[str, List[int]] = defaultdict(list)
for layer in CATEGORIES_BY_LAYER:
groups = report.get(layer, {}) or {}
for category in CATEGORIES_BY_LAYER[layer]:
group = groups.get(category)
if not group:
continue
for sig in group.get("signals", []) or []:
out[sig["signal_type"]].append(sig.get("message_index"))
for k in out:
out[k].sort(key=lambda x: (x is None, x))
return dict(out)
def diff_counts(
a: Dict[str, int], b: Dict[str, int]
) -> List[Tuple[str, int, int]]:
"""Return [(signal_type, a_count, b_count)] for entries that differ."""
keys = set(a) | set(b)
out = []
for k in sorted(keys):
ac = a.get(k, 0)
bc = b.get(k, 0)
if ac != bc:
out.append((k, ac, bc))
return out
def diff_indices(
a: Dict[str, List[int]], b: Dict[str, List[int]]
) -> List[Tuple[str, List[int], List[int]]]:
keys = set(a) | set(b)
out = []
for k in sorted(keys):
ai = a.get(k, [])
bi = b.get(k, [])
if ai != bi:
out.append((k, ai, bi))
return out
def compare_one(
convo_id: str, py: Dict[str, Any], rust: Dict[str, Any]
) -> Dict[str, Any] | None:
"""Compare a single conversation. Return diff record, or None if identical."""
if "error" in py or "error" in rust:
return {
"id": convo_id,
"tier": "A",
"kind": "error_in_runner",
"python_error": py.get("error"),
"rust_error": rust.get("error"),
}
py_report = py["report"]
rust_report = rust["report"]
py_counts = per_type_counts(py_report)
rust_counts = per_type_counts(rust_report)
count_diff = diff_counts(py_counts, rust_counts)
py_quality = py_report.get("overall_quality")
rust_quality = rust_report.get("overall_quality")
quality_mismatch = py_quality != rust_quality
if count_diff or quality_mismatch:
return {
"id": convo_id,
"tier": "A",
"kind": "signal_or_quality_mismatch",
"quality": {"python": py_quality, "rust": rust_quality},
"count_diff": [
{"signal_type": st, "python": pc, "rust": rc}
for (st, pc, rc) in count_diff
],
}
py_idx = per_type_indices(py_report)
rust_idx = per_type_indices(rust_report)
idx_diff = diff_indices(py_idx, rust_idx)
if idx_diff:
return {
"id": convo_id,
"tier": "B",
"kind": "instance_index_mismatch",
"diff": [
{"signal_type": st, "python_indices": pi, "rust_indices": ri}
for (st, pi, ri) in idx_diff
],
}
return None
def confusion_matrix(
pairs: List[Tuple[str, str]], labels: List[str]
) -> Dict[str, Dict[str, int]]:
cm: Dict[str, Dict[str, int]] = {a: {b: 0 for b in labels} for a in labels}
for py, rust in pairs:
if py not in cm:
cm[py] = {b: 0 for b in labels}
if rust not in cm[py]:
cm[py][rust] = 0
cm[py][rust] += 1
return cm
def main() -> int:
args = parse_args()
out_dir = args.output_dir
py_reports = load_jsonl(out_dir / "python_reports.jsonl")
rust_reports = load_jsonl(out_dir / "rust_reports.jsonl")
common_ids = sorted(set(py_reports) & set(rust_reports))
only_py = sorted(set(py_reports) - set(rust_reports))
only_rust = sorted(set(rust_reports) - set(py_reports))
diffs: List[Dict[str, Any]] = []
quality_pairs: List[Tuple[str, str]] = []
per_type_total = Counter()
per_type_disagree = Counter()
tier_a = 0
tier_b = 0
for cid in common_ids:
d = compare_one(cid, py_reports[cid], rust_reports[cid])
if d is None:
quality_pairs.append(
(
py_reports[cid]["report"]["overall_quality"],
rust_reports[cid]["report"]["overall_quality"],
)
)
for st, _ in per_type_counts(py_reports[cid]["report"]).items():
per_type_total[st] += 1
else:
diffs.append(d)
if d["tier"] == "A":
tier_a += 1
elif d["tier"] == "B":
tier_b += 1
if "report" in py_reports[cid] and "report" in rust_reports[cid]:
quality_pairs.append(
(
py_reports[cid]["report"].get("overall_quality", "?"),
rust_reports[cid]["report"].get("overall_quality", "?"),
)
)
for cd in d.get("count_diff", []) or []:
per_type_disagree[cd["signal_type"]] += 1
per_type_total[cd["signal_type"]] += 1
n_total = len(common_ids)
n_match = n_total - len(diffs)
agreement = (n_match / n_total) if n_total else 0.0
quality_labels = ["excellent", "good", "neutral", "poor", "severe"]
cm = confusion_matrix(quality_pairs, quality_labels)
metrics = {
"n_python_reports": len(py_reports),
"n_rust_reports": len(rust_reports),
"n_common": n_total,
"n_only_python": len(only_py),
"n_only_rust": len(only_rust),
"n_full_match": n_match,
"agreement_pct": round(100.0 * agreement, 4),
"tier_a_divergences": tier_a,
"tier_b_divergences": tier_b,
"quality_confusion_matrix": cm,
"per_signal_type_total": dict(per_type_total),
"per_signal_type_disagree": dict(per_type_disagree),
}
# Pull in run metadata if present.
rm_path = out_dir / "run_metadata.json"
if rm_path.exists():
metrics["run_metadata"] = json.loads(rm_path.read_text())
(out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
with (out_dir / "diffs.jsonl").open("w") as f:
for d in diffs:
f.write(json.dumps(d, ensure_ascii=False))
f.write("\n")
write_summary_md(out_dir / "summary.md", metrics, diffs[:20])
print(json.dumps({k: v for k, v in metrics.items() if k != "quality_confusion_matrix"}, indent=2))
print(f"\ndiffs: {out_dir / 'diffs.jsonl'} metrics: {out_dir / 'metrics.json'}")
print(f"summary: {out_dir / 'summary.md'}")
if tier_a > 0:
print(f"\nFAIL: {tier_a} Tier-A divergence(s) detected.", file=sys.stderr)
return 1
return 0
def write_summary_md(path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]) -> None:
lines: List[str] = []
lines.append("# Signals Parity Report")
lines.append("")
rm = metrics.get("run_metadata", {})
if rm:
lines.append("## Run metadata")
lines.append("")
for k in (
"dataset_name",
"dataset_revision",
"seed",
"num_samples_actual",
"plano_git_sha",
"signals_python_version",
"rust_binary_sha256",
):
if k in rm:
lines.append(f"- **{k}**: `{rm[k]}`")
lines.append("")
lines.append("## Summary")
lines.append("")
lines.append(f"- Conversations compared: **{metrics['n_common']}**")
lines.append(f"- Full matches: **{metrics['n_full_match']}**")
lines.append(f"- Agreement: **{metrics['agreement_pct']}%**")
lines.append(f"- Tier-A divergences: **{metrics['tier_a_divergences']}**")
lines.append(f"- Tier-B divergences: **{metrics['tier_b_divergences']}**")
lines.append("")
lines.append("## Per-signal-type disagreement")
lines.append("")
lines.append("| Signal type | Total reports | Disagreements |")
lines.append("|---|---:|---:|")
totals = metrics["per_signal_type_total"]
disagrees = metrics["per_signal_type_disagree"]
for k in sorted(set(totals) | set(disagrees)):
lines.append(f"| `{k}` | {totals.get(k, 0)} | {disagrees.get(k, 0)} |")
lines.append("")
lines.append("## Quality bucket confusion matrix (rows = python, cols = rust)")
lines.append("")
cm = metrics["quality_confusion_matrix"]
labels = list(cm.keys())
lines.append("| | " + " | ".join(labels) + " |")
lines.append("|---|" + "|".join(["---:"] * len(labels)) + "|")
for r in labels:
lines.append(f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |")
lines.append("")
if sample_diffs:
lines.append("## Sample divergences (first 20)")
lines.append("")
for d in sample_diffs:
lines.append(f"### `{d['id']}` — tier {d['tier']}{d['kind']}")
lines.append("")
lines.append("```json")
lines.append(json.dumps(d, indent=2))
lines.append("```")
lines.append("")
path.write_text("\n".join(lines))
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,3 @@
huggingface_hub>=0.25
pyarrow>=15
tqdm>=4.65

View file

@ -0,0 +1,316 @@
#!/usr/bin/env python3
"""
Parity harness driver.
Samples conversations from `lmsys/lmsys-chat-1m`, runs both the Python
reference analyzer (in-process) and the Rust port (subprocess), writes both
reports to disk for `compare.py` to diff.
Usage:
python run_parity.py \\
--num-samples 2000 \\
--seed 42 \\
--dataset-revision <hf-revision-sha> \\
--rust-binary ../../../crates/target/release/signals_replay \\
--output-dir out/
"""
from __future__ import annotations
import argparse
import hashlib
import json
import random
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterator, List
try:
import pyarrow.parquet as pq
from huggingface_hub import hf_hub_download, list_repo_files
except ImportError:
print("error: install dependencies first: pip install -r requirements.txt", file=sys.stderr)
sys.exit(2)
try:
from signals.analyzer import SignalAnalyzer
except ImportError:
print(
"error: the python `signals` package is not installed. "
"install it from your local checkout: pip install -e /path/to/signals",
file=sys.stderr,
)
sys.exit(2)
try:
from tqdm import tqdm
except ImportError:
def tqdm(it, **_kwargs): # type: ignore[no-redef]
return it
DATASET_NAME = "lmsys/lmsys-chat-1m"
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--num-samples", type=int, default=2000)
p.add_argument("--seed", type=int, default=42)
p.add_argument(
"--dataset-revision",
default=None,
help="HF dataset revision to pin (default: latest, NOT recommended for reproducibility)",
)
p.add_argument(
"--rust-binary",
type=Path,
required=True,
help="path to the `signals_replay` binary built from crates/brightstaff",
)
p.add_argument(
"--output-dir",
type=Path,
default=Path("out"),
help="directory to write the conversations + both runners' outputs",
)
p.add_argument(
"--max-conv-messages",
type=int,
default=200,
help="drop conversations with more than this many messages (the analyzer "
"truncates to last 100 anyway; this is a sanity cap on input parsing)",
)
return p.parse_args()
def lmsys_to_sharegpt(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""Convert lmsys-chat-1m's `[{role, content}]` to ShareGPT's `[{from, value}]`.
lmsys uses `user` / `assistant` (no tools, no system role in `conversation`).
"""
out = []
for m in conversation:
role = m.get("role", "")
content = m.get("content", "")
if not isinstance(content, str):
content = str(content) if content is not None else ""
if role == "user":
from_ = "human"
elif role == "assistant":
from_ = "gpt"
else:
# lmsys is human/assistant only; skip anything else defensively.
continue
out.append({"from": from_, "value": content})
return out
def _list_parquet_files(revision: str | None) -> List[str]:
"""Return the list of parquet shard paths in the dataset repo."""
files = list_repo_files(DATASET_NAME, repo_type="dataset", revision=revision)
return sorted(f for f in files if f.endswith(".parquet"))
def _download_shards(paths: List[str], revision: str | None) -> List[Path]:
"""Download each parquet shard to the HF cache, return local paths."""
local: List[Path] = []
for rel in tqdm(paths, desc="downloading shards", unit="shard"):
p = hf_hub_download(
DATASET_NAME,
filename=rel,
repo_type="dataset",
revision=revision,
)
local.append(Path(p))
return local
def sample_conversations(
*,
num_samples: int,
seed: int,
revision: str | None,
max_conv_messages: int,
) -> Iterator[Dict[str, Any]]:
"""Yield `num_samples` conversations sampled uniformly across the dataset.
We bypass the `datasets` loader (which has a Python 3.14 pickle issue)
and read the parquet shards directly via pyarrow.
"""
print(
f"listing {DATASET_NAME}"
f"{' @ ' + revision if revision else ' (no revision pinned!)'}",
file=sys.stderr,
)
shard_paths = _list_parquet_files(revision)
if not shard_paths:
raise SystemExit(f"no parquet shards found for {DATASET_NAME}")
local_paths = _download_shards(shard_paths, revision)
# Collect row counts without reading data.
shard_row_counts: List[int] = []
for p in local_paths:
pf = pq.ParquetFile(str(p))
shard_row_counts.append(pf.metadata.num_rows)
total_rows = sum(shard_row_counts)
print(f"dataset has {total_rows:,} rows across {len(local_paths)} shards", file=sys.stderr)
rng = random.Random(seed)
global_indices = sorted(rng.sample(range(total_rows), num_samples))
# Bucket indices by shard.
by_shard: Dict[int, List[int]] = {}
cumulative = 0
shard_offsets = []
for c in shard_row_counts:
shard_offsets.append(cumulative)
cumulative += c
for gi in global_indices:
# Find which shard this index belongs to.
for si, off in enumerate(shard_offsets):
if gi < off + shard_row_counts[si]:
by_shard.setdefault(si, []).append(gi - off)
break
yielded = 0
for si in sorted(by_shard.keys()):
local_rows = by_shard[si]
pf = pq.ParquetFile(str(local_paths[si]))
table = pf.read(columns=["conversation"])
conv_col = table.column("conversation")
for local_idx in local_rows:
raw = conv_col[local_idx].as_py()
if not raw:
continue
conversation = raw if isinstance(raw, list) else raw.get("conversation", [])
if len(conversation) > max_conv_messages:
continue
messages = lmsys_to_sharegpt(conversation)
if not messages:
continue
global_idx = shard_offsets[si] + local_idx
yield {
"id": f"lmsys-{global_idx}",
"messages": messages,
}
yielded += 1
print(f"yielded {yielded} conversations after filtering", file=sys.stderr)
def write_conversations(out_path: Path, samples: Iterator[Dict[str, Any]]) -> int:
n = 0
with out_path.open("w") as f:
for s in tqdm(samples, desc="sampling", unit="convo"):
f.write(json.dumps(s, ensure_ascii=False))
f.write("\n")
n += 1
return n
def run_rust(rust_binary: Path, conv_path: Path, out_path: Path) -> None:
print(f"running rust analyzer: {rust_binary}", file=sys.stderr)
t0 = time.monotonic()
with conv_path.open("rb") as fin, out_path.open("wb") as fout:
proc = subprocess.run(
[str(rust_binary)],
stdin=fin,
stdout=fout,
stderr=subprocess.PIPE,
check=False,
)
if proc.returncode != 0:
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
raise SystemExit(f"rust runner exited {proc.returncode}")
elapsed = time.monotonic() - t0
print(f" rust runner: {elapsed:.1f}s", file=sys.stderr)
def run_python(conv_path: Path, out_path: Path) -> None:
print("running python analyzer...", file=sys.stderr)
t0 = time.monotonic()
analyzer = SignalAnalyzer()
with conv_path.open() as fin, out_path.open("w") as fout:
for line in tqdm(fin, desc="python", unit="convo"):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
report = analyzer.analyze(obj["messages"])
fout.write(
json.dumps(
{"id": obj["id"], "report": report.to_dict()},
ensure_ascii=False,
)
)
except Exception as e:
fout.write(json.dumps({"id": obj.get("id"), "error": str(e)}))
fout.write("\n")
elapsed = time.monotonic() - t0
print(f" python runner: {elapsed:.1f}s", file=sys.stderr)
def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -> None:
"""Write the input metadata so compare.py can include it in the report."""
binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
try:
plano_sha = subprocess.check_output(
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
).decode().strip()
except Exception:
plano_sha = "unknown"
try:
signals_version = subprocess.check_output(
[sys.executable, "-m", "pip", "show", "signals"]
).decode()
signals_version = next(
(l.split(":", 1)[1].strip() for l in signals_version.splitlines() if l.startswith("Version")),
"unknown",
)
except Exception:
signals_version = "unknown"
meta = {
"dataset_name": DATASET_NAME,
"dataset_revision": args.dataset_revision,
"seed": args.seed,
"num_samples_requested": args.num_samples,
"num_samples_actual": n_samples,
"rust_binary": str(args.rust_binary.resolve()),
"rust_binary_sha256": binary_sha,
"plano_git_sha": plano_sha,
"signals_python_version": signals_version,
"max_conv_messages": args.max_conv_messages,
}
(output_dir / "run_metadata.json").write_text(json.dumps(meta, indent=2))
print(f"wrote {output_dir / 'run_metadata.json'}", file=sys.stderr)
def main() -> None:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
if not args.rust_binary.exists():
raise SystemExit(f"rust binary not found at {args.rust_binary}")
conv_path = args.output_dir / "conversations.jsonl"
rust_path = args.output_dir / "rust_reports.jsonl"
py_path = args.output_dir / "python_reports.jsonl"
samples = sample_conversations(
num_samples=args.num_samples,
seed=args.seed,
revision=args.dataset_revision,
max_conv_messages=args.max_conv_messages,
)
n = write_conversations(conv_path, samples)
print(f"wrote {n} conversations to {conv_path}", file=sys.stderr)
run_rust(args.rust_binary, conv_path, rust_path)
run_python(conv_path, py_path)
stamp_metadata(args, args.output_dir, n)
print("done. now run: python compare.py --output-dir " + str(args.output_dir))
if __name__ == "__main__":
main()