diff --git a/tests/parity/signals/_smoke_test.py b/tests/parity/signals/_smoke_test.py index ccd753bb..68c6e879 100644 --- a/tests/parity/signals/_smoke_test.py +++ b/tests/parity/signals/_smoke_test.py @@ -6,6 +6,7 @@ hand-picked set of conversations without touching the lmsys dataset. Run from this directory: python _smoke_test.py --rust-binary """ + from __future__ import annotations import argparse @@ -38,7 +39,10 @@ SAMPLES = [ "messages": [ {"from": "human", "value": "Book me a flight to NYC for tomorrow"}, {"from": "gpt", "value": "Sure, here are flights to NYC for Friday."}, - {"from": "human", "value": "No, I meant flights for Saturday, not tomorrow"}, + { + "from": "human", + "value": "No, I meant flights for Saturday, not tomorrow", + }, ], }, { @@ -75,7 +79,9 @@ def main() -> int: f.write(json.dumps(s) + "\n") with conv_path.open("rb") as fin, rust_path.open("wb") as fout: - proc = subprocess.run([str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE) + proc = subprocess.run( + [str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE + ) if proc.returncode != 0: sys.stderr.write(proc.stderr.decode("utf-8", errors="replace")) return 2 diff --git a/tests/parity/signals/compare.py b/tests/parity/signals/compare.py index 427e5bba..80f56295 100644 --- a/tests/parity/signals/compare.py +++ b/tests/parity/signals/compare.py @@ -5,6 +5,7 @@ Diff Rust vs Python signal reports produced by run_parity.py. See README.md for the tier definitions. Exits non-zero iff any Tier-A divergence is found. """ + from __future__ import annotations import argparse @@ -15,7 +16,12 @@ from pathlib import Path from typing import Any, Dict, List, Tuple CATEGORIES_BY_LAYER = { - "interaction_signals": ["misalignment", "stagnation", "disengagement", "satisfaction"], + "interaction_signals": [ + "misalignment", + "stagnation", + "disengagement", + "satisfaction", + ], "execution_signals": ["failure", "loops"], "environment_signals": ["exhaustion"], } @@ -69,9 +75,7 @@ def per_type_indices(report: Dict[str, Any]) -> Dict[str, List[int]]: return dict(out) -def diff_counts( - a: Dict[str, int], b: Dict[str, int] -) -> List[Tuple[str, int, int]]: +def diff_counts(a: Dict[str, int], b: Dict[str, int]) -> List[Tuple[str, int, int]]: """Return [(signal_type, a_count, b_count)] for entries that differ.""" keys = set(a) | set(b) out = [] @@ -242,7 +246,12 @@ def main() -> int: write_summary_md(out_dir / "summary.md", metrics, diffs[:20]) - print(json.dumps({k: v for k, v in metrics.items() if k != "quality_confusion_matrix"}, indent=2)) + print( + json.dumps( + {k: v for k, v in metrics.items() if k != "quality_confusion_matrix"}, + indent=2, + ) + ) print(f"\ndiffs: {out_dir / 'diffs.jsonl'} metrics: {out_dir / 'metrics.json'}") print(f"summary: {out_dir / 'summary.md'}") @@ -252,7 +261,9 @@ def main() -> int: return 0 -def write_summary_md(path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]) -> None: +def write_summary_md( + path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]] +) -> None: lines: List[str] = [] lines.append("# Signals Parity Report") lines.append("") @@ -299,7 +310,9 @@ def write_summary_md(path: Path, metrics: Dict[str, Any], sample_diffs: List[Dic lines.append("| | " + " | ".join(labels) + " |") lines.append("|---|" + "|".join(["---:"] * len(labels)) + "|") for r in labels: - lines.append(f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |") + lines.append( + f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |" + ) lines.append("") if sample_diffs: diff --git a/tests/parity/signals/run_parity.py b/tests/parity/signals/run_parity.py index bdb54966..1d14630e 100644 --- a/tests/parity/signals/run_parity.py +++ b/tests/parity/signals/run_parity.py @@ -14,6 +14,7 @@ Usage: --rust-binary ../../../crates/target/release/signals_replay \\ --output-dir out/ """ + from __future__ import annotations import argparse @@ -30,7 +31,10 @@ try: import pyarrow.parquet as pq from huggingface_hub import hf_hub_download, list_repo_files except ImportError: - print("error: install dependencies first: pip install -r requirements.txt", file=sys.stderr) + print( + "error: install dependencies first: pip install -r requirements.txt", + file=sys.stderr, + ) sys.exit(2) try: @@ -46,6 +50,7 @@ except ImportError: try: from tqdm import tqdm except ImportError: + def tqdm(it, **_kwargs): # type: ignore[no-redef] return it @@ -154,7 +159,10 @@ def sample_conversations( pf = pq.ParquetFile(str(p)) shard_row_counts.append(pf.metadata.num_rows) total_rows = sum(shard_row_counts) - print(f"dataset has {total_rows:,} rows across {len(local_paths)} shards", file=sys.stderr) + print( + f"dataset has {total_rows:,} rows across {len(local_paths)} shards", + file=sys.stderr, + ) rng = random.Random(seed) global_indices = sorted(rng.sample(range(total_rows), num_samples)) @@ -255,9 +263,13 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) - """Write the input metadata so compare.py can include it in the report.""" binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest() try: - plano_sha = subprocess.check_output( - ["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent - ).decode().strip() + plano_sha = ( + subprocess.check_output( + ["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent + ) + .decode() + .strip() + ) except Exception: plano_sha = "unknown" try: @@ -265,7 +277,11 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) - [sys.executable, "-m", "pip", "show", "signals"] ).decode() signals_version = next( - (l.split(":", 1)[1].strip() for l in signals_version.splitlines() if l.startswith("Version")), + ( + l.split(":", 1)[1].strip() + for l in signals_version.splitlines() + if l.startswith("Version") + ), "unknown", ) except Exception: