mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
style: format parity harness with black
Made-with: Cursor
This commit is contained in:
parent
e10b385926
commit
12b6b3d814
3 changed files with 50 additions and 15 deletions
|
|
@ -6,6 +6,7 @@ hand-picked set of conversations without touching the lmsys dataset.
|
|||
Run from this directory:
|
||||
python _smoke_test.py --rust-binary <path>
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
|
@ -38,7 +39,10 @@ SAMPLES = [
|
|||
"messages": [
|
||||
{"from": "human", "value": "Book me a flight to NYC for tomorrow"},
|
||||
{"from": "gpt", "value": "Sure, here are flights to NYC for Friday."},
|
||||
{"from": "human", "value": "No, I meant flights for Saturday, not tomorrow"},
|
||||
{
|
||||
"from": "human",
|
||||
"value": "No, I meant flights for Saturday, not tomorrow",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
|
|
@ -75,7 +79,9 @@ def main() -> int:
|
|||
f.write(json.dumps(s) + "\n")
|
||||
|
||||
with conv_path.open("rb") as fin, rust_path.open("wb") as fout:
|
||||
proc = subprocess.run([str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE)
|
||||
proc = subprocess.run(
|
||||
[str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
|
||||
return 2
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Diff Rust vs Python signal reports produced by run_parity.py.
|
|||
See README.md for the tier definitions. Exits non-zero iff any Tier-A
|
||||
divergence is found.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
|
@ -15,7 +16,12 @@ from pathlib import Path
|
|||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
CATEGORIES_BY_LAYER = {
|
||||
"interaction_signals": ["misalignment", "stagnation", "disengagement", "satisfaction"],
|
||||
"interaction_signals": [
|
||||
"misalignment",
|
||||
"stagnation",
|
||||
"disengagement",
|
||||
"satisfaction",
|
||||
],
|
||||
"execution_signals": ["failure", "loops"],
|
||||
"environment_signals": ["exhaustion"],
|
||||
}
|
||||
|
|
@ -69,9 +75,7 @@ def per_type_indices(report: Dict[str, Any]) -> Dict[str, List[int]]:
|
|||
return dict(out)
|
||||
|
||||
|
||||
def diff_counts(
|
||||
a: Dict[str, int], b: Dict[str, int]
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
def diff_counts(a: Dict[str, int], b: Dict[str, int]) -> List[Tuple[str, int, int]]:
|
||||
"""Return [(signal_type, a_count, b_count)] for entries that differ."""
|
||||
keys = set(a) | set(b)
|
||||
out = []
|
||||
|
|
@ -242,7 +246,12 @@ def main() -> int:
|
|||
|
||||
write_summary_md(out_dir / "summary.md", metrics, diffs[:20])
|
||||
|
||||
print(json.dumps({k: v for k, v in metrics.items() if k != "quality_confusion_matrix"}, indent=2))
|
||||
print(
|
||||
json.dumps(
|
||||
{k: v for k, v in metrics.items() if k != "quality_confusion_matrix"},
|
||||
indent=2,
|
||||
)
|
||||
)
|
||||
print(f"\ndiffs: {out_dir / 'diffs.jsonl'} metrics: {out_dir / 'metrics.json'}")
|
||||
print(f"summary: {out_dir / 'summary.md'}")
|
||||
|
||||
|
|
@ -252,7 +261,9 @@ def main() -> int:
|
|||
return 0
|
||||
|
||||
|
||||
def write_summary_md(path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]) -> None:
|
||||
def write_summary_md(
|
||||
path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]
|
||||
) -> None:
|
||||
lines: List[str] = []
|
||||
lines.append("# Signals Parity Report")
|
||||
lines.append("")
|
||||
|
|
@ -299,7 +310,9 @@ def write_summary_md(path: Path, metrics: Dict[str, Any], sample_diffs: List[Dic
|
|||
lines.append("| | " + " | ".join(labels) + " |")
|
||||
lines.append("|---|" + "|".join(["---:"] * len(labels)) + "|")
|
||||
for r in labels:
|
||||
lines.append(f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |")
|
||||
lines.append(
|
||||
f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |"
|
||||
)
|
||||
lines.append("")
|
||||
|
||||
if sample_diffs:
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ Usage:
|
|||
--rust-binary ../../../crates/target/release/signals_replay \\
|
||||
--output-dir out/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
|
@ -30,7 +31,10 @@ try:
|
|||
import pyarrow.parquet as pq
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
except ImportError:
|
||||
print("error: install dependencies first: pip install -r requirements.txt", file=sys.stderr)
|
||||
print(
|
||||
"error: install dependencies first: pip install -r requirements.txt",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(2)
|
||||
|
||||
try:
|
||||
|
|
@ -46,6 +50,7 @@ except ImportError:
|
|||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
|
||||
def tqdm(it, **_kwargs): # type: ignore[no-redef]
|
||||
return it
|
||||
|
||||
|
|
@ -154,7 +159,10 @@ def sample_conversations(
|
|||
pf = pq.ParquetFile(str(p))
|
||||
shard_row_counts.append(pf.metadata.num_rows)
|
||||
total_rows = sum(shard_row_counts)
|
||||
print(f"dataset has {total_rows:,} rows across {len(local_paths)} shards", file=sys.stderr)
|
||||
print(
|
||||
f"dataset has {total_rows:,} rows across {len(local_paths)} shards",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
rng = random.Random(seed)
|
||||
global_indices = sorted(rng.sample(range(total_rows), num_samples))
|
||||
|
|
@ -255,9 +263,13 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -
|
|||
"""Write the input metadata so compare.py can include it in the report."""
|
||||
binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
|
||||
try:
|
||||
plano_sha = subprocess.check_output(
|
||||
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
|
||||
).decode().strip()
|
||||
plano_sha = (
|
||||
subprocess.check_output(
|
||||
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
|
||||
)
|
||||
.decode()
|
||||
.strip()
|
||||
)
|
||||
except Exception:
|
||||
plano_sha = "unknown"
|
||||
try:
|
||||
|
|
@ -265,7 +277,11 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -
|
|||
[sys.executable, "-m", "pip", "show", "signals"]
|
||||
).decode()
|
||||
signals_version = next(
|
||||
(l.split(":", 1)[1].strip() for l in signals_version.splitlines() if l.startswith("Version")),
|
||||
(
|
||||
l.split(":", 1)[1].strip()
|
||||
for l in signals_version.splitlines()
|
||||
if l.startswith("Version")
|
||||
),
|
||||
"unknown",
|
||||
)
|
||||
except Exception:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue