style: format parity harness with black

Made-with: Cursor
This commit is contained in:
Syed Hashmi 2026-04-22 21:34:23 -07:00
parent 471116b542
commit b14a74348e
No known key found for this signature in database
3 changed files with 50 additions and 15 deletions

View file

@ -6,6 +6,7 @@ hand-picked set of conversations without touching the lmsys dataset.
Run from this directory:
python _smoke_test.py --rust-binary <path>
"""
from __future__ import annotations
import argparse
@ -38,7 +39,10 @@ SAMPLES = [
"messages": [
{"from": "human", "value": "Book me a flight to NYC for tomorrow"},
{"from": "gpt", "value": "Sure, here are flights to NYC for Friday."},
{"from": "human", "value": "No, I meant flights for Saturday, not tomorrow"},
{
"from": "human",
"value": "No, I meant flights for Saturday, not tomorrow",
},
],
},
{
@ -75,7 +79,9 @@ def main() -> int:
f.write(json.dumps(s) + "\n")
with conv_path.open("rb") as fin, rust_path.open("wb") as fout:
proc = subprocess.run([str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE)
proc = subprocess.run(
[str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE
)
if proc.returncode != 0:
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
return 2

View file

@ -5,6 +5,7 @@ Diff Rust vs Python signal reports produced by run_parity.py.
See README.md for the tier definitions. Exits non-zero iff any Tier-A
divergence is found.
"""
from __future__ import annotations
import argparse
@ -15,7 +16,12 @@ from pathlib import Path
from typing import Any, Dict, List, Tuple
CATEGORIES_BY_LAYER = {
"interaction_signals": ["misalignment", "stagnation", "disengagement", "satisfaction"],
"interaction_signals": [
"misalignment",
"stagnation",
"disengagement",
"satisfaction",
],
"execution_signals": ["failure", "loops"],
"environment_signals": ["exhaustion"],
}
@ -69,9 +75,7 @@ def per_type_indices(report: Dict[str, Any]) -> Dict[str, List[int]]:
return dict(out)
def diff_counts(
a: Dict[str, int], b: Dict[str, int]
) -> List[Tuple[str, int, int]]:
def diff_counts(a: Dict[str, int], b: Dict[str, int]) -> List[Tuple[str, int, int]]:
"""Return [(signal_type, a_count, b_count)] for entries that differ."""
keys = set(a) | set(b)
out = []
@ -242,7 +246,12 @@ def main() -> int:
write_summary_md(out_dir / "summary.md", metrics, diffs[:20])
print(json.dumps({k: v for k, v in metrics.items() if k != "quality_confusion_matrix"}, indent=2))
print(
json.dumps(
{k: v for k, v in metrics.items() if k != "quality_confusion_matrix"},
indent=2,
)
)
print(f"\ndiffs: {out_dir / 'diffs.jsonl'} metrics: {out_dir / 'metrics.json'}")
print(f"summary: {out_dir / 'summary.md'}")
@ -252,7 +261,9 @@ def main() -> int:
return 0
def write_summary_md(path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]) -> None:
def write_summary_md(
path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]
) -> None:
lines: List[str] = []
lines.append("# Signals Parity Report")
lines.append("")
@ -299,7 +310,9 @@ def write_summary_md(path: Path, metrics: Dict[str, Any], sample_diffs: List[Dic
lines.append("| | " + " | ".join(labels) + " |")
lines.append("|---|" + "|".join(["---:"] * len(labels)) + "|")
for r in labels:
lines.append(f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |")
lines.append(
f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |"
)
lines.append("")
if sample_diffs:

View file

@ -14,6 +14,7 @@ Usage:
--rust-binary ../../../crates/target/release/signals_replay \\
--output-dir out/
"""
from __future__ import annotations
import argparse
@ -30,7 +31,10 @@ try:
import pyarrow.parquet as pq
from huggingface_hub import hf_hub_download, list_repo_files
except ImportError:
print("error: install dependencies first: pip install -r requirements.txt", file=sys.stderr)
print(
"error: install dependencies first: pip install -r requirements.txt",
file=sys.stderr,
)
sys.exit(2)
try:
@ -46,6 +50,7 @@ except ImportError:
try:
from tqdm import tqdm
except ImportError:
def tqdm(it, **_kwargs): # type: ignore[no-redef]
return it
@ -154,7 +159,10 @@ def sample_conversations(
pf = pq.ParquetFile(str(p))
shard_row_counts.append(pf.metadata.num_rows)
total_rows = sum(shard_row_counts)
print(f"dataset has {total_rows:,} rows across {len(local_paths)} shards", file=sys.stderr)
print(
f"dataset has {total_rows:,} rows across {len(local_paths)} shards",
file=sys.stderr,
)
rng = random.Random(seed)
global_indices = sorted(rng.sample(range(total_rows), num_samples))
@ -255,9 +263,13 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -
"""Write the input metadata so compare.py can include it in the report."""
binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
try:
plano_sha = subprocess.check_output(
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
).decode().strip()
plano_sha = (
subprocess.check_output(
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
)
.decode()
.strip()
)
except Exception:
plano_sha = "unknown"
try:
@ -265,7 +277,11 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -
[sys.executable, "-m", "pip", "show", "signals"]
).decode()
signals_version = next(
(l.split(":", 1)[1].strip() for l in signals_version.splitlines() if l.startswith("Version")),
(
l.split(":", 1)[1].strip()
for l in signals_version.splitlines()
if l.startswith("Version")
),
"unknown",
)
except Exception: