mirror of
https://github.com/katanemo/plano.git
synced 2026-05-30 14:25:15 +02:00
style: format parity harness with black
Made-with: Cursor
This commit is contained in:
parent
e10b385926
commit
12b6b3d814
3 changed files with 50 additions and 15 deletions
|
|
@ -14,6 +14,7 @@ Usage:
|
|||
--rust-binary ../../../crates/target/release/signals_replay \\
|
||||
--output-dir out/
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
|
|
@ -30,7 +31,10 @@ try:
|
|||
import pyarrow.parquet as pq
|
||||
from huggingface_hub import hf_hub_download, list_repo_files
|
||||
except ImportError:
|
||||
print("error: install dependencies first: pip install -r requirements.txt", file=sys.stderr)
|
||||
print(
|
||||
"error: install dependencies first: pip install -r requirements.txt",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(2)
|
||||
|
||||
try:
|
||||
|
|
@ -46,6 +50,7 @@ except ImportError:
|
|||
try:
|
||||
from tqdm import tqdm
|
||||
except ImportError:
|
||||
|
||||
def tqdm(it, **_kwargs): # type: ignore[no-redef]
|
||||
return it
|
||||
|
||||
|
|
@ -154,7 +159,10 @@ def sample_conversations(
|
|||
pf = pq.ParquetFile(str(p))
|
||||
shard_row_counts.append(pf.metadata.num_rows)
|
||||
total_rows = sum(shard_row_counts)
|
||||
print(f"dataset has {total_rows:,} rows across {len(local_paths)} shards", file=sys.stderr)
|
||||
print(
|
||||
f"dataset has {total_rows:,} rows across {len(local_paths)} shards",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
rng = random.Random(seed)
|
||||
global_indices = sorted(rng.sample(range(total_rows), num_samples))
|
||||
|
|
@ -255,9 +263,13 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -
|
|||
"""Write the input metadata so compare.py can include it in the report."""
|
||||
binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
|
||||
try:
|
||||
plano_sha = subprocess.check_output(
|
||||
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
|
||||
).decode().strip()
|
||||
plano_sha = (
|
||||
subprocess.check_output(
|
||||
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
|
||||
)
|
||||
.decode()
|
||||
.strip()
|
||||
)
|
||||
except Exception:
|
||||
plano_sha = "unknown"
|
||||
try:
|
||||
|
|
@ -265,7 +277,11 @@ def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -
|
|||
[sys.executable, "-m", "pip", "show", "signals"]
|
||||
).decode()
|
||||
signals_version = next(
|
||||
(l.split(":", 1)[1].strip() for l in signals_version.splitlines() if l.startswith("Version")),
|
||||
(
|
||||
l.split(":", 1)[1].strip()
|
||||
for l in signals_version.splitlines()
|
||||
if l.startswith("Version")
|
||||
),
|
||||
"unknown",
|
||||
)
|
||||
except Exception:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue