# doc-to-lora/train.py
#
# Training entry point (source snapshot dated 2026-02-27 03:47:04 +00:00).
import contextlib
import logging
import os
from copy import deepcopy
from functools import partial
import numpy as np
import torch
import wandb
from datasets import disable_caching
from peft import PeftModel
from transformers import (
AutoConfig,
set_seed,
)
from ctx_to_lora.configs import (
AggregatorArguments,
ArgumentParser,
CtxEncoderArguments,
CtxTrainingArguments,
DataArguments,
ExperimentSetup,
HypernetArguments,
LoRAArguments,
ModelArguments,
TrainingArguments,
)
from ctx_to_lora.data.collator import ( # train_packed_collator,; DefaultDataCollator,
flatten_if_not_packed,
)
from ctx_to_lora.data.processing import get_tokenized_dataset, pack
from ctx_to_lora.metrics import (
Evaluator,
compute_metrics,
compute_per_token_acc,
compute_perplexity,
compute_prefix_matching,
)
from ctx_to_lora.model_loading import (
check_is_vision_model,
get_lora_config,
get_model_and_tokenizer,
get_tokenizer,
)
from ctx_to_lora.modeling.hypernet import (
ModulatedPretrainedModel,
get_hypernet_config,
)
from ctx_to_lora.trainer import train_model
from ctx_to_lora.utils import (
compile_linear,
extract_cli_args,
get_run_name,
log_num_train_params,
save_yaml,
setup_logging,
validate_args,
)
# Module-wide logger. No name is passed, so this is the root logger.
logger = logging.getLogger()

