import logging import os import torch from peft import PeftModel from peft import get_peft_config as _get_peft_config from peft.utils import PeftType from transformers import ( AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Gemma3ForConditionalGeneration, ) logger = logging.getLogger() GEMMA_VISION_MODELS = [ "google/gemma-3-4b-it", "google/gemma-3-12b-it", "google/gemma-3-27b-it", ] def check_is_vision_model(model_name): return model_name in GEMMA_VISION_MODELS def get_model_and_tokenizer( model_name_or_path, train, requires_grad, use_flash_attn=True, peft_config=None, model_kwargs=None, tokenizer_kwargs=None, use_q_lora=False, device="cuda", dtype=torch.bfloat16, ): model = get_model( model_name_or_path, train, requires_grad, use_flash_attn, peft_config, model_kwargs, use_q_lora, device, dtype, ) tokenizer = get_tokenizer(model_name_or_path, tokenizer_kwargs, peft_config, train) model.config.pad_token_id = tokenizer.pad_token_id if getattr(model, "generation_config", None): model.generation_config.pad_token_id = tokenizer.pad_token_id return model, tokenizer def get_tokenizer( model_name_or_path, tokenizer_kwargs=None, peft_config=None, train=False ): padding_side = "left" if not train else "right" truncation_side = "left" if tokenizer_kwargs is None: tokenizer_kwargs = {} tokenizer = AutoTokenizer.from_pretrained( model_name_or_path, add_bos_tokens=False, add_eos_tokens=False, padding_side=padding_side, truncation_side=truncation_side, trust_remote_code=True, **tokenizer_kwargs, ) if tokenizer.pad_token_id is None: tokenizer.pad_token_id = tokenizer.eos_token_id template_path = f"chat_templates/{model_name_or_path}.jinja" if not os.path.exists(template_path): logger.warning( f"Chat template not found at {template_path}. Using default template." ) return tokenizer logger.info(f"Using chat template from {template_path}") chat_template = open(template_path).read() chat_template = chat_template.replace(" ", "").replace("\n", "") tokenizer.chat_template = chat_template return tokenizer def get_model( model_name_or_path, train, requires_grad, use_flash_attn=True, peft_config=None, model_kwargs=None, use_q_lora=False, device="cuda", dtype=torch.bfloat16, ): model_init_kwargs = dict( pretrained_model_name_or_path=model_name_or_path, device_map=device, torch_dtype=dtype, trust_remote_code=True, attn_implementation="eager", use_cache=None, ) is_vision_model = check_is_vision_model(model_name_or_path) if model_kwargs is not None: model_init_kwargs.update(model_kwargs) is_bidir_model = ( "bert" in model_name_or_path.lower() or "gte" in model_name_or_path.lower() ) if use_flash_attn: if "gte" not in model_name_or_path: model_init_kwargs["attn_implementation"] = "flash_attention_2" elif "gte" in model_name_or_path: model_init_kwargs["attn_implementation"] = "sdpa" if is_vision_model: # always use sdpa for vision models # model_init_kwargs["attn_implementation"] = "sdpa" model_init_kwargs.pop("use_cache") elif is_bidir_model: model_init_kwargs["torch_dtype"] = torch.float32 model_init_kwargs.pop("use_cache") if use_q_lora: # https://huggingface.co/blog/4bit-transformers-bitsandbytes # https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing # see bitsandbytes for the quantization implementation https://github.com/bitsandbytes-foundation/bitsandbytes # see unsloth https://huggingface.co/docs/trl/v0.7.11/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth # does work currently bc it modifies the forward pass call of Linear bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) model_init_kwargs["quantization_config"] = bnb_config logger.debug(f"Model init kwargs: {model_init_kwargs}") if not is_vision_model: if is_bidir_model: model = AutoModel.from_pretrained(**model_init_kwargs) else: model = AutoModelForCausalLM.from_pretrained(**model_init_kwargs) else: model = Gemma3ForConditionalGeneration.from_pretrained(**model_init_kwargs) model = model.language_model if peft_config is not None: model = PeftModel(model, peft_config) model.train(train) for name, param in model.named_parameters(): param.requires_grad = requires_grad return model def get_lora_config(model_dir, **kwargs): if "target_modules" not in kwargs or kwargs["target_modules"] is None: logger.info("No target modules specified for LoRA.") return None r = kwargs.pop("lora_r", 8) peft_conf_kwargs = dict( r=r, peft_type=PeftType.LORA, base_model_name_or_path=model_dir, task_type="CAUSAL_LM", lora_dropout=kwargs.get("lora_dropout", 0.0), lora_alpha=r ** (3 / 2) * 2, ) peft_conf_kwargs.update(kwargs) peft_config = _get_peft_config(peft_conf_kwargs) return peft_config