signals: feature parity with the latest Signals paper. Porting logic from python repo (#903)

* signals: port to layered taxonomy with dual-emit OTel

Made-with: Cursor

* fix: silence collapsible_match clippy lint (rustc 1.95)

Made-with: Cursor

* test: parity harness for rust vs python signals analyzer

Validates the brightstaff signals port against the katanemo/signals Python
reference on lmsys/lmsys-chat-1m. Adds a signals_replay bin emitting python-
compatible JSON, a pyarrow-based driver (bypasses the datasets loader pickle
bug on python 3.14), a 3-tier comparator, and an on-demand workflow_dispatch
CI job.

Made-with: Cursor

* Remove signals test from the gitops flow

* style: format parity harness with black

Made-with: Cursor

* signals: group summary by taxonomy, factor misalignment_ratio

Addresses #903 review feedback from @nehcgs:

- generate_summary() now renders explicit Interaction / Execution /
  Environment headers so the paper taxonomy is visible at a glance,
  even when no signals fired in a given layer. Quality-driving callouts
  (high misalignment rate, looping detected, escalation requested) are
  appended after the layer summary as an alerts tail.

- repair_ratio (legacy taxonomy name) renamed to misalignment_ratio
  and factored into a single InteractionSignals::misalignment_ratio()
  helper so assess_quality and generate_summary share one source of
  truth instead of recomputing the same divide twice.

Two new unit tests pin the layer headers and the (sev N) severity
suffix. Parity with the python reference is preserved at the Tier-A
level (per-type counts + overall_quality); only the human-readable
summary string diverges, which the parity comparator already classifies
as Tier-C.

Made-with: Cursor
This commit is contained in:
Syed A. Hashmi 2026-04-23 12:02:30 -07:00 committed by GitHub
parent 6701195a5d
commit c8079ac971
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
31 changed files with 5246 additions and 3261 deletions

4
tests/parity/signals/.gitignore vendored Normal file
View file

@ -0,0 +1,4 @@
out/
.venv/
__pycache__/
*.pyc

View file

@ -0,0 +1,98 @@
# Signals Parity Harness
Validates that `crates/brightstaff/src/signals/` (Rust port) produces the same
`SignalReport` as the Python reference at <https://github.com/katanemo/signals>
on a fixed sample of `lmsys/lmsys-chat-1m` conversations.
This harness is **not** part of normal CI. It downloads several GB and is run
on demand to gate releases of the signals subsystem (or to investigate
regressions reported in production).
## What gets compared
For each conversation, both analyzers emit a `SignalReport`. The comparator
classifies any divergence into three tiers:
| Tier | Field | Action on divergence |
|------|------------------------------------------------|----------------------|
| A | set of `SignalType` present, per-type counts, `overall_quality` | Fail the run |
| B | per-instance `message_index`, instance counts per type | Log + collect, do not fail |
| C | metadata, snippet text, summary | Information only |
Quality buckets are compared by string (`excellent` / `good` / ...).
## What this harness does *not* cover
`lmsys-chat-1m` is plain user/assistant chat. It exercises the **interaction**
layer well (misalignment, stagnation, disengagement, satisfaction) but does
**not** exercise:
- `execution.failure.*`
- `execution.loops.*`
- `environment.exhaustion.*`
Those signals require `function_call` / `observation` ShareGPT roles. They are
covered by the Rust unit tests and the Python repo's own test fixtures, both
of which run on every PR. A synthetic tool-trace dataset for full coverage is
deferred to a follow-up.
## One-time setup
```bash
# 1. Build the Rust replay binary.
cd ../../../crates && cargo build --release -p brightstaff --bin signals_replay
# 2. Set up the Python environment for the harness driver.
cd ../tests/parity/signals
python3 -m venv .venv && source .venv/bin/activate
pip install -r requirements.txt
# 3. Install the Python signals reference.
# Either point at a local checkout:
pip install -e /path/to/signals
# or pull from git:
pip install 'signals @ git+https://github.com/katanemo/signals@<sha>'
```
## Running
```bash
source .venv/bin/activate
python run_parity.py \
--num-samples 2000 \
--seed 42 \
--dataset-revision <hf-dataset-revision-sha> \
--rust-binary ../../../crates/target/release/signals_replay \
--output-dir out/
python compare.py --output-dir out/
```
`run_parity.py` will:
1. Download `lmsys/lmsys-chat-1m` (cached in `~/.cache/huggingface`).
2. Pick `--num-samples` rows under `--seed`.
3. Convert each to ShareGPT, write `out/conversations.jsonl`.
4. Run the Rust binary as a subprocess → `out/rust_reports.jsonl`.
5. Run the Python analyzer in-process → `out/python_reports.jsonl`.
`compare.py` reads both report files and writes:
- `out/diffs.jsonl` — one record per mismatched conversation, with tier + structural diff
- `out/metrics.json` — agreement %, per-`SignalType` confusion matrix, quality-bucket confusion matrix
- `out/summary.md` — human-readable PR-ready report
Exit code is non-zero iff any Tier-A divergence is observed.
## Reproducibility
Every run pins:
- `dataset_revision` — the HF dataset commit
- `seed` — RNG seed for sampling
- `signals_python_version``pip show signals` version
- `plano_git_sha``git rev-parse HEAD` of this repo
- `signals_replay_binary_sha256` — the hash of the Rust bin
All are stamped into `metrics.json`.

View file

@ -0,0 +1,103 @@
#!/usr/bin/env python3
"""
Local smoke test for the parity harness runs both runners on a tiny
hand-picked set of conversations without touching the lmsys dataset.
Run from this directory:
python _smoke_test.py --rust-binary <path>
"""
from __future__ import annotations
import argparse
import json
import subprocess
import sys
from pathlib import Path
from signals.analyzer import SignalAnalyzer
SAMPLES = [
{
"id": "smoke-gratitude",
"messages": [
{"from": "human", "value": "What is the weather in Istanbul?"},
{"from": "gpt", "value": "Istanbul is 14C and partly cloudy."},
{"from": "human", "value": "That worked, exactly what I needed. Thanks!"},
],
},
{
"id": "smoke-escalation",
"messages": [
{"from": "human", "value": "This isn't helpful at all"},
{"from": "gpt", "value": "I'm sorry, can you tell me more?"},
{"from": "human", "value": "Get me a human, this is useless"},
],
},
{
"id": "smoke-correction",
"messages": [
{"from": "human", "value": "Book me a flight to NYC for tomorrow"},
{"from": "gpt", "value": "Sure, here are flights to NYC for Friday."},
{
"from": "human",
"value": "No, I meant flights for Saturday, not tomorrow",
},
],
},
{
"id": "smoke-clean",
"messages": [
{"from": "human", "value": "Hi"},
{"from": "gpt", "value": "Hello, how can I help?"},
],
},
{
"id": "smoke-rephrase",
"messages": [
{"from": "human", "value": "Can you summarize the news please"},
{"from": "gpt", "value": "Sure, here is a summary."},
{"from": "human", "value": "Could you please summarize the news"},
],
},
]
def main() -> int:
p = argparse.ArgumentParser()
p.add_argument("--rust-binary", required=True, type=Path)
args = p.parse_args()
out_dir = Path("out_smoke")
out_dir.mkdir(exist_ok=True)
conv_path = out_dir / "conversations.jsonl"
rust_path = out_dir / "rust_reports.jsonl"
py_path = out_dir / "python_reports.jsonl"
with conv_path.open("w") as f:
for s in SAMPLES:
f.write(json.dumps(s) + "\n")
with conv_path.open("rb") as fin, rust_path.open("wb") as fout:
proc = subprocess.run(
[str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE
)
if proc.returncode != 0:
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
return 2
analyzer = SignalAnalyzer()
with conv_path.open() as fin, py_path.open("w") as fout:
for line in fin:
obj = json.loads(line)
r = analyzer.analyze(obj["messages"])
fout.write(json.dumps({"id": obj["id"], "report": r.to_dict()}) + "\n")
rc = subprocess.call(
[sys.executable, "compare.py", "--output-dir", str(out_dir)],
)
return rc
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,333 @@
#!/usr/bin/env python3
"""
Diff Rust vs Python signal reports produced by run_parity.py.
See README.md for the tier definitions. Exits non-zero iff any Tier-A
divergence is found.
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List, Tuple
CATEGORIES_BY_LAYER = {
"interaction_signals": [
"misalignment",
"stagnation",
"disengagement",
"satisfaction",
],
"execution_signals": ["failure", "loops"],
"environment_signals": ["exhaustion"],
}
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--output-dir", type=Path, default=Path("out"))
return p.parse_args()
def load_jsonl(path: Path) -> Dict[str, Dict[str, Any]]:
"""Load a JSONL file keyed by `id`. Lines with errors are still indexed."""
out: Dict[str, Dict[str, Any]] = {}
with path.open() as f:
for line in f:
line = line.strip()
if not line:
continue
obj = json.loads(line)
out[str(obj.get("id"))] = obj
return out
def per_type_counts(report: Dict[str, Any]) -> Dict[str, int]:
"""Return {signal_type: count} across all groups in a report dict."""
counts: Counter[str] = Counter()
for layer in CATEGORIES_BY_LAYER:
groups = report.get(layer, {}) or {}
for category in CATEGORIES_BY_LAYER[layer]:
group = groups.get(category)
if not group:
continue
for sig in group.get("signals", []) or []:
counts[sig["signal_type"]] += 1
return dict(counts)
def per_type_indices(report: Dict[str, Any]) -> Dict[str, List[int]]:
out: Dict[str, List[int]] = defaultdict(list)
for layer in CATEGORIES_BY_LAYER:
groups = report.get(layer, {}) or {}
for category in CATEGORIES_BY_LAYER[layer]:
group = groups.get(category)
if not group:
continue
for sig in group.get("signals", []) or []:
out[sig["signal_type"]].append(sig.get("message_index"))
for k in out:
out[k].sort(key=lambda x: (x is None, x))
return dict(out)
def diff_counts(a: Dict[str, int], b: Dict[str, int]) -> List[Tuple[str, int, int]]:
"""Return [(signal_type, a_count, b_count)] for entries that differ."""
keys = set(a) | set(b)
out = []
for k in sorted(keys):
ac = a.get(k, 0)
bc = b.get(k, 0)
if ac != bc:
out.append((k, ac, bc))
return out
def diff_indices(
a: Dict[str, List[int]], b: Dict[str, List[int]]
) -> List[Tuple[str, List[int], List[int]]]:
keys = set(a) | set(b)
out = []
for k in sorted(keys):
ai = a.get(k, [])
bi = b.get(k, [])
if ai != bi:
out.append((k, ai, bi))
return out
def compare_one(
convo_id: str, py: Dict[str, Any], rust: Dict[str, Any]
) -> Dict[str, Any] | None:
"""Compare a single conversation. Return diff record, or None if identical."""
if "error" in py or "error" in rust:
return {
"id": convo_id,
"tier": "A",
"kind": "error_in_runner",
"python_error": py.get("error"),
"rust_error": rust.get("error"),
}
py_report = py["report"]
rust_report = rust["report"]
py_counts = per_type_counts(py_report)
rust_counts = per_type_counts(rust_report)
count_diff = diff_counts(py_counts, rust_counts)
py_quality = py_report.get("overall_quality")
rust_quality = rust_report.get("overall_quality")
quality_mismatch = py_quality != rust_quality
if count_diff or quality_mismatch:
return {
"id": convo_id,
"tier": "A",
"kind": "signal_or_quality_mismatch",
"quality": {"python": py_quality, "rust": rust_quality},
"count_diff": [
{"signal_type": st, "python": pc, "rust": rc}
for (st, pc, rc) in count_diff
],
}
py_idx = per_type_indices(py_report)
rust_idx = per_type_indices(rust_report)
idx_diff = diff_indices(py_idx, rust_idx)
if idx_diff:
return {
"id": convo_id,
"tier": "B",
"kind": "instance_index_mismatch",
"diff": [
{"signal_type": st, "python_indices": pi, "rust_indices": ri}
for (st, pi, ri) in idx_diff
],
}
return None
def confusion_matrix(
pairs: List[Tuple[str, str]], labels: List[str]
) -> Dict[str, Dict[str, int]]:
cm: Dict[str, Dict[str, int]] = {a: {b: 0 for b in labels} for a in labels}
for py, rust in pairs:
if py not in cm:
cm[py] = {b: 0 for b in labels}
if rust not in cm[py]:
cm[py][rust] = 0
cm[py][rust] += 1
return cm
def main() -> int:
args = parse_args()
out_dir = args.output_dir
py_reports = load_jsonl(out_dir / "python_reports.jsonl")
rust_reports = load_jsonl(out_dir / "rust_reports.jsonl")
common_ids = sorted(set(py_reports) & set(rust_reports))
only_py = sorted(set(py_reports) - set(rust_reports))
only_rust = sorted(set(rust_reports) - set(py_reports))
diffs: List[Dict[str, Any]] = []
quality_pairs: List[Tuple[str, str]] = []
per_type_total = Counter()
per_type_disagree = Counter()
tier_a = 0
tier_b = 0
for cid in common_ids:
d = compare_one(cid, py_reports[cid], rust_reports[cid])
if d is None:
quality_pairs.append(
(
py_reports[cid]["report"]["overall_quality"],
rust_reports[cid]["report"]["overall_quality"],
)
)
for st, _ in per_type_counts(py_reports[cid]["report"]).items():
per_type_total[st] += 1
else:
diffs.append(d)
if d["tier"] == "A":
tier_a += 1
elif d["tier"] == "B":
tier_b += 1
if "report" in py_reports[cid] and "report" in rust_reports[cid]:
quality_pairs.append(
(
py_reports[cid]["report"].get("overall_quality", "?"),
rust_reports[cid]["report"].get("overall_quality", "?"),
)
)
for cd in d.get("count_diff", []) or []:
per_type_disagree[cd["signal_type"]] += 1
per_type_total[cd["signal_type"]] += 1
n_total = len(common_ids)
n_match = n_total - len(diffs)
agreement = (n_match / n_total) if n_total else 0.0
quality_labels = ["excellent", "good", "neutral", "poor", "severe"]
cm = confusion_matrix(quality_pairs, quality_labels)
metrics = {
"n_python_reports": len(py_reports),
"n_rust_reports": len(rust_reports),
"n_common": n_total,
"n_only_python": len(only_py),
"n_only_rust": len(only_rust),
"n_full_match": n_match,
"agreement_pct": round(100.0 * agreement, 4),
"tier_a_divergences": tier_a,
"tier_b_divergences": tier_b,
"quality_confusion_matrix": cm,
"per_signal_type_total": dict(per_type_total),
"per_signal_type_disagree": dict(per_type_disagree),
}
# Pull in run metadata if present.
rm_path = out_dir / "run_metadata.json"
if rm_path.exists():
metrics["run_metadata"] = json.loads(rm_path.read_text())
(out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
with (out_dir / "diffs.jsonl").open("w") as f:
for d in diffs:
f.write(json.dumps(d, ensure_ascii=False))
f.write("\n")
write_summary_md(out_dir / "summary.md", metrics, diffs[:20])
print(
json.dumps(
{k: v for k, v in metrics.items() if k != "quality_confusion_matrix"},
indent=2,
)
)
print(f"\ndiffs: {out_dir / 'diffs.jsonl'} metrics: {out_dir / 'metrics.json'}")
print(f"summary: {out_dir / 'summary.md'}")
if tier_a > 0:
print(f"\nFAIL: {tier_a} Tier-A divergence(s) detected.", file=sys.stderr)
return 1
return 0
def write_summary_md(
path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]
) -> None:
lines: List[str] = []
lines.append("# Signals Parity Report")
lines.append("")
rm = metrics.get("run_metadata", {})
if rm:
lines.append("## Run metadata")
lines.append("")
for k in (
"dataset_name",
"dataset_revision",
"seed",
"num_samples_actual",
"plano_git_sha",
"signals_python_version",
"rust_binary_sha256",
):
if k in rm:
lines.append(f"- **{k}**: `{rm[k]}`")
lines.append("")
lines.append("## Summary")
lines.append("")
lines.append(f"- Conversations compared: **{metrics['n_common']}**")
lines.append(f"- Full matches: **{metrics['n_full_match']}**")
lines.append(f"- Agreement: **{metrics['agreement_pct']}%**")
lines.append(f"- Tier-A divergences: **{metrics['tier_a_divergences']}**")
lines.append(f"- Tier-B divergences: **{metrics['tier_b_divergences']}**")
lines.append("")
lines.append("## Per-signal-type disagreement")
lines.append("")
lines.append("| Signal type | Total reports | Disagreements |")
lines.append("|---|---:|---:|")
totals = metrics["per_signal_type_total"]
disagrees = metrics["per_signal_type_disagree"]
for k in sorted(set(totals) | set(disagrees)):
lines.append(f"| `{k}` | {totals.get(k, 0)} | {disagrees.get(k, 0)} |")
lines.append("")
lines.append("## Quality bucket confusion matrix (rows = python, cols = rust)")
lines.append("")
cm = metrics["quality_confusion_matrix"]
labels = list(cm.keys())
lines.append("| | " + " | ".join(labels) + " |")
lines.append("|---|" + "|".join(["---:"] * len(labels)) + "|")
for r in labels:
lines.append(
f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |"
)
lines.append("")
if sample_diffs:
lines.append("## Sample divergences (first 20)")
lines.append("")
for d in sample_diffs:
lines.append(f"### `{d['id']}` — tier {d['tier']}{d['kind']}")
lines.append("")
lines.append("```json")
lines.append(json.dumps(d, indent=2))
lines.append("```")
lines.append("")
path.write_text("\n".join(lines))
if __name__ == "__main__":
sys.exit(main())

View file

@ -0,0 +1,3 @@
huggingface_hub>=0.25
pyarrow>=15
tqdm>=4.65

View file

@ -0,0 +1,332 @@
#!/usr/bin/env python3
"""
Parity harness driver.
Samples conversations from `lmsys/lmsys-chat-1m`, runs both the Python
reference analyzer (in-process) and the Rust port (subprocess), writes both
reports to disk for `compare.py` to diff.
Usage:
python run_parity.py \\
--num-samples 2000 \\
--seed 42 \\
--dataset-revision <hf-revision-sha> \\
--rust-binary ../../../crates/target/release/signals_replay \\
--output-dir out/
"""
from __future__ import annotations
import argparse
import hashlib
import json
import random
import subprocess
import sys
import time
from pathlib import Path
from typing import Any, Dict, Iterator, List
try:
import pyarrow.parquet as pq
from huggingface_hub import hf_hub_download, list_repo_files
except ImportError:
print(
"error: install dependencies first: pip install -r requirements.txt",
file=sys.stderr,
)
sys.exit(2)
try:
from signals.analyzer import SignalAnalyzer
except ImportError:
print(
"error: the python `signals` package is not installed. "
"install it from your local checkout: pip install -e /path/to/signals",
file=sys.stderr,
)
sys.exit(2)
try:
from tqdm import tqdm
except ImportError:
def tqdm(it, **_kwargs): # type: ignore[no-redef]
return it
DATASET_NAME = "lmsys/lmsys-chat-1m"
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(description=__doc__)
p.add_argument("--num-samples", type=int, default=2000)
p.add_argument("--seed", type=int, default=42)
p.add_argument(
"--dataset-revision",
default=None,
help="HF dataset revision to pin (default: latest, NOT recommended for reproducibility)",
)
p.add_argument(
"--rust-binary",
type=Path,
required=True,
help="path to the `signals_replay` binary built from crates/brightstaff",
)
p.add_argument(
"--output-dir",
type=Path,
default=Path("out"),
help="directory to write the conversations + both runners' outputs",
)
p.add_argument(
"--max-conv-messages",
type=int,
default=200,
help="drop conversations with more than this many messages (the analyzer "
"truncates to last 100 anyway; this is a sanity cap on input parsing)",
)
return p.parse_args()
def lmsys_to_sharegpt(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
"""Convert lmsys-chat-1m's `[{role, content}]` to ShareGPT's `[{from, value}]`.
lmsys uses `user` / `assistant` (no tools, no system role in `conversation`).
"""
out = []
for m in conversation:
role = m.get("role", "")
content = m.get("content", "")
if not isinstance(content, str):
content = str(content) if content is not None else ""
if role == "user":
from_ = "human"
elif role == "assistant":
from_ = "gpt"
else:
# lmsys is human/assistant only; skip anything else defensively.
continue
out.append({"from": from_, "value": content})
return out
def _list_parquet_files(revision: str | None) -> List[str]:
"""Return the list of parquet shard paths in the dataset repo."""
files = list_repo_files(DATASET_NAME, repo_type="dataset", revision=revision)
return sorted(f for f in files if f.endswith(".parquet"))
def _download_shards(paths: List[str], revision: str | None) -> List[Path]:
"""Download each parquet shard to the HF cache, return local paths."""
local: List[Path] = []
for rel in tqdm(paths, desc="downloading shards", unit="shard"):
p = hf_hub_download(
DATASET_NAME,
filename=rel,
repo_type="dataset",
revision=revision,
)
local.append(Path(p))
return local
def sample_conversations(
*,
num_samples: int,
seed: int,
revision: str | None,
max_conv_messages: int,
) -> Iterator[Dict[str, Any]]:
"""Yield `num_samples` conversations sampled uniformly across the dataset.
We bypass the `datasets` loader (which has a Python 3.14 pickle issue)
and read the parquet shards directly via pyarrow.
"""
print(
f"listing {DATASET_NAME}"
f"{' @ ' + revision if revision else ' (no revision pinned!)'}",
file=sys.stderr,
)
shard_paths = _list_parquet_files(revision)
if not shard_paths:
raise SystemExit(f"no parquet shards found for {DATASET_NAME}")
local_paths = _download_shards(shard_paths, revision)
# Collect row counts without reading data.
shard_row_counts: List[int] = []
for p in local_paths:
pf = pq.ParquetFile(str(p))
shard_row_counts.append(pf.metadata.num_rows)
total_rows = sum(shard_row_counts)
print(
f"dataset has {total_rows:,} rows across {len(local_paths)} shards",
file=sys.stderr,
)
rng = random.Random(seed)
global_indices = sorted(rng.sample(range(total_rows), num_samples))
# Bucket indices by shard.
by_shard: Dict[int, List[int]] = {}
cumulative = 0
shard_offsets = []
for c in shard_row_counts:
shard_offsets.append(cumulative)
cumulative += c
for gi in global_indices:
# Find which shard this index belongs to.
for si, off in enumerate(shard_offsets):
if gi < off + shard_row_counts[si]:
by_shard.setdefault(si, []).append(gi - off)
break
yielded = 0
for si in sorted(by_shard.keys()):
local_rows = by_shard[si]
pf = pq.ParquetFile(str(local_paths[si]))
table = pf.read(columns=["conversation"])
conv_col = table.column("conversation")
for local_idx in local_rows:
raw = conv_col[local_idx].as_py()
if not raw:
continue
conversation = raw if isinstance(raw, list) else raw.get("conversation", [])
if len(conversation) > max_conv_messages:
continue
messages = lmsys_to_sharegpt(conversation)
if not messages:
continue
global_idx = shard_offsets[si] + local_idx
yield {
"id": f"lmsys-{global_idx}",
"messages": messages,
}
yielded += 1
print(f"yielded {yielded} conversations after filtering", file=sys.stderr)
def write_conversations(out_path: Path, samples: Iterator[Dict[str, Any]]) -> int:
n = 0
with out_path.open("w") as f:
for s in tqdm(samples, desc="sampling", unit="convo"):
f.write(json.dumps(s, ensure_ascii=False))
f.write("\n")
n += 1
return n
def run_rust(rust_binary: Path, conv_path: Path, out_path: Path) -> None:
print(f"running rust analyzer: {rust_binary}", file=sys.stderr)
t0 = time.monotonic()
with conv_path.open("rb") as fin, out_path.open("wb") as fout:
proc = subprocess.run(
[str(rust_binary)],
stdin=fin,
stdout=fout,
stderr=subprocess.PIPE,
check=False,
)
if proc.returncode != 0:
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
raise SystemExit(f"rust runner exited {proc.returncode}")
elapsed = time.monotonic() - t0
print(f" rust runner: {elapsed:.1f}s", file=sys.stderr)
def run_python(conv_path: Path, out_path: Path) -> None:
print("running python analyzer...", file=sys.stderr)
t0 = time.monotonic()
analyzer = SignalAnalyzer()
with conv_path.open() as fin, out_path.open("w") as fout:
for line in tqdm(fin, desc="python", unit="convo"):
line = line.strip()
if not line:
continue
try:
obj = json.loads(line)
report = analyzer.analyze(obj["messages"])
fout.write(
json.dumps(
{"id": obj["id"], "report": report.to_dict()},
ensure_ascii=False,
)
)
except Exception as e:
fout.write(json.dumps({"id": obj.get("id"), "error": str(e)}))
fout.write("\n")
elapsed = time.monotonic() - t0
print(f" python runner: {elapsed:.1f}s", file=sys.stderr)
def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -> None:
"""Write the input metadata so compare.py can include it in the report."""
binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
try:
plano_sha = (
subprocess.check_output(
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
)
.decode()
.strip()
)
except Exception:
plano_sha = "unknown"
try:
signals_version = subprocess.check_output(
[sys.executable, "-m", "pip", "show", "signals"]
).decode()
signals_version = next(
(
l.split(":", 1)[1].strip()
for l in signals_version.splitlines()
if l.startswith("Version")
),
"unknown",
)
except Exception:
signals_version = "unknown"
meta = {
"dataset_name": DATASET_NAME,
"dataset_revision": args.dataset_revision,
"seed": args.seed,
"num_samples_requested": args.num_samples,
"num_samples_actual": n_samples,
"rust_binary": str(args.rust_binary.resolve()),
"rust_binary_sha256": binary_sha,
"plano_git_sha": plano_sha,
"signals_python_version": signals_version,
"max_conv_messages": args.max_conv_messages,
}
(output_dir / "run_metadata.json").write_text(json.dumps(meta, indent=2))
print(f"wrote {output_dir / 'run_metadata.json'}", file=sys.stderr)
def main() -> None:
args = parse_args()
args.output_dir.mkdir(parents=True, exist_ok=True)
if not args.rust_binary.exists():
raise SystemExit(f"rust binary not found at {args.rust_binary}")
conv_path = args.output_dir / "conversations.jsonl"
rust_path = args.output_dir / "rust_reports.jsonl"
py_path = args.output_dir / "python_reports.jsonl"
samples = sample_conversations(
num_samples=args.num_samples,
seed=args.seed,
revision=args.dataset_revision,
max_conv_messages=args.max_conv_messages,
)
n = write_conversations(conv_path, samples)
print(f"wrote {n} conversations to {conv_path}", file=sys.stderr)
run_rust(args.rust_binary, conv_path, rust_path)
run_python(conv_path, py_path)
stamp_metadata(args, args.output_dir, n)
print("done. now run: python compare.py --output-dir " + str(args.output_dir))
if __name__ == "__main__":
main()