# Rank of this process on its node, read from the LOCAL_RANK environment
# variable; a missing variable is treated as rank 0.
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0"))
def main():
    """Parse the CLI arguments, build the model and datasets, and run training.

    Steps, in order:
      1. Parse and validate all argument groups, seed the RNGs.
      2. Derive the run name / output directory, set up logging and dump the
         CLI arguments and the merged argument dict as YAML.
      3. Build the model (fresh ``ModulatedPretrainedModel`` or one restored
         from ``ctx_args.from_pretrained_checkpoint``).
      4. Tokenize the train/validation datasets and pack the training set.
      5. Compile the models, initialise wandb and call ``train_model``.

    Raises:
        AssertionError: if ``ctx_args.use_sequence_packing`` is falsy.
        ValueError: if the context encoder or the base model still has
            trainable parameters in the HyperLoRA setup.
    """
    # Function-scope import: the module does not import ``sys`` at top level
    # and ``os.sys`` is an undocumented implementation detail.
    import sys

    ############ Argument parsing
    parser = ArgumentParser(
        (
            DataArguments,
            CtxTrainingArguments,
            ModelArguments,
            LoRAArguments,
            TrainingArguments,
            HypernetArguments,
            AggregatorArguments,
            CtxEncoderArguments,
        )
    )
    (
        data_args,
        ctx_args,
        model_args,
        lora_args,
        training_args,
        hypernet_args,
        aggregator_args,
        ctx_encoder_args,
    ) = parser.parse()

    # there shouldn't be overlap between args
    validate_args(
        [
            data_args,
            ctx_args,
            model_args,
            lora_args,
            training_args,
            hypernet_args,
            aggregator_args,
            ctx_encoder_args,
        ]
    )

    assert ctx_args.use_sequence_packing, (
        f"Please set use_sequence_packing=True in {ctx_args}. It's faster!"
    )

    set_seed(training_args.seed)
    checkpoint_dir = training_args.resume_from_checkpoint

    # The run name should be the same across processes
    # (it is still possible to have a name clash though).
    # logging_dir is just "runs/DATE_TIME_HOSTNAME"
    slurm_job_id = f"_{os.getenv('SLURM_JOB_ID')}" if os.getenv("SLURM_JOB_ID") else ""
    if checkpoint_dir:
        # Resuming: the run name is the parent directory of the checkpoint.
        run_name = checkpoint_dir.strip("/").split("/")[-2]
    else:
        # Remove the exact "runs/" prefix. ``str.strip("runs/")`` would treat
        # the argument as a character set and also eat trailing r/u/n/s//
        # characters (e.g. from a hostname ending in "s").
        run_name = get_run_name(
            seed_str=_strip_prefix(training_args.logging_dir, "runs/") + slurm_job_id
        )

    output_dir = f"train_outputs/runs/{run_name}"
    setup_logging(output_dir, debug=os.getenv("DEBUG", False))

    logger.debug(f"CMD: {' '.join(sys.argv)}")
    cli_args = extract_cli_args(sys.argv)
    save_yaml(cli_args, f"{output_dir}/cli_args.yaml")
    if "config" in cli_args:
        # Tag the wandb run with the config file's base name.
        config_name = os.path.basename(cli_args["config"]).split(".yaml")[0]
        os.environ["WANDB_TAGS"] = config_name

    run_name = os.path.basename(output_dir)
    training_args.run_name = run_name
    training_args.output_dir = output_dir
    training_args.logging_dir = output_dir
    if (
        training_args.lr_scheduler_type == "cosine_with_min_lr"
        and training_args.lr_scheduler_kwargs is None
    ):
        training_args.lr_scheduler_kwargs = {"min_lr": 1e-7}

    # Flat snapshot of every argument group, saved to disk and sent to wandb.
    args = {
        **vars(deepcopy(data_args)),
        **vars(deepcopy(ctx_args)),
        **vars(deepcopy(model_args)),
        **vars(deepcopy(lora_args)),
        **vars(deepcopy(training_args)),
        **vars(deepcopy(hypernet_args)),
        **vars(deepcopy(aggregator_args)),
        **vars(deepcopy(ctx_encoder_args)),
    }
    args["deepspeed_plugin"] = None
    logger.debug(f"args: {args}")
    save_yaml(args, f"{output_dir}/args.yaml")

    ############ Model setup
    if not ctx_args.from_pretrained_checkpoint:
        model_name = model_args.model_name_or_path
        base_model, tokenizer = get_model_and_tokenizer(
            **vars(model_args),
            train=True,
            requires_grad=False,
            peft_config=get_lora_config(model_name, **vars(lora_args)),
        )

        ctx_name = ctx_encoder_args.ctx_encoder_model_name_or_path
        if ctx_name is not None:
            # A dedicated context encoder was requested.
            ctx_encoder_model_config = AutoConfig.from_pretrained(
                ctx_name, trust_remote_code=True
            )
            if ("Llama" in ctx_name and "Vision" in ctx_name) or check_is_vision_model(
                ctx_name
            ):
                ctx_encoder_model_config = ctx_encoder_model_config.text_config
            ctx_tokenizer = get_tokenizer(ctx_name, train=True)
        else:
            # Fall back to the base model as the context encoder.
            ctx_name = base_model.base_model.config.name_or_path
            ctx_encoder_model_config = base_model.config
            ctx_tokenizer = tokenizer

    if ctx_args.exp_setup == ExperimentSetup.HYPERLORA:
        logger.info("Using HyperLoRA")
        if not ctx_args.from_pretrained_checkpoint:
            hypernet_config = get_hypernet_config(
                base_model,
                ctx_encoder_model_config,
                hypernet_args,
                aggregator_args,
                ctx_encoder_args,
            )
            if ctx_encoder_args.layer_idx is None:
                ctx_encoder_args.layer_idx = (
                    ctx_encoder_model_config.num_hidden_layers // 4
                )
                logger.info(
                    f"Using the first {ctx_encoder_args.layer_idx} layers"
                    " as the context encoder"
                )
            # Re-read the raw argument: from here on ``ctx_name`` is None when
            # no separate context encoder was requested.
            ctx_name = ctx_encoder_args.ctx_encoder_model_name_or_path
            if ctx_encoder_args.ctx_encoder_last_layer is None and (
                ctx_name is not None and ctx_name != base_model.name_or_path
            ):
                logger.info(
                    f"Setting ctx_encoder_last_layer to {base_model.name_or_path} max layers"
                    f":{base_model.config.num_hidden_layers}"
                )
                ctx_encoder_args.ctx_encoder_last_layer = (
                    base_model.config.num_hidden_layers
                )
            model = ModulatedPretrainedModel(
                base_model, hypernet_config, ctx_encoder_args
            )
        else:
            if checkpoint_dir:
                ctx_args.from_pretrained_checkpoint = (
                    f"{checkpoint_dir}/pytorch_model.bin"
                )
            logger.info(
                f"Loading from checkpoint: {ctx_args.from_pretrained_checkpoint}"
            )
            # NOTE(review): weights_only=False unpickles arbitrary objects;
            # only load checkpoints from trusted sources.
            model = ModulatedPretrainedModel.from_state_dict(
                torch.load(ctx_args.from_pretrained_checkpoint, weights_only=False),
                train=True,
                use_flash_attn=model_args.use_flash_attn,
            )
            tokenizer = get_tokenizer(model.base_model.config.name_or_path, train=True)
            ctx_name = model.ctx_encoder_args.ctx_encoder_model_name_or_path
            if ctx_name is None:
                ctx_name = model.base_model.config.name_or_path
            ctx_tokenizer = get_tokenizer(ctx_name, train=True)

        # Forward the loss-related options to the trainer arguments.
        training_args.gen_lora_l1_reg_coef = ctx_args.gen_lora_l1_reg_coef
        training_args.use_kl_loss = ctx_args.use_kl_loss
        training_args.use_per_ctx_average_loss = ctx_args.use_per_ctx_average_loss

        # Neither the context encoder nor the base model may be trainable.
        if any(p.requires_grad for p in model.ctx_encoder.parameters()):
            raise ValueError("ctx_encoder contains trainable parameters")
        if any(p.requires_grad for p in model.base_model.parameters()):
            raise ValueError("base model contains trainable parameters")
        model.hypernet.compile(fullgraph=True, mode="max-autotune")
    else:
        # activate LoRA
        base_model_config = AutoConfig.from_pretrained(
            model_args.model_name_or_path, trust_remote_code=True
        )
        base_model_config.save_pretrained(output_dir)
        logger.info("Using LoRA")
        # NOTE(review): ``model`` is never assigned on this path before use,
        # and the dataset setup below reads ``model.ctx_encoder``; the plain
        # LoRA setup looks unsupported as written — confirm intended wiring.
        model.set_adapter("default")
        model = torch.compile(model)

    model.train()
    logger.debug(model)
    log_num_train_params(model)

    ############ Dataset setup
    logger.info("Loading dataset...")
    add_ctx_to_chat = not isinstance(model, ModulatedPretrainedModel)

    ctx_model_max_len = model.ctx_encoder.config.max_position_embeddings
    if ctx_args.max_ctx_len > 0:
        ctx_model_max_len = ctx_args.max_ctx_len
    if ctx_args.max_ctx_chunk_len <= 0:
        # set default chunk size to max length of the ctx encoder
        ctx_args.max_ctx_chunk_len = ctx_model_max_len
    if ctx_args.num_chunk_probs is not None:
        # Normalise keys/values to int -> float.
        ctx_args.num_chunk_probs = {
            int(k): float(v) for k, v in ctx_args.num_chunk_probs.items()
        }

    _get_tokenized_dataset = partial(
        get_tokenized_dataset,
        max_qas_len=ctx_args.max_qas_len,
        max_qas_per_sample=ctx_args.max_qas_per_sample,
        base_model_max_len=model.base_model.config.max_position_embeddings,
        tokenizer=tokenizer,
        ctx_model_max_len=ctx_model_max_len,
        ctx_tokenizer=ctx_tokenizer,
        add_ctx_to_chat=add_ctx_to_chat,
        add_negative_prompt=ctx_args.add_negative_prompt,
        max_ctx_chunk_len=ctx_args.max_ctx_chunk_len,
        min_ctx_chunk_len=ctx_args.min_ctx_chunk_len,
        num_chunk_probs=ctx_args.num_chunk_probs,
        max_ctx_chunk_num=ctx_args.max_ctx_chunk_num,
        use_kl_loss=ctx_args.use_kl_loss,
    )

    splits = ["train"]
    if training_args.eval_strategy != "no":
        splits.append("validation")

    tokenized_ds = {split: {} for split in splits}
    for split, ds_names in zip(
        splits,
        [data_args.train_ds_names, data_args.val_ds_names],
    ):
        if not ds_names:
            continue
        ctx_mgr = (
            training_args.main_process_first()
            if split == "train"
            else contextlib.nullcontext()
        )
        with ctx_mgr:
            # process and tokenize on the main process
            # then other replicas can just load the cached dataset
            # we dont save cache for validation ds
            for ds_name in ds_names:
                ds = _get_tokenized_dataset(ds_name, split)
                # Key datasets by base name, keeping the "self_gen/" marker.
                base_name = os.path.basename(ds_name)
                if ds_name.startswith("self_gen/"):
                    ds_name = "self_gen/" + base_name
                else:
                    ds_name = base_name
                tokenized_ds[split][ds_name] = ds

    train_ds = tokenized_ds["train"]
    if data_args.max_train_samples_per_ds is not None:
        # Cap each training dataset (only existing keys are reassigned, so
        # mutating the dict while iterating its items is safe).
        for ds_name, ds in train_ds.items():
            if data_args.max_train_samples_per_ds >= len(ds):
                continue
            train_ds[ds_name] = ds.take(data_args.max_train_samples_per_ds)
    logger.info(f"train_ds: {train_ds}")

    val_ds = {}
    if "validation" in tokenized_ds:
        n_val_samples = data_args.max_val_samples_per_ds
        for ds_name, ds in tokenized_ds["validation"].items():
            if ds is None:
                # take some samples from train_ds
                ds = train_ds[ds_name].take(n_val_samples)
                train_ds[ds_name] = train_ds[ds_name].skip(n_val_samples)
            val_ds[ds_name] = ds
            # Random subset of at most n_val_samples examples.
            val_indices = np.random.permutation(len(ds))[:n_val_samples]
            val_ds[ds_name] = val_ds[ds_name].select(val_indices)

    with training_args.main_process_first():
        train_ds = pack(
            train_ds,
            ctx_args.max_packed_inp_len,
            ctx_args.max_packed_ctx_len,
            max_packed_size=-1,
            seed=training_args.seed,
            num_proc=30,
        )
    logger.info("Setting per_device_train_batch_size to 1")
    training_args.per_device_train_batch_size = 1

    logger.info(f"train_ds: {train_ds}")
    logger.info(f"val_ds: {val_ds}")

    collator = flatten_if_not_packed

    ############ Compilation
    if isinstance(model, ModulatedPretrainedModel):
        if isinstance(model.base_model, PeftModel):
            base_model = model.base_model.base_model
        else:
            base_model = model.base_model
        if ctx_name is not None:
            logger.info("Compiling ctx_encoder_model")
            ctx_base_model = model.ctx_encoder.base_model
            compile_linear(ctx_base_model)
    elif isinstance(model, PeftModel):
        base_model = model.base_model
    logger.info("Compiling base_model")
    base_model.compile(fullgraph=True, mode="max-autotune")

    ############ Experiment tracking
    if LOCAL_RANK == 0:
        wandb.init(
            project=os.getenv("WANDB_PROJECT"),
            name=run_name,
            group=run_name,
            config=args,
            # WANDB_TAGS is only set above when a config file was passed, so
            # it may be missing; fall back to "no tags" instead of crashing.
            tags=_parse_wandb_tags(os.getenv("WANDB_TAGS")),
            notes=ctx_args.notes,
            resume="allow",
        )
    else:
        # Only local rank 0 reports to wandb.
        wandb.init(mode="disabled")

    ############ Training
    train_model(
        model,
        training_args,
        train_ds,
        val_ds,
        collator,
        compute_metrics=partial(
            compute_metrics,
            evaluator=Evaluator(
                [compute_per_token_acc, compute_prefix_matching, compute_perplexity]
            ),
        ),
    )
    logger.info(f"Training run finished and saved to {output_dir}")


