From ffaa4da22cafac3f4187fd683f4301f87b529029 Mon Sep 17 00:00:00 2001 From: Syed Hashmi Date: Wed, 22 Apr 2026 12:28:22 -0700 Subject: [PATCH] test: parity harness for rust vs python signals analyzer Validates the brightstaff signals port against the katanemo/signals Python reference on lmsys/lmsys-chat-1m. Adds a signals_replay bin emitting python- compatible JSON, a pyarrow-based driver (bypasses the datasets loader pickle bug on python 3.14), a 3-tier comparator, and an on-demand workflow_dispatch CI job. Made-with: Cursor --- .github/workflows/parity-signals.yml | 97 ++++++ crates/brightstaff/Cargo.toml | 8 + crates/brightstaff/src/bin/signals_replay.rs | 175 ++++++++++ tests/parity/signals/.gitignore | 4 + tests/parity/signals/README.md | 98 ++++++ tests/parity/signals/_smoke_test.py | 97 ++++++ tests/parity/signals/compare.py | 320 +++++++++++++++++++ tests/parity/signals/requirements.txt | 3 + tests/parity/signals/run_parity.py | 316 ++++++++++++++++++ 9 files changed, 1118 insertions(+) create mode 100644 .github/workflows/parity-signals.yml create mode 100644 crates/brightstaff/src/bin/signals_replay.rs create mode 100644 tests/parity/signals/.gitignore create mode 100644 tests/parity/signals/README.md create mode 100644 tests/parity/signals/_smoke_test.py create mode 100644 tests/parity/signals/compare.py create mode 100644 tests/parity/signals/requirements.txt create mode 100644 tests/parity/signals/run_parity.py diff --git a/.github/workflows/parity-signals.yml b/.github/workflows/parity-signals.yml new file mode 100644 index 00000000..716f4665 --- /dev/null +++ b/.github/workflows/parity-signals.yml @@ -0,0 +1,97 @@ +name: parity-signals + +# On-demand parity validation of the Rust signals port against the Python +# reference (https://github.com/katanemo/signals). Not run on every PR +# because it downloads several GB of dataset content. 
on:
  workflow_dispatch:
    inputs:
      num_samples:
        description: "Number of conversations to sample from lmsys-chat-1m"
        required: true
        default: "2000"
      seed:
        description: "Sampling seed (use the same value for reproducibility)"
        required: true
        default: "42"
      dataset_revision:
        description: "HF dataset revision (commit SHA). Empty = latest (NOT pinned)."
        required: false
        default: ""
      signals_ref:
        description: "Git ref of katanemo/signals to install"
        required: true
        default: "main"

permissions:
  contents: read

jobs:
  parity:
    runs-on: ubuntu-latest
    timeout-minutes: 90
    steps:
      - uses: actions/checkout@v6

      - name: Install Rust (stable)
        uses: dtolnay/rust-toolchain@stable

      - name: Cache cargo
        uses: Swatinem/rust-cache@v2
        with:
          workspaces: crates

      - name: Build signals_replay
        working-directory: crates
        run: cargo build --release -p brightstaff --bin signals_replay

      - uses: actions/setup-python@v6
        with:
          python-version: "3.11"

      # Workflow inputs are handed to the shell through `env:` instead of
      # being expanded with `${{ ... }}` inside `run:`. Expression expansion
      # happens before the script is parsed, so a crafted input value could
      # otherwise inject shell syntax.
      - name: Install harness deps
        working-directory: tests/parity/signals
        env:
          SIGNALS_REF: ${{ inputs.signals_ref }}
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install "signals @ git+https://github.com/katanemo/signals@${SIGNALS_REF}"

      - name: Authenticate Hugging Face (lmsys-chat-1m is gated)
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          if [ -z "$HF_TOKEN" ]; then
            echo "::error::HF_TOKEN secret is required to download lmsys-chat-1m"
            exit 1
          fi
          mkdir -p ~/.cache/huggingface
          echo -n "$HF_TOKEN" > ~/.cache/huggingface/token

      - name: Run parity harness
        working-directory: tests/parity/signals
        env:
          NUM_SAMPLES: ${{ inputs.num_samples }}
          SEED: ${{ inputs.seed }}
          DATASET_REV: ${{ inputs.dataset_revision }}
        run: |
          ARGS=(
            --num-samples "$NUM_SAMPLES"
            --seed "$SEED"
            --rust-binary "${GITHUB_WORKSPACE}/crates/target/release/signals_replay"
            --output-dir out/
          )
          # An empty revision means "latest"; only pin when one was supplied.
          if [ -n "$DATASET_REV" ]; then
            ARGS+=(--dataset-revision "$DATASET_REV")
          fi
          python run_parity.py "${ARGS[@]}"

      - name: Compare reports
        working-directory: tests/parity/signals
        run: python compare.py --output-dir out/

      # Uploaded even when an earlier step failed so divergences can be inspected.
      - name: Upload artifacts
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: signals-parity-out
          path: tests/parity/signals/out/
          if-no-files-found: warn
