# Mirror of https://github.com/SakanaAI/doc-to-lora.git
# Synced 2026-04-26 08:36:23 +02:00
# (288 lines, 10 KiB, Python)
import argparse
|
|
import json
|
|
import math
|
|
import os
|
|
import random
|
|
|
|
# -----------------------------
# Config knobs (edit or use CLI)
# -----------------------------
# Heuristic only — used to convert token-bin bounds into block counts when no
# tokenizer is supplied (see main()).
TOKENS_PER_BLOCK = 40  # rough heuristic tokens per noise block
BASE_SAMPLES_PER_BIN = (
    320_000  # training samples budget scaler only (val/test fixed at 1000 each)
)
# Default RNG seed; overridable via --seed.
RNG_SEED = 42
# Filler sentence repeated to pad the context around the "needle".
NOISE_BLOCK = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
# Template for the "needle" sentence; {magic_number} is a 4-digit string.
SPECIAL_TPL = "The special magic number is {magic_number}."
SEP = "\n"  # between blocks
|
|
|
|
|
|
def save_jsonl(data: list[dict], filepath: str) -> None:
    """Write *data* to *filepath* in JSON Lines format, one object per line.

    Creates the parent directory first if the path has one.
    """
    directory = os.path.dirname(filepath)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filepath, "w") as fh:
        fh.writelines(json.dumps(record) + "\n" for record in data)
|
|
|
|
|
|
def essential_digits4() -> str:
    """Return a uniformly random 4-digit, zero-padded string (e.g. "0042").

    Fix: the original bound a lambda to a name (PEP 8 E731) — a ``def`` gives
    the callable a proper name, repr, and docstring with identical behavior.
    """
    return f"{random.randint(0, 9_999):04d}"
|
|
|
|
|
|
def _choose_position(total_blocks: int, depth_bin: int) -> int:
|
|
"""Choose an insertion index for the special sentence within [0, total_blocks-1]
|
|
such that its relative depth falls within the depth bin [i/10, (i+1)/10).
|
|
"""
|
|
if total_blocks <= 0:
|
|
return 0
|
|
# Use floor for start and ceil for end to cover boundaries evenly
|
|
start = math.floor(total_blocks * (depth_bin / 10))
|
|
end = math.ceil(total_blocks * ((depth_bin + 1) / 10)) - 1
|
|
# clamp
|
|
start = max(0, min(start, total_blocks - 1))
|
|
end = max(start, min(end, total_blocks - 1))
|
|
return random.randint(start, end)
|
|
|
|
|
|
def _build_example(total_blocks: int, depth_bin: int) -> dict:
    """Assemble one needle-in-a-haystack example.

    total_blocks: total number of blocks in the final context (the special
        line included); values < 1 are bumped up to 1.
    depth_bin: integer in [0, 9] selecting which tenth of the context the
        special line lands in.

    Returns a dict with "context", "prompt", and "response" keys.
    """
    total_blocks = max(1, total_blocks)

    # Draw the magic number first, then the position — this RNG call order
    # matters for reproducibility under a fixed seed.
    magic = essential_digits4()
    special_line = SPECIAL_TPL.format(magic_number=magic)

    noise_count = max(0, total_blocks - 1)
    insert_at = _choose_position(total_blocks, depth_bin)

    # Splice the special line into the run of identical noise blocks at the
    # chosen index (handles noise_count == 0 naturally: result is [special]).
    noise = [NOISE_BLOCK] * noise_count
    final_blocks = noise[:insert_at] + [special_line] + noise[insert_at:]

    return {
        "context": SEP.join(final_blocks),
        "prompt": "What is the special magic number? Reply with only the number.",
        "response": magic,
    }
|
|
|
|
|
|
def generate_examples(n: int, k: int) -> list[dict]:
    """Generate *n* examples (all at block length *k*), spread as evenly as
    possible across the 10 depth bins, then shuffled.
    """
    if n <= 0:
        return []
    quotient, remainder = divmod(n, 10)
    out: list[dict] = []
    for bin_idx in range(10):
        # The first `remainder` bins each absorb one extra example.
        count = quotient + (1 if bin_idx < remainder else 0)
        out.extend(
            _build_example(total_blocks=k, depth_bin=bin_idx) for _ in range(count)
        )
    random.shuffle(out)
    return out
