mirror of
https://github.com/SakanaAI/doc-to-lora.git
synced 2026-04-26 00:26:22 +02:00
Doc-to-LoRA release
This commit is contained in:
commit
1abe8ae16d
92 changed files with 22131 additions and 0 deletions
8
scripts/main_exp/eval/base_model.sh
Executable file
8
scripts/main_exp/eval/base_model.sh
Executable file
|
|
@ -0,0 +1,8 @@
|
|||
# Baseline evaluation of the raw gemma-2-2b-it model (no LoRA, no distillation)
# under three input regimes. Word-splitting of the unquoted variables below
# reproduces the original multi-value --datasets arguments exactly.
MODEL=google/gemma-2-2b-it
QA_SETS="squad drop ropes"
LB_SETS="longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e"

# no truncation
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path $MODEL --datasets $QA_SETS $LB_SETS --split test --eval_batch_size_gen 1

# w/ truncation
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path $MODEL --datasets $LB_SETS --split test --eval_batch_size_gen 1 --truncate_if_too_long_inp

# no context
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path $MODEL --datasets $QA_SETS $LB_SETS --split test --eval_batch_size_gen 1 --remove_context
|
||||
6
scripts/main_exp/eval/cd.sh
Executable file
6
scripts/main_exp/eval/cd.sh
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
# Context-distillation (CD) evaluation with generated queries.
# Both suites share the CD flags below; only --q_gen_rounds differs.
CD_ARGS="--split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q"

# qa
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes $CD_ARGS --q_gen_rounds=4

# longbench
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e $CD_ARGS --q_gen_rounds=1
|
||||
1
scripts/main_exp/eval/cd_minibatch.sh
Normal file
1
scripts/main_exp/eval/cd_minibatch.sh
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Mini-batched context-distillation eval: 50 CD update iterations, 5 query-gen
# rounds, cd_batch_size=2.
# Usage: cd_minibatch.sh <dataset> — $1 is deliberately left unquoted so a
# space-separated list expands into multiple --datasets values.
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets $1 --split test --use_cd --cd_update_iterations 50 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=5 --cd_batch_size=2
|
||||
6
scripts/main_exp/eval/cd_oracle.sh
Executable file
6
scripts/main_exp/eval/cd_oracle.sh
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
# Oracle context-distillation eval: same CD setup as cd.sh but WITHOUT
# generated queries (no --cd_use_gen_q / --q_gen_rounds).
CD_ARGS="--split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp"

# qa
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes $CD_ARGS

# longbench
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e $CD_ARGS
|
||||
13
scripts/main_exp/eval/d2l.sh
Executable file
13
scripts/main_exp/eval/d2l.sh
Executable file
|
|
@ -0,0 +1,13 @@
|
|||
# Doc-to-LoRA (D2L) evaluation of a trained hypernetwork checkpoint.
# Expects $RUN_NAME and $step to be set by the caller.
CKPT=train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin

# main results
# batched
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path $CKPT --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1

# iterative
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path $CKPT --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1 --use_iterative_mode

# query internalization
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path $CKPT --datasets squad --split test --eval_batch_size_gen=1 --flip_ctx_inp

# replaced squad context
# (normalized from "uv run python run_eval.py" to "uv run run_eval.py" for
# consistency with every other invocation in this directory; uv runs .py
# scripts with the project interpreter either way)
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path $CKPT --datasets squad_assistant_ctx_no_passage --split test
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path $CKPT --datasets squad_negative_no_passage --split test
|
||||
279
scripts/main_exp/eval/imagenette_eval.py
Normal file
279
scripts/main_exp/eval/imagenette_eval.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from argparse import Namespace
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
from ctx_to_lora.model_loading import get_tokenizer
|
||||
from ctx_to_lora.modeling.ctx_encoder import PerLayerActivations
|
||||
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
|
||||
from ctx_to_lora.modeling.lora_layer import apply_lora_to_layers
|
||||
from ctx_to_lora.modeling.lora_merger import combine_lora
|
||||
|
||||
# The ten Imagenette classes, in dataset label order (list index == integer label).
CLASS_NAMES = [
    "tench",
    "English springer",
    "cassette player",
    "chain saw",
    "church",
    "French horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
]
# Reverse lookup: canonical class name -> integer label.
CLASS_TO_INT = {name: i for i, name in enumerate(CLASS_NAMES)}
# Classification prompt sent to the modulated base model.
# NOTE(review): "Response with" reads like a typo for "Respond with", but this
# is a runtime prompt string, so it is kept byte-for-byte to preserve the
# evaluated behavior.
INPUT_TXT = f"What is in this image? Choose exactly one of the following classes: {', '.join(CLASS_NAMES)}. Response with only the correct class without any other text."
# Training run whose checkpoint-80000 is evaluated; eval outputs are written here too.
RUN_DIR = "train_outputs/runs/Oct16_02-37-04_slurm0-a3nodeset-8_94074_1d62ecb8"
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
|
||||
text = re.sub(r"[^a-z0-9\s]", " ", text.lower())
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def _normalize_compact(text: str) -> str:
|
||||
return re.sub(r"[^a-z0-9]", "", text.lower())
|
||||
|
||||
|
||||
def _build_alias_map():
    """Build a lookup from normalized alias strings to canonical class names.

    Each alias is registered twice: in spaced form (via _normalize_text) and in
    compact form (via _normalize_compact), so pred_to_class_id can match either
    normalization of a model prediction.
    """
    # Hand-curated synonyms / near-misses the model is known to emit.
    alias_overrides = {
        "english springer spaniel": "English springer",
        "springer spaniel": "English springer",
        "chainsaw": "chain saw",
        "dump truck": "garbage truck",
        "refuse truck": "garbage truck",
        "garbage lorry": "garbage truck",
        "fuel pump": "gas pump",
        "gas station pump": "gas pump",
        "cassette deck": "cassette player",
        "cassette recorder": "cassette player",
        "fish": "tench",
        "tench fish": "tench",
        "french horn instrument": "French horn",
        "golfball": "golf ball",
        "skydiving": "parachute",
        "parachutist": "parachute",
    }

    alias_map = {}

    def register(alias: str, canonical: str):
        # Skip aliases that normalize to nothing (no alphanumerics).
        normed = _normalize_text(alias)
        if not normed:
            return
        alias_map[normed] = canonical
        alias_map[_normalize_compact(normed)] = canonical

    # Every canonical name plus its squashed and hyphenated spellings.
    for name in CLASS_NAMES:
        for variant in (name, name.replace(" ", ""), name.replace(" ", "-")):
            register(variant, name)

    for alias, canonical in alias_overrides.items():
        register(alias, canonical)

    return alias_map
|
||||
|
||||
|
||||
# Built once at import time; consulted by pred_to_class_id for substring matching.
CLASS_ALIAS_MAP = _build_alias_map()
|
||||
|
||||
|
||||
def pred_to_class_id(pred_txt: str) -> int:
    """Map a free-form model prediction to an Imagenette class id.

    Resolution order:
      1. Alias-table substring match (both spaced and compact normalizations).
      2. Token-subset match: all tokens of a canonical name appear in the
         prediction.
      3. Fuzzy fallback: the class name with the highest difflib
         SequenceMatcher ratio against the normalized prediction.

    Always returns a class id, even for garbage input (step 3 never fails).

    Fix vs. previous version: the old code also tallied per-token hit counts in
    step 2, but the winner of that tally was unconditionally overwritten by the
    ratio search in step 3 (whose best_ratio started at -1.0), so the tally was
    dead code. It has been removed; returned values are unchanged.
    """
    norm_pred = _normalize_text(pred_txt)
    compact_pred = _normalize_compact(pred_txt)

    # 1) alias containment in either normalization of the prediction
    for alias, canonical in CLASS_ALIAS_MAP.items():
        if alias and (alias in norm_pred or alias in compact_pred):
            return CLASS_TO_INT[canonical]

    # 2) all tokens of a canonical class name present in the prediction
    pred_tokens = set(norm_pred.split())
    for name in CLASS_NAMES:
        class_tokens = set(_normalize_text(name).split())
        if class_tokens and class_tokens.issubset(pred_tokens):
            return CLASS_TO_INT[name]

    # 3) fuzzy fallback; max() keeps the first name on ties, matching the
    # original strict ">" comparison loop.
    best_class = max(
        CLASS_NAMES,
        key=lambda name: SequenceMatcher(None, norm_pred, _normalize_text(name)).ratio(),
    )
    return CLASS_TO_INT[best_class]