+ +use std::io::{self, BufRead, BufWriter, Write}; + +use serde::Deserialize; +use serde_json::{json, Map, Value}; + +use brightstaff::signals::{SignalAnalyzer, SignalGroup, SignalReport}; + +#[derive(Debug, Deserialize)] +struct InputLine { + id: Value, + messages: Vec, +} + +#[derive(Debug, Deserialize)] +struct MessageRow { + #[serde(default)] + from: String, + #[serde(default)] + value: String, +} + +fn main() { + let stdin = io::stdin(); + let stdout = io::stdout(); + let mut out = BufWriter::new(stdout.lock()); + let analyzer = SignalAnalyzer::default(); + + for line in stdin.lock().lines() { + let line = match line { + Ok(l) => l, + Err(e) => { + eprintln!("read error: {e}"); + std::process::exit(1); + } + }; + let trimmed = line.trim(); + if trimmed.is_empty() { + continue; + } + let result = process_line(&analyzer, trimmed); + // Always emit one line per input line so id ordering stays aligned. + if let Err(e) = writeln!(out, "{result}") { + eprintln!("write error: {e}"); + std::process::exit(1); + } + // Flush periodically isn't strictly needed — BufWriter handles it, + // and the parent process reads the whole stream when we're done. + } + let _ = out.flush(); +} + +fn process_line(analyzer: &SignalAnalyzer, line: &str) -> Value { + let parsed: InputLine = match serde_json::from_str(line) { + Ok(p) => p, + Err(e) => { + return json!({ + "id": Value::Null, + "error": format!("input parse: {e}"), + }); + } + }; + + let id = parsed.id.clone(); + + let view: Vec> = parsed + .messages + .iter() + .map(|m| brightstaff::signals::analyzer::ShareGptMessage { + from: m.from.as_str(), + value: m.value.as_str(), + }) + .collect(); + + let report = analyzer.analyze_sharegpt(&view); + let report_dict = report_to_python_dict(&report); + json!({ + "id": id, + "report": report_dict, + }) +} + +/// Convert a `SignalReport` into the Python reference's `to_dict()` shape. 
+/// +/// Ordering of category keys in each layer dict follows the Python source +/// exactly so even string-equality comparisons behave deterministically. +fn report_to_python_dict(r: &SignalReport) -> Value { + let mut interaction = Map::new(); + interaction.insert( + "misalignment".to_string(), + signal_group_to_python(&r.interaction.misalignment), + ); + interaction.insert( + "stagnation".to_string(), + signal_group_to_python(&r.interaction.stagnation), + ); + interaction.insert( + "disengagement".to_string(), + signal_group_to_python(&r.interaction.disengagement), + ); + interaction.insert( + "satisfaction".to_string(), + signal_group_to_python(&r.interaction.satisfaction), + ); + + let mut execution = Map::new(); + execution.insert( + "failure".to_string(), + signal_group_to_python(&r.execution.failure), + ); + execution.insert( + "loops".to_string(), + signal_group_to_python(&r.execution.loops), + ); + + let mut environment = Map::new(); + environment.insert( + "exhaustion".to_string(), + signal_group_to_python(&r.environment.exhaustion), + ); + + json!({ + "interaction_signals": Value::Object(interaction), + "execution_signals": Value::Object(execution), + "environment_signals": Value::Object(environment), + "overall_quality": r.overall_quality.as_str(), + "summary": r.summary, + }) +} + +fn signal_group_to_python(g: &SignalGroup) -> Value { + let signals: Vec = g + .signals + .iter() + .map(|s| { + json!({ + "signal_type": s.signal_type.as_str(), + "message_index": s.message_index, + "snippet": s.snippet, + "confidence": s.confidence, + "metadata": s.metadata, + }) + }) + .collect(); + + json!({ + "category": g.category, + "count": g.count, + "severity": g.severity, + "signals": signals, + }) +} diff --git a/tests/parity/signals/.gitignore b/tests/parity/signals/.gitignore new file mode 100644 index 00000000..3a7e0d4f --- /dev/null +++ b/tests/parity/signals/.gitignore @@ -0,0 +1,4 @@ +out/ +.venv/ +__pycache__/ +*.pyc diff --git 
a/tests/parity/signals/README.md b/tests/parity/signals/README.md new file mode 100644 index 00000000..67193d60 --- /dev/null +++ b/tests/parity/signals/README.md @@ -0,0 +1,98 @@ +# Signals Parity Harness + +Validates that `crates/brightstaff/src/signals/` (Rust port) produces the same +`SignalReport` as the Python reference at <https://github.com/katanemo/signals> +on a fixed sample of `lmsys/lmsys-chat-1m` conversations. + +This harness is **not** part of normal CI. It downloads several GB and is run +on demand to gate releases of the signals subsystem (or to investigate +regressions reported in production). + +## What gets compared + +For each conversation, both analyzers emit a `SignalReport`. The comparator +classifies any divergence into three tiers: + +| Tier | Field | Action on divergence | +|------|------------------------------------------------|----------------------| +| A | set of `SignalType` present, per-type counts, `overall_quality` | Fail the run | +| B | per-instance `message_index` (checked only when the per-type counts already agree) | Log + collect, do not fail | +| C | metadata, snippet text, summary | Information only (not currently diffed by `compare.py`) | + +Quality buckets are compared by string (`excellent` / `good` / ...). + +## What this harness does *not* cover + +`lmsys-chat-1m` is plain user/assistant chat. It exercises the **interaction** +layer well (misalignment, stagnation, disengagement, satisfaction) but does +**not** exercise: + +- `execution.failure.*` +- `execution.loops.*` +- `environment.exhaustion.*` + +Those signals require `function_call` / `observation` ShareGPT roles. They are +covered by the Rust unit tests and the Python repo's own test fixtures, both +of which run on every PR. A synthetic tool-trace dataset for full coverage is +deferred to a follow-up. + +## One-time setup + +```bash +# 1. Build the Rust replay binary. +cd ../../../crates && cargo build --release -p brightstaff --bin signals_replay + +# 2. Set up the Python environment for the harness driver.
+cd ../tests/parity/signals +python3 -m venv .venv && source .venv/bin/activate +pip install -r requirements.txt + +# 3. Install the Python signals reference. +# Either point at a local checkout: +pip install -e /path/to/signals +# or pull from git: +pip install 'signals @ git+https://github.com/katanemo/signals@<ref>' +``` + +## Running + +```bash +source .venv/bin/activate + +python run_parity.py \ + --num-samples 2000 \ + --seed 42 \ + --dataset-revision <commit-sha> \ + --rust-binary ../../../crates/target/release/signals_replay \ + --output-dir out/ + +python compare.py --output-dir out/ +``` + +`run_parity.py` will: + +1. Download `lmsys/lmsys-chat-1m` (cached in `~/.cache/huggingface`). +2. Pick `--num-samples` rows under `--seed`. +3. Convert each to ShareGPT, write `out/conversations.jsonl`. +4. Run the Rust binary as a subprocess → `out/rust_reports.jsonl`. +5. Run the Python analyzer in-process → `out/python_reports.jsonl`. + +`compare.py` reads both report files and writes: + +- `out/diffs.jsonl` — one record per mismatched conversation, with tier + structural diff +- `out/metrics.json` — agreement %, per-`SignalType` totals and disagreement counts, quality-bucket confusion matrix +- `out/summary.md` — human-readable PR-ready report + +Exit code is non-zero iff any Tier-A divergence is observed. + +## Reproducibility + +Every run pins: + +- `dataset_revision` — the HF dataset commit +- `seed` — RNG seed for sampling +- `signals_python_version` — `pip show signals` version +- `plano_git_sha` — `git rev-parse HEAD` of this repo +- `rust_binary_sha256` — the SHA-256 hash of the Rust binary + +All are written to `run_metadata.json` and stamped into `metrics.json` under `run_metadata`.
diff --git a/tests/parity/signals/_smoke_test.py b/tests/parity/signals/_smoke_test.py new file mode 100644 index 00000000..ccd753bb --- /dev/null +++ b/tests/parity/signals/_smoke_test.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +Local smoke test for the parity harness — runs both runners on a tiny +hand-picked set of conversations without touching the lmsys dataset. + +Run from this directory: + python _smoke_test.py --rust-binary +""" +from __future__ import annotations + +import argparse +import json +import subprocess +import sys +from pathlib import Path + +from signals.analyzer import SignalAnalyzer + +SAMPLES = [ + { + "id": "smoke-gratitude", + "messages": [ + {"from": "human", "value": "What is the weather in Istanbul?"}, + {"from": "gpt", "value": "Istanbul is 14C and partly cloudy."}, + {"from": "human", "value": "That worked, exactly what I needed. Thanks!"}, + ], + }, + { + "id": "smoke-escalation", + "messages": [ + {"from": "human", "value": "This isn't helpful at all"}, + {"from": "gpt", "value": "I'm sorry, can you tell me more?"}, + {"from": "human", "value": "Get me a human, this is useless"}, + ], + }, + { + "id": "smoke-correction", + "messages": [ + {"from": "human", "value": "Book me a flight to NYC for tomorrow"}, + {"from": "gpt", "value": "Sure, here are flights to NYC for Friday."}, + {"from": "human", "value": "No, I meant flights for Saturday, not tomorrow"}, + ], + }, + { + "id": "smoke-clean", + "messages": [ + {"from": "human", "value": "Hi"}, + {"from": "gpt", "value": "Hello, how can I help?"}, + ], + }, + { + "id": "smoke-rephrase", + "messages": [ + {"from": "human", "value": "Can you summarize the news please"}, + {"from": "gpt", "value": "Sure, here is a summary."}, + {"from": "human", "value": "Could you please summarize the news"}, + ], + }, +] + + +def main() -> int: + p = argparse.ArgumentParser() + p.add_argument("--rust-binary", required=True, type=Path) + args = p.parse_args() + + out_dir = Path("out_smoke") + 
def load_jsonl(path: Path) -> Dict[str, Dict[str, Any]]:
    """Load a JSONL report file into a dict keyed by the stringified `id`.

    Blank lines are skipped. Records that carry an `error` instead of a
    `report` are indexed like any other record.

    A record whose `id` is missing or null gets a synthetic, per-line key
    (`<missing-id:FILE:LINENO>`). Previously every such record was stored
    under the single key "None", so all but the last one were silently
    overwritten and runner failures could disappear from the comparison.

    If two records share the same non-null id, the later one wins.
    """
    records: Dict[str, Dict[str, Any]] = {}
    with path.open() as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            raw_id = obj.get("id")
            if raw_id is None:
                # Keep id-less records distinct so each one is counted.
                key = f"<missing-id:{path.name}:{lineno}>"
            else:
                key = str(raw_id)
            records[key] = obj
    return records
Return diff record, or None if identical.""" + if "error" in py or "error" in rust: + return { + "id": convo_id, + "tier": "A", + "kind": "error_in_runner", + "python_error": py.get("error"), + "rust_error": rust.get("error"), + } + py_report = py["report"] + rust_report = rust["report"] + + py_counts = per_type_counts(py_report) + rust_counts = per_type_counts(rust_report) + count_diff = diff_counts(py_counts, rust_counts) + + py_quality = py_report.get("overall_quality") + rust_quality = rust_report.get("overall_quality") + quality_mismatch = py_quality != rust_quality + + if count_diff or quality_mismatch: + return { + "id": convo_id, + "tier": "A", + "kind": "signal_or_quality_mismatch", + "quality": {"python": py_quality, "rust": rust_quality}, + "count_diff": [ + {"signal_type": st, "python": pc, "rust": rc} + for (st, pc, rc) in count_diff + ], + } + + py_idx = per_type_indices(py_report) + rust_idx = per_type_indices(rust_report) + idx_diff = diff_indices(py_idx, rust_idx) + if idx_diff: + return { + "id": convo_id, + "tier": "B", + "kind": "instance_index_mismatch", + "diff": [ + {"signal_type": st, "python_indices": pi, "rust_indices": ri} + for (st, pi, ri) in idx_diff + ], + } + + return None + + +def confusion_matrix( + pairs: List[Tuple[str, str]], labels: List[str] +) -> Dict[str, Dict[str, int]]: + cm: Dict[str, Dict[str, int]] = {a: {b: 0 for b in labels} for a in labels} + for py, rust in pairs: + if py not in cm: + cm[py] = {b: 0 for b in labels} + if rust not in cm[py]: + cm[py][rust] = 0 + cm[py][rust] += 1 + return cm + + +def main() -> int: + args = parse_args() + out_dir = args.output_dir + + py_reports = load_jsonl(out_dir / "python_reports.jsonl") + rust_reports = load_jsonl(out_dir / "rust_reports.jsonl") + + common_ids = sorted(set(py_reports) & set(rust_reports)) + only_py = sorted(set(py_reports) - set(rust_reports)) + only_rust = sorted(set(rust_reports) - set(py_reports)) + + diffs: List[Dict[str, Any]] = [] + quality_pairs: 
List[Tuple[str, str]] = [] + per_type_total = Counter() + per_type_disagree = Counter() + + tier_a = 0 + tier_b = 0 + for cid in common_ids: + d = compare_one(cid, py_reports[cid], rust_reports[cid]) + if d is None: + quality_pairs.append( + ( + py_reports[cid]["report"]["overall_quality"], + rust_reports[cid]["report"]["overall_quality"], + ) + ) + for st, _ in per_type_counts(py_reports[cid]["report"]).items(): + per_type_total[st] += 1 + else: + diffs.append(d) + if d["tier"] == "A": + tier_a += 1 + elif d["tier"] == "B": + tier_b += 1 + if "report" in py_reports[cid] and "report" in rust_reports[cid]: + quality_pairs.append( + ( + py_reports[cid]["report"].get("overall_quality", "?"), + rust_reports[cid]["report"].get("overall_quality", "?"), + ) + ) + for cd in d.get("count_diff", []) or []: + per_type_disagree[cd["signal_type"]] += 1 + per_type_total[cd["signal_type"]] += 1 + + n_total = len(common_ids) + n_match = n_total - len(diffs) + agreement = (n_match / n_total) if n_total else 0.0 + + quality_labels = ["excellent", "good", "neutral", "poor", "severe"] + cm = confusion_matrix(quality_pairs, quality_labels) + + metrics = { + "n_python_reports": len(py_reports), + "n_rust_reports": len(rust_reports), + "n_common": n_total, + "n_only_python": len(only_py), + "n_only_rust": len(only_rust), + "n_full_match": n_match, + "agreement_pct": round(100.0 * agreement, 4), + "tier_a_divergences": tier_a, + "tier_b_divergences": tier_b, + "quality_confusion_matrix": cm, + "per_signal_type_total": dict(per_type_total), + "per_signal_type_disagree": dict(per_type_disagree), + } + + # Pull in run metadata if present. 
+ rm_path = out_dir / "run_metadata.json" + if rm_path.exists(): + metrics["run_metadata"] = json.loads(rm_path.read_text()) + + (out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2)) + with (out_dir / "diffs.jsonl").open("w") as f: + for d in diffs: + f.write(json.dumps(d, ensure_ascii=False)) + f.write("\n") + + write_summary_md(out_dir / "summary.md", metrics, diffs[:20]) + + print(json.dumps({k: v for k, v in metrics.items() if k != "quality_confusion_matrix"}, indent=2)) + print(f"\ndiffs: {out_dir / 'diffs.jsonl'} metrics: {out_dir / 'metrics.json'}") + print(f"summary: {out_dir / 'summary.md'}") + + if tier_a > 0: + print(f"\nFAIL: {tier_a} Tier-A divergence(s) detected.", file=sys.stderr) + return 1 + return 0 + + +def write_summary_md(path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]) -> None: + lines: List[str] = [] + lines.append("# Signals Parity Report") + lines.append("") + rm = metrics.get("run_metadata", {}) + if rm: + lines.append("## Run metadata") + lines.append("") + for k in ( + "dataset_name", + "dataset_revision", + "seed", + "num_samples_actual", + "plano_git_sha", + "signals_python_version", + "rust_binary_sha256", + ): + if k in rm: + lines.append(f"- **{k}**: `{rm[k]}`") + lines.append("") + + lines.append("## Summary") + lines.append("") + lines.append(f"- Conversations compared: **{metrics['n_common']}**") + lines.append(f"- Full matches: **{metrics['n_full_match']}**") + lines.append(f"- Agreement: **{metrics['agreement_pct']}%**") + lines.append(f"- Tier-A divergences: **{metrics['tier_a_divergences']}**") + lines.append(f"- Tier-B divergences: **{metrics['tier_b_divergences']}**") + lines.append("") + + lines.append("## Per-signal-type disagreement") + lines.append("") + lines.append("| Signal type | Total reports | Disagreements |") + lines.append("|---|---:|---:|") + totals = metrics["per_signal_type_total"] + disagrees = metrics["per_signal_type_disagree"] + for k in sorted(set(totals) | 
set(disagrees)): + lines.append(f"| `{k}` | {totals.get(k, 0)} | {disagrees.get(k, 0)} |") + lines.append("") + + lines.append("## Quality bucket confusion matrix (rows = python, cols = rust)") + lines.append("") + cm = metrics["quality_confusion_matrix"] + labels = list(cm.keys()) + lines.append("| | " + " | ".join(labels) + " |") + lines.append("|---|" + "|".join(["---:"] * len(labels)) + "|") + for r in labels: + lines.append(f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |") + lines.append("") + + if sample_diffs: + lines.append("## Sample divergences (first 20)") + lines.append("") + for d in sample_diffs: + lines.append(f"### `{d['id']}` — tier {d['tier']} — {d['kind']}") + lines.append("") + lines.append("```json") + lines.append(json.dumps(d, indent=2)) + lines.append("```") + lines.append("") + + path.write_text("\n".join(lines)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/parity/signals/requirements.txt b/tests/parity/signals/requirements.txt new file mode 100644 index 00000000..7b25f179 --- /dev/null +++ b/tests/parity/signals/requirements.txt @@ -0,0 +1,3 @@ +huggingface_hub>=0.25 +pyarrow>=15 +tqdm>=4.65 diff --git a/tests/parity/signals/run_parity.py b/tests/parity/signals/run_parity.py new file mode 100644 index 00000000..bdb54966 --- /dev/null +++ b/tests/parity/signals/run_parity.py @@ -0,0 +1,316 @@ +#!/usr/bin/env python3 +""" +Parity harness driver. + +Samples conversations from `lmsys/lmsys-chat-1m`, runs both the Python +reference analyzer (in-process) and the Rust port (subprocess), writes both +reports to disk for `compare.py` to diff. 
def parse_args() -> argparse.Namespace:
    """Parse the driver's command-line flags.

    Only `--rust-binary` is mandatory; every other flag has a default.
    """
    # (flag, argparse keyword arguments), registered in this order.
    flag_specs = [
        ("--num-samples", {"type": int, "default": 2000}),
        ("--seed", {"type": int, "default": 42}),
        (
            "--dataset-revision",
            {
                "default": None,
                "help": "HF dataset revision to pin (default: latest, NOT recommended for reproducibility)",
            },
        ),
        (
            "--rust-binary",
            {
                "type": Path,
                "required": True,
                "help": "path to the `signals_replay` binary built from crates/brightstaff",
            },
        ),
        (
            "--output-dir",
            {
                "type": Path,
                "default": Path("out"),
                "help": "directory to write the conversations + both runners' outputs",
            },
        ),
        (
            "--max-conv-messages",
            {
                "type": int,
                "default": 200,
                "help": "drop conversations with more than this many messages (the analyzer "
                "truncates to last 100 anyway; this is a sanity cap on input parsing)",
            },
        ),
    ]

    parser = argparse.ArgumentParser(description=__doc__)
    for flag, options in flag_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def sample_conversations(
    *,
    num_samples: int,
    seed: int,
    revision: str | None,
    max_conv_messages: int,
) -> Iterator[Dict[str, Any]]:
    """Yield up to `num_samples` conversations sampled uniformly across the dataset.

    Row indices are drawn without replacement from the whole dataset using
    `random.Random(seed)`, then visited in ascending order. Rows that are
    empty, exceed `max_conv_messages`, or convert to no ShareGPT messages are
    skipped, so fewer than `num_samples` records may be yielded.

    If `num_samples` exceeds the number of rows, every row is sampled and a
    warning is printed, rather than failing with a bare `ValueError`.

    Each yielded dict is `{"id": "lmsys-<global row index>", "messages": [...]}`.

    We bypass the `datasets` loader (which has a Python 3.14 pickle issue)
    and read the parquet shards directly via pyarrow.

    Raises:
        SystemExit: if the repo lists no parquet shards, or `num_samples` is
            negative.
    """
    from bisect import bisect_right

    if num_samples < 0:
        raise SystemExit(f"--num-samples must be >= 0, got {num_samples}")

    print(
        f"listing {DATASET_NAME}"
        f"{' @ ' + revision if revision else ' (no revision pinned!)'}",
        file=sys.stderr,
    )
    shard_paths = _list_parquet_files(revision)
    if not shard_paths:
        raise SystemExit(f"no parquet shards found for {DATASET_NAME}")
    local_paths = _download_shards(shard_paths, revision)

    # Collect row counts from parquet metadata without reading data.
    shard_row_counts: List[int] = [
        pq.ParquetFile(str(p)).metadata.num_rows for p in local_paths
    ]
    total_rows = sum(shard_row_counts)
    print(f"dataset has {total_rows:,} rows across {len(local_paths)} shards", file=sys.stderr)

    if num_samples > total_rows:
        print(
            f"warning: requested {num_samples} samples but the dataset has only "
            f"{total_rows} rows; sampling all of them",
            file=sys.stderr,
        )
    sample_size = min(num_samples, total_rows)

    rng = random.Random(seed)
    global_indices = sorted(rng.sample(range(total_rows), sample_size))

    # shard_offsets[i] is the global index of shard i's first row.
    shard_offsets: List[int] = []
    cumulative = 0
    for c in shard_row_counts:
        shard_offsets.append(cumulative)
        cumulative += c

    # Bucket indices by shard. The rightmost shard whose offset is <= gi is
    # the one that holds it; this also steps over empty shards correctly.
    by_shard: Dict[int, List[int]] = {}
    for gi in global_indices:
        si = bisect_right(shard_offsets, gi) - 1
        by_shard.setdefault(si, []).append(gi - shard_offsets[si])

    yielded = 0
    for si in sorted(by_shard.keys()):
        pf = pq.ParquetFile(str(local_paths[si]))
        conv_col = pf.read(columns=["conversation"]).column("conversation")
        for local_idx in by_shard[si]:
            raw = conv_col[local_idx].as_py()
            if not raw:
                continue
            conversation = raw if isinstance(raw, list) else raw.get("conversation", [])
            if len(conversation) > max_conv_messages:
                continue
            messages = lmsys_to_sharegpt(conversation)
            if not messages:
                continue
            global_idx = shard_offsets[si] + local_idx
            yield {
                "id": f"lmsys-{global_idx}",
                "messages": messages,
            }
            yielded += 1
    print(f"yielded {yielded} conversations after filtering", file=sys.stderr)
def run_python(conv_path: Path, out_path: Path) -> None:
    """Run the Python reference analyzer over every conversation in `conv_path`.

    Writes one JSON line per non-blank input line to `out_path`:
    `{"id": ..., "report": ...}` on success, `{"id": ..., "error": "..."}`
    when decoding, analysis or serialization raises. The error record carries
    the id of the line that failed, or null if the failure happened before
    the id could be read.
    """
    print("running python analyzer...", file=sys.stderr)
    t0 = time.monotonic()
    analyzer = SignalAnalyzer()
    with conv_path.open() as fin, out_path.open("w") as fout:
        for line in tqdm(fin, desc="python", unit="convo"):
            line = line.strip()
            if not line:
                continue
            # Reset per line: if json.loads fails, the handler must not see
            # an unbound name or the id left over from the previous line.
            convo_id = None
            try:
                obj = json.loads(line)
                convo_id = obj["id"]
                report = analyzer.analyze(obj["messages"])
                payload = json.dumps(
                    {"id": convo_id, "report": report.to_dict()},
                    ensure_ascii=False,
                )
            except Exception as e:
                payload = json.dumps({"id": convo_id, "error": str(e)})
            fout.write(payload)
            fout.write("\n")
    elapsed = time.monotonic() - t0
    print(f" python runner: {elapsed:.1f}s", file=sys.stderr)
def main() -> None:
    """Sample conversations, run both analyzers, and record run metadata.

    Creates the output directory, aborts if the Rust binary is missing, then
    writes `conversations.jsonl`, `rust_reports.jsonl`, `python_reports.jsonl`
    and `run_metadata.json` into it.
    """
    args = parse_args()
    out_dir = args.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)
    if not args.rust_binary.exists():
        raise SystemExit(f"rust binary not found at {args.rust_binary}")

    conversations_file = out_dir / "conversations.jsonl"
    rust_file = out_dir / "rust_reports.jsonl"
    python_file = out_dir / "python_reports.jsonl"

    # The sampler is lazy; write_conversations drives it and counts the rows.
    written = write_conversations(
        conversations_file,
        sample_conversations(
            num_samples=args.num_samples,
            seed=args.seed,
            revision=args.dataset_revision,
            max_conv_messages=args.max_conv_messages,
        ),
    )
    print(f"wrote {written} conversations to {conversations_file}", file=sys.stderr)

    run_rust(args.rust_binary, conversations_file, rust_file)
    run_python(conversations_file, python_file)
    stamp_metadata(args, out_dir, written)
    print("done. now run: python compare.py --output-dir " + str(out_dir))


if __name__ == "__main__":
    main()