|
|
|
|
|
|
def main():
    """CLI entry point: generate needle-in-a-haystack datasets, one output
    directory per token-length bin.

    Pipeline: parse args -> seed RNG -> (optionally) measure real token
    lengths with a HF tokenizer -> map token bins to block-count bins ->
    generate train/val/test splits and write them as JSONL.
    """
    parser = argparse.ArgumentParser(
        description="Generate noise-wrapped special magic number dataset (similar structure to generate_ctx_kv.py)",
    )
    parser.add_argument("--seed", type=int, default=RNG_SEED, help="Random seed")
    parser.add_argument(
        "--tokenizer-name",
        type=str,
        default="google/gemma-2-2b-it",
        help=("Tokenizer name"),
    )
    parser.add_argument(
        "--base-samples-per-bin",
        type=int,
        default=BASE_SAMPLES_PER_BIN,
        help="Baseline number of TRAINING samples per token bin (scaled by bin width). Validation & test are always 1000 each.",
    )
    parser.add_argument(
        "--out-prefix",
        type=str,
        default="data/raw_datasets/ctx_magic_number",
        help="Output directory prefix (bin range will be appended)",
    )
    parser.add_argument(
        # "--tokens-per-pair" is kept as an alias (presumably for
        # compatibility with generate_ctx_kv.py's flag name — unverified).
        "--tokens-per-block",
        "--tokens-per-pair",
        dest="tokens_per_block",
        type=int,
        default=TOKENS_PER_BLOCK,
        help="Heuristic tokens per noise block for bucketing",
    )
    parser.add_argument(
        "--only-first-n-bins",
        type=int,
        default=None,
        help="For quick tests: only generate the first N token bins",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print a small sample and exit without writing files",
    )

    args = parser.parse_args()

    random.seed(args.seed)

    # ----------------------------------------------------
    # Optional: report tokenizer-based token length stats
    # ----------------------------------------------------
    if args.tokenizer_name:
        # Lazy import so the script still runs (via the heuristic path below)
        # when transformers is not installed and --tokenizer-name is empty.
        try:
            from transformers import AutoTokenizer  # type: ignore
        except Exception as e:  # pragma: no cover
            raise RuntimeError(
                "Failed to import transformers. Install it or omit --tokenizer-name."
            ) from e

        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
        noise_token_count = len(tokenizer(NOISE_BLOCK).input_ids)
        special_example = SPECIAL_TPL.format(magic_number="0000")
        special_token_count = len(tokenizer(special_example).input_ids)
        print(
            f"[Tokenizer: {args.tokenizer_name}] Noise block tokens: {noise_token_count} | Special line tokens: {special_token_count}"
        )

    # Token-length bins: a few small fixed bins, then 1K-wide bins up to 16K,
    # then 4K-wide bins to 32K, then 8K-wide bins beyond that.
    tok_bins = [(32, 128), (128, 256), (256, 512), (512, 1024), (32, 1024)] + [
        (1024 * i, 1024 * (i + 1)) for i in range(1, 16)
    ]
    tok_bins += [(2**14 + 2**12 * (i), 2**14 + 2**12 * (i + 1)) for i in range(4)]
    tok_bins += [(2**15 + 2**13 * (i), 2**15 + 2**13 * (i + 1)) for i in range(12)]
    if args.only_first_n_bins is not None:
        tok_bins = tok_bins[: args.only_first_n_bins]

    if args.tokenizer_name:
        max_hi = max(hi for _, hi in tok_bins)

        def measure_len(k: int) -> int:
            # Token length of a context with k blocks (k-1 noise + 1 special,
            # using a "0000" placeholder for the magic number).
            if k == 1:
                ctx = SPECIAL_TPL.format(magic_number="0000")
            else:
                blocks = [NOISE_BLOCK] * (k - 1) + [
                    SPECIAL_TPL.format(magic_number="0000")
                ]
                ctx = SEP.join(blocks)
            return len(tokenizer(ctx).input_ids)

        # lengths[k] = measured token length of a k-block context; index 0 is
        # a sentinel so that indices line up with block counts.
        lengths: list[int] = [0]
        k = 1
        while True:
            L = measure_len(k)
            lengths.append(L)
            if L >= max_hi:
                break
            k += 1

        # Translate each token bin (lo, hi) into a half-open block-count
        # range (k_lo, k_hi); (0, 0) marks a bin with no valid block count.
        len_bins = []
        for lo, hi in tok_bins:
            k_lo = None
            for kk in range(1, len(lengths)):
                if lengths[kk] >= lo:
                    k_lo = kk
                    break
            if k_lo is None or lengths[k_lo] >= hi:
                len_bins.append((0, 0))
                continue
            k_hi = len(lengths)
            for kk in range(k_lo, len(lengths)):
                if lengths[kk] >= hi:
                    k_hi = kk
                    break
            len_bins.append((k_lo, k_hi))

        base_tokens = lengths[1]
        delta = (lengths[2] - lengths[1]) if len(lengths) > 2 else 0
        print(
            f"Using tokenizer-measured block ranges. base_tokens={base_tokens} approx_delta={delta}"
        )
    else:
        # No tokenizer: fall back to the rough tokens-per-block heuristic.
        len_bins = [
            (lo // args.tokens_per_block, hi // args.tokens_per_block)
            for (lo, hi) in tok_bins
        ]

    if args.dry_run:
        # Print one sample from the first non-empty bin and exit.
        for lb in len_bins:
            if lb[1] > lb[0]:
                k = max(1, lb[0])
                sample = generate_examples(10, k)
                print("Sample entry:")
                print(json.dumps(sample[0], indent=2))
                break
        return
    # -----------------------------------------------
    # Main generation per token bin
    # -----------------------------------------------
    TARGET_VAL = 1000
    TARGET_TEST = 1000
    for len_bin, tok_bin in zip(len_bins, tok_bins):
        if len_bin[1] <= len_bin[0]:
            print(f"Skipping token bin {tok_bin} (no valid block counts)")
            continue
        # k_values is the half-open range of block counts for this bin.
        k_start = max(1, len_bin[0])
        k_end = max(1, len_bin[1])
        k_values = list(range(k_start, k_end))
        bin_size = len(k_values)
        save_dir = f"{args.out_prefix}_{tok_bin[0]}_{tok_bin[1]}"
        # Training splits are only produced for bins up to 1024 tokens.
        training_enabled = tok_bin[1] <= 1024  # unchanged policy
        if training_enabled:
            train_data: list[dict] = []
            # Distribute training budget across k values.
            # Scale: per_k = base_samples_per_bin / bin_size
            per_k_train = max(1, args.base_samples_per_bin // max(1, bin_size))
            for k in k_values:
                train_data += generate_examples(per_k_train, k)
        val_data: list[dict] = []
        test_data: list[dict] = []
        # Split the fixed 1000-example val/test budgets evenly across
        # k values; the first `rem` values absorb the remainder.
        base_val = TARGET_VAL // bin_size
        rem_val = TARGET_VAL % bin_size
        base_test = TARGET_TEST // bin_size
        rem_test = TARGET_TEST % bin_size
        for idx, k in enumerate(k_values):
            n_val_k = base_val + (1 if idx < rem_val else 0)
            n_test_k = base_test + (1 if idx < rem_test else 0)
            if n_val_k:
                val_data += generate_examples(n_val_k, k)
            if n_test_k:
                test_data += generate_examples(n_test_k, k)
        random.shuffle(val_data)
        random.shuffle(test_data)
        os.makedirs(save_dir, exist_ok=True)
        if training_enabled:
            save_jsonl(train_data, f"{save_dir}/train.jsonl")
        save_jsonl(val_data, f"{save_dir}/val.jsonl")
        save_jsonl(test_data, f"{save_dir}/test.jsonl")
        if training_enabled:
            print(
                f"Dataset generated at {save_dir} (train={len(train_data)} val={len(val_data)} test={len(test_data)})"
            )
        else:
            print(
                f"Dataset (val/test only) generated at {save_dir} (val={len(val_data)} test={len(test_data)})"
            )
|
|
|
|
|
|
# Standard script entry point: run the generator only when executed directly.
if __name__ == "__main__":
    main()
|