doc-to-lora/src/ctx_to_lora/model_loading.py

184 lines
5.5 KiB
Python
Raw Normal View History

2026-02-27 03:47:04 +00:00
import logging
import os
import torch
from peft import PeftModel
from peft import get_peft_config as _get_peft_config
from peft.utils import PeftType
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
)
# Module-scoped logger: records carry this module's name and can be
# filtered/levelled independently instead of going through the root logger.
logger = logging.getLogger(__name__)

# Checkpoint IDs that `get_model` loads through
# `Gemma3ForConditionalGeneration` (then keeps its `language_model`).
GEMMA_VISION_MODELS = [
    "google/gemma-3-4b-it",
    "google/gemma-3-12b-it",
    "google/gemma-3-27b-it",
]


def check_is_vision_model(model_name):
    """Return True if ``model_name`` is one of ``GEMMA_VISION_MODELS``.

    This is an exact string match against the IDs listed above, so local
    paths or differently-cased IDs are not recognised.
    """
    return model_name in GEMMA_VISION_MODELS
def get_model_and_tokenizer(
    model_name_or_path,
    train,
    requires_grad,
    use_flash_attn=True,
    peft_config=None,
    model_kwargs=None,
    tokenizer_kwargs=None,
    use_q_lora=False,
    device="cuda",
    dtype=torch.bfloat16,
):
    """Load a model and its tokenizer, then align the model's pad token id.

    Model-loading options go to ``get_model``; ``tokenizer_kwargs`` goes only
    to ``get_tokenizer``. ``model_name_or_path``, ``train`` and
    ``peft_config`` are passed to both.

    Returns:
        ``(model, tokenizer)``. ``model.config.pad_token_id`` is set to the
        tokenizer's pad token id, and so is ``model.generation_config``'s
        when that attribute exists and is truthy.
    """
    model = get_model(
        model_name_or_path,
        train,
        requires_grad,
        use_flash_attn=use_flash_attn,
        peft_config=peft_config,
        model_kwargs=model_kwargs,
        use_q_lora=use_q_lora,
        device=device,
        dtype=dtype,
    )
    tokenizer = get_tokenizer(
        model_name_or_path,
        tokenizer_kwargs=tokenizer_kwargs,
        peft_config=peft_config,
        train=train,
    )

    pad_id = tokenizer.pad_token_id
    model.config.pad_token_id = pad_id
    # Guarded: the model object may not expose a (truthy) generation_config.
    generation_config = getattr(model, "generation_config", None)
    if generation_config:
        generation_config.pad_token_id = pad_id
    return model, tokenizer
def get_tokenizer(
    model_name_or_path, tokenizer_kwargs=None, peft_config=None, train=False
):
    """Load the tokenizer for ``model_name_or_path`` and attach a chat template.

    Args:
        model_name_or_path: Hub ID or local path passed to
            ``AutoTokenizer.from_pretrained``. It is also used to look up an
            optional chat template at ``chat_templates/<model_name_or_path>.jinja``,
            resolved relative to the current working directory.
        tokenizer_kwargs: Extra keyword arguments for ``from_pretrained``.
            Entries here override this function's defaults (e.g.
            ``padding_side``) instead of raising ``TypeError`` for a
            duplicated keyword.
        peft_config: Unused; accepted for call-site compatibility.
        train: Pad on the right when truthy, on the left otherwise.

    Returns:
        The tokenizer. When the checkpoint defines no pad token, its
        ``pad_token_id`` falls back to ``eos_token_id``. When a template file
        exists, it is installed as ``tokenizer.chat_template``.
    """
    load_kwargs = dict(
        # NOTE(review): Hugging Face tokenizers spell these `add_bos_token` /
        # `add_eos_token` (singular); the plural forms may be silently ignored.
        # Left as-is because correcting them would change tokenization.
        add_bos_tokens=False,
        add_eos_tokens=False,
        padding_side="right" if train else "left",
        truncation_side="left",
        trust_remote_code=True,
    )
    # Caller-supplied kwargs take precedence over the defaults above.
    load_kwargs.update(tokenizer_kwargs or {})
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **load_kwargs)

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    template_path = f"chat_templates/{model_name_or_path}.jinja"
    if not os.path.exists(template_path):
        logger.warning(
            "Chat template not found at %s. Using default template.", template_path
        )
        return tokenizer

    logger.info("Using chat template from %s", template_path)
    with open(template_path, encoding="utf-8") as template_file:
        raw_template = template_file.read()
    # Template files are laid out across indented lines for readability.
    # Strip each line's leading indentation and join the lines so the layout
    # does not leak into rendered prompts. Deleting *every* space would also
    # fuse Jinja tokens ("{% if x %}" -> "{%ifx%}") and mangle literal text.
    tokenizer.chat_template = "".join(
        line.lstrip() for line in raw_template.splitlines()
    )
    return tokenizer