|
||||
|
||||
|
||||
def load_checkpoint():
    """Load the trained hypernetwork checkpoint and the base-model tokenizer.

    Returns:
        (model, tokenizer): a ModulatedPretrainedModel in eval mode and the
        gemma-2-2b-it tokenizer.
    """
    checkpoint_path = f"{RUN_DIR}/checkpoint-80000/pytorch_model.bin"
    # NOTE(review): torch.load without weights_only= is fine for a trusted local
    # checkpoint, but newer torch versions default weights_only=True — confirm
    # this still loads under the pinned torch version.
    state_dict = torch.load(checkpoint_path)

    model = ModulatedPretrainedModel.from_state_dict(
        state_dict,
        train=False,
        base_model_kwargs=dict(attn_implementation="flash_attention_2"),
        use_flash_attn=True,
        use_sequence_packing=False,  # for generation
    )
    tokenizer = get_tokenizer("google/gemma-2-2b-it")
    model.eval()
    return model, tokenizer
|
||||
|
||||
|
||||
def load_ctx_encoder():
    """Load Gemma-3-4B-it as the image-capable context encoder.

    Its language model is wrapped in PerLayerActivations (last layer 26,
    LM head kept) so per-layer hidden states are exposed for the hypernetwork.

    Returns:
        (ctx_model, processor): the wrapped model (eval mode, device_map="auto")
        and its AutoProcessor.
    """
    model_id = "google/gemma-3-4b-it"
    ctx_model = Gemma3ForConditionalGeneration.from_pretrained(
        model_id, device_map="auto"
    ).eval()
    ctx_encoder_config = Namespace(ctx_encoder_last_layer=26, keep_lm_head=True)
    ctx_model.language_model = PerLayerActivations(
        ctx_model.language_model, ctx_encoder_config
    )
    processor = AutoProcessor.from_pretrained(model_id)
    return ctx_model, processor
|
||||
|
||||
|
||||
def template_image(img, ctx_processor):
    """Tokenize *img* for the context encoder via its chat template.

    The conversation is an empty system turn followed by an image-only user
    turn; a generation prompt is appended and tensors are returned as PyTorch.
    """
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": ""}],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": img},
        ],
    }

    return ctx_processor.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def get_ctx_features(ctx_inputs, ctx_encoder):
|
||||
forward_outputs = ctx_encoder(**ctx_inputs, output_hidden_states=True)
|
||||
ctx_features = torch.stack(forward_outputs.hidden_states, dim=1)
|
||||
return ctx_features
|
||||
|
||||
|
||||
def generate_loras(ctx_inputs, ctx_features):
    """Generate LoRA weights from context features and merge them into one set.

    NOTE(review): reads the module-level ``model`` that is only bound inside the
    ``__main__`` block below (implicit global), so this function works only when
    the file runs as a script — consider passing ``model`` explicitly.
    """
    generated_loras, _ = model.hypernet.generate_weights(
        ctx_features, attn_mask=torch.ones_like(ctx_inputs["input_ids"])
    )
    # One context chunk per sample, hence the constant n_chunks of 1.
    generated_loras = combine_lora(
        generated_loras,
        n_chunks=torch.tensor((1,), device=model.device),
        # Head bias is only merged in when the hypernet was trained with one.
        lora_bias=model.hypernet.get_head_bias()
        if model.hypernet.config.use_bias
        else None,
    )
    return generated_loras
