style: format parity harness with black

Made-with: Cursor
This commit is contained in:
Syed Hashmi 2026-04-22 21:34:23 -07:00
parent e10b385926
commit 12b6b3d814
3 changed files with 50 additions and 15 deletions

View file

@ -14,6 +14,7 @@ Usage:
--rust-binary ../../../crates/target/release/signals_replay \\
--output-dir out/
"""
from __future__ import annotations
import argparse
@ -30,7 +31,10 @@ try:
import pyarrow.parquet as pq
from huggingface_hub import hf_hub_download, list_repo_files
except ImportError:
print("error: install dependencies first: pip install -r requirements.txt", file=sys.stderr)
print(
"error: install dependencies first: pip install -r requirements.txt",
file=sys.stderr,
)
sys.exit(2)
try:
@ -46,6 +50,7 @@ except ImportError:
try:
from tqdm import tqdm
except ImportError:
# Fallback defined when `from tqdm import tqdm` raises ImportError: a no-op
# stand-in that accepts and ignores any keyword arguments and returns the
# iterable unchanged, so call sites work without a progress bar.
def tqdm(it, **_kwargs): # type: ignore[no-redef]
return it
@ -154,7 +159,10 @@ def sample_conversations(
pf = pq.ParquetFile(str(p))
shard_row_counts.append(pf.metadata.num_rows)
total_rows = sum(shard_row_counts)
print(f"dataset has {total_rows:,} rows across {len(local_paths)} shards", file=sys.stderr)
print(
f"dataset has {total_rows:,} rows across {len(local_paths)} shards",
file=sys.stderr,
)
rng = random.Random(seed)
global_indices = sorted(rng.sample(range(total_rows), num_samples))
@ -255,9 +263,13 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -
"""Write the input metadata so compare.py can include it in the report."""
binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
try:
plano_sha = subprocess.check_output(
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
).decode().strip()
plano_sha = (
subprocess.check_output(
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
)
.decode()
.strip()
)
except Exception:
plano_sha = "unknown"
try:
@ -265,7 +277,11 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -
[sys.executable, "-m", "pip", "show", "signals"]
).decode()
signals_version = next(
(l.split(":", 1)[1].strip() for l in signals_version.splitlines() if l.startswith("Version")),
(
l.split(":", 1)[1].strip()
for l in signals_version.splitlines()
if l.startswith("Version")
),
"unknown",
)
except Exception: