mirror of
https://github.com/katanemo/plano.git
synced 2026-04-25 00:36:34 +02:00
test: parity harness for rust vs python signals analyzer
Validates the brightstaff signals port against the katanemo/signals Python reference on lmsys/lmsys-chat-1m. Adds a signals_replay bin emitting python- compatible JSON, a pyarrow-based driver (bypasses the datasets loader pickle bug on python 3.14), a 3-tier comparator, and an on-demand workflow_dispatch CI job. Made-with: Cursor
This commit is contained in:
parent
cb2fdc9dbf
commit
ffaa4da22c
9 changed files with 1118 additions and 0 deletions
97
.github/workflows/parity-signals.yml
vendored
Normal file
97
.github/workflows/parity-signals.yml
vendored
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
name: parity-signals
|
||||||
|
|
||||||
|
# On-demand parity validation of the Rust signals port against the Python
|
||||||
|
# reference (https://github.com/katanemo/signals). Not run on every PR
|
||||||
|
# because it downloads several GB of dataset content.
|
||||||
|
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
num_samples:
|
||||||
|
description: "Number of conversations to sample from lmsys-chat-1m"
|
||||||
|
required: true
|
||||||
|
default: "2000"
|
||||||
|
seed:
|
||||||
|
description: "Sampling seed (use the same value for reproducibility)"
|
||||||
|
required: true
|
||||||
|
default: "42"
|
||||||
|
dataset_revision:
|
||||||
|
description: "HF dataset revision (commit SHA). Empty = latest (NOT pinned)."
|
||||||
|
required: false
|
||||||
|
default: ""
|
||||||
|
signals_ref:
|
||||||
|
description: "Git ref of katanemo/signals to install"
|
||||||
|
required: true
|
||||||
|
default: "main"
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
parity:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
timeout-minutes: 90
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
|
- name: Install Rust (stable)
|
||||||
|
uses: dtolnay/rust-toolchain@stable
|
||||||
|
|
||||||
|
- name: Cache cargo
|
||||||
|
uses: Swatinem/rust-cache@v2
|
||||||
|
with:
|
||||||
|
workspaces: crates
|
||||||
|
|
||||||
|
- name: Build signals_replay
|
||||||
|
working-directory: crates
|
||||||
|
run: cargo build --release -p brightstaff --bin signals_replay
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install harness deps
|
||||||
|
working-directory: tests/parity/signals
|
||||||
|
run: |
|
||||||
|
python -m pip install --upgrade pip
|
||||||
|
pip install -r requirements.txt
|
||||||
|
pip install "signals @ git+https://github.com/katanemo/signals@${{ inputs.signals_ref }}"
|
||||||
|
|
||||||
|
- name: Authenticate Hugging Face (lmsys-chat-1m is gated)
|
||||||
|
env:
|
||||||
|
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
||||||
|
run: |
|
||||||
|
if [ -z "$HF_TOKEN" ]; then
|
||||||
|
echo "::error::HF_TOKEN secret is required to download lmsys-chat-1m"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
mkdir -p ~/.cache/huggingface
|
||||||
|
echo -n "$HF_TOKEN" > ~/.cache/huggingface/token
|
||||||
|
|
||||||
|
- name: Run parity harness
|
||||||
|
working-directory: tests/parity/signals
|
||||||
|
env:
|
||||||
|
DATASET_REV: ${{ inputs.dataset_revision }}
|
||||||
|
run: |
|
||||||
|
ARGS=(
|
||||||
|
--num-samples "${{ inputs.num_samples }}"
|
||||||
|
--seed "${{ inputs.seed }}"
|
||||||
|
--rust-binary "${GITHUB_WORKSPACE}/crates/target/release/signals_replay"
|
||||||
|
--output-dir out/
|
||||||
|
)
|
||||||
|
if [ -n "$DATASET_REV" ]; then
|
||||||
|
ARGS+=(--dataset-revision "$DATASET_REV")
|
||||||
|
fi
|
||||||
|
python run_parity.py "${ARGS[@]}"
|
||||||
|
|
||||||
|
- name: Compare reports
|
||||||
|
working-directory: tests/parity/signals
|
||||||
|
run: python compare.py --output-dir out/
|
||||||
|
|
||||||
|
- name: Upload artifacts
|
||||||
|
if: always()
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: signals-parity-out
|
||||||
|
path: tests/parity/signals/out/
|
||||||
|
if-no-files-found: warn
|
||||||
|
|
@ -3,6 +3,14 @@ name = "brightstaff"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "brightstaff"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "signals_replay"
|
||||||
|
path = "src/bin/signals_replay.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-openai = "0.30.1"
|
async-openai = "0.30.1"
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
|
|
|
||||||
175
crates/brightstaff/src/bin/signals_replay.rs
Normal file
175
crates/brightstaff/src/bin/signals_replay.rs
Normal file
|
|
@ -0,0 +1,175 @@
|
||||||
|
//! `signals-replay` — batch driver for the `brightstaff` signal analyzer.
|
||||||
|
//!
|
||||||
|
//! Reads JSONL conversations from stdin (one per line) and emits matching
|
||||||
|
//! JSONL reports on stdout, one per input conversation, in the same order.
|
||||||
|
//!
|
||||||
|
//! Input shape (per line):
|
||||||
|
//! ```json
|
||||||
|
//! {"id": "convo-42", "messages": [{"from": "human", "value": "..."}, ...]}
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! Output shape (per line, success):
|
||||||
|
//! ```json
|
||||||
|
//! {"id": "convo-42", "report": { ...python-compatible SignalReport dict... }}
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! On per-line failure (parse / analyzer error), emits:
|
||||||
|
//! ```json
|
||||||
|
//! {"id": "convo-42", "error": "..."}
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
//! The output report dict is shaped to match the Python reference's
|
||||||
|
//! `SignalReport.to_dict()` byte-for-byte so the parity comparator can do a
|
||||||
|
//! direct structural diff.
|
||||||
|
|
||||||
|
use std::io::{self, BufRead, BufWriter, Write};
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::{json, Map, Value};
|
||||||
|
|
||||||
|
use brightstaff::signals::{SignalAnalyzer, SignalGroup, SignalReport};
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct InputLine {
|
||||||
|
id: Value,
|
||||||
|
messages: Vec<MessageRow>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct MessageRow {
|
||||||
|
#[serde(default)]
|
||||||
|
from: String,
|
||||||
|
#[serde(default)]
|
||||||
|
value: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let stdin = io::stdin();
|
||||||
|
let stdout = io::stdout();
|
||||||
|
let mut out = BufWriter::new(stdout.lock());
|
||||||
|
let analyzer = SignalAnalyzer::default();
|
||||||
|
|
||||||
|
for line in stdin.lock().lines() {
|
||||||
|
let line = match line {
|
||||||
|
Ok(l) => l,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("read error: {e}");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let trimmed = line.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let result = process_line(&analyzer, trimmed);
|
||||||
|
// Always emit one line per input line so id ordering stays aligned.
|
||||||
|
if let Err(e) = writeln!(out, "{result}") {
|
||||||
|
eprintln!("write error: {e}");
|
||||||
|
std::process::exit(1);
|
||||||
|
}
|
||||||
|
// Flush periodically isn't strictly needed — BufWriter handles it,
|
||||||
|
// and the parent process reads the whole stream when we're done.
|
||||||
|
}
|
||||||
|
let _ = out.flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn process_line(analyzer: &SignalAnalyzer, line: &str) -> Value {
|
||||||
|
let parsed: InputLine = match serde_json::from_str(line) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
return json!({
|
||||||
|
"id": Value::Null,
|
||||||
|
"error": format!("input parse: {e}"),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let id = parsed.id.clone();
|
||||||
|
|
||||||
|
let view: Vec<brightstaff::signals::analyzer::ShareGptMessage<'_>> = parsed
|
||||||
|
.messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| brightstaff::signals::analyzer::ShareGptMessage {
|
||||||
|
from: m.from.as_str(),
|
||||||
|
value: m.value.as_str(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let report = analyzer.analyze_sharegpt(&view);
|
||||||
|
let report_dict = report_to_python_dict(&report);
|
||||||
|
json!({
|
||||||
|
"id": id,
|
||||||
|
"report": report_dict,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert a `SignalReport` into the Python reference's `to_dict()` shape.
|
||||||
|
///
|
||||||
|
/// Ordering of category keys in each layer dict follows the Python source
|
||||||
|
/// exactly so even string-equality comparisons behave deterministically.
|
||||||
|
fn report_to_python_dict(r: &SignalReport) -> Value {
|
||||||
|
let mut interaction = Map::new();
|
||||||
|
interaction.insert(
|
||||||
|
"misalignment".to_string(),
|
||||||
|
signal_group_to_python(&r.interaction.misalignment),
|
||||||
|
);
|
||||||
|
interaction.insert(
|
||||||
|
"stagnation".to_string(),
|
||||||
|
signal_group_to_python(&r.interaction.stagnation),
|
||||||
|
);
|
||||||
|
interaction.insert(
|
||||||
|
"disengagement".to_string(),
|
||||||
|
signal_group_to_python(&r.interaction.disengagement),
|
||||||
|
);
|
||||||
|
interaction.insert(
|
||||||
|
"satisfaction".to_string(),
|
||||||
|
signal_group_to_python(&r.interaction.satisfaction),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut execution = Map::new();
|
||||||
|
execution.insert(
|
||||||
|
"failure".to_string(),
|
||||||
|
signal_group_to_python(&r.execution.failure),
|
||||||
|
);
|
||||||
|
execution.insert(
|
||||||
|
"loops".to_string(),
|
||||||
|
signal_group_to_python(&r.execution.loops),
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut environment = Map::new();
|
||||||
|
environment.insert(
|
||||||
|
"exhaustion".to_string(),
|
||||||
|
signal_group_to_python(&r.environment.exhaustion),
|
||||||
|
);
|
||||||
|
|
||||||
|
json!({
|
||||||
|
"interaction_signals": Value::Object(interaction),
|
||||||
|
"execution_signals": Value::Object(execution),
|
||||||
|
"environment_signals": Value::Object(environment),
|
||||||
|
"overall_quality": r.overall_quality.as_str(),
|
||||||
|
"summary": r.summary,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn signal_group_to_python(g: &SignalGroup) -> Value {
|
||||||
|
let signals: Vec<Value> = g
|
||||||
|
.signals
|
||||||
|
.iter()
|
||||||
|
.map(|s| {
|
||||||
|
json!({
|
||||||
|
"signal_type": s.signal_type.as_str(),
|
||||||
|
"message_index": s.message_index,
|
||||||
|
"snippet": s.snippet,
|
||||||
|
"confidence": s.confidence,
|
||||||
|
"metadata": s.metadata,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
json!({
|
||||||
|
"category": g.category,
|
||||||
|
"count": g.count,
|
||||||
|
"severity": g.severity,
|
||||||
|
"signals": signals,
|
||||||
|
})
|
||||||
|
}
|
||||||
4
tests/parity/signals/.gitignore
vendored
Normal file
4
tests/parity/signals/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,4 @@
|
||||||
|
out/
|
||||||
|
.venv/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
98
tests/parity/signals/README.md
Normal file
98
tests/parity/signals/README.md
Normal file
|
|
@ -0,0 +1,98 @@
|
||||||
|
# Signals Parity Harness
|
||||||
|
|
||||||
|
Validates that `crates/brightstaff/src/signals/` (Rust port) produces the same
|
||||||
|
`SignalReport` as the Python reference at <https://github.com/katanemo/signals>
|
||||||
|
on a fixed sample of `lmsys/lmsys-chat-1m` conversations.
|
||||||
|
|
||||||
|
This harness is **not** part of normal CI. It downloads several GB and is run
|
||||||
|
on demand to gate releases of the signals subsystem (or to investigate
|
||||||
|
regressions reported in production).
|
||||||
|
|
||||||
|
## What gets compared
|
||||||
|
|
||||||
|
For each conversation, both analyzers emit a `SignalReport`. The comparator
|
||||||
|
classifies any divergence into three tiers:
|
||||||
|
|
||||||
|
| Tier | Field | Action on divergence |
|
||||||
|
|------|------------------------------------------------|----------------------|
|
||||||
|
| A | set of `SignalType` present, per-type counts, `overall_quality` | Fail the run |
|
||||||
|
| B | per-instance `message_index`, instance counts per type | Log + collect, do not fail |
|
||||||
|
| C | metadata, snippet text, summary | Information only |
|
||||||
|
|
||||||
|
Quality buckets are compared by string (`excellent` / `good` / ...).
|
||||||
|
|
||||||
|
## What this harness does *not* cover
|
||||||
|
|
||||||
|
`lmsys-chat-1m` is plain user/assistant chat. It exercises the **interaction**
|
||||||
|
layer well (misalignment, stagnation, disengagement, satisfaction) but does
|
||||||
|
**not** exercise:
|
||||||
|
|
||||||
|
- `execution.failure.*`
|
||||||
|
- `execution.loops.*`
|
||||||
|
- `environment.exhaustion.*`
|
||||||
|
|
||||||
|
Those signals require `function_call` / `observation` ShareGPT roles. They are
|
||||||
|
covered by the Rust unit tests and the Python repo's own test fixtures, both
|
||||||
|
of which run on every PR. A synthetic tool-trace dataset for full coverage is
|
||||||
|
deferred to a follow-up.
|
||||||
|
|
||||||
|
## One-time setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 1. Build the Rust replay binary.
|
||||||
|
cd ../../../crates && cargo build --release -p brightstaff --bin signals_replay
|
||||||
|
|
||||||
|
# 2. Set up the Python environment for the harness driver.
|
||||||
|
cd ../tests/parity/signals
|
||||||
|
python3 -m venv .venv && source .venv/bin/activate
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# 3. Install the Python signals reference.
|
||||||
|
# Either point at a local checkout:
|
||||||
|
pip install -e /path/to/signals
|
||||||
|
# or pull from git:
|
||||||
|
pip install 'signals @ git+https://github.com/katanemo/signals@<sha>'
|
||||||
|
```
|
||||||
|
|
||||||
|
## Running
|
||||||
|
|
||||||
|
```bash
|
||||||
|
source .venv/bin/activate
|
||||||
|
|
||||||
|
python run_parity.py \
|
||||||
|
--num-samples 2000 \
|
||||||
|
--seed 42 \
|
||||||
|
--dataset-revision <hf-dataset-revision-sha> \
|
||||||
|
--rust-binary ../../../crates/target/release/signals_replay \
|
||||||
|
--output-dir out/
|
||||||
|
|
||||||
|
python compare.py --output-dir out/
|
||||||
|
```
|
||||||
|
|
||||||
|
`run_parity.py` will:
|
||||||
|
|
||||||
|
1. Download `lmsys/lmsys-chat-1m` (cached in `~/.cache/huggingface`).
|
||||||
|
2. Pick `--num-samples` rows under `--seed`.
|
||||||
|
3. Convert each to ShareGPT, write `out/conversations.jsonl`.
|
||||||
|
4. Run the Rust binary as a subprocess → `out/rust_reports.jsonl`.
|
||||||
|
5. Run the Python analyzer in-process → `out/python_reports.jsonl`.
|
||||||
|
|
||||||
|
`compare.py` reads both report files and writes:
|
||||||
|
|
||||||
|
- `out/diffs.jsonl` — one record per mismatched conversation, with tier + structural diff
|
||||||
|
- `out/metrics.json` — agreement %, per-`SignalType` confusion matrix, quality-bucket confusion matrix
|
||||||
|
- `out/summary.md` — human-readable PR-ready report
|
||||||
|
|
||||||
|
Exit code is non-zero iff any Tier-A divergence is observed.
|
||||||
|
|
||||||
|
## Reproducibility
|
||||||
|
|
||||||
|
Every run pins:
|
||||||
|
|
||||||
|
- `dataset_revision` — the HF dataset commit
|
||||||
|
- `seed` — RNG seed for sampling
|
||||||
|
- `signals_python_version` — `pip show signals` version
|
||||||
|
- `plano_git_sha` — `git rev-parse HEAD` of this repo
|
||||||
|
- `signals_replay_binary_sha256` — the hash of the Rust bin
|
||||||
|
|
||||||
|
All are stamped into `metrics.json`.
|
||||||
97
tests/parity/signals/_smoke_test.py
Normal file
97
tests/parity/signals/_smoke_test.py
Normal file
|
|
@ -0,0 +1,97 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Local smoke test for the parity harness — runs both runners on a tiny
|
||||||
|
hand-picked set of conversations without touching the lmsys dataset.
|
||||||
|
|
||||||
|
Run from this directory:
|
||||||
|
python _smoke_test.py --rust-binary <path>
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from signals.analyzer import SignalAnalyzer
|
||||||
|
|
||||||
|
SAMPLES = [
|
||||||
|
{
|
||||||
|
"id": "smoke-gratitude",
|
||||||
|
"messages": [
|
||||||
|
{"from": "human", "value": "What is the weather in Istanbul?"},
|
||||||
|
{"from": "gpt", "value": "Istanbul is 14C and partly cloudy."},
|
||||||
|
{"from": "human", "value": "That worked, exactly what I needed. Thanks!"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "smoke-escalation",
|
||||||
|
"messages": [
|
||||||
|
{"from": "human", "value": "This isn't helpful at all"},
|
||||||
|
{"from": "gpt", "value": "I'm sorry, can you tell me more?"},
|
||||||
|
{"from": "human", "value": "Get me a human, this is useless"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "smoke-correction",
|
||||||
|
"messages": [
|
||||||
|
{"from": "human", "value": "Book me a flight to NYC for tomorrow"},
|
||||||
|
{"from": "gpt", "value": "Sure, here are flights to NYC for Friday."},
|
||||||
|
{"from": "human", "value": "No, I meant flights for Saturday, not tomorrow"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "smoke-clean",
|
||||||
|
"messages": [
|
||||||
|
{"from": "human", "value": "Hi"},
|
||||||
|
{"from": "gpt", "value": "Hello, how can I help?"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "smoke-rephrase",
|
||||||
|
"messages": [
|
||||||
|
{"from": "human", "value": "Can you summarize the news please"},
|
||||||
|
{"from": "gpt", "value": "Sure, here is a summary."},
|
||||||
|
{"from": "human", "value": "Could you please summarize the news"},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
p = argparse.ArgumentParser()
|
||||||
|
p.add_argument("--rust-binary", required=True, type=Path)
|
||||||
|
args = p.parse_args()
|
||||||
|
|
||||||
|
out_dir = Path("out_smoke")
|
||||||
|
out_dir.mkdir(exist_ok=True)
|
||||||
|
conv_path = out_dir / "conversations.jsonl"
|
||||||
|
rust_path = out_dir / "rust_reports.jsonl"
|
||||||
|
py_path = out_dir / "python_reports.jsonl"
|
||||||
|
|
||||||
|
with conv_path.open("w") as f:
|
||||||
|
for s in SAMPLES:
|
||||||
|
f.write(json.dumps(s) + "\n")
|
||||||
|
|
||||||
|
with conv_path.open("rb") as fin, rust_path.open("wb") as fout:
|
||||||
|
proc = subprocess.run([str(args.rust_binary)], stdin=fin, stdout=fout, stderr=subprocess.PIPE)
|
||||||
|
if proc.returncode != 0:
|
||||||
|
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
|
||||||
|
return 2
|
||||||
|
|
||||||
|
analyzer = SignalAnalyzer()
|
||||||
|
with conv_path.open() as fin, py_path.open("w") as fout:
|
||||||
|
for line in fin:
|
||||||
|
obj = json.loads(line)
|
||||||
|
r = analyzer.analyze(obj["messages"])
|
||||||
|
fout.write(json.dumps({"id": obj["id"], "report": r.to_dict()}) + "\n")
|
||||||
|
|
||||||
|
rc = subprocess.call(
|
||||||
|
[sys.executable, "compare.py", "--output-dir", str(out_dir)],
|
||||||
|
)
|
||||||
|
return rc
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
320
tests/parity/signals/compare.py
Normal file
320
tests/parity/signals/compare.py
Normal file
|
|
@ -0,0 +1,320 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Diff Rust vs Python signal reports produced by run_parity.py.
|
||||||
|
|
||||||
|
See README.md for the tier definitions. Exits non-zero iff any Tier-A
|
||||||
|
divergence is found.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from collections import Counter, defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List, Tuple
|
||||||
|
|
||||||
|
CATEGORIES_BY_LAYER = {
|
||||||
|
"interaction_signals": ["misalignment", "stagnation", "disengagement", "satisfaction"],
|
||||||
|
"execution_signals": ["failure", "loops"],
|
||||||
|
"environment_signals": ["exhaustion"],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
p = argparse.ArgumentParser(description=__doc__)
|
||||||
|
p.add_argument("--output-dir", type=Path, default=Path("out"))
|
||||||
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def load_jsonl(path: Path) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""Load a JSONL file keyed by `id`. Lines with errors are still indexed."""
|
||||||
|
out: Dict[str, Dict[str, Any]] = {}
|
||||||
|
with path.open() as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
obj = json.loads(line)
|
||||||
|
out[str(obj.get("id"))] = obj
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def per_type_counts(report: Dict[str, Any]) -> Dict[str, int]:
|
||||||
|
"""Return {signal_type: count} across all groups in a report dict."""
|
||||||
|
counts: Counter[str] = Counter()
|
||||||
|
for layer in CATEGORIES_BY_LAYER:
|
||||||
|
groups = report.get(layer, {}) or {}
|
||||||
|
for category in CATEGORIES_BY_LAYER[layer]:
|
||||||
|
group = groups.get(category)
|
||||||
|
if not group:
|
||||||
|
continue
|
||||||
|
for sig in group.get("signals", []) or []:
|
||||||
|
counts[sig["signal_type"]] += 1
|
||||||
|
return dict(counts)
|
||||||
|
|
||||||
|
|
||||||
|
def per_type_indices(report: Dict[str, Any]) -> Dict[str, List[int]]:
|
||||||
|
out: Dict[str, List[int]] = defaultdict(list)
|
||||||
|
for layer in CATEGORIES_BY_LAYER:
|
||||||
|
groups = report.get(layer, {}) or {}
|
||||||
|
for category in CATEGORIES_BY_LAYER[layer]:
|
||||||
|
group = groups.get(category)
|
||||||
|
if not group:
|
||||||
|
continue
|
||||||
|
for sig in group.get("signals", []) or []:
|
||||||
|
out[sig["signal_type"]].append(sig.get("message_index"))
|
||||||
|
for k in out:
|
||||||
|
out[k].sort(key=lambda x: (x is None, x))
|
||||||
|
return dict(out)
|
||||||
|
|
||||||
|
|
||||||
|
def diff_counts(
|
||||||
|
a: Dict[str, int], b: Dict[str, int]
|
||||||
|
) -> List[Tuple[str, int, int]]:
|
||||||
|
"""Return [(signal_type, a_count, b_count)] for entries that differ."""
|
||||||
|
keys = set(a) | set(b)
|
||||||
|
out = []
|
||||||
|
for k in sorted(keys):
|
||||||
|
ac = a.get(k, 0)
|
||||||
|
bc = b.get(k, 0)
|
||||||
|
if ac != bc:
|
||||||
|
out.append((k, ac, bc))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def diff_indices(
|
||||||
|
a: Dict[str, List[int]], b: Dict[str, List[int]]
|
||||||
|
) -> List[Tuple[str, List[int], List[int]]]:
|
||||||
|
keys = set(a) | set(b)
|
||||||
|
out = []
|
||||||
|
for k in sorted(keys):
|
||||||
|
ai = a.get(k, [])
|
||||||
|
bi = b.get(k, [])
|
||||||
|
if ai != bi:
|
||||||
|
out.append((k, ai, bi))
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def compare_one(
|
||||||
|
convo_id: str, py: Dict[str, Any], rust: Dict[str, Any]
|
||||||
|
) -> Dict[str, Any] | None:
|
||||||
|
"""Compare a single conversation. Return diff record, or None if identical."""
|
||||||
|
if "error" in py or "error" in rust:
|
||||||
|
return {
|
||||||
|
"id": convo_id,
|
||||||
|
"tier": "A",
|
||||||
|
"kind": "error_in_runner",
|
||||||
|
"python_error": py.get("error"),
|
||||||
|
"rust_error": rust.get("error"),
|
||||||
|
}
|
||||||
|
py_report = py["report"]
|
||||||
|
rust_report = rust["report"]
|
||||||
|
|
||||||
|
py_counts = per_type_counts(py_report)
|
||||||
|
rust_counts = per_type_counts(rust_report)
|
||||||
|
count_diff = diff_counts(py_counts, rust_counts)
|
||||||
|
|
||||||
|
py_quality = py_report.get("overall_quality")
|
||||||
|
rust_quality = rust_report.get("overall_quality")
|
||||||
|
quality_mismatch = py_quality != rust_quality
|
||||||
|
|
||||||
|
if count_diff or quality_mismatch:
|
||||||
|
return {
|
||||||
|
"id": convo_id,
|
||||||
|
"tier": "A",
|
||||||
|
"kind": "signal_or_quality_mismatch",
|
||||||
|
"quality": {"python": py_quality, "rust": rust_quality},
|
||||||
|
"count_diff": [
|
||||||
|
{"signal_type": st, "python": pc, "rust": rc}
|
||||||
|
for (st, pc, rc) in count_diff
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
py_idx = per_type_indices(py_report)
|
||||||
|
rust_idx = per_type_indices(rust_report)
|
||||||
|
idx_diff = diff_indices(py_idx, rust_idx)
|
||||||
|
if idx_diff:
|
||||||
|
return {
|
||||||
|
"id": convo_id,
|
||||||
|
"tier": "B",
|
||||||
|
"kind": "instance_index_mismatch",
|
||||||
|
"diff": [
|
||||||
|
{"signal_type": st, "python_indices": pi, "rust_indices": ri}
|
||||||
|
for (st, pi, ri) in idx_diff
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def confusion_matrix(
|
||||||
|
pairs: List[Tuple[str, str]], labels: List[str]
|
||||||
|
) -> Dict[str, Dict[str, int]]:
|
||||||
|
cm: Dict[str, Dict[str, int]] = {a: {b: 0 for b in labels} for a in labels}
|
||||||
|
for py, rust in pairs:
|
||||||
|
if py not in cm:
|
||||||
|
cm[py] = {b: 0 for b in labels}
|
||||||
|
if rust not in cm[py]:
|
||||||
|
cm[py][rust] = 0
|
||||||
|
cm[py][rust] += 1
|
||||||
|
return cm
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
args = parse_args()
|
||||||
|
out_dir = args.output_dir
|
||||||
|
|
||||||
|
py_reports = load_jsonl(out_dir / "python_reports.jsonl")
|
||||||
|
rust_reports = load_jsonl(out_dir / "rust_reports.jsonl")
|
||||||
|
|
||||||
|
common_ids = sorted(set(py_reports) & set(rust_reports))
|
||||||
|
only_py = sorted(set(py_reports) - set(rust_reports))
|
||||||
|
only_rust = sorted(set(rust_reports) - set(py_reports))
|
||||||
|
|
||||||
|
diffs: List[Dict[str, Any]] = []
|
||||||
|
quality_pairs: List[Tuple[str, str]] = []
|
||||||
|
per_type_total = Counter()
|
||||||
|
per_type_disagree = Counter()
|
||||||
|
|
||||||
|
tier_a = 0
|
||||||
|
tier_b = 0
|
||||||
|
for cid in common_ids:
|
||||||
|
d = compare_one(cid, py_reports[cid], rust_reports[cid])
|
||||||
|
if d is None:
|
||||||
|
quality_pairs.append(
|
||||||
|
(
|
||||||
|
py_reports[cid]["report"]["overall_quality"],
|
||||||
|
rust_reports[cid]["report"]["overall_quality"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for st, _ in per_type_counts(py_reports[cid]["report"]).items():
|
||||||
|
per_type_total[st] += 1
|
||||||
|
else:
|
||||||
|
diffs.append(d)
|
||||||
|
if d["tier"] == "A":
|
||||||
|
tier_a += 1
|
||||||
|
elif d["tier"] == "B":
|
||||||
|
tier_b += 1
|
||||||
|
if "report" in py_reports[cid] and "report" in rust_reports[cid]:
|
||||||
|
quality_pairs.append(
|
||||||
|
(
|
||||||
|
py_reports[cid]["report"].get("overall_quality", "?"),
|
||||||
|
rust_reports[cid]["report"].get("overall_quality", "?"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
for cd in d.get("count_diff", []) or []:
|
||||||
|
per_type_disagree[cd["signal_type"]] += 1
|
||||||
|
per_type_total[cd["signal_type"]] += 1
|
||||||
|
|
||||||
|
n_total = len(common_ids)
|
||||||
|
n_match = n_total - len(diffs)
|
||||||
|
agreement = (n_match / n_total) if n_total else 0.0
|
||||||
|
|
||||||
|
quality_labels = ["excellent", "good", "neutral", "poor", "severe"]
|
||||||
|
cm = confusion_matrix(quality_pairs, quality_labels)
|
||||||
|
|
||||||
|
metrics = {
|
||||||
|
"n_python_reports": len(py_reports),
|
||||||
|
"n_rust_reports": len(rust_reports),
|
||||||
|
"n_common": n_total,
|
||||||
|
"n_only_python": len(only_py),
|
||||||
|
"n_only_rust": len(only_rust),
|
||||||
|
"n_full_match": n_match,
|
||||||
|
"agreement_pct": round(100.0 * agreement, 4),
|
||||||
|
"tier_a_divergences": tier_a,
|
||||||
|
"tier_b_divergences": tier_b,
|
||||||
|
"quality_confusion_matrix": cm,
|
||||||
|
"per_signal_type_total": dict(per_type_total),
|
||||||
|
"per_signal_type_disagree": dict(per_type_disagree),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Pull in run metadata if present.
|
||||||
|
rm_path = out_dir / "run_metadata.json"
|
||||||
|
if rm_path.exists():
|
||||||
|
metrics["run_metadata"] = json.loads(rm_path.read_text())
|
||||||
|
|
||||||
|
(out_dir / "metrics.json").write_text(json.dumps(metrics, indent=2))
|
||||||
|
with (out_dir / "diffs.jsonl").open("w") as f:
|
||||||
|
for d in diffs:
|
||||||
|
f.write(json.dumps(d, ensure_ascii=False))
|
||||||
|
f.write("\n")
|
||||||
|
|
||||||
|
write_summary_md(out_dir / "summary.md", metrics, diffs[:20])
|
||||||
|
|
||||||
|
print(json.dumps({k: v for k, v in metrics.items() if k != "quality_confusion_matrix"}, indent=2))
|
||||||
|
print(f"\ndiffs: {out_dir / 'diffs.jsonl'} metrics: {out_dir / 'metrics.json'}")
|
||||||
|
print(f"summary: {out_dir / 'summary.md'}")
|
||||||
|
|
||||||
|
if tier_a > 0:
|
||||||
|
print(f"\nFAIL: {tier_a} Tier-A divergence(s) detected.", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def write_summary_md(path: Path, metrics: Dict[str, Any], sample_diffs: List[Dict[str, Any]]) -> None:
|
||||||
|
lines: List[str] = []
|
||||||
|
lines.append("# Signals Parity Report")
|
||||||
|
lines.append("")
|
||||||
|
rm = metrics.get("run_metadata", {})
|
||||||
|
if rm:
|
||||||
|
lines.append("## Run metadata")
|
||||||
|
lines.append("")
|
||||||
|
for k in (
|
||||||
|
"dataset_name",
|
||||||
|
"dataset_revision",
|
||||||
|
"seed",
|
||||||
|
"num_samples_actual",
|
||||||
|
"plano_git_sha",
|
||||||
|
"signals_python_version",
|
||||||
|
"rust_binary_sha256",
|
||||||
|
):
|
||||||
|
if k in rm:
|
||||||
|
lines.append(f"- **{k}**: `{rm[k]}`")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
lines.append("## Summary")
|
||||||
|
lines.append("")
|
||||||
|
lines.append(f"- Conversations compared: **{metrics['n_common']}**")
|
||||||
|
lines.append(f"- Full matches: **{metrics['n_full_match']}**")
|
||||||
|
lines.append(f"- Agreement: **{metrics['agreement_pct']}%**")
|
||||||
|
lines.append(f"- Tier-A divergences: **{metrics['tier_a_divergences']}**")
|
||||||
|
lines.append(f"- Tier-B divergences: **{metrics['tier_b_divergences']}**")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
lines.append("## Per-signal-type disagreement")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("| Signal type | Total reports | Disagreements |")
|
||||||
|
lines.append("|---|---:|---:|")
|
||||||
|
totals = metrics["per_signal_type_total"]
|
||||||
|
disagrees = metrics["per_signal_type_disagree"]
|
||||||
|
for k in sorted(set(totals) | set(disagrees)):
|
||||||
|
lines.append(f"| `{k}` | {totals.get(k, 0)} | {disagrees.get(k, 0)} |")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
lines.append("## Quality bucket confusion matrix (rows = python, cols = rust)")
|
||||||
|
lines.append("")
|
||||||
|
cm = metrics["quality_confusion_matrix"]
|
||||||
|
labels = list(cm.keys())
|
||||||
|
lines.append("| | " + " | ".join(labels) + " |")
|
||||||
|
lines.append("|---|" + "|".join(["---:"] * len(labels)) + "|")
|
||||||
|
for r in labels:
|
||||||
|
lines.append(f"| {r} | " + " | ".join(str(cm[r].get(c, 0)) for c in labels) + " |")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
if sample_diffs:
|
||||||
|
lines.append("## Sample divergences (first 20)")
|
||||||
|
lines.append("")
|
||||||
|
for d in sample_diffs:
|
||||||
|
lines.append(f"### `{d['id']}` — tier {d['tier']} — {d['kind']}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("```json")
|
||||||
|
lines.append(json.dumps(d, indent=2))
|
||||||
|
lines.append("```")
|
||||||
|
lines.append("")
|
||||||
|
|
||||||
|
path.write_text("\n".join(lines))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
3
tests/parity/signals/requirements.txt
Normal file
3
tests/parity/signals/requirements.txt
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
huggingface_hub>=0.25
|
||||||
|
pyarrow>=15
|
||||||
|
tqdm>=4.65
|
||||||
316
tests/parity/signals/run_parity.py
Normal file
316
tests/parity/signals/run_parity.py
Normal file
|
|
@ -0,0 +1,316 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Parity harness driver.
|
||||||
|
|
||||||
|
Samples conversations from `lmsys/lmsys-chat-1m`, runs both the Python
|
||||||
|
reference analyzer (in-process) and the Rust port (subprocess), writes both
|
||||||
|
reports to disk for `compare.py` to diff.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python run_parity.py \\
|
||||||
|
--num-samples 2000 \\
|
||||||
|
--seed 42 \\
|
||||||
|
--dataset-revision <hf-revision-sha> \\
|
||||||
|
--rust-binary ../../../crates/target/release/signals_replay \\
|
||||||
|
--output-dir out/
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Iterator, List
|
||||||
|
|
||||||
|
try:
|
||||||
|
import pyarrow.parquet as pq
|
||||||
|
from huggingface_hub import hf_hub_download, list_repo_files
|
||||||
|
except ImportError:
|
||||||
|
print("error: install dependencies first: pip install -r requirements.txt", file=sys.stderr)
|
||||||
|
sys.exit(2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from signals.analyzer import SignalAnalyzer
|
||||||
|
except ImportError:
|
||||||
|
print(
|
||||||
|
"error: the python `signals` package is not installed. "
|
||||||
|
"install it from your local checkout: pip install -e /path/to/signals",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
sys.exit(2)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm import tqdm
|
||||||
|
except ImportError:
|
||||||
|
def tqdm(it, **_kwargs): # type: ignore[no-redef]
|
||||||
|
return it
|
||||||
|
|
||||||
|
|
||||||
|
DATASET_NAME = "lmsys/lmsys-chat-1m"
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
p = argparse.ArgumentParser(description=__doc__)
|
||||||
|
p.add_argument("--num-samples", type=int, default=2000)
|
||||||
|
p.add_argument("--seed", type=int, default=42)
|
||||||
|
p.add_argument(
|
||||||
|
"--dataset-revision",
|
||||||
|
default=None,
|
||||||
|
help="HF dataset revision to pin (default: latest, NOT recommended for reproducibility)",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--rust-binary",
|
||||||
|
type=Path,
|
||||||
|
required=True,
|
||||||
|
help="path to the `signals_replay` binary built from crates/brightstaff",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--output-dir",
|
||||||
|
type=Path,
|
||||||
|
default=Path("out"),
|
||||||
|
help="directory to write the conversations + both runners' outputs",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--max-conv-messages",
|
||||||
|
type=int,
|
||||||
|
default=200,
|
||||||
|
help="drop conversations with more than this many messages (the analyzer "
|
||||||
|
"truncates to last 100 anyway; this is a sanity cap on input parsing)",
|
||||||
|
)
|
||||||
|
return p.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def lmsys_to_sharegpt(conversation: List[Dict[str, str]]) -> List[Dict[str, str]]:
|
||||||
|
"""Convert lmsys-chat-1m's `[{role, content}]` to ShareGPT's `[{from, value}]`.
|
||||||
|
|
||||||
|
lmsys uses `user` / `assistant` (no tools, no system role in `conversation`).
|
||||||
|
"""
|
||||||
|
out = []
|
||||||
|
for m in conversation:
|
||||||
|
role = m.get("role", "")
|
||||||
|
content = m.get("content", "")
|
||||||
|
if not isinstance(content, str):
|
||||||
|
content = str(content) if content is not None else ""
|
||||||
|
if role == "user":
|
||||||
|
from_ = "human"
|
||||||
|
elif role == "assistant":
|
||||||
|
from_ = "gpt"
|
||||||
|
else:
|
||||||
|
# lmsys is human/assistant only; skip anything else defensively.
|
||||||
|
continue
|
||||||
|
out.append({"from": from_, "value": content})
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _list_parquet_files(revision: str | None) -> List[str]:
|
||||||
|
"""Return the list of parquet shard paths in the dataset repo."""
|
||||||
|
files = list_repo_files(DATASET_NAME, repo_type="dataset", revision=revision)
|
||||||
|
return sorted(f for f in files if f.endswith(".parquet"))
|
||||||
|
|
||||||
|
|
||||||
|
def _download_shards(paths: List[str], revision: str | None) -> List[Path]:
|
||||||
|
"""Download each parquet shard to the HF cache, return local paths."""
|
||||||
|
local: List[Path] = []
|
||||||
|
for rel in tqdm(paths, desc="downloading shards", unit="shard"):
|
||||||
|
p = hf_hub_download(
|
||||||
|
DATASET_NAME,
|
||||||
|
filename=rel,
|
||||||
|
repo_type="dataset",
|
||||||
|
revision=revision,
|
||||||
|
)
|
||||||
|
local.append(Path(p))
|
||||||
|
return local
|
||||||
|
|
||||||
|
|
||||||
|
def sample_conversations(
|
||||||
|
*,
|
||||||
|
num_samples: int,
|
||||||
|
seed: int,
|
||||||
|
revision: str | None,
|
||||||
|
max_conv_messages: int,
|
||||||
|
) -> Iterator[Dict[str, Any]]:
|
||||||
|
"""Yield `num_samples` conversations sampled uniformly across the dataset.
|
||||||
|
|
||||||
|
We bypass the `datasets` loader (which has a Python 3.14 pickle issue)
|
||||||
|
and read the parquet shards directly via pyarrow.
|
||||||
|
"""
|
||||||
|
print(
|
||||||
|
f"listing {DATASET_NAME}"
|
||||||
|
f"{' @ ' + revision if revision else ' (no revision pinned!)'}",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
shard_paths = _list_parquet_files(revision)
|
||||||
|
if not shard_paths:
|
||||||
|
raise SystemExit(f"no parquet shards found for {DATASET_NAME}")
|
||||||
|
local_paths = _download_shards(shard_paths, revision)
|
||||||
|
|
||||||
|
# Collect row counts without reading data.
|
||||||
|
shard_row_counts: List[int] = []
|
||||||
|
for p in local_paths:
|
||||||
|
pf = pq.ParquetFile(str(p))
|
||||||
|
shard_row_counts.append(pf.metadata.num_rows)
|
||||||
|
total_rows = sum(shard_row_counts)
|
||||||
|
print(f"dataset has {total_rows:,} rows across {len(local_paths)} shards", file=sys.stderr)
|
||||||
|
|
||||||
|
rng = random.Random(seed)
|
||||||
|
global_indices = sorted(rng.sample(range(total_rows), num_samples))
|
||||||
|
|
||||||
|
# Bucket indices by shard.
|
||||||
|
by_shard: Dict[int, List[int]] = {}
|
||||||
|
cumulative = 0
|
||||||
|
shard_offsets = []
|
||||||
|
for c in shard_row_counts:
|
||||||
|
shard_offsets.append(cumulative)
|
||||||
|
cumulative += c
|
||||||
|
for gi in global_indices:
|
||||||
|
# Find which shard this index belongs to.
|
||||||
|
for si, off in enumerate(shard_offsets):
|
||||||
|
if gi < off + shard_row_counts[si]:
|
||||||
|
by_shard.setdefault(si, []).append(gi - off)
|
||||||
|
break
|
||||||
|
|
||||||
|
yielded = 0
|
||||||
|
for si in sorted(by_shard.keys()):
|
||||||
|
local_rows = by_shard[si]
|
||||||
|
pf = pq.ParquetFile(str(local_paths[si]))
|
||||||
|
table = pf.read(columns=["conversation"])
|
||||||
|
conv_col = table.column("conversation")
|
||||||
|
for local_idx in local_rows:
|
||||||
|
raw = conv_col[local_idx].as_py()
|
||||||
|
if not raw:
|
||||||
|
continue
|
||||||
|
conversation = raw if isinstance(raw, list) else raw.get("conversation", [])
|
||||||
|
if len(conversation) > max_conv_messages:
|
||||||
|
continue
|
||||||
|
messages = lmsys_to_sharegpt(conversation)
|
||||||
|
if not messages:
|
||||||
|
continue
|
||||||
|
global_idx = shard_offsets[si] + local_idx
|
||||||
|
yield {
|
||||||
|
"id": f"lmsys-{global_idx}",
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
yielded += 1
|
||||||
|
print(f"yielded {yielded} conversations after filtering", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def write_conversations(out_path: Path, samples: Iterator[Dict[str, Any]]) -> int:
|
||||||
|
n = 0
|
||||||
|
with out_path.open("w") as f:
|
||||||
|
for s in tqdm(samples, desc="sampling", unit="convo"):
|
||||||
|
f.write(json.dumps(s, ensure_ascii=False))
|
||||||
|
f.write("\n")
|
||||||
|
n += 1
|
||||||
|
return n
|
||||||
|
|
||||||
|
|
||||||
|
def run_rust(rust_binary: Path, conv_path: Path, out_path: Path) -> None:
|
||||||
|
print(f"running rust analyzer: {rust_binary}", file=sys.stderr)
|
||||||
|
t0 = time.monotonic()
|
||||||
|
with conv_path.open("rb") as fin, out_path.open("wb") as fout:
|
||||||
|
proc = subprocess.run(
|
||||||
|
[str(rust_binary)],
|
||||||
|
stdin=fin,
|
||||||
|
stdout=fout,
|
||||||
|
stderr=subprocess.PIPE,
|
||||||
|
check=False,
|
||||||
|
)
|
||||||
|
if proc.returncode != 0:
|
||||||
|
sys.stderr.write(proc.stderr.decode("utf-8", errors="replace"))
|
||||||
|
raise SystemExit(f"rust runner exited {proc.returncode}")
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
print(f" rust runner: {elapsed:.1f}s", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def run_python(conv_path: Path, out_path: Path) -> None:
|
||||||
|
print("running python analyzer...", file=sys.stderr)
|
||||||
|
t0 = time.monotonic()
|
||||||
|
analyzer = SignalAnalyzer()
|
||||||
|
with conv_path.open() as fin, out_path.open("w") as fout:
|
||||||
|
for line in tqdm(fin, desc="python", unit="convo"):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
obj = json.loads(line)
|
||||||
|
report = analyzer.analyze(obj["messages"])
|
||||||
|
fout.write(
|
||||||
|
json.dumps(
|
||||||
|
{"id": obj["id"], "report": report.to_dict()},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
fout.write(json.dumps({"id": obj.get("id"), "error": str(e)}))
|
||||||
|
fout.write("\n")
|
||||||
|
elapsed = time.monotonic() - t0
|
||||||
|
print(f" python runner: {elapsed:.1f}s", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def stamp_metadata(args: argparse.Namespace, output_dir: Path, n_samples: int) -> None:
|
||||||
|
"""Write the input metadata so compare.py can include it in the report."""
|
||||||
|
binary_sha = hashlib.sha256(args.rust_binary.read_bytes()).hexdigest()
|
||||||
|
try:
|
||||||
|
plano_sha = subprocess.check_output(
|
||||||
|
["git", "rev-parse", "HEAD"], cwd=Path(__file__).parent
|
||||||
|
).decode().strip()
|
||||||
|
except Exception:
|
||||||
|
plano_sha = "unknown"
|
||||||
|
try:
|
||||||
|
signals_version = subprocess.check_output(
|
||||||
|
[sys.executable, "-m", "pip", "show", "signals"]
|
||||||
|
).decode()
|
||||||
|
signals_version = next(
|
||||||
|
(l.split(":", 1)[1].strip() for l in signals_version.splitlines() if l.startswith("Version")),
|
||||||
|
"unknown",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
signals_version = "unknown"
|
||||||
|
|
||||||
|
meta = {
|
||||||
|
"dataset_name": DATASET_NAME,
|
||||||
|
"dataset_revision": args.dataset_revision,
|
||||||
|
"seed": args.seed,
|
||||||
|
"num_samples_requested": args.num_samples,
|
||||||
|
"num_samples_actual": n_samples,
|
||||||
|
"rust_binary": str(args.rust_binary.resolve()),
|
||||||
|
"rust_binary_sha256": binary_sha,
|
||||||
|
"plano_git_sha": plano_sha,
|
||||||
|
"signals_python_version": signals_version,
|
||||||
|
"max_conv_messages": args.max_conv_messages,
|
||||||
|
}
|
||||||
|
(output_dir / "run_metadata.json").write_text(json.dumps(meta, indent=2))
|
||||||
|
print(f"wrote {output_dir / 'run_metadata.json'}", file=sys.stderr)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
if not args.rust_binary.exists():
|
||||||
|
raise SystemExit(f"rust binary not found at {args.rust_binary}")
|
||||||
|
|
||||||
|
conv_path = args.output_dir / "conversations.jsonl"
|
||||||
|
rust_path = args.output_dir / "rust_reports.jsonl"
|
||||||
|
py_path = args.output_dir / "python_reports.jsonl"
|
||||||
|
|
||||||
|
samples = sample_conversations(
|
||||||
|
num_samples=args.num_samples,
|
||||||
|
seed=args.seed,
|
||||||
|
revision=args.dataset_revision,
|
||||||
|
max_conv_messages=args.max_conv_messages,
|
||||||
|
)
|
||||||
|
n = write_conversations(conv_path, samples)
|
||||||
|
print(f"wrote {n} conversations to {conv_path}", file=sys.stderr)
|
||||||
|
|
||||||
|
run_rust(args.rust_binary, conv_path, rust_path)
|
||||||
|
run_python(conv_path, py_path)
|
||||||
|
stamp_metadata(args, args.output_dir, n)
|
||||||
|
print("done. now run: python compare.py --output-dir " + str(args.output_dir))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue