Doc-to-LoRA release

commit 1abe8ae16d by 51616, 2026-02-27 03:47:04 +00:00
92 changed files with 22131 additions and 0 deletions

@@ -0,0 +1,17 @@
from huggingface_hub import snapshot_download
if __name__ == "__main__":
self_gen_data_dir = "./data/raw_datasets/self_gen/"
snapshot_download(
"SakanaAI/self_gen_qa_d2l",
repo_type="dataset",
local_dir=self_gen_data_dir,
# we can filter based on model by using the `allow_patterns` argument
# based on https://huggingface.co/datasets/SakanaAI/self_gen_qa_d2l/tree/main
# we can use
# - `Qwen` for downloading the data for `Qwen/Qwen3-4B-Instruct-2507`
# - `google` for downloading the data for `google/gemma-2-2b-it`
# - `mistralai` for downloading the data for `mistralai/Mistral-7B-Instruct-v0.2`
#
# allow_patterns="google/*", # downloading the data for `google/gemma-2-2b-it`
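# note: allow_patterns also accepts a list, e.g. to fetch data for two models at once:
# allow_patterns=["google/*", "Qwen/*"],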
)

scripts/main_exp/1-train.sh (executable file)

@@ -0,0 +1,14 @@
#!/bin/bash
port=29051
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj --lora_r=8 \
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
--max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
--quantize_ctx_encoder=True

@@ -0,0 +1,30 @@
#!/bin/bash
port=29051
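# RUN_NAME is assumed to name the run directory holding the 80000-step checkpoint
# from the first training stage (see --from_pretrained_checkpoint below).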
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj \
--lora_r=8 \
--eval_strategy=no \
--max_qas_len=512 \
--max_qas_per_sample=1 \
--per_rank_gen=True \
--per_layer_processing=True \
--gen_lora_l1_reg_coef=0.1 \
--max_steps=20000 \
--gradient_accumulation_steps=16 \
--max_packed_inp_len=1024 \
--max_packed_ctx_len=2048 \
--use_per_ctx_average_loss=True \
--use_kl_loss=True \
--quantize_ctx_encoder=True \
--torch_empty_cache_steps=10 \
--from_pretrained_checkpoint=train_outputs/runs/$RUN_NAME/checkpoint-80000/pytorch_model.bin \
--max_ctx_chunk_len=512 \
--min_ctx_chunk_len=25 \
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
--warmup_steps=2000 \
--learning_rate=2e-5
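The `num_chunk_probs` argument above maps a number of context chunks to a probability, and the listed values sum to 1.0, consistent with a sampling distribution over chunk counts. A minimal sketch under that assumption (illustrative only, not the training code):
```python
# Illustrative only: treat num_chunk_probs as a categorical distribution over
# how many chunks each context is split into. The flag's values are strings,
# so they are parsed to floats first.
import json
import random

num_chunk_probs = json.loads(
    '{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625",'
    ' "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}'
)
counts = [int(k) for k in num_chunk_probs]
weights = [float(v) for v in num_chunk_probs.values()]
assert abs(sum(weights) - 1.0) < 1e-9  # 0.5 + 0.125 + 6 * 0.0625 == 1.0
print(random.choices(counts, weights=weights, k=1)[0])  # draws 1 chunk half the time
```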

@@ -0,0 +1,25 @@
# D2L pipeline
### Data
You can either download the generated data (recommended, roughly 100 GB per model) or generate it yourself.
See [`0-download_data.sh`](0-download_data.sh) for how to download the data for a specific model only.
```bash
# download training data for all three models (328 GB total)
uv run bash scripts/main_exp/0-download_data.sh
```
Generating the data from scratch can take a very long time unless it is parallelized across multiple GPUs.
```bash
# optional: generate the training data from scratch instead of downloading it
# uv run bash scripts/main_exp/gen_data.sh
```
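To download the data for a single base model only, the dataset snapshot can be filtered by repository prefix, as noted in the comments of the download script; a minimal sketch (prefixes as listed there):
```python
from huggingface_hub import snapshot_download

# fetch only the gemma-2-2b-it portion of the self-generated QA data
snapshot_download(
    "SakanaAI/self_gen_qa_d2l",
    repo_type="dataset",
    local_dir="./data/raw_datasets/self_gen/",
    allow_patterns="google/*",  # use "Qwen/*" or "mistralai/*" for the other models
)
```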
### Training
Simply run the training script once the data is ready.
```bash
# train
uv run bash scripts/main_exp/1-train.sh
```
### Evaluation
All evaluation scripts for reproducing the main results in the paper are included in the [eval](eval/) directory.

@@ -0,0 +1,8 @@
# no truncation
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1
# w/ truncation
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1 --truncate_if_too_long_inp
# no context
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1 --remove_context

scripts/main_exp/eval/cd.sh (executable file)

@@ -0,0 +1,6 @@
# qa
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=4
# longbench
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=1

@@ -0,0 +1 @@
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets $1 --split test --use_cd --cd_update_iterations 50 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=5 --cd_batch_size=2

@@ -0,0 +1,6 @@
# qa
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp
# longbench
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp

scripts/main_exp/eval/d2l.sh (executable file)

@@ -0,0 +1,13 @@
# main results
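# RUN_NAME and step are assumed to identify a trained D2L checkpoint, i.e.
# train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin must exist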
# batched
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1
# iterative
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1 --use_iterative_mode
# query internalization
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad --split test --eval_batch_size_gen=1 --flip_ctx_inp
# replaced squad context
WANDB_MODE=disabled uv run python run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad_assistant_ctx_no_passage --split test
WANDB_MODE=disabled uv run python run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad_negative_no_passage --split test

@@ -0,0 +1,279 @@
import json
import os
import re
from argparse import Namespace
from difflib import SequenceMatcher
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from ctx_to_lora.model_loading import get_tokenizer
from ctx_to_lora.modeling.ctx_encoder import PerLayerActivations
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
from ctx_to_lora.modeling.lora_layer import apply_lora_to_layers
from ctx_to_lora.modeling.lora_merger import combine_lora
CLASS_NAMES = [
"tench",
"English springer",
"cassette player",
"chain saw",
"church",
"French horn",
"garbage truck",
"gas pump",
"golf ball",
"parachute",
]
CLASS_TO_INT = {name: i for i, name in enumerate(CLASS_NAMES)}
INPUT_TXT = f"What is in this image? Choose exactly one of the following classes: {', '.join(CLASS_NAMES)}. Response with only the correct class without any other text."
RUN_DIR = "train_outputs/runs/Oct16_02-37-04_slurm0-a3nodeset-8_94074_1d62ecb8"
def _normalize_text(text: str) -> str:
text = re.sub(r"[^a-z0-9\s]", " ", text.lower())
return " ".join(text.split())
def _normalize_compact(text: str) -> str:
return re.sub(r"[^a-z0-9]", "", text.lower())
def _build_alias_map():
alias_overrides = {
"english springer spaniel": "English springer",
"springer spaniel": "English springer",
"chainsaw": "chain saw",
"dump truck": "garbage truck",
"refuse truck": "garbage truck",
"garbage lorry": "garbage truck",
"fuel pump": "gas pump",
"gas station pump": "gas pump",
"cassette deck": "cassette player",
"cassette recorder": "cassette player",
"fish": "tench",
"tench fish": "tench",
"french horn instrument": "French horn",
"golfball": "golf ball",
"skydiving": "parachute",
"parachutist": "parachute",
}
alias_map = {}
def register(alias: str, canonical: str):
alias = _normalize_text(alias)
if alias:
alias_map[alias] = canonical
alias_map[_normalize_compact(alias)] = canonical
for name in CLASS_NAMES:
register(name, name)
register(name.replace(" ", ""), name)
register(name.replace(" ", "-"), name)
for alias, canonical in alias_overrides.items():
register(alias, canonical)
return alias_map
CLASS_ALIAS_MAP = _build_alias_map()
def pred_to_class_id(pred_txt: str) -> int:
norm_pred = _normalize_text(pred_txt)
compact_pred = _normalize_compact(pred_txt)
for alias, canonical in CLASS_ALIAS_MAP.items():
if alias and (alias in norm_pred or alias in compact_pred):
return CLASS_TO_INT[canonical]
pred_tokens = set(norm_pred.split())
best_class = None
best_token_hits = -1
for name in CLASS_NAMES:
class_tokens = set(_normalize_text(name).split())
if class_tokens and class_tokens.issubset(pred_tokens):
return CLASS_TO_INT[name]
hits = sum(token in pred_tokens for token in class_tokens)
if hits > best_token_hits:
best_token_hits = hits
best_class = name
best_ratio = -1.0
for name in CLASS_NAMES:
ratio = SequenceMatcher(None, norm_pred, _normalize_text(name)).ratio()
if ratio > best_ratio:
best_ratio = ratio
best_class = name
return CLASS_TO_INT[best_class]
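# Illustrative behaviour of the matcher above: alias lookup resolves
# pred_to_class_id("a chainsaw") -> 3 ("chain saw") and
# pred_to_class_id("a parachutist descending") -> 9 ("parachute");
# otherwise the token-subset and SequenceMatcher fallbacks pick the closest class.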
def load_checkpoint():
checkpoint_path = f"{RUN_DIR}/checkpoint-80000/pytorch_model.bin"
state_dict = torch.load(checkpoint_path)
model = ModulatedPretrainedModel.from_state_dict(
state_dict,
train=False,
base_model_kwargs=dict(attn_implementation="flash_attention_2"),
use_flash_attn=True,
use_sequence_packing=False, # for generation
)
tokenizer = get_tokenizer("google/gemma-2-2b-it")
model.eval()
return model, tokenizer
def load_ctx_encoder():
model_id = "google/gemma-3-4b-it"
ctx_model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto"
).eval()
ctx_encoder_config = Namespace(ctx_encoder_last_layer=26, keep_lm_head=True)
ctx_model.language_model = PerLayerActivations(
ctx_model.language_model, ctx_encoder_config
)
processor = AutoProcessor.from_pretrained(model_id)
return ctx_model, processor
def template_image(img, ctx_processor):
messages = [
{
"role": "system",
"content": [{"type": "text", "text": ""}],
},
{
"role": "user",
"content": [
{"type": "image", "image": img},
],
},
]
inputs = ctx_processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
return inputs
@torch.inference_mode
def get_ctx_features(ctx_inputs, ctx_encoder):
forward_outputs = ctx_encoder(**ctx_inputs, output_hidden_states=True)
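# hidden_states is a tuple with one tensor per layer; stacking along dim=1 yields
# a single (batch, n_layers, seq_len, hidden) tensor of per-layer activations
# for the hypernetwork to consume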
ctx_features = torch.stack(forward_outputs.hidden_states, dim=1)
return ctx_features
def generate_loras(ctx_inputs, ctx_features):
generated_loras, _ = model.hypernet.generate_weights(
ctx_features, attn_mask=torch.ones_like(ctx_inputs["input_ids"])
)
generated_loras = combine_lora(
generated_loras,
n_chunks=torch.tensor((1,), device=model.device),
lora_bias=model.hypernet.get_head_bias()
if model.hypernet.config.use_bias
else None,
)
return generated_loras
def apply_loras(model, generated_loras):
n_queries = torch.ones(1, dtype=torch.int32, device=model.device)
apply_lora_to_layers(
model.base_model,
model.hypernet.layer_indices,
generated_loras,
n_queries,
)
if __name__ == "__main__":
model, base_tokenizer = load_checkpoint()
ctx_encoder, ctx_processor = load_ctx_encoder()
ds = load_dataset("frgfm/imagenette", "full_size", split="validation")
# ds = ds.shuffle().select(range(int(0.05 * len(ds))))
input_ids = base_tokenizer.apply_chat_template(
[{"role": "user", "content": INPUT_TXT}],
add_special_tokens=False,
return_attention_mask=False,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
preds = []
pred_txts = []
corrects = []
labels = ds["label"]
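# Per-image evaluation loop: encode the image with the Gemma-3 context encoder,
# map its per-layer activations to a LoRA with the hypernetwork, patch that LoRA
# into the Gemma-2 base model, then greedily decode an answer to the fixed
# classification prompt and score it against the Imagenette label.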
for sample in tqdm(ds):
img = sample["image"]
ctx_inputs = template_image(img, ctx_processor).to(ctx_encoder.device)
ctx_features = get_ctx_features(ctx_inputs, ctx_encoder)
generated_loras = generate_loras(ctx_inputs, ctx_features)
apply_loras(model, generated_loras)
model_outputs = model.base_model.generate(
input_ids, max_new_tokens=256, do_sample=False
)
pred_txt = base_tokenizer.decode(
model_outputs[0][len(input_ids[0]) :], skip_special_tokens=True
)
pred_txts.append(pred_txt)
preds.append(pred_to_class_id(pred_txt))
is_correct = preds[-1] == labels[len(preds) - 1]
corrects.append(is_correct)
print(
f"GT: {CLASS_NAMES[labels[len(preds) - 1]]}, Pred: {pred_txt} -> {CLASS_NAMES[preds[-1]]}, Correct: {is_correct}"
)
acc = sum(corrects) / len(corrects)
print(f"Final accuracy: {acc:4f}")
jsonl_path = os.path.join(RUN_DIR, "imagenette_eval.jsonl")
meta_path = os.path.join(RUN_DIR, "imagenette_eval.meta.json")
with open(jsonl_path, "w") as f:
for i, (pred_txt, pred_id, label_id) in enumerate(
zip(pred_txts, preds, labels)
):
f.write(
json.dumps(
{
"index": i,
"label": int(label_id),
"label_name": CLASS_NAMES[label_id],
"pred_text": pred_txt,
"pred_class_id": int(pred_id),
"pred_class_name": CLASS_NAMES[pred_id],
"correct": bool(pred_id == label_id),
}
)
+ "\n"
)
meta = {
"dataset": "frgfm/imagenette",
"subset": "full_size",
"split": "validation",
"run_dir": RUN_DIR,
"prompt": INPUT_TXT,
"accuracy": float(acc),
"num_samples": len(preds),
"class_names": CLASS_NAMES,
}
with open(meta_path, "w") as f:
json.dump(meta, f, indent=2)
print(f"Wrote samples to {jsonl_path}")
print(f"Wrote metadata to {meta_path}")

@@ -0,0 +1,5 @@
for dataset in squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e; do
for rate in 0.9 0.8 0.6 0.4 0.2 0.1; do
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets "$dataset" --split test --eval_batch_size_gen=1 --use_llmlingua --llmlingua_compression_rate "$rate" --truncate_if_too_long_ctx
done
done

scripts/main_exp/eval/t2l.sh (executable file)

@@ -0,0 +1,4 @@
# download t2l checkpoint
uv run huggingface-cli download SakanaAI/text-to-lora --local-dir . --include "trained_t2l/gemma_2b_t2l"
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen=1 --use_t2l

scripts/main_exp/gen_data.sh (executable file)

@@ -0,0 +1,20 @@
# download fineweb_edu to `data/raw_datasets/fineweb_edu`
uv run data/download_fineweb_edu.py
# generate QA data (shards 000 to 013)
for shard_id in $(seq -f "%03g" 0 13); do
uv run data/generate_fw_edu_qa_v2.py --shard_pattern "${shard_id}_00000" --n_qa_pairs=5 --vllm_model=google/gemma-3-12b-it --max_length=2000 --max_model_length=2048
uv run data/generate_fw_edu_qa_v2_repeat.py --shard_pattern "min_0_to_2000/${shard_id}*level_0" --n_qa_pairs=5 --vllm_model=google/gemma-3-12b-it
# self-generated response QA data
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --glob_pattern "data/raw_datasets/fw_qa_v2/min_0_to_2000/${shard_id}*_level_1*" --closed_qa_prob 1.0
done
# val split
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --glob_pattern 'data/raw_datasets/fw_qa_v2/min_0_to_2000/*_level_0_val.parquet'
# self-generated QA data for the other datasets
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --ds_names squad_compact ropes_compact drop_compact --split train --closed_qa_prob 1.0
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --ds_names pwc_compact --split train --closed_qa_prob 0.0

@@ -0,0 +1,29 @@
#!/bin/bash
#SBATCH --job-name=ctxlora
#SBATCH --nodes=1
#SBATCH --partition=a3
#SBATCH --gpus=8
#SBATCH --output=slurm_logs/%x-%j.out
#SBATCH --error=slurm_logs/%x-%j.out
port=$((10000 + ($SLURM_JOBID % 50000)))
echo "Using port: $port"
# port=29051
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj --lora_r=8 \
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
--max_steps=20000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
--quantize_ctx_encoder=True --ctx_encoder_model_name_or_path=google/gemma-3-4b-it \
--max_ctx_chunk_len=512 \
--min_ctx_chunk_len=25 \
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
--warmup_steps=2000 \
--learning_rate=2e-5 \
"$@"

@@ -0,0 +1,24 @@
#!/bin/bash
#SBATCH --job-name=ctxlora
#SBATCH --nodes=1
#SBATCH --partition=a3
#SBATCH --gpus=8
#SBATCH --output=slurm_logs/%x-%j.out
#SBATCH --error=slurm_logs/%x-%j.out
port=$((10000 + ($SLURM_JOBID % 50000)))
echo "Using port: $port"
# port=29051
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj --lora_r=8 \
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
--max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
--quantize_ctx_encoder=True --ctx_encoder_model_name_or_path=google/gemma-3-4b-it \
"$@"

scripts/main_exp/train_no_qa.sh (executable file)

@@ -0,0 +1,15 @@
#!/bin/bash
port=29051
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_no_qa_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj --lora_r=8 \
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
--max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
--quantize_ctx_encoder=True \
"$@"