def _strip_prefix(text, prefix):
    """Return ``text`` without a leading ``prefix``; unchanged if absent."""
    if text.startswith(prefix):
        return text[len(prefix):]
    return text


def _parse_wandb_tags(raw):
    """Split a comma-separated tag string; return None when there are no tags."""
    if not raw:
        return None
    tags = [tag.strip() for tag in raw.split(",") if tag.strip()]
    return tags or None
if __name__ == "__main__":
    # Process-wide environment settings, applied before main() runs.
    # WANDB_PROJECT keeps a value already present in the environment and
    # otherwise falls back to "ctx_to_lora"; every other key is overwritten.
    os.environ.update(
        {
            "TRANSFORMERS_NO_ADVISORY_WARNINGS": "true",
            "TOKENIZERS_PARALLELISM": "true",
            "WANDB_DIR": ".wandb/",
            "WANDB_PROJECT": os.getenv("WANDB_PROJECT") or "ctx_to_lora",
            "WANDB_WATCH": "",
            "WANDB_CONSOLE": "off",
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
            "OMP_NUM_THREADS": "23",
        }
    )

    # torch.compile / numerics switches.
    torch._dynamo.config.capture_scalar_outputs = True
    for tf32_backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
        tf32_backend.allow_tf32 = True

    # Any non-empty DEBUG value turns off the `datasets` cache.
    if os.getenv("DEBUG", False):
        disable_caching()

    main()