mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
signals: feature parity with the latest Signals paper. Porting logic from python repo (#903)
* signals: port to layered taxonomy with dual-emit OTel Made-with: Cursor * fix: silence collapsible_match clippy lint (rustc 1.95) Made-with: Cursor * test: parity harness for rust vs python signals analyzer Validates the brightstaff signals port against the katanemo/signals Python reference on lmsys/lmsys-chat-1m. Adds a signals_replay bin emitting python- compatible JSON, a pyarrow-based driver (bypasses the datasets loader pickle bug on python 3.14), a 3-tier comparator, and an on-demand workflow_dispatch CI job. Made-with: Cursor * Remove signals test from the gitops flow * style: format parity harness with black Made-with: Cursor * signals: group summary by taxonomy, factor misalignment_ratio Addresses #903 review feedback from @nehcgs: - generate_summary() now renders explicit Interaction / Execution / Environment headers so the paper taxonomy is visible at a glance, even when no signals fired in a given layer. Quality-driving callouts (high misalignment rate, looping detected, escalation requested) are appended after the layer summary as an alerts tail. - repair_ratio (legacy taxonomy name) renamed to misalignment_ratio and factored into a single InteractionSignals::misalignment_ratio() helper so assess_quality and generate_summary share one source of truth instead of recomputing the same divide twice. Two new unit tests pin the layer headers and the (sev N) severity suffix. Parity with the python reference is preserved at the Tier-A level (per-type counts + overall_quality); only the human-readable summary string diverges, which the parity comparator already classifies as Tier-C. Made-with: Cursor
This commit is contained in:
parent
6701195a5d
commit
c8079ac971
31 changed files with 5246 additions and 3261 deletions
4
tests/parity/signals/.gitignore
vendored
Normal file
4
tests/parity/signals/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
out/
|
||||
.venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
98
tests/parity/signals/README.md
Normal file
98
tests/parity/signals/README.md
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
# Signals Parity Harness
|
||||
|
||||
Validates that `crates/brightstaff/src/signals/` (Rust port) produces the same
|
||||
`SignalReport` as the Python reference at <https://github.com/katanemo/signals>
|
||||
on a fixed sample of `lmsys/lmsys-chat-1m` conversations.
|
||||
|
||||
This harness is **not** part of normal CI. It downloads several GB and is run
|
||||
on demand to gate releases of the signals subsystem (or to investigate
|
||||
regressions reported in production).
|
||||
|
||||
## What gets compared
|
||||
|
||||
For each conversation, both analyzers emit a `SignalReport`. The comparator
|
||||
classifies any divergence into three tiers:
|
||||
|
||||
| Tier | Field | Action on divergence |
|
||||
|------|------------------------------------------------|----------------------|
|
||||
| A | set of `SignalType` present, per-type counts, `overall_quality` | Fail the run |
|
||||
| B | per-instance `message_index`, instance counts per type | Log + collect, do not fail |
|
||||
| C | metadata, snippet text, summary | Information only |
|
||||
|
||||
Quality buckets are compared by string (`excellent` / `good` / ...).
|
||||
|
||||
## What this harness does *not* cover
|
||||
|
||||
`lmsys-chat-1m` is plain user/assistant chat. It exercises the **interaction**
|
||||
layer well (misalignment, stagnation, disengagement, satisfaction) but does
|
||||
**not** exercise:
|
||||
|
||||
- `execution.failure.*`
|
||||
- `execution.loops.*`
|
||||
- `environment.exhaustion.*`
|
||||
|
||||
Those signals require `function_call` / `observation` ShareGPT roles. They are
|
||||
covered by the Rust unit tests and the Python repo's own test fixtures, both
|
||||
of which run on every PR. A synthetic tool-trace dataset for full coverage is
|
||||
deferred to a follow-up.
|
||||
|
||||
## One-time setup
|
||||
|
||||
```bash
|
||||
# 1. Build the Rust replay binary.
|
||||
cd ../../../crates && cargo build --release -p brightstaff --bin signals_replay
|
||||
|
||||
# 2. Set up the Python environment for the harness driver.
|
||||
cd ../tests/parity/signals
|
||||
python3 -m venv .venv && source .venv/bin/activate
|
||||
pip install -r requirements.txt
|
||||
|
||||
# 3. Install the Python signals reference.
|
||||
# Either point at a local checkout:
|
||||
pip install -e /path/to/signals
|
||||
# or pull from git:
|
||||
pip install 'signals @ git+https://github.com/katanemo/signals@<sha>'
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
```bash
|
||||
source .venv/bin/activate
|
||||
|
||||
python run_parity.py \
|
||||
--num-samples 2000 \
|
||||
--seed 42 \
|
||||
--dataset-revision <hf-dataset-revision-sha> \
|
||||
--rust-binary ../../../crates/target/release/signals_replay \
|
||||
--output-dir out/
|
||||
|
||||
python compare.py --output-dir out/
|
||||
```
|
||||
|
||||
`run_parity.py` will:
|
||||
|
||||
1. Download `lmsys/lmsys-chat-1m` (cached in `~/.cache/huggingface`).
|
||||
2. Pick `--num-samples` rows under `--seed`.
|
||||
3. Convert each to ShareGPT, write `out/conversations.jsonl`.
|
||||
4. Run the Rust binary as a subprocess → `out/rust_reports.jsonl`.
|
||||
5. Run the Python analyzer in-process → `out/python_reports.jsonl`.
|
||||
|
||||
`compare.py` reads both report files and writes:
|
||||
|
||||
- `out/diffs.jsonl` — one record per mismatched conversation, with tier + structural diff
|
||||
- `out/metrics.json` — agreement %, per-`SignalType` confusion matrix, quality-bucket confusion matrix
|
||||
- `out/summary.md` — human-readable PR-ready report
|
||||
|
||||
Exit code is non-zero iff any Tier-A divergence is observed.
|
||||
|
||||
## Reproducibility
|
||||
|
||||
Every run pins:
|
||||
|
||||
- `dataset_revision` — the HF dataset commit
|
||||
- `seed` — RNG seed for sampling
|
||||
- `signals_python_version` — `pip show signals` version
|
||||
- `plano_git_sha` — `git rev-parse HEAD` of this repo
|
||||
- `signals_replay_binary_sha256` — the hash of the Rust bin
|
||||
|
||||
All are stamped into `metrics.json`.
|
||||
103
tests/parity/signals/_smoke_test.py
Normal file
103
tests/parity/signals/_smoke_test.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Local smoke test for the parity harness — runs both runners on a tiny
|
||||
hand-picked set of conversations without touching the lmsys dataset.
|
||||
|
||||
Run from this directory:
|
||||
python _smoke_test.py --rust-binary <path>
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from signals.analyzer import SignalAnalyzer
|
||||
|
||||
SAMPLES = [
|
||||
{
|
||||
"id": "smoke-gratitude",
|
||||
"messages": [
|
||||
{"from": "human", "value": "What is the weather in Istanbul?"},
|
||||
{"from": "gpt", "value": "Istanbul is 14C and partly cloudy."},
|
||||
{"from": "human", "value": "That worked, exactly what I needed. Thanks!"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "smoke-escalation",
|
||||
"messages": [
|
||||
{"from": "human", "value": "This isn't helpful at all"},
|
||||
{"from": "gpt", "value": "I'm sorry, can you tell me more?"},
|
||||
{"from": "human", "value": "Get me a human, this is useless"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "smoke-correction",
|
||||
"messages": [
|
||||
{"from": "human", "value": "Book me a flight to NYC for tomorrow"},
|
||||
{"from": "gpt", "value": "Sure, here are flights to NYC for Friday."},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "No, I meant flights for Saturday, not tomorrow",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "smoke-clean",
|
||||
"messages": [
|
||||
{"from": "human", "value": "Hi"},
|
||||
{"from": "gpt", "value": "Hello, how can I help?"},
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "smoke-rephrase",
|
||||
"messages": [
|
||||
{"from": "human", "value": "Can you summarize the news please"},
|
||||
{"from": "gpt", "value": "Sure, here is a summary."},
|
||||
{"from": "human", "value": "Could you please summarize the news"},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def main() -> int:
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--rust-binary", required=True, type=Path)
|
||||
args = p.parse_args()
|
||||
|
||||
out_dir = Path("out_smoke")
|
||||
out_dir.mkdir(exist_ok=True)
|
||||
conv_path = out_dir / "conversations.jsonl"
|
||||
rust_path = out_dir / "rust_reports.jsonl"
|
||||
py_path = out_dir / "python_reports.jsonl"
|
||||
|
||||
with conv_path.open("w") as f:
|
||||
for s in SAMPLES:
|
||||
f.write(json.dumps(s) + "\n")
|
||||
|
||||
with conv_path.open("rb") as fin, rust_path.open("wb") as fout:
|
||||
proc = subprocess.run(
|
||||
[str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
|
||||
return 2
|
||||
|
||||
analyzer = SignalAnalyzer()
|
||||
with conv_path.open() as fin, py_path.open("w") as fout:
|
||||
for line in fin:
|
||||
obj = json.loads(line)
|
||||
r = analyzer.analyze(obj["messages"])
|
||||
fout.write(json.dumps({"id": obj["id"], "report": r.to_dict()}) + "\n")
|
||||
|
||||
rc = subprocess.call(
|
||||
[sys.executable, "compare.py", "--output-dir", str(out_dir)],
|
||||
)
|
||||
return rc
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
333
tests/parity/signals/compare.py
Normal file
333
tests/parity/signals/compare.py
Normal file
|
|
@ -0,0 +1,333 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Diff Rust vs Python signal reports produced by run_parity.py.
|
||||
|
||||
See README.md for the tier definitions. Exits non-zero iff any Tier-A
|
||||
divergence is found.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter, defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
CATEGORIES_BY_LAYER = {
|
||||
"interaction_signals": [
|
||||
"misalignment",
|
||||
"stagnation",
|
||||
"disengagement",
|
||||
"satisfaction",
|
||||
],
|
||||
"execution_signals": ["failure", "loops"],
|
||||
"environment_signals": ["exhaustion"],
|
||||
}
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument("--output-dir", type=Path, default=Path("out"))
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def load_jsonl(path: Path) -> Dict[str, Dict[str, Any]]:
|
||||
"""Load a JSONL file keyed by `id`. Lines with errors are still indexed."""
|
||||
out: Dict[str, Dict[str, Any]] = {}
|
||||
with path.open() as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
obj = json.loads(line)
|
||||
out[str(obj.get("id"))] = obj
|
||||
return out
|
||||
|
||||
|
||||
def per_type_counts(report: Dict[str, Any]) -> Dict[str, int]:
|
||||
"""Return {signal_type: count} across all groups in a report dict."""
|
||||
counts: Counter[str] = Counter()
|
||||
for layer in CATEGORIES_BY_LAYER:
|
||||
groups = report.get(layer, {}) or {}
|
||||
for category in CATEGORIES_BY_LAYER[layer]:
|
||||
group = groups.get(category)
|
||||
if not group:
|
||||
continue
|
||||
for sig in group.get("signals", []) or []:
|
||||
counts[sig["signal_type"]] += 1
|
||||
return dict(counts)
|
||||
|
||||
|
||||
def per_type_indices(report: Dict[str, Any]) -> Dict[str, List[int]]:
|
||||
out: Dict[str, List[int]] = defaultdict(list)
|
||||
for layer in CATEGORIES_BY_LAYER:
|
||||
groups = report.get(layer, {}) or {}
|
||||
for category in CATEGORIES_BY_LAYER[layer]:
|
||||
group = groups.get(category)
|
||||
if not group:
|
||||
continue
|
||||
for sig in group.get("signals", []) or []:
|
||||
out[sig["signal_type"]].append(sig.get("message_index"))
|
||||
for k in out:
|
||||
out[k].sort(key=lambda x: (x is None, x))
|
||||
return dict(out)
|
||||
|
||||
|
||||
def diff_counts(a: Dict[str, int], b: Dict[str, int]) -> List[Tuple[str, int, int]]:
|
||||
"""Return [(signal_type, a_count, b_count)] for entries that differ."""
|
||||
keys = set(a) | set(b)
|
||||
out = []
|
||||
for k in sorted(keys):
|
||||
ac = a.get(k, 0)
|
||||
bc = b.get(k, 0)
|
||||
if ac != bc:
|
||||
out.append((k, ac, bc))
|
||||
return out
|
||||
|
||||
|
||||
def diff_indices(
|
||||
a: Dict[str, List[int]], b: Dict[str, List[int]]
|
||||
) -> List[Tuple[str, List[int], List[int]]]:
|
||||
keys = set(a) | set(b)
|
||||
out = []
|
||||
for k in sorted(keys):
|
||||
ai = a.get(k, [])
|
||||
bi = b.get(k, [])
|
||||
if ai != bi:
|
||||
out.append((k, ai, bi))
|
||||
return out
|
||||
|
||||
|
||||
def compare_one(
|
||||
convo_id: str, py: Dict[str, Any], rust: Dict[str, Any]
|
||||
) -> Dict[str, Any] | None:
|
||||
"""Compare a single conversation. Return diff record, or None if identical."""
|
||||
if "error" in py or "error" in rust:
|
||||
return {
|
||||
"id": convo_id,
|
||||
"tier": "A",
|
||||
"kind": "error_in_runner",
|
||||
"python_error": py.get("error"),
|
||||
"rust_error": rust.get("error"),
|
||||
}
|
||||
py_report = py["report"]
|
||||
rust_report = rust["report"]
|
||||
|
||||
py_counts = per_type_counts(py_report)
|
||||
rust_counts = per_type_counts(rust_report)
|
||||
count_diff = diff_counts(py_counts, rust_counts)
|
||||
|
||||
py_quality = py_report.get("overall_quality")
|
||||
rust_quality = rust_report.get("overall_quality")
|
||||
quality_mismatch = py_quality != rust_quality
|
||||
|
||||
if count_diff or quality_mismatch:
|
||||
return {
|
||||
"id": convo_id,
|
||||
"tier": "A",
|
||||
"kind": "signal_or_quality_mismatch",
|
||||
"quality": {"python": py_quality, "rust": rust_quality},
|
||||
"count_diff": [
|
||||
{"signal_type": st, "python": pc, "rust": rc}
|
||||
for (st, pc, rc) in count_diff
|
||||
],
|
||||
}
|
||||
|
||||
py_idx = per_type_indices(py_report)
|
||||
rust_idx = per_type_indices(rust_report)
|
||||
idx_diff = diff_indices(py_idx, rust_idx)
|
||||
if idx_diff:
|
||||
return {
|
||||
"id": convo_id,
|
||||
"tier": "B",
|
||||
"kind": "instance_index_mismatch",
|
||||
"diff": [
|
||||
{"signal_type": st, "python_indices": pi, "rust_indices": ri}
|
||||
for (st, pi, ri) in idx_diff
|
||||
],
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def confusion_matrix(
|
||||
pairs: List[Tuple[str, str]], labels: List[str]
|
||||
) -> Dict[str, Dict[str, int]]:
|
||||
cm: Dict[str, Dict[str, int]] = {a: {b: 0 for b in labels} for a in labels}
|
||||
for py, rust in pairs:
|
||||
if py not in cm:
|
||||
cm[py] = {b: 0 for b in labels}
|
||||
if rust not in cm[py]:
|
||||
cm[py][rust] = 0
|
||||
cm[py][rust] += 1
|
||||
return cm
|
||||
|
||||
|
||||
def main() -> int:
|
||||
args = parse_args()
|
||||
out_dir = args.output_dir
|
||||
|
||||
py_reports = load_jsonl(out_dir / "python_reports.jsonl")
|
||||
rust_reports = load_jsonl(out_dir / "rust_reports.jsonl")
|
||||
|
||||
common_ids = sorted(set(py_reports) & set(rust_reports))
|
||||
only_py = sorted(set(py_reports) - set(rust_reports))
|
||||
only_rust = sorted(set(rust_reports) - set(py_reports))
|
||||
|
||||
diffs: List[Dict[str, Any]] = []
|
||||
quality_pairs: List[Tuple[str, str]] = []
|
||||
per_type_total = Counter()
|
||||
per_type_disagree = Counter()
|
||||
|
||||
tier_a = 0
|
||||
tier_b = 0
|
||||
for cid in common_ids:
|
||||
d = compare_one(cid, py_reports[cid], rust_reports[cid])
|
||||
if d is None:
|
||||
quality_pairs.append(
|
||||
(
|
||||
py_reports[cid]["report"]["overall_quality"],
|
||||
rust_reports[cid]["report"]["overall_quality"],
|
||||
)
|
||||
)
|
||||
for st, _ in per_type_counts(py_reports[cid]["report"]).items():
|
||||
per_type_total[st] += 1
|
||||
else:
|
||||
diffs.append(d)
|
||||
if d["tier"] == "A":
|
||||
tier_a += 1
|
||||
elif d["tier"] == "B":
|
||||
tier_b += 1
|
||||
if "report" in py_reports[cid] and "report" in rust_reports[cid]:
|
||||
quality_pairs.append(
|
||||
(
|
||||
py_reports[cid]["report"].get("overall_quality", "?"),
|
||||
rust_reports[cid]["report"].get("overall_quality", "?"),
|
||||
)
|
||||
)
|
||||
for cd in d.get("count_diff", []) or []:
|
||||
per_type_disagree[cd["signal_type"]] += 1
|
||||
per_type_total[cd["signal_type"]] += 1
|
||||
|
||||
n_total = len(common_ids)
|
||||
n_match = n_total - len(diffs)
|
||||
agreement = (n_match / n_total) if n_total else 0.0
|
||||
|
||||
quality_labels = ["excellent", "good", "neutral", "poor", "severe"]
|
||||
cm = confusion_matrix(quality_pairs, quality_labels)
|
||||
|
||||
metrics = {
|
||||
"n_python_reports": len(py_reports),
|
||||
"n_rust_reports": len(rust_reports),
|
||||
"n_common": n_total,
|
||||
"n_only_python": len(only_py),
|
||||
"n_only_rust": len(only_rust),
|
||||
"n_full_match": n_match,
|
||||
"agreement_pct": round(100.0 * agreement, 4),
|
||||
"tier_a_divergences": tier_a,
|
||||
"tier_b_divergences": tier_b,
|
||||
"quality_confusion_matrix": cm,
|
||||
"per_signal_type_total": dict(per_type_total),
|
||||
"per_signal_type_disagree": dict(per_type_disagree),
|
||||
}
|
||||
|
||||
# Pull in run metadata if present.
|
||||
rm_path = out_dir / "run_metadata.json"
|
||||
if rm_path.exists():
|
||||
metrics["run_metadata"] = json.loads(rm_path.read_text())
|
||||
|
||||
(out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
|
||||
with (out_dir / "diffs.jsonl").open("w") as f:
|
||||
for d in diffs:
|
||||
f.write(json.dumps(d, ensure_ascii=False))
|
||||
f.write("\n")
|
||||
|
||||
write_summary_md(out_dir / "summary.md", metrics, diffs[:20])
|
||||
|
||||
print(
|
||||
json.dumps(
|
||||
{k: v for k, v in metrics.items() if k != "quality_confusion_matrix"},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
print(f"\ndiffs: {out_dir / 'diffs.jsonl'} metrics: {out_dir / 'metrics.json'}")
|
||||
print(f"summary: {out_dir / 'summary.md'}")
|
||||
|
||||
if tier_a > 0:
|
||||
print(f"\nFAIL: {tier_a} Tier-A divergence(s) detected.", file=sys.stderr)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
|
||||
def write_summary_md(
|
||||
path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
lines: List[str] = []
|
||||
lines.append("# Signals Parity Report")
|
||||
lines.append("")
|
||||
rm = metrics.get("run_metadata", {})
|
||||
if rm:
|
||||
lines.append("## Run metadata")
|
||||
lines.append("")
|
||||
for k in (
|
||||
"dataset_name",
|
||||
"dataset_revision",
|
||||
"seed",
|
||||
"num_samples_actual",
|
||||
"plano_git_sha",
|
||||
"signals_python_version",
|
||||
"rust_binary_sha256",
|
||||
):
|
||||
if k in rm:
|
||||
lines.append(f"- **{k}**: `{rm[k]}`")
|
||||
lines.append("")
|
||||
|
||||
lines.append("## Summary")
|
||||
lines.append("")
|
||||
lines.append(f"- Conversations compared: **{metrics['n_common']}**")
|
||||
lines.append(f"- Full matches: **{metrics['n_full_match']}**")
|
||||
lines.append(f"- Agreement: **{metrics['agreement_pct']}%**")
|
||||
lines.append(f"- Tier-A divergences: **{metrics['tier_a_divergences']}**")
|
||||
lines.append(f"- Tier-B divergences: **{metrics['tier_b_divergences']}**")
|
||||
lines.append("")
|
||||
|
||||
lines.append("## Per-signal-type disagreement")
|
||||
lines.append("")
|
||||
lines.append("| Signal type | Total reports | Disagreements |")
|
||||
lines.append("|---|---:|---:|")
|
||||
totals = metrics["per_signal_type_total"]
|
||||
disagrees = metrics["per_signal_type_disagree"]
|
||||
for k in sorted(set(totals) | set(disagrees)):
|
||||
lines.append(f"| `{k}` | {totals.get(k, 0)} | {disagrees.get(k, 0)} |")
|
||||
lines.append("")
|
||||
|
||||
lines.append("## Quality bucket confusion matrix (rows = python, cols = rust)")
|
||||
lines.append("")
|
||||
cm = metrics["quality_confusion_matrix"]
|
||||
labels = list(cm.keys())
|
||||
lines.append("| | " + " | ".join(labels) + " |")
|
||||
lines.append("|---|" + "|".join(["---:"] * len(labels)) + "|")
|
||||
for r in labels:
|
||||
lines.append(
|
||||
f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
if sample_diffs:
|
||||
lines.append("## Sample divergences (first 20)")
|
||||
lines.append("")
|
||||
for d in sample_diffs:
|
||||
lines.append(f"### `{d['id']}` — tier {d['tier']} — {d['kind']}")
|
||||
lines.append("")
|
||||
lines.append("```json")
|
||||
lines.append(json.dumps(d, indent=2))
|
||||
lines.append("```")
|
||||
lines.append("")
|
||||
|
||||
path.write_text("\n".join(lines))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
3
tests/parity/signals/requirements.txt
Normal file
3
tests/parity/signals/requirements.txt
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
huggingface_hub>=0.25
|
||||
pyarrow>=15
|
||||
tqdm>=4.65
|
||||
332
tests/parity/signals/run_parity.py
Normal file
332
tests/parity/signals/run_parity.py
Normal file
|
|
@ -0,0 +1,332 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Parity harness driver.
|
||||
|
||||
Samples conversations from `lmsys/lmsys-chat-1m`, runs both the Python
|
||||
reference analyzer (in-process) and the Rust port (subprocess), writes both
|
||||
reports to disk for `compare.py` to diff.
|
||||
|
||||
Usage:
|
||||
python run_parity.py \\
|
||||
--num-samples 2000 \\
|
||||
--seed 42 \\
|
||||
--dataset-revision <hf-revision-sha> \\
|
||||
--rust-binary ../../../crates/target/release/signals_replay \\
|
||||
--output-dir out/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import random
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List
|
||||
|
||||
try:
|
||||
import pyarrow.parquet as pq
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
except ImportError:
|
||||
print(
|
||||
"error: install dependencies first: pip install -r requirements.txt",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(2)
|
||||
|
||||
try:
|
||||
from signals.analyzer import SignalAnalyzer
|
||||
except ImportError:
|
||||
print(
|
||||
"error: the python `signals` package is not installed. "
|
||||
"install it from your local checkout: pip install -e /path/to/signals",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(2)
|
||||
|
||||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
|
||||
def tqdm(it, **_kwargs): # type: ignore[no-redef]
|
||||
return it
|
||||
|
||||
|
||||
DATASET_NAME = "lmsys/lmsys-chat-1m"
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
p = argparse.ArgumentParser(description=__doc__)
|
||||
p.add_argument("--num-samples", type=int, default=2000)
|
||||
p.add_argument("--seed", type=int, default=42)
|
||||
p.add_argument(
|
||||
"--dataset-revision",
|
||||
default=None,
|
||||
help="HF dataset revision to pin (default: latest, NOT recommended for reproducibility)",
|
||||
)
|
||||
p.add_argument(
|
||||
"--rust-binary",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="path to the `signals_replay` binary built from crates/brightstaff",
|
||||
)
|
||||
p.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("out"),
|
||||
help="directory to write the conversations + both runners' outputs",
|
||||
)
|
||||
p.add_argument(
|
||||
"--max-conv-messages",
|
||||
type=int,
|
||||
default=200,
|
||||
help="drop conversations with more than this many messages (the analyzer "
|
||||
"truncates to last 100 anyway; this is a sanity cap on input parsing)",
|
||||
)
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def lmsys_to_sharegpt(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
||||
"""Convert lmsys-chat-1m's `[{role, content}]` to ShareGPT's `[{from, value}]`.
|
||||
|
||||
lmsys uses `user` / `assistant` (no tools, no system role in `conversation`).
|
||||
"""
|
||||
out = []
|
||||
for m in conversation:
|
||||
role = m.get("role", "")
|
||||
content = m.get("content", "")
|
||||
if not isinstance(content, str):
|
||||
content = str(content) if content is not None else ""
|
||||
if role == "user":
|
||||
from_ = "human"
|
||||
elif role == "assistant":
|
||||
from_ = "gpt"
|
||||
else:
|
||||
# lmsys is human/assistant only; skip anything else defensively.
|
||||
continue
|
||||
out.append({"from": from_, "value": content})
|
||||
return out
|
||||
|
||||
|
||||
def _list_parquet_files(revision: str | None) -> List[str]:
|
||||
"""Return the list of parquet shard paths in the dataset repo."""
|
||||
files = list_repo_files(DATASET_NAME, repo_type="dataset", revision=revision)
|
||||
return sorted(f for f in files if f.endswith(".parquet"))
|
||||
|
||||
|
||||
def _download_shards(paths: List[str], revision: str | None) -> List[Path]:
|
||||
"""Download each parquet shard to the HF cache, return local paths."""
|
||||
local: List[Path] = []
|
||||
for rel in tqdm(paths, desc="downloading shards", unit="shard"):
|
||||
p = hf_hub_download(
|
||||
DATASET_NAME,
|
||||
filename=rel,
|
||||
repo_type="dataset",
|
||||
revision=revision,
|
||||
)
|
||||
local.append(Path(p))
|
||||
return local
|
||||
|
||||
|
||||
def sample_conversations(
|
||||
*,
|
||||
num_samples: int,
|
||||
seed: int,
|
||||
revision: str | None,
|
||||
max_conv_messages: int,
|
||||
) -> Iterator[Dict[str, Any]]:
|
||||
"""Yield `num_samples` conversations sampled uniformly across the dataset.
|
||||
|
||||
We bypass the `datasets` loader (which has a Python 3.14 pickle issue)
|
||||
and read the parquet shards directly via pyarrow.
|
||||
"""
|
||||
print(
|
||||
f"listing {DATASET_NAME}"
|
||||
f"{' @ ' + revision if revision else ' (no revision pinned!)'}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
shard_paths = _list_parquet_files(revision)
|
||||
if not shard_paths:
|
||||
raise SystemExit(f"no parquet shards found for {DATASET_NAME}")
|
||||
local_paths = _download_shards(shard_paths, revision)
|
||||
|
||||
# Collect row counts without reading data.
|
||||
shard_row_counts: List[int] = []
|
||||
for p in local_paths:
|
||||
pf = pq.ParquetFile(str(p))
|
||||
shard_row_counts.append(pf.metadata.num_rows)
|
||||
total_rows = sum(shard_row_counts)
|
||||
print(
|
||||
f"dataset has {total_rows:,} rows across {len(local_paths)} shards",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
rng = random.Random(seed)
|
||||
global_indices = sorted(rng.sample(range(total_rows), num_samples))
|
||||
|
||||
# Bucket indices by shard.
|
||||
by_shard: Dict[int, List[int]] = {}
|
||||
cumulative = 0
|
||||
shard_offsets = []
|
||||
for c in shard_row_counts:
|
||||
shard_offsets.append(cumulative)
|
||||
cumulative += c
|
||||
for gi in global_indices:
|
||||
# Find which shard this index belongs to.
|
||||
for si, off in enumerate(shard_offsets):
|
||||
if gi < off + shard_row_counts[si]:
|
||||
by_shard.setdefault(si, []).append(gi - off)
|
||||
break
|
||||
|
||||
yielded = 0
|
||||
for si in sorted(by_shard.keys()):
|
||||
local_rows = by_shard[si]
|
||||
pf = pq.ParquetFile(str(local_paths[si]))
|
||||
table = pf.read(columns=["conversation"])
|
||||
conv_col = table.column("conversation")
|
||||
for local_idx in local_rows:
|
||||
raw = conv_col[local_idx].as_py()
|
||||
if not raw:
|
||||
continue
|
||||
conversation = raw if isinstance(raw, list) else raw.get("conversation", [])
|
||||
if len(conversation) > max_conv_messages:
|
||||
continue
|
||||
messages = lmsys_to_sharegpt(conversation)
|
||||
if not messages:
|
||||
continue
|
||||
global_idx = shard_offsets[si] + local_idx
|
||||
yield {
|
||||
"id": f"lmsys-{global_idx}",
|
||||
"messages": messages,
|
||||
}
|
||||
yielded += 1
|
||||
print(f"yielded {yielded} conversations after filtering", file=sys.stderr)
|
||||
|
||||
|
||||
def write_conversations(out_path: Path, samples: Iterator[Dict[str, Any]]) -> int:
|
||||
n = 0
|
||||
with out_path.open("w") as f:
|
||||
for s in tqdm(samples, desc="sampling", unit="convo"):
|
||||
f.write(json.dumps(s, ensure_ascii=False))
|
||||
f.write("\n")
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
def run_rust(rust_binary: Path, conv_path: Path, out_path: Path) -> None:
|
||||
print(f"running rust analyzer: {rust_binary}", file=sys.stderr)
|
||||
t0 = time.monotonic()
|
||||
with conv_path.open("rb") as fin, out_path.open("wb") as fout:
|
||||
proc = subprocess.run(
|
||||
[str(rust_binary)],
|
||||
stdin=fin,
|
||||
stdout=fout,
|
||||
stderr=subprocess.PIPE,
|
||||
check=False,
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
|
||||
raise SystemExit(f"rust runner exited {proc.returncode}")
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f" rust runner: {elapsed:.1f}s", file=sys.stderr)
|
||||
|
||||
|
||||
def run_python(conv_path: Path, out_path: Path) -> None:
|
||||
print("running python analyzer...", file=sys.stderr)
|
||||
t0 = time.monotonic()
|
||||
analyzer = SignalAnalyzer()
|
||||
with conv_path.open() as fin, out_path.open("w") as fout:
|
||||
for line in tqdm(fin, desc="python", unit="convo"):
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
obj = json.loads(line)
|
||||
report = analyzer.analyze(obj["messages"])
|
||||
fout.write(
|
||||
json.dumps(
|
||||
{"id": obj["id"], "report": report.to_dict()},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
fout.write(json.dumps({"id": obj.get("id"), "error": str(e)}))
|
||||
fout.write("\n")
|
||||
elapsed = time.monotonic() - t0
|
||||
print(f" python runner: {elapsed:.1f}s", file=sys.stderr)
|
||||
|
||||
|
||||
def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -> None:
|
||||
"""Write the input metadata so compare.py can include it in the report."""
|
||||
binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
|
||||
try:
|
||||
plano_sha = (
|
||||
subprocess.check_output(
|
||||
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
|
||||
)
|
||||
.decode()
|
||||
.strip()
|
||||
)
|
||||
except Exception:
|
||||
plano_sha = "unknown"
|
||||
try:
|
||||
signals_version = subprocess.check_output(
|
||||
[sys.executable, "-m", "pip", "show", "signals"]
|
||||
).decode()
|
||||
signals_version = next(
|
||||
(
|
||||
l.split(":", 1)[1].strip()
|
||||
for l in signals_version.splitlines()
|
||||
if l.startswith("Version")
|
||||
),
|
||||
"unknown",
|
||||
)
|
||||
except Exception:
|
||||
signals_version = "unknown"
|
||||
|
||||
meta = {
|
||||
"dataset_name": DATASET_NAME,
|
||||
"dataset_revision": args.dataset_revision,
|
||||
"seed": args.seed,
|
||||
"num_samples_requested": args.num_samples,
|
||||
"num_samples_actual": n_samples,
|
||||
"rust_binary": str(args.rust_binary.resolve()),
|
||||
"rust_binary_sha256": binary_sha,
|
||||
"plano_git_sha": plano_sha,
|
||||
"signals_python_version": signals_version,
|
||||
"max_conv_messages": args.max_conv_messages,
|
||||
}
|
||||
(output_dir / "run_metadata.json").write_text(json.dumps(meta, indent=2))
|
||||
print(f"wrote {output_dir / 'run_metadata.json'}", file=sys.stderr)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
if not args.rust_binary.exists():
|
||||
raise SystemExit(f"rust binary not found at {args.rust_binary}")
|
||||
|
||||
conv_path = args.output_dir / "conversations.jsonl"
|
||||
rust_path = args.output_dir / "rust_reports.jsonl"
|
||||
py_path = args.output_dir / "python_reports.jsonl"
|
||||
|
||||
samples = sample_conversations(
|
||||
num_samples=args.num_samples,
|
||||
seed=args.seed,
|
||||
revision=args.dataset_revision,
|
||||
max_conv_messages=args.max_conv_messages,
|
||||
)
|
||||
n = write_conversations(conv_path, samples)
|
||||
print(f"wrote {n} conversations to {conv_path}", file=sys.stderr)
|
||||
|
||||
run_rust(args.rust_binary, conv_path, rust_path)
|
||||
run_python(conv_path, py_path)
|
||||
stamp_metadata(args, args.output_dir, n)
|
||||
print("done. now run: python compare.py --output-dir " + str(args.output_dir))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue