# doc-to-lora/data/generate_ctx_magic_number.py

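"""Generate a needle-in-a-haystack style dataset: contexts of repeated noise
sentences with a single "special magic number" sentence inserted at a
controlled relative depth, bucketed into token-length bins.

Each JSONL record has three fields: "context", "prompt", and "response".
"""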
import argparse
import json
import math
import os
import random
# -----------------------------
# Config knobs (edit or use CLI)
# -----------------------------
TOKENS_PER_BLOCK = 40  # rough heuristic tokens per noise block
BASE_SAMPLES_PER_BIN = 320_000  # training-samples budget scaler only (val/test fixed at 1000 each)
RNG_SEED = 42
NOISE_BLOCK = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
SPECIAL_TPL = "The special magic number is {magic_number}."
SEP = "\n"  # separator between blocks
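
# For illustration, a generated record looks roughly like this (the magic
# number is random, so the exact value below is hypothetical):
#   {
#     "context": "The grass is green. ...\nThe special magic number is 0042.\n...",
#     "prompt": "What is the special magic number? Reply with only the number.",
#     "response": "0042"
#   }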


def save_jsonl(data: list[dict], filepath: str) -> None:
    """Write `data` to `filepath` as JSON Lines, creating parent dirs if needed."""
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(filepath, "w") as f:
        for entry in data:
            json.dump(entry, f)
            f.write("\n")


def essential_digits4() -> str:
    """Return a random 4-digit string, zero-padded (e.g. "0042")."""
    return f"{random.randint(0, 9_999):04d}"


def _choose_position(total_blocks: int, depth_bin: int) -> int:
    """Choose an insertion index for the special sentence within [0, total_blocks-1]
    such that its relative depth falls within the depth bin [i/10, (i+1)/10).
    """
    if total_blocks <= 0:
        return 0
    # Use floor for start and ceil for end to cover boundaries evenly
    start = math.floor(total_blocks * (depth_bin / 10))
    end = math.ceil(total_blocks * ((depth_bin + 1) / 10)) - 1
    # Clamp both endpoints into the valid index range
    start = max(0, min(start, total_blocks - 1))
    end = max(start, min(end, total_blocks - 1))
    return random.randint(start, end)
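
# Example (hypothetical numbers): with total_blocks=100 and depth_bin=3,
# start = floor(100 * 0.3) = 30 and end = ceil(100 * 0.4) - 1 = 39, so the
# special line lands at a uniform random index in [30, 39], i.e. at roughly
# 30-40% relative depth in the context.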


def _build_example(total_blocks: int, depth_bin: int) -> dict:
    """Build one example with a special line inserted among noise blocks.

    total_blocks: total number of blocks in the final context (including the special one)
    depth_bin: integer in [0, 9]
    """
    total_blocks = max(1, total_blocks)
    # Prepare blocks
    magic = essential_digits4()
    special_line = SPECIAL_TPL.format(magic_number=magic)
    # We'll have (total_blocks - 1) noise blocks and 1 special line
    noise_count = max(0, total_blocks - 1)
    blocks: list[str] = [NOISE_BLOCK for _ in range(noise_count)]
    insert_at = _choose_position(total_blocks, depth_bin)
    # Insert the special line at the desired position within the final sequence.
    # If noise_count == 0, the context is just the special line.
    if noise_count == 0:
        final_blocks = [special_line]
    else:
        # Build a list of length `total_blocks`, filling with noise blocks and
        # placing the special line at index `insert_at`
        final_blocks = []
        noise_idx = 0
        for idx in range(total_blocks):
            if idx == insert_at:
                final_blocks.append(special_line)
            else:
                final_blocks.append(blocks[noise_idx])
                noise_idx += 1
    context = SEP.join(final_blocks)
    prompt = "What is the special magic number? Reply with only the number."
    response = magic
    return {"context": context, "prompt": prompt, "response": response}


def generate_examples(n: int, k: int) -> list[dict]:
    """Generate n examples (all with block count k), spread evenly across 10 depth bins."""
    if n <= 0:
        return []
    base = n // 10
    rem = n % 10
    counts = [base + (1 if i < rem else 0) for i in range(10)]
    out: list[dict] = []
    for depth_bin, c in enumerate(counts):
        for _ in range(c):
            out.append(_build_example(total_blocks=k, depth_bin=depth_bin))
    random.shuffle(out)
    return out
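
# Sketch of the split: generate_examples(1000, k) yields 100 examples per
# depth bin; with n=1003 the remainder goes to the first three bins (101 each).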


def main():
    parser = argparse.ArgumentParser(
        description="Generate noise-wrapped special magic number dataset (similar structure to generate_ctx_kv.py)",
    )
    parser.add_argument("--seed", type=int, default=RNG_SEED, help="Random seed")
    parser.add_argument(
        "--tokenizer-name",
        type=str,
        default="google/gemma-2-2b-it",
        help="Tokenizer name",
    )
    parser.add_argument(
        "--base-samples-per-bin",
        type=int,
        default=BASE_SAMPLES_PER_BIN,
        help=(
            "Baseline number of TRAINING samples per token bin (scaled by bin "
            "width). Validation & test are always 1000 each."
        ),
    )
    parser.add_argument(
        "--out-prefix",
        type=str,
        default="data/raw_datasets/ctx_magic_number",
        help="Output directory prefix (bin range will be appended)",
    )
    parser.add_argument(
        "--tokens-per-block",
        "--tokens-per-pair",
        dest="tokens_per_block",
        type=int,
        default=TOKENS_PER_BLOCK,
        help="Heuristic tokens per noise block for bucketing",
    )
    parser.add_argument(
        "--only-first-n-bins",
        type=int,
        default=None,
        help="For quick tests: only generate the first N token bins",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print a small sample and exit without writing files",
    )
    args = parser.parse_args()
    random.seed(args.seed)
    # ----------------------------------------------------
    # Optional: report tokenizer-based token length stats
    # ----------------------------------------------------
    if args.tokenizer_name:
        try:
            from transformers import AutoTokenizer  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError(
                "Failed to import transformers. Install it or omit --tokenizer-name."
            ) from e
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
        noise_token_count = len(tokenizer(NOISE_BLOCK).input_ids)
        special_example = SPECIAL_TPL.format(magic_number="0000")
        special_token_count = len(tokenizer(special_example).input_ids)
        print(
            f"[Tokenizer: {args.tokenizer_name}] Noise block tokens: {noise_token_count} | Special line tokens: {special_token_count}"
        )
    # Token-length bins: a few short fine-grained bins plus a 32-1024 catch-all,
    # then 1K steps up to 16K, 4K steps up to 32K, and 8K steps up to 128K.
    tok_bins = [(32, 128), (128, 256), (256, 512), (512, 1024), (32, 1024)] + [
        (1024 * i, 1024 * (i + 1)) for i in range(1, 16)
    ]
    tok_bins += [(2**14 + 2**12 * i, 2**14 + 2**12 * (i + 1)) for i in range(4)]
    tok_bins += [(2**15 + 2**13 * i, 2**15 + 2**13 * (i + 1)) for i in range(12)]
    if args.only_first_n_bins is not None:
        tok_bins = tok_bins[: args.only_first_n_bins]
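
    # Next, convert each token bin to a half-open range of block counts
    # [k_lo, k_hi): with a tokenizer, by measuring actual context lengths;
    # without one, via the rough tokens-per-block heuristic. For example
    # (hypothetical numbers), if one block is ~25 tokens, the (256, 512)
    # token bin maps to block counts of roughly 10 through 20.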
    if args.tokenizer_name:
        max_hi = max(hi for _, hi in tok_bins)

        def measure_len(k: int) -> int:
            """Measure the tokenized length of a context built from k blocks."""
            if k == 1:
                ctx = SPECIAL_TPL.format(magic_number="0000")
            else:
                blocks = [NOISE_BLOCK] * (k - 1) + [
                    SPECIAL_TPL.format(magic_number="0000")
                ]
                ctx = SEP.join(blocks)
            return len(tokenizer(ctx).input_ids)

        # lengths[k] = measured token length of a k-block context (lengths[0] unused)
        lengths: list[int] = [0]
        k = 1
        while True:
            L = measure_len(k)
            lengths.append(L)
            if L >= max_hi:
                break
            k += 1
        len_bins = []
        for lo, hi in tok_bins:
            # Smallest block count whose measured length reaches the bin's lower edge
            k_lo = None
            for kk in range(1, len(lengths)):
                if lengths[kk] >= lo:
                    k_lo = kk
                    break
            if k_lo is None or lengths[k_lo] >= hi:
                # No block count produces a length inside [lo, hi)
                len_bins.append((0, 0))
                continue
            # Smallest block count whose measured length reaches the upper edge
            k_hi = len(lengths)
            for kk in range(k_lo, len(lengths)):
                if lengths[kk] >= hi:
                    k_hi = kk
                    break
            len_bins.append((k_lo, k_hi))
        base_tokens = lengths[1]
        delta = (lengths[2] - lengths[1]) if len(lengths) > 2 else 0
        print(
            f"Using tokenizer-measured block ranges. base_tokens={base_tokens} approx_delta={delta}"
        )
    else:
        # Fall back to the rough tokens-per-block heuristic
        len_bins = [
            (lo // args.tokens_per_block, hi // args.tokens_per_block)
            for (lo, hi) in tok_bins
        ]
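
    # At this point len_bins[i] = (k_lo, k_hi) means: contexts built from
    # k in range(k_lo, k_hi) blocks fall inside tok_bins[i] (exactly when
    # measured with a tokenizer, approximately on the heuristic path);
    # (0, 0) marks a token bin no block count can hit, skipped below.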
    if args.dry_run:
        # Print one sample from the first non-empty bin and exit
        for lb in len_bins:
            if lb[1] > lb[0]:
                k = max(1, lb[0])
                sample = generate_examples(10, k)
                print("Sample entry:")
                print(json.dumps(sample[0], indent=2))
                break
        return
    # -----------------------------------------------
    # Main generation per token bin
    # -----------------------------------------------
    TARGET_VAL = 1000
    TARGET_TEST = 1000
    for len_bin, tok_bin in zip(len_bins, tok_bins):
        if len_bin[1] <= len_bin[0]:
            print(f"Skipping token bin {tok_bin} (no valid block counts)")
            continue
        k_start = max(1, len_bin[0])
        k_end = max(1, len_bin[1])
        k_values = list(range(k_start, k_end))
        bin_size = len(k_values)
        save_dir = f"{args.out_prefix}_{tok_bin[0]}_{tok_bin[1]}"
        training_enabled = tok_bin[1] <= 1024  # unchanged policy
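        # Policy note: only bins that end at or below 1024 tokens get a train
        # split; longer bins are generated as val/test only.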
        if training_enabled:
            train_data: list[dict] = []
            # Distribute training budget across k values:
            # per_k = base_samples_per_bin / bin_size
            per_k_train = max(1, args.base_samples_per_bin // max(1, bin_size))
            for k in k_values:
                train_data += generate_examples(per_k_train, k)
        val_data: list[dict] = []
        test_data: list[dict] = []
        base_val = TARGET_VAL // bin_size
        rem_val = TARGET_VAL % bin_size
        base_test = TARGET_TEST // bin_size
        rem_test = TARGET_TEST % bin_size
        for idx, k in enumerate(k_values):
            n_val_k = base_val + (1 if idx < rem_val else 0)
            n_test_k = base_test + (1 if idx < rem_test else 0)
            if n_val_k:
                val_data += generate_examples(n_val_k, k)
            if n_test_k:
                test_data += generate_examples(n_test_k, k)
        random.shuffle(val_data)
        random.shuffle(test_data)
        os.makedirs(save_dir, exist_ok=True)
        if training_enabled:
            save_jsonl(train_data, f"{save_dir}/train.jsonl")
        save_jsonl(val_data, f"{save_dir}/val.jsonl")
        save_jsonl(test_data, f"{save_dir}/test.jsonl")
        if training_enabled:
            print(
                f"Dataset generated at {save_dir} (train={len(train_data)} val={len(val_data)} test={len(test_data)})"
            )
        else:
            print(
                f"Dataset (val/test only) generated at {save_dir} (val={len(val_data)} test={len(test_data)})"
            )
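
# Example invocation (a sketch; all flags shown are defined above):
#   python generate_ctx_magic_number.py --only-first-n-bins 2 --dry-run
# Full runs write e.g. data/raw_datasets/ctx_magic_number_32_128/{val,test}.jsonl,
# plus train.jsonl for bins ending at <= 1024 tokens.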


if __name__ == "__main__":
    main()