def get_model(
    model_name_or_path,
    train,
    requires_grad,
    use_flash_attn=True,
    peft_config=None,
    model_kwargs=None,
    use_q_lora=False,
    device="cuda",
    dtype=torch.bfloat16,
):
    """Load a pretrained model, optionally 4-bit quantized and/or PEFT-wrapped.

    Args:
        model_name_or_path: Hub ID or local path.
        train: Passed to ``model.train()`` (train vs. eval mode).
        requires_grad: Applied to every parameter of the returned model,
            including any PEFT adapter parameters.
        use_flash_attn: Request ``flash_attention_2``. Names containing
            "gte" (case-insensitive) get ``sdpa`` instead. When False, the
            default ``eager`` is kept unless ``model_kwargs`` sets it.
        peft_config: If given, the loaded model is wrapped in ``PeftModel``.
        model_kwargs: Extra ``from_pretrained`` kwargs. They are merged
            *before* the attention/dtype/``use_cache`` adjustments below, so
            those adjustments win on conflict.
        use_q_lora: Load weights as 4-bit NF4 with double quantization
            through bitsandbytes.
        device: ``device_map`` passed to ``from_pretrained``.
        dtype: ``torch_dtype`` passed to ``from_pretrained``. For non-vision
            names containing "bert" or "gte" it is forced to float32.

    Returns:
        The model: a ``PeftModel`` when ``peft_config`` is given. For the
        Gemma vision checkpoints, only the ``language_model`` submodule is
        returned.
    """
    name_lower = model_name_or_path.lower()
    is_vision_model = check_is_vision_model(model_name_or_path)
    # Encoder-style (bidirectional) checkpoints are detected by name only.
    is_bidir_model = "bert" in name_lower or "gte" in name_lower

    model_init_kwargs = dict(
        pretrained_model_name_or_path=model_name_or_path,
        device_map=device,
        torch_dtype=dtype,
        trust_remote_code=True,
        attn_implementation="eager",
        use_cache=None,
    )
    if model_kwargs is not None:
        model_init_kwargs.update(model_kwargs)

    if use_flash_attn:
        # Case-insensitive to agree with `is_bidir_model`. Otherwise a "GTE"
        # name would be loaded in float32 (below) with flash_attention_2,
        # which only supports fp16/bf16.
        model_init_kwargs["attn_implementation"] = (
            "sdpa" if "gte" in name_lower else "flash_attention_2"
        )

    if is_vision_model:
        # NOTE(review): vision models keep the attention implementation chosen
        # above; a forced "sdpa" override used to sit here, commented out.
        model_init_kwargs.pop("use_cache")
    elif is_bidir_model:
        model_init_kwargs["torch_dtype"] = torch.float32
        model_init_kwargs.pop("use_cache")

    if use_q_lora:
        # https://huggingface.co/blog/4bit-transformers-bitsandbytes
        # https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing
        # see bitsandbytes for the quantization implementation https://github.com/bitsandbytes-foundation/bitsandbytes
        # see unsloth https://huggingface.co/docs/trl/v0.7.11/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth
        # NOTE(review): the original note read "does work currently bc it
        # modifies the forward pass call of Linear"; possibly meant "does not
        # work". Confirm.
        # NOTE(review): compute dtype is fixed to bfloat16 regardless of `dtype`.
        model_init_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

    logger.debug("Model init kwargs: %s", model_init_kwargs)
    if is_vision_model:
        model = Gemma3ForConditionalGeneration.from_pretrained(**model_init_kwargs)
        # Keep only the `language_model` submodule.
        # NOTE(review): whether this includes the LM head depends on the
        # installed transformers version. Confirm against the pinned version.
        model = model.language_model
    elif is_bidir_model:
        model = AutoModel.from_pretrained(**model_init_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(**model_init_kwargs)

    if peft_config is not None:
        model = PeftModel(model, peft_config)
    model.train(train)
    # Applies uniformly to every parameter, adapter weights included.
    for param in model.parameters():
        param.requires_grad = requires_grad
    return model
def get_lora_config(model_dir, **kwargs):
    """Build a PEFT LoRA config, or return None when no target modules are set.

    Args:
        model_dir: Stored as ``base_model_name_or_path`` on the config.
        **kwargs: LoRA options.
            - ``target_modules`` must be present and not None; otherwise no
              config is built.
            - ``lora_r`` is consumed here as the rank. If it is absent, PEFT's
              own ``r`` is used, and 8 is the default.
            - Every other key is forwarded to ``peft.get_peft_config`` and
              overrides the defaults below (``lora_alpha``, ``lora_dropout``,
              ``task_type``, ...).

    Returns:
        The PEFT config built by ``get_peft_config``, or ``None`` when
        ``target_modules`` is missing or None.
    """
    if kwargs.get("target_modules") is None:
        logger.info("No target modules specified for LoRA.")
        return None

    # Fall back to `r` so the default alpha below is derived from the rank
    # that is actually used. Previously a bare `r=16` got alpha computed for r=8.
    r = kwargs.pop("lora_r", kwargs.get("r", 8))
    peft_conf_kwargs = dict(
        r=r,
        peft_type=PeftType.LORA,
        base_model_name_or_path=model_dir,
        task_type="CAUSAL_LM",
        # Overridden by kwargs["lora_dropout"] via the update below, if given.
        lora_dropout=0.0,
        # Default alpha grows super-linearly with rank: 2 * r**1.5.
        lora_alpha=r ** (3 / 2) * 2,
    )
    peft_conf_kwargs.update(kwargs)
    return _get_peft_config(peft_conf_kwargs)