|
||||
|
||||
|
||||
def apply_loras(model, generated_loras):
    """Install *generated_loras* onto the base model's hypernet-target layers."""
    # A single query (one generated LoRA set) is applied per forward pass.
    single_query = torch.ones(1, dtype=torch.int32, device=model.device)

    apply_lora_to_layers(
        model.base_model,
        model.hypernet.layer_indices,
        generated_loras,
        single_query,
    )
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Pipeline per image: encode image -> generate LoRA -> patch base model ->
    # greedy-generate a class name -> map text back to a class id.
    model, base_tokenizer = load_checkpoint()
    ctx_encoder, ctx_processor = load_ctx_encoder()
    ds = load_dataset("frgfm/imagenette", "full_size", split="validation")
    # ds = ds.shuffle().select(range(int(0.05 * len(ds))))

    # The text prompt is identical for every image; tokenize it once up front.
    input_ids = base_tokenizer.apply_chat_template(
        [{"role": "user", "content": INPUT_TXT}],
        add_special_tokens=False,
        return_attention_mask=False,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    preds = []
    pred_txts = []
    corrects = []
    labels = ds["label"]
    for sample in tqdm(ds):
        img = sample["image"]
        ctx_inputs = template_image(img, ctx_processor).to(ctx_encoder.device)
        ctx_features = get_ctx_features(ctx_inputs, ctx_encoder)
        generated_loras = generate_loras(ctx_inputs, ctx_features)
        # Patch this sample's LoRA into the base model before generating.
        apply_loras(model, generated_loras)

        model_outputs = model.base_model.generate(
            input_ids, max_new_tokens=256, do_sample=False
        )
        # Decode only the newly generated tokens (strip the prompt prefix).
        pred_txt = base_tokenizer.decode(
            model_outputs[0][len(input_ids[0]) :], skip_special_tokens=True
        )

        pred_txts.append(pred_txt)
        preds.append(pred_to_class_id(pred_txt))
        is_correct = preds[-1] == labels[len(preds) - 1]
        corrects.append(is_correct)
        print(
            f"GT: {CLASS_NAMES[labels[len(preds) - 1]]}, Pred: {pred_txt} -> {CLASS_NAMES[preds[-1]]}, Correct: {is_correct}"
        )

    acc = sum(corrects) / len(corrects)
    # Fix: was "{acc:4f}" (minimum field WIDTH 4, default 6 decimals); the
    # intended spec is a 4-decimal PRECISION, i.e. "{acc:.4f}".
    print(f"Final accuracy: {acc:.4f}")

    # Persist per-sample predictions (JSONL) and run metadata (JSON) to RUN_DIR.
    jsonl_path = os.path.join(RUN_DIR, "imagenette_eval.jsonl")
    meta_path = os.path.join(RUN_DIR, "imagenette_eval.meta.json")

    with open(jsonl_path, "w") as f:
        for i, (pred_txt, pred_id, label_id) in enumerate(
            zip(pred_txts, preds, labels)
        ):
            f.write(
                json.dumps(
                    {
                        "index": i,
                        "label": int(label_id),
                        "label_name": CLASS_NAMES[label_id],
                        "pred_text": pred_txt,
                        "pred_class_id": int(pred_id),
                        "pred_class_name": CLASS_NAMES[pred_id],
                        "correct": bool(pred_id == label_id),
                    }
                )
                + "\n"
            )

    meta = {
        "dataset": "frgfm/imagenette",
        "subset": "full_size",
        "split": "validation",
        "run_dir": RUN_DIR,
        "prompt": INPUT_TXT,
        "accuracy": float(acc),
        "num_samples": len(preds),
        "class_names": CLASS_NAMES,
    }
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)

    print(f"Wrote samples to {jsonl_path}")
    print(f"Wrote metadata to {meta_path}")
|
||||
5
scripts/main_exp/eval/llmlingua.sh
Executable file
5
scripts/main_exp/eval/llmlingua.sh
Executable file
|
|
@ -0,0 +1,5 @@
|
|||
# Sweep LLMLingua prompt compression over six QA/LongBench datasets and six
# compression rates (same iteration order and argv as before; loop variables
# renamed and quoted expansions kept).
for ds in squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e; do
  for r in 0.9 0.8 0.6 0.4 0.2 0.1; do
    WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets "$ds" --split test --eval_batch_size_gen=1 --use_llmlingua --llmlingua_compression_rate "$r" --truncate_if_too_long_ctx
  done
done
|
||||
4
scripts/main_exp/eval/t2l.sh
Executable file
4
scripts/main_exp/eval/t2l.sh
Executable file
|
|
@ -0,0 +1,4 @@
|
|||
# Text-to-LoRA (T2L) baseline: fetch the released gemma-2b T2L checkpoint into
# the working directory, then run the standard eval suite with --use_t2l.
# download t2l checkpoint
uv run huggingface-cli download SakanaAI/text-to-lora --local-dir . --include "trained_t2l/gemma_2b_t2l"

WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen=1 --use_t2l
|
||||
Loading…
Add table
Add a link
Reference in a new issue