mirror of
https://github.com/SakanaAI/doc-to-lora.git
synced 2026-04-28 01:26:21 +02:00
Doc-to-LoRA release
This commit is contained in:
commit
1abe8ae16d
92 changed files with 22131 additions and 0 deletions
17
scripts/main_exp/0-download_data.py
Normal file
17
scripts/main_exp/0-download_data.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Download the self-generated QA dataset used for Doc-to-LoRA training."""

from huggingface_hub import snapshot_download


if __name__ == "__main__":
    # Local target directory for the raw self-generated dataset shards.
    self_gen_data_dir = "./data/raw_datasets/self_gen/"
    # Snapshot the whole dataset repo; see the note below for per-model filtering.
    snapshot_download(
        "SakanaAI/self_gen_qa_d2l",
        repo_type="dataset",
        local_dir=self_gen_data_dir,
        # we can filter based on model by using the `allow_patterns` argument
        # based on https://huggingface.co/datasets/SakanaAI/self_gen_qa_d2l/tree/main
        # we can use
        # - `Qwen` for downloading the data for `Qwen/Qwen3-4B-Instruct-2507`
        # - `google` for downloading the data for `google/gemma-2-2b-it`
        # - `mistralai` for downloading the data for `mistralai/Mistral-7B-Instruct-v0.2`
        #
        # allow_patterns="google/*", # downloading the data for `google/gemma-2-2b-it`
    )
|
||||
14
scripts/main_exp/1-train.sh
Executable file
14
scripts/main_exp/1-train.sh
Executable file
|
|
@ -0,0 +1,14 @@
|
|||
#!/bin/bash
# Stage-1 D2L training: 8-GPU accelerate launch of train.py with the
# self-gen closed-QA config on google/gemma-2-2b-it (LoRA on down_proj, r=8).

port=29051

uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
    --num_processes=8 --gpu_ids all train.py \
    configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
    --model_name_or_path=google/gemma-2-2b-it \
    --target_modules=down_proj --lora_r=8 \
    --eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
    --per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
    --max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
    --max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
    --quantize_ctx_encoder=True
|
||||
30
scripts/main_exp/2-train-chunk.sh
Executable file
30
scripts/main_exp/2-train-chunk.sh
Executable file
|
|
@ -0,0 +1,30 @@
|
|||
#!/bin/bash
# Stage-2 chunked training: resumes from the stage-1 checkpoint and trains
# with multi-chunk contexts (1-8 chunks, probabilities below).
# NOTE(review): $RUN_NAME is not set here — it must be exported by the caller
# and point at the stage-1 run containing checkpoint-80000.

port=29051

uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
    --num_processes=8 --gpu_ids all train.py \
    configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
    --model_name_or_path=google/gemma-2-2b-it \
    --target_modules=down_proj \
    --lora_r=8 \
    --eval_strategy=no \
    --max_qas_len=512 \
    --max_qas_per_sample=1 \
    --per_rank_gen=True \
    --per_layer_processing=True \
    --gen_lora_l1_reg_coef=0.1 \
    --max_steps=20000 \
    --gradient_accumulation_steps=16 \
    --max_packed_inp_len=1024 \
    --max_packed_ctx_len=2048 \
    --use_per_ctx_average_loss=True \
    --use_kl_loss=True \
    --quantize_ctx_encoder=True \
    --torch_empty_cache_steps=10 \
    --from_pretrained_checkpoint=train_outputs/runs/$RUN_NAME/checkpoint-80000/pytorch_model.bin \
    --max_ctx_chunk_len=512 \
    --min_ctx_chunk_len=25 \
    --num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
    --warmup_steps=2000 \
    --learning_rate=2e-5
|
||||
25
scripts/main_exp/README.md
Normal file
25
scripts/main_exp/README.md
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
# D2L pipeline
|
||||
### Data
|
||||
You can either download the generated data (recommended, ~100 GB for each model) or generate it by yourself.
|
||||
Please see [`0-download_data.sh`](0-download_data.sh) for how to do model-specific data download.
|
||||
```bash
|
||||
# download training data for all three models (328GB)
|
||||
uv run bash scripts/main_exp/0-download_data.sh
|
||||
```
|
||||
|
||||
Generating the data from scratch can take a very long time if not parallelized across multiple GPUs.
|
||||
```bash
|
||||
# generate training data (takes very long if not parallelized across multiple gpus)
|
||||
# optional: use the command below for generating data from scratch
|
||||
# uv run bash scripts/main_exp/gen_data.sh
|
||||
```
|
||||
|
||||
### Training
|
||||
Simply run the training script once the data is ready.
|
||||
```bash
|
||||
# train
|
||||
uv run bash scripts/main_exp/1-train.sh
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
All evaluation scripts for reproducing the main results in the paper are included in the [eval](eval/) directory.
|
||||
8
scripts/main_exp/eval/base_model.sh
Executable file
8
scripts/main_exp/eval/base_model.sh
Executable file
|
|
@ -0,0 +1,8 @@
|
|||
# Baseline evaluation of the raw google/gemma-2-2b-it model (no generated LoRA).

# no truncation
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1

# w/ truncation: longbench-only rerun with over-long inputs truncated
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1 --truncate_if_too_long_inp

# no context: the passage is removed (--remove_context)
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1 --remove_context
|
||||
6
scripts/main_exp/eval/cd.sh
Executable file
6
scripts/main_exp/eval/cd.sh
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
# Context-distillation (CD) baseline using self-generated questions
# (--cd_use_gen_q), 300 update iterations per context.

# qa
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=4


# longbench
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=1
|
||||
1
scripts/main_exp/eval/cd_minibatch.sh
Normal file
1
scripts/main_exp/eval/cd_minibatch.sh
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Mini-batched context-distillation baseline (--cd_batch_size=2, 50 iterations).
# Usage: cd_minibatch.sh <dataset> — $1 is forwarded to --datasets.
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets $1 --split test --use_cd --cd_update_iterations 50 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=5 --cd_batch_size=2
|
||||
6
scripts/main_exp/eval/cd_oracle.sh
Executable file
6
scripts/main_exp/eval/cd_oracle.sh
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
# Context-distillation "oracle" baseline: same as cd.sh but without
# self-generated questions (no --cd_use_gen_q).

# qa
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp


# longbench
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp
|
||||
13
scripts/main_exp/eval/d2l.sh
Executable file
13
scripts/main_exp/eval/d2l.sh
Executable file
|
|
@ -0,0 +1,13 @@
|
|||
# D2L evaluation of a trained checkpoint.
# NOTE(review): $RUN_NAME and $step are not set here — export both before
# running so the --checkpoint_path resolves.

# main results
# batched
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1

# iterative
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1 --use_iterative_mode

# query internalization
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad --split test --eval_batch_size_gen=1 --flip_ctx_inp

# replaced squad context
WANDB_MODE=disabled uv run python run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad_assistant_ctx_no_passage --split test
WANDB_MODE=disabled uv run python run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad_negative_no_passage --split test
|
||||
279
scripts/main_exp/eval/imagenette_eval.py
Normal file
279
scripts/main_exp/eval/imagenette_eval.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from argparse import Namespace
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
from ctx_to_lora.model_loading import get_tokenizer
|
||||
from ctx_to_lora.modeling.ctx_encoder import PerLayerActivations
|
||||
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
|
||||
from ctx_to_lora.modeling.lora_layer import apply_lora_to_layers
|
||||
from ctx_to_lora.modeling.lora_merger import combine_lora
|
||||
|
||||
# The ten Imagenette classes, in label-index order (index == dataset label id).
CLASS_NAMES = [
    "tench",
    "English springer",
    "cassette player",
    "chain saw",
    "church",
    "French horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
]
# Reverse lookup: canonical class name -> integer label id.
CLASS_TO_INT = {name: i for i, name in enumerate(CLASS_NAMES)}
# Classification prompt sent to the base model for every image.
# NOTE(review): "Response with" looks like a typo for "Respond with"; kept
# verbatim so results remain comparable with runs that used this exact prompt.
INPUT_TXT = f"What is in this image? Choose exactly one of the following classes: {', '.join(CLASS_NAMES)}. Response with only the correct class without any other text."
# Training run whose checkpoint is evaluated; eval outputs are written here too.
RUN_DIR = "train_outputs/runs/Oct16_02-37-04_slurm0-a3nodeset-8_94074_1d62ecb8"
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
|
||||
text = re.sub(r"[^a-z0-9\s]", " ", text.lower())
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def _normalize_compact(text: str) -> str:
|
||||
return re.sub(r"[^a-z0-9]", "", text.lower())
|
||||
|
||||
|
||||
def _build_alias_map():
    """Build a lookup from normalized alias text to a canonical class name.

    Every canonical name is registered under its normalized, compact,
    space-stripped, and hyphenated spellings; a hand-curated override table
    adds common paraphrases. Later registrations overwrite earlier ones, so
    the explicit overrides take precedence.
    """
    alias_overrides = {
        "english springer spaniel": "English springer",
        "springer spaniel": "English springer",
        "chainsaw": "chain saw",
        "dump truck": "garbage truck",
        "refuse truck": "garbage truck",
        "garbage lorry": "garbage truck",
        "fuel pump": "gas pump",
        "gas station pump": "gas pump",
        "cassette deck": "cassette player",
        "cassette recorder": "cassette player",
        "fish": "tench",
        "tench fish": "tench",
        "french horn instrument": "French horn",
        "golfball": "golf ball",
        "skydiving": "parachute",
        "parachutist": "parachute",
    }

    alias_map = {}

    def _register(raw_alias: str, canonical: str):
        # Index both the whitespace-normalized and fully compact spellings.
        key = _normalize_text(raw_alias)
        if key:
            alias_map[key] = canonical
            alias_map[_normalize_compact(key)] = canonical

    for class_name in CLASS_NAMES:
        _register(class_name, class_name)
        _register(class_name.replace(" ", ""), class_name)
        _register(class_name.replace(" ", "-"), class_name)

    for alias_text, canonical_name in alias_overrides.items():
        _register(alias_text, canonical_name)

    return alias_map


CLASS_ALIAS_MAP = _build_alias_map()
|
||||
|
||||
|
||||
def pred_to_class_id(pred_txt: str) -> int:
    """Map a free-form model prediction onto one of the ten Imagenette classes.

    Resolution order:
      1. substring match against the alias table (normalized and compact forms);
      2. token-subset match: every word of a canonical class name appears in
         the prediction;
      3. fuzzy fallback: the class whose normalized name has the highest
         SequenceMatcher ratio against the normalized prediction.

    NOTE: a previous version also tracked a per-class token-hit count, but
    that result was unconditionally overwritten by the fuzzy-ratio pass
    (which starts from best_ratio = -1.0) and never affected the output, so
    the dead tracking has been removed. Behavior is unchanged.
    """
    norm_pred = _normalize_text(pred_txt)
    compact_pred = _normalize_compact(pred_txt)

    # 1. Alias lookup: any alias appearing inside the prediction wins.
    for alias, canonical in CLASS_ALIAS_MAP.items():
        if alias and (alias in norm_pred or alias in compact_pred):
            return CLASS_TO_INT[canonical]

    # 2. Token-subset match against the canonical class names.
    pred_tokens = set(norm_pred.split())
    for name in CLASS_NAMES:
        class_tokens = set(_normalize_text(name).split())
        if class_tokens and class_tokens.issubset(pred_tokens):
            return CLASS_TO_INT[name]

    # 3. Fuzzy fallback: pick the most similar class name.
    best_class = None
    best_ratio = -1.0
    for name in CLASS_NAMES:
        ratio = SequenceMatcher(None, norm_pred, _normalize_text(name)).ratio()
        if ratio > best_ratio:
            best_ratio = ratio
            best_class = name

    return CLASS_TO_INT[best_class]
|
||||
|
||||
|
||||
def load_checkpoint():
    """Load the trained D2L checkpoint and the matching base-model tokenizer.

    Loads the 80k-step checkpoint from RUN_DIR into a ModulatedPretrainedModel
    configured for inference (flash-attention, no sequence packing).

    Returns:
        (model, tokenizer): the modulated model in eval mode and the
        google/gemma-2-2b-it tokenizer.
    """
    checkpoint_path = f"{RUN_DIR}/checkpoint-80000/pytorch_model.bin"
    state_dict = torch.load(checkpoint_path)

    model = ModulatedPretrainedModel.from_state_dict(
        state_dict,
        train=False,
        base_model_kwargs=dict(attn_implementation="flash_attention_2"),
        use_flash_attn=True,
        use_sequence_packing=False,  # for generation
    )
    tokenizer = get_tokenizer("google/gemma-2-2b-it")
    model.eval()
    return model, tokenizer
|
||||
|
||||
|
||||
def load_ctx_encoder():
    """Load gemma-3-4b-it as the multimodal context encoder.

    The language model is wrapped with PerLayerActivations (last layer 26,
    LM head kept per the config below) so per-layer hidden states are exposed
    for the hypernetwork.

    Returns:
        (ctx_model, processor): the wrapped model in eval mode and its
        AutoProcessor.
    """
    model_id = "google/gemma-3-4b-it"
    ctx_model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto"
    ).eval()
    ctx_encoder_config = Namespace(ctx_encoder_last_layer=26, keep_lm_head=True)
    ctx_model.language_model = PerLayerActivations(
        ctx_model.language_model, ctx_encoder_config
    )
    processor = AutoProcessor.from_pretrained(model_id)
    return ctx_model, processor
|
||||
|
||||
|
||||
def template_image(img, ctx_processor):
    """Build tokenized processor inputs for *img* via the chat template.

    The conversation is an empty system turn followed by a single user turn
    containing only the image; a generation prompt is appended.
    """
    chat = [
        {
            "role": "system",
            "content": [{"type": "text", "text": ""}],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": img},
            ],
        },
    ]
    return ctx_processor.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def get_ctx_features(ctx_inputs, ctx_encoder):
|
||||
forward_outputs = ctx_encoder(**ctx_inputs, output_hidden_states=True)
|
||||
ctx_features = torch.stack(forward_outputs.hidden_states, dim=1)
|
||||
return ctx_features
|
||||
|
||||
|
||||
def generate_loras(ctx_inputs, ctx_features):
    """Generate LoRA weights from the context-encoder features.

    NOTE(review): relies on the module-global `model` (bound in the __main__
    block) rather than receiving it as a parameter — confirm before reuse.
    """
    # All-ones attention mask: every context token participates in generation.
    generated_loras, _ = model.hypernet.generate_weights(
        ctx_features, attn_mask=torch.ones_like(ctx_inputs["input_ids"])
    )
    # Combine the per-chunk LoRAs (a single chunk here) into one adapter,
    # adding the hypernet head bias only when the config enables it.
    generated_loras = combine_lora(
        generated_loras,
        n_chunks=torch.tensor((1,), device=model.device),
        lora_bias=model.hypernet.get_head_bias()
        if model.hypernet.config.use_bias
        else None,
    )
    return generated_loras
|
||||
|
||||
|
||||
def apply_loras(model, generated_loras):
    """Patch the generated LoRA weights into the model's target layers."""
    # One query per batch element.
    query_counts = torch.ones(1, dtype=torch.int32, device=model.device)
    apply_lora_to_layers(
        model.base_model,
        model.hypernet.layer_indices,
        generated_loras,
        query_counts,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Evaluate the image-conditioned D2L model on the Imagenette validation
    # split: encode each image, generate a LoRA from it, apply it to the base
    # model, and classify via free-form generation.
    model, base_tokenizer = load_checkpoint()
    ctx_encoder, ctx_processor = load_ctx_encoder()
    ds = load_dataset("frgfm/imagenette", "full_size", split="validation")
    # ds = ds.shuffle().select(range(int(0.05 * len(ds))))

    # The text prompt is identical for every image; tokenize it once up front.
    input_ids = base_tokenizer.apply_chat_template(
        [{"role": "user", "content": INPUT_TXT}],
        add_special_tokens=False,
        return_attention_mask=False,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    preds = []
    pred_txts = []
    corrects = []
    labels = ds["label"]
    for i, sample in enumerate(tqdm(ds)):
        # Encode the image, generate a LoRA, and patch the base model with it.
        img = sample["image"]
        ctx_inputs = template_image(img, ctx_processor).to(ctx_encoder.device)
        ctx_features = get_ctx_features(ctx_inputs, ctx_encoder)
        generated_loras = generate_loras(ctx_inputs, ctx_features)
        apply_loras(model, generated_loras)

        # Greedy generation with the image-conditioned LoRA applied.
        model_outputs = model.base_model.generate(
            input_ids, max_new_tokens=256, do_sample=False
        )
        pred_txt = base_tokenizer.decode(
            model_outputs[0][len(input_ids[0]) :], skip_special_tokens=True
        )

        pred_txts.append(pred_txt)
        preds.append(pred_to_class_id(pred_txt))
        is_correct = preds[-1] == labels[i]
        corrects.append(is_correct)
        print(
            f"GT: {CLASS_NAMES[labels[i]]}, Pred: {pred_txt} -> {CLASS_NAMES[preds[-1]]}, Correct: {is_correct}"
        )

    acc = sum(corrects) / len(corrects)
    # ".4f" = four decimal places; the previous ":4f" was a format-spec typo
    # (minimum field width 4, default six decimals).
    print(f"Final accuracy: {acc:.4f}")

    jsonl_path = os.path.join(RUN_DIR, "imagenette_eval.jsonl")
    meta_path = os.path.join(RUN_DIR, "imagenette_eval.meta.json")

    # One JSON record per evaluated image.
    with open(jsonl_path, "w") as f:
        for i, (pred_txt, pred_id, label_id) in enumerate(
            zip(pred_txts, preds, labels)
        ):
            f.write(
                json.dumps(
                    {
                        "index": i,
                        "label": int(label_id),
                        "label_name": CLASS_NAMES[label_id],
                        "pred_text": pred_txt,
                        "pred_class_id": int(pred_id),
                        "pred_class_name": CLASS_NAMES[pred_id],
                        "correct": bool(pred_id == label_id),
                    }
                )
                + "\n"
            )

    # Run-level summary metadata.
    meta = {
        "dataset": "frgfm/imagenette",
        "subset": "full_size",
        "split": "validation",
        "run_dir": RUN_DIR,
        "prompt": INPUT_TXT,
        "accuracy": float(acc),
        "num_samples": len(preds),
        "class_names": CLASS_NAMES,
    }
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)

    print(f"Wrote samples to {jsonl_path}")
    print(f"Wrote metadata to {meta_path}")
|
||||
5
scripts/main_exp/eval/llmlingua.sh
Executable file
5
scripts/main_exp/eval/llmlingua.sh
Executable file
|
|
@ -0,0 +1,5 @@
|
|||
# LLMLingua prompt-compression baseline: sweep compression rates per dataset.
for dataset in squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e; do
    for rate in 0.9 0.8 0.6 0.4 0.2 0.1; do
        WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets "$dataset" --split test --eval_batch_size_gen=1 --use_llmlingua --llmlingua_compression_rate "$rate" --truncate_if_too_long_ctx
    done
done
|
||||
4
scripts/main_exp/eval/t2l.sh
Executable file
4
scripts/main_exp/eval/t2l.sh
Executable file
|
|
@ -0,0 +1,4 @@
|
|||
# Text-to-LoRA (T2L) baseline: fetch the released checkpoint, then evaluate
# with --use_t2l.

# download t2l checkpoint
uv run huggingface-cli download SakanaAI/text-to-lora --local-dir . --include "trained_t2l/gemma_2b_t2l"

WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen=1 --use_t2l
|
||||
20
scripts/main_exp/gen_data.sh
Executable file
20
scripts/main_exp/gen_data.sh
Executable file
|
|
@ -0,0 +1,20 @@
|
|||
# Full training-data generation pipeline (alternative to downloading the
# released dataset).

# download fineweb_edu to `data/raw_datasets/fineweb_edu`
uv run data/download_fineweb_edu.py

# generate qa data
# run from 000 to 013
for shard_id in $(seq -f "%03g" 0 13); do
    uv run data/generate_fw_edu_qa_v2.py --shard_pattern "${shard_id}_00000" --n_qa_pairs=5 --vllm_model=google/gemma-3-12b-it --max_length=2000 --max_model_length=2048
    uv run data/generate_fw_edu_qa_v2_repeat.py --shard_pattern "min_0_to_2000/${shard_id}*level_0" --n_qa_pairs=5 --vllm_model=google/gemma-3-12b-it

    # self-generated response QA data
    uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --glob_pattern "data/raw_datasets/fw_qa_v2/min_0_to_2000/${shard_id}*_level_1*" --closed_qa_prob 1.0
done


# val split
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --glob_pattern 'data/raw_datasets/fw_qa_v2/min_0_to_2000/*_level_0_val.parquet'

# self-gen data for other ds
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --ds_names squad_compact ropes_compact drop_compact --split train --closed_qa_prob 1.0
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --ds_names pwc_compact --split train --closed_qa_prob 0.0
|
||||
29
scripts/main_exp/train-cross-enc-chunk-slurm.sh
Normal file
29
scripts/main_exp/train-cross-enc-chunk-slurm.sh
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/bin/bash
# SLURM launcher for stage-2 chunked training with the gemma-3-4b-it cross
# encoder as context encoder. Extra CLI args are forwarded via "$@".
#SBATCH --job-name=ctxlora
#SBATCH --nodes=1
#SBATCH --partition=a3
#SBATCH --gpus=8
#SBATCH --output=slurm_logs/%x-%j.out
#SBATCH --error=slurm_logs/%x-%j.out

# Derive a per-job port from the SLURM job id to avoid collisions.
port=$((10000 + ($SLURM_JOBID % 50000)))
echo "Using port: $port"

# port=29051

uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
    --num_processes=8 --gpu_ids all train.py \
    configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
    --model_name_or_path=google/gemma-2-2b-it \
    --target_modules=down_proj --lora_r=8 \
    --eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
    --per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
    --max_steps=20000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
    --max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
    --quantize_ctx_encoder=True --ctx_encoder_model_name_or_path=google/gemma-3-4b-it \
    --max_ctx_chunk_len=512 \
    --min_ctx_chunk_len=25 \
    --num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
    --warmup_steps=2000 \
    --learning_rate=2e-5 \
    "$@"
|
||||
24
scripts/main_exp/train-cross-enc-slurm.sh
Normal file
24
scripts/main_exp/train-cross-enc-slurm.sh
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
# SLURM launcher for stage-1 training with the gemma-3-4b-it cross encoder
# as context encoder. Extra CLI args are forwarded via "$@".
#SBATCH --job-name=ctxlora
#SBATCH --nodes=1
#SBATCH --partition=a3
#SBATCH --gpus=8
#SBATCH --output=slurm_logs/%x-%j.out
#SBATCH --error=slurm_logs/%x-%j.out

# Derive a per-job port from the SLURM job id to avoid collisions.
port=$((10000 + ($SLURM_JOBID % 50000)))
echo "Using port: $port"

# port=29051

uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
    --num_processes=8 --gpu_ids all train.py \
    configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
    --model_name_or_path=google/gemma-2-2b-it \
    --target_modules=down_proj --lora_r=8 \
    --eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
    --per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
    --max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
    --max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
    --quantize_ctx_encoder=True --ctx_encoder_model_name_or_path=google/gemma-3-4b-it \
    "$@"
|
||||
15
scripts/main_exp/train_no_qa.sh
Executable file
15
scripts/main_exp/train_no_qa.sh
Executable file
|
|
@ -0,0 +1,15 @@
|
|||
#!/bin/bash
# Ablation: stage-1 training without the QA objective (no_qa config).
# Extra CLI args are forwarded via "$@".

port=29051

uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
    --num_processes=8 --gpu_ids all train.py \
    configs/main_exp/self_gen_lv1_closed_qa_1_no_qa_l2l.yaml \
    --model_name_or_path=google/gemma-2-2b-it \
    --target_modules=down_proj --lora_r=8 \
    --eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
    --per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
    --max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
    --max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
    --quantize_ctx_encoder=True \
    "$@"
|
||||
Loading…
Add table
Add a link
Reference in a new issue