Doc-to-LoRA release

This commit is contained in:
51616 2026-02-27 03:47:04 +00:00
commit 1abe8ae16d
92 changed files with 22131 additions and 0 deletions

34
.gitignore vendored Normal file
View file

@ -0,0 +1,34 @@
.venv/
.cursorrules
tmp/
openai_batches/
models/
results/
datasets/
llm-comparator/
trained_d2l/
# Ignore data files but not scripts
/data/processed_datasets/
/data/raw_datasets/
/data/distil/
__pycache__/
*egg-info/
.vscode/
.ipynb_checkpoints/
wandb/
*.bak
*outputs/
plots/
*.out
*.err
*.pt
*.pth
*.bin
*.safetensors
*tfevents*
watcher_state.yaml
eval_results/
.github/
*.code-workspace
.wandb/
.ruff_cache/

24
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,24 @@
repos:
- repo: https://github.com/asottile/pyupgrade
rev: v3.20.0
hooks:
- id: pyupgrade
args:
- --py310-plus
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.11.10
hooks:
# Run the linter.
- id: ruff
types_or: [python, pyi]
args: [--fix]
# Run the formatter.
- id: ruff-format
types_or: [python, pyi]
- repo: https://github.com/PyCQA/isort
rev: 6.0.1
hooks:
- id: isort
args:
- --profile=black

109
README.md Normal file
View file

@ -0,0 +1,109 @@
<div align="center">
<h1>Doc-to-LoRA (D2L): Learning to Instantly Internalize Contexts</h1>
:sparkles:<a href="https://pub.sakana.ai/doc-to-lora/">Interactive Web</a> |
:newspaper:<a href="https://x.com/SakanaAILabs">X</a> |
:scroll:<a href="https://arxiv.org/abs/2602.15902">Paper</a> |
:hugs:<a href="https://huggingface.co/SakanaAI">Hugging Face</a> |
:octocat:<a href="https://github.com/SakanaAI/doc-to-lora">GitHub</a>
<br>A reference implementation of Doc-to-LoRA (D2L).<br>
</div>
<div align="center">
<img height="300px" src="assets/overview_animation.gif" />
</div>
---
## 🛠️ Installation
```
curl -LsSf https://astral.sh/uv/install.sh | sh
./install.sh
```
## 🤗 Pre-Trained Models
```
uv run huggingface-cli login
uv run huggingface-cli download SakanaAI/doc-to-lora --local-dir trained_d2l --include "*/"
```
## 🚀 Python API Usage
```python
# caveat: this interface only supports non-batched inputs
# for batched inference please see `src/ctx_to_lora/modeling/hypernet.py`
import torch
from ctx_to_lora.model_loading import get_tokenizer
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
# model loading
checkpoint_path = "trained_d2l/gemma_demo/checkpoint-80000/pytorch_model.bin"
state_dict = torch.load(checkpoint_path, weights_only=False)
model = ModulatedPretrainedModel.from_state_dict(
state_dict, train=False, use_sequence_packing=False
)
model.reset()
tokenizer = get_tokenizer(model.base_model.name_or_path)
# prepare data
doc = open("data/sakana_wiki.txt", "r").read()
chat = [{"role": "user", "content": "Tell me about Sakana AI."}]
chat_ids = tokenizer.apply_chat_template(
chat,
add_special_tokens=False,
return_attention_mask=False,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
# calls after internalization will be influenced by internalized info
model.internalize(doc)
outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))
# remove internalized info
# model.reset()
# without internalized info, the model will halucinate
# outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
# print(tokenizer.decode(outputs[0]))
```
### 🎮 Interactive Demo
```bash
uv run demo/app.py
```
<div align="center">
<h3>Video Demo</h3>
<video src="https://github.com/user-attachments/assets/16781365-5ec2-4c1c-b4f4-aeeebe3c2be5" controls autoplay muted playsinline preload="metadata" width="900"></video>
</div>
### 🧪 Experimental Scripts
To run any of the following scripts, use `uv run $PATH_TO_SCRIPT` from the root of this project.
| Experiment | Data prep | Training | Evaluation | Notes |
| ------------------------------------ | ------------------------------------- | ----------------------------- | ---------------------------- | ----------------------------------------------------------------------------------------------------------------------------------- |
| [Main experiment](scripts/main_exp/) | `scripts/main_exp/0-download_data.sh` | `scripts/main_exp/1-train.sh` | `scripts/main_exp/eval/*.sh` | Downloading data is fastest; regenerate only if you need fresh synthetic data. Evaluation scripts reproduce the main paper metrics. |
| [NIAH](scripts/niah/) | `scripts/niah/0-gen_data.sh` | `scripts/niah/1-train.sh` | `scripts/niah/2-eval.sh` | Run the scripts in order; data generation only needs to happen once |
### 🔬 Self-Generated Data Viewer
After downloading/generating the data, we can see samples of the data using this script.
```bash
uv run webui/self_gen_viewer.py
```
See more info at [webui/SELF_GEN_VIEWER.md](webui/SELF_GEN_VIEWER.md).
### 📚 Citation
```bibtex
@techreport{sakana2025doc-to-lora,
title = {{Doc-to-LoRA: Learning to Instantly Internalize Contexts}},
author = {Rujikorn Charakorn and Edoardo Cetin and Shinnosuke Uesaka and Robert Tjarko Lange},
institution = {Sakana AI},
year = {2026},
month = {Febuary},
note = {Technical Report}
}
```

17
accelerate_config.yaml Normal file
View file

@ -0,0 +1,17 @@
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
enable_cpu_affinity: true
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

BIN
assets/cover.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.5 MiB

View file

@ -0,0 +1,18 @@
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- for message in messages %}
{%- if message.content is string %}
{%- set content = message.content %}
{%- else %}
{%- set content = '' %}
{%- endif %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- elif message.role == "assistant" %}
{{- '<|im_start|>' + message.role + '\n' }}{% generation %}{{ content + '<|im_end|>' }}{% endgeneration %}{{ '\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}

View file

@ -0,0 +1,24 @@
{{- bos_token }}
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{% for message in loop_messages %}
{% if (message['role'] == 'assistant') %}
{% set role = 'model' %}
{% else %}
{% set role = message['role'] %}
{% endif %}
{%- if message['role'] == 'user' and loop.first and system_message is defined %}
{{ '<start_of_turn>' + role + '\n' + system_message + '\n\n' + message['content'] | trim + '<end_of_turn>\n' }}
{%- elif message['role'] == 'assistant' %}
{{ '<start_of_turn>' + role + '\n' }}{% generation %}{{ message['content'] + '<end_of_turn>' }}{% endgeneration %}{{ '\n' }}
{%- else %}
{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}
{%- endif %}
{% endfor %}
{% if add_generation_prompt %}
{{'<start_of_turn>model\n'}}
{% endif %}

View file

@ -0,0 +1,24 @@
{%- if messages[0]['role'] == 'system' %}
{%- set system_message = messages[0]['content'] %}
{%- set loop_messages = messages[1:] %}
{%- else %}
{%- set loop_messages = messages %}
{%- endif %}
{{- bos_token }}
{%- for message in loop_messages %}
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
{{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
{%- endif %}
{%- if message['role'] == 'user' %}
{%- if loop.first and system_message is defined %}
{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST] ' }}
{%- else %}
{{- ' [INST] ' + message['content'] + ' [/INST] ' }}
{%- endif %}
{%- elif message['role'] == 'assistant' %}
{% generation %}{{ message['content'] + eos_token }}{% endgeneration %}
{%- else %}
{{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}
{%- endif %}
{%- endfor %}

View file

@ -0,0 +1,31 @@
# LoRA
lora_r: 8
lora_dropout: 0.0
target_modules:
- down_proj
use_kl_loss: true
ctx_encoder_type: per_layer_activations
n_latent_queries: 8
num_blocks: 9
num_self_attn_per_block: 0
gradient_accumulation_steps: 11
max_packed_inp_len: 6144
max_packed_ctx_len: 6144
# data
train_ds_names:
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/fw_qa_v2/min_0_to_2000/train/*level_1*.parquet
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_0.0/pwc_compact
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/squad_compact
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/ropes_compact
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/drop_compact
val_ds_names:
- squad
- pwc
- drop
- ropes
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_0.0/fw_qa_v2/min_0_to_2000/train/*level_0_val*.parquet

View file

@ -0,0 +1,30 @@
# LoRA
lora_r: 8
lora_dropout: 0.0
target_modules:
- down_proj
use_kl_loss: true
ctx_encoder_type: per_layer_activations
n_latent_queries: 8
num_blocks: 9
num_self_attn_per_block: 0
gradient_accumulation_steps: 11
max_packed_inp_len: 6144
max_packed_ctx_len: 6144
# data
train_ds_names:
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_1.0/fw_qa_v2/min_0_to_2000/train/*level_1*.parquet
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_0.0/pwc_compact
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_1.0/squad_compact
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_1.0/ropes_compact
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_1.0/drop_compact
val_ds_names:
- squad
- pwc
- drop
- ropes

View file

@ -0,0 +1,31 @@
# LoRA
lora_r: 8
lora_dropout: 0.0
target_modules:
- down_proj
use_kl_loss: true
ctx_encoder_type: per_layer_activations
n_latent_queries: 8
num_blocks: 9
num_self_attn_per_block: 0
gradient_accumulation_steps: 11
max_packed_inp_len: 6144
max_packed_ctx_len: 6144
# data
train_ds_names:
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/fw_qa_v2/min_0_to_2000/train/*level_1*.parquet
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_0.0/pwc_compact
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/squad_compact
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/ropes_compact
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/drop_compact
val_ds_names:
- squad
- pwc
- drop
- ropes
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_0.0/fw_qa_v2/min_0_to_2000/train/*level_0_val*.parquet

View file

@ -0,0 +1,27 @@
# LoRA
lora_r: 8
lora_dropout: 0.0
target_modules:
- down_proj
use_kl_loss: true
ctx_encoder_type: per_layer_activations
n_latent_queries: 8
num_blocks: 9
num_self_attn_per_block: 0
gradient_accumulation_steps: 11
max_packed_inp_len: 6144
max_packed_ctx_len: 6144
# data
train_ds_names:
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/fw_qa_v2/min_0_to_2000/train/*level_1*.parquet
val_ds_names:
- squad
- pwc
- drop
- ropes
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_0.0/fw_qa_v2/min_0_to_2000/train/*level_0_val*.parquet

View file

@ -0,0 +1,14 @@
# LoRA
lora_r: 8
lora_dropout: 0.0
target_modules:
- down_proj
# data
train_ds_names:
- ctx_magic_number_32_128
- ctx_magic_number_128_256
val_ds_names:
- ctx_magic_number_32_128
- ctx_magic_number_128_256

View file

@ -0,0 +1,42 @@
import gc
from datasets import Dataset, load_dataset
from tqdm import tqdm
if __name__ == "__main__":
ds_name = "ucinlp/drop"
for split in ["train", "validation"]:
ctx_qa_dict = dict()
ds = load_dataset(ds_name, split=split)
print(f"Original size: {len(ds)}")
for i, sample in tqdm(enumerate(ds)):
ctx = sample["passage"]
if ctx not in ctx_qa_dict:
ctx_qa_dict[ctx] = {"prompts": [], "responses": []}
question = sample["question"]
answer = sample["answers_spans"]["spans"][0]
ctx_qa_dict[ctx]["prompts"].append(question)
ctx_qa_dict[ctx]["responses"].append(answer)
print(f"Unique contexts: {len(ctx_qa_dict)}")
# convert ctx_qa_dict to a list of dictionaries
samples = [
{
"context": ctx,
"prompts": ctx_qa_dict[ctx]["prompts"],
"responses": ctx_qa_dict[ctx]["responses"],
}
for ctx in ctx_qa_dict
]
print(f"Sampled data: {samples[0]}")
# breakpoint()
# save to a new dataset
ds = Dataset.from_list(samples)
save_path = f"./data/raw_datasets/drop_compact/{split}/ds.parquet"
print(f"Saving dataset to {save_path}")
ds.to_parquet(save_path)
print("=" * 80)
del ds, samples, ctx_qa_dict
gc.collect()

43
data/build_pwc_compact.py Normal file
View file

@ -0,0 +1,43 @@
import gc
from datasets import Dataset, load_dataset
from tqdm import tqdm
if __name__ == "__main__":
ds_name = "sggetao/PwC"
for split in ["train", "test"]:
ctx_qa_dict = dict()
ds = load_dataset(ds_name, split=split)
print(f"Original size: {len(ds)}")
for i, sample in tqdm(enumerate(ds)):
ctx = sample["input"]
if ctx not in ctx_qa_dict:
ctx_qa_dict[ctx] = {"prompts": [], "responses": []}
# question = closed_qa_prompting(sample["prompt"])
question = sample["prompt"]
answer = sample["answer"]
ctx_qa_dict[ctx]["prompts"].append(question)
ctx_qa_dict[ctx]["responses"].append(answer)
print(f"Unique contexts: {len(ctx_qa_dict)}")
# convert ctx_qa_dict to a list of dictionaries
samples = [
{
"context": ctx,
"prompts": ctx_qa_dict[ctx]["prompts"],
"responses": ctx_qa_dict[ctx]["responses"],
}
for ctx in ctx_qa_dict
]
print(f"Sampled data: {samples[0]}")
# breakpoint()
# save to a new dataset
ds = Dataset.from_list(samples)
save_path = f"./data/raw_datasets/pwc_compact/{split}/ds.parquet"
print(f"Saving dataset to {save_path}")
ds.to_parquet(save_path)
print("=" * 80)
del ds, samples, ctx_qa_dict
gc.collect()

View file

@ -0,0 +1,45 @@
import gc
from datasets import Dataset, load_dataset
from tqdm import tqdm
if __name__ == "__main__":
ds_name = "allenai/ropes"
for split in ["train", "validation"]:
ctx_qa_dict = dict()
ds = load_dataset(ds_name, split=split)
print(f"Original size: {len(ds)}")
for i, sample in tqdm(enumerate(ds)):
ctx_template = "{background}\n{situation}"
response = sample["answers"]["text"][0]
bg_txt = sample["background"]
situation_txt = sample["situation"]
ctx = ctx_template.format(background=bg_txt, situation=situation_txt)
q = sample["question"]
if ctx not in ctx_qa_dict:
ctx_qa_dict[ctx] = {"prompts": [], "responses": []}
ctx_qa_dict[ctx]["prompts"].append(q)
ctx_qa_dict[ctx]["responses"].append(response)
print(f"Unique contexts: {len(ctx_qa_dict)}")
# convert ctx_qa_dict to a list of dictionaries
samples = [
{
"context": ctx,
"prompts": ctx_qa_dict[ctx]["prompts"],
"responses": ctx_qa_dict[ctx]["responses"],
}
for ctx in ctx_qa_dict
]
print(f"Sampled data: {samples[0]}")
# breakpoint()
# save to a new dataset
ds = Dataset.from_list(samples)
save_path = f"./data/raw_datasets/ropes_compact/{split}/ds.parquet"
print(f"Saving dataset to {save_path}")
ds.to_parquet(save_path)
print("=" * 80)
del ds, samples, ctx_qa_dict
gc.collect()

View file

@ -0,0 +1,42 @@
import gc
from datasets import Dataset, load_dataset
from tqdm import tqdm
if __name__ == "__main__":
ds_name = "data/raw_datasets/squad"
for split in ["train", "validation"]:
ctx_qa_dict = dict()
ds = load_dataset(ds_name, split=split)
print(f"Original size: {len(ds)}")
for i, sample in tqdm(enumerate(ds)):
ctx = sample["context"]
if ctx not in ctx_qa_dict:
ctx_qa_dict[ctx] = {"prompts": [], "responses": []}
question = sample["question"]
answer = sample["answers"]["text"][0]
ctx_qa_dict[ctx]["prompts"].append(question)
ctx_qa_dict[ctx]["responses"].append(answer)
print(f"Unique contexts: {len(ctx_qa_dict)}")
# convert ctx_qa_dict to a list of dictionaries
samples = [
{
"context": ctx,
"prompts": ctx_qa_dict[ctx]["prompts"],
"responses": ctx_qa_dict[ctx]["responses"],
}
for ctx in ctx_qa_dict
]
print(f"Sampled data: {samples[0]}")
# breakpoint()
# save to a new dataset
ds = Dataset.from_list(samples)
save_path = f"./data/raw_datasets/squad_compact/{split}/ds.parquet"
print(f"Saving dataset to {save_path}")
ds.to_parquet(save_path)
print("=" * 80)
del ds, samples, ctx_qa_dict
gc.collect()

View file

@ -0,0 +1,10 @@
from huggingface_hub import snapshot_download
if __name__ == "__main__":
fw_dir = "./data/raw_datasets/fineweb_edu/"
snapshot_download(
"HuggingFaceFW/fineweb-edu",
repo_type="dataset",
local_dir=fw_dir,
allow_patterns="sample/100BT/*",
)

View file

@ -0,0 +1,288 @@
import argparse
import json
import math
import os
import random
# -----------------------------
# Config knobs (edit or use CLI)
# -----------------------------
TOKENS_PER_BLOCK = 40 # rough heuristic tokens per noise block
BASE_SAMPLES_PER_BIN = (
320_000 # training samples budget scaler only (val/test fixed at 1000 each)
)
RNG_SEED = 42
NOISE_BLOCK = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
SPECIAL_TPL = "The special magic number is {magic_number}."
SEP = "\n" # between blocks
def save_jsonl(data: list[dict], filepath: str) -> None:
parent_dir = os.path.dirname(filepath)
if parent_dir:
os.makedirs(parent_dir, exist_ok=True)
with open(filepath, "w") as f:
for entry in data:
json.dump(entry, f)
f.write("\n")
essential_digits4 = lambda: f"{random.randint(0, 9_999):04d}"
def _choose_position(total_blocks: int, depth_bin: int) -> int:
"""Choose an insertion index for the special sentence within [0, total_blocks-1]
such that its relative depth falls within the depth bin [i/10, (i+1)/10).
"""
if total_blocks <= 0:
return 0
# Use floor for start and ceil for end to cover boundaries evenly
start = math.floor(total_blocks * (depth_bin / 10))
end = math.ceil(total_blocks * ((depth_bin + 1) / 10)) - 1
# clamp
start = max(0, min(start, total_blocks - 1))
end = max(start, min(end, total_blocks - 1))
return random.randint(start, end)
def _build_example(total_blocks: int, depth_bin: int) -> dict:
"""Build one example with a special line inserted among noise blocks.
total_blocks: total number of blocks in the final context (including the special one)
depth_bin: integer in [0, 9]
"""
total_blocks = max(1, total_blocks)
# Prepare blocks
magic = essential_digits4()
special_line = SPECIAL_TPL.format(magic_number=magic)
# We'll have (total_blocks - 1) noise blocks and 1 special line
noise_count = max(0, total_blocks - 1)
blocks: list[str] = [NOISE_BLOCK for _ in range(noise_count)]
insert_at = _choose_position(total_blocks, depth_bin)
# Insert special line at the desired position within the final sequence
# If noise_count == 0, we just return special
if noise_count == 0:
final_blocks = [special_line]
else:
# Compose by interleaving noise and inserting special at index
# Build a list of length `total_blocks` and fill
final_blocks = []
noise_idx = 0
for idx in range(total_blocks):
if idx == insert_at:
final_blocks.append(special_line)
else:
final_blocks.append(blocks[noise_idx])
noise_idx += 1
context = SEP.join(final_blocks)
prompt = "What is the special magic number? Reply with only the number."
response = magic
return {"context": context, "prompt": prompt, "response": response}
def generate_examples(n: int, k: int) -> list[dict]:
"""Generate n examples (all for block length k) evenly across 10 depth bins."""
if n <= 0:
return []
base = n // 10
rem = n % 10
counts = [base + (1 if i < rem else 0) for i in range(10)]
out: list[dict] = []
for depth_bin, c in enumerate(counts):
for _ in range(c):
out.append(_build_example(total_blocks=k, depth_bin=depth_bin))
random.shuffle(out)
return out
def main():
parser = argparse.ArgumentParser(
description="Generate noise-wrapped special magic number dataset (similar structure to generate_ctx_kv.py)",
)
parser.add_argument("--seed", type=int, default=RNG_SEED, help="Random seed")
parser.add_argument(
"--tokenizer-name",
type=str,
default="google/gemma-2-2b-it",
help=("Tokenizer name"),
)
parser.add_argument(
"--base-samples-per-bin",
type=int,
default=BASE_SAMPLES_PER_BIN,
help="Baseline number of TRAINING samples per token bin (scaled by bin width). Validation & test are always 1000 each.",
)
parser.add_argument(
"--out-prefix",
type=str,
default="data/raw_datasets/ctx_magic_number",
help="Output directory prefix (bin range will be appended)",
)
parser.add_argument(
"--tokens-per-block",
"--tokens-per-pair",
dest="tokens_per_block",
type=int,
default=TOKENS_PER_BLOCK,
help="Heuristic tokens per noise block for bucketing",
)
parser.add_argument(
"--only-first-n-bins",
type=int,
default=None,
help="For quick tests: only generate the first N token bins",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Print a small sample and exit without writing files",
)
args = parser.parse_args()
random.seed(args.seed)
# ----------------------------------------------------
# Optional: report tokenizer-based token length stats
# ----------------------------------------------------
if args.tokenizer_name:
try:
from transformers import AutoTokenizer # type: ignore
except Exception as e: # pragma: no cover
raise RuntimeError(
"Failed to import transformers. Install it or omit --tokenizer-name."
) from e
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
noise_token_count = len(tokenizer(NOISE_BLOCK).input_ids)
special_example = SPECIAL_TPL.format(magic_number="0000")
special_token_count = len(tokenizer(special_example).input_ids)
print(
f"[Tokenizer: {args.tokenizer_name}] Noise block tokens: {noise_token_count} | Special line tokens: {special_token_count}"
)
tok_bins = [(32, 128), (128, 256), (256, 512), (512, 1024), (32, 1024)] + [
(1024 * i, 1024 * (i + 1)) for i in range(1, 16)
]
tok_bins += [(2**14 + 2**12 * (i), 2**14 + 2**12 * (i + 1)) for i in range(4)]
tok_bins += [(2**15 + 2**13 * (i), 2**15 + 2**13 * (i + 1)) for i in range(12)]
if args.only_first_n_bins is not None:
tok_bins = tok_bins[: args.only_first_n_bins]
if args.tokenizer_name:
max_hi = max(hi for _, hi in tok_bins)
def measure_len(k: int) -> int:
if k == 1:
ctx = SPECIAL_TPL.format(magic_number="0000")
else:
blocks = [NOISE_BLOCK] * (k - 1) + [
SPECIAL_TPL.format(magic_number="0000")
]
ctx = SEP.join(blocks)
return len(tokenizer(ctx).input_ids)
lengths: list[int] = [0]
k = 1
while True:
L = measure_len(k)
lengths.append(L)
if L >= max_hi:
break
k += 1
len_bins = []
for lo, hi in tok_bins:
k_lo = None
for kk in range(1, len(lengths)):
if lengths[kk] >= lo:
k_lo = kk
break
if k_lo is None or lengths[k_lo] >= hi:
len_bins.append((0, 0))
continue
k_hi = len(lengths)
for kk in range(k_lo, len(lengths)):
if lengths[kk] >= hi:
k_hi = kk
break
len_bins.append((k_lo, k_hi))
base_tokens = lengths[1]
delta = (lengths[2] - lengths[1]) if len(lengths) > 2 else 0
print(
f"Using tokenizer-measured block ranges. base_tokens={base_tokens} approx_delta={delta}"
)
else:
len_bins = [
(lo // args.tokens_per_block, hi // args.tokens_per_block)
for (lo, hi) in tok_bins
]
if args.dry_run:
for lb in len_bins:
if lb[1] > lb[0]:
k = max(1, lb[0])
sample = generate_examples(10, k)
print("Sample entry:")
print(json.dumps(sample[0], indent=2))
break
return
# -----------------------------------------------
# Main generation per token bin
# -----------------------------------------------
TARGET_VAL = 1000
TARGET_TEST = 1000
for len_bin, tok_bin in zip(len_bins, tok_bins):
if len_bin[1] <= len_bin[0]:
print(f"Skipping token bin {tok_bin} (no valid block counts)")
continue
k_start = max(1, len_bin[0])
k_end = max(1, len_bin[1])
k_values = list(range(k_start, k_end))
bin_size = len(k_values)
save_dir = f"{args.out_prefix}_{tok_bin[0]}_{tok_bin[1]}"
training_enabled = tok_bin[1] <= 1024 # unchanged policy
if training_enabled:
train_data: list[dict] = []
# Distribute training budget across k values.
# Scale: per_k = base_samples_per_bin / bin_size
per_k_train = max(1, args.base_samples_per_bin // max(1, bin_size))
for k in k_values:
train_data += generate_examples(per_k_train, k)
val_data: list[dict] = []
test_data: list[dict] = []
base_val = TARGET_VAL // bin_size
rem_val = TARGET_VAL % bin_size
base_test = TARGET_TEST // bin_size
rem_test = TARGET_TEST % bin_size
for idx, k in enumerate(k_values):
n_val_k = base_val + (1 if idx < rem_val else 0)
n_test_k = base_test + (1 if idx < rem_test else 0)
if n_val_k:
val_data += generate_examples(n_val_k, k)
if n_test_k:
test_data += generate_examples(n_test_k, k)
random.shuffle(val_data)
random.shuffle(test_data)
os.makedirs(save_dir, exist_ok=True)
if training_enabled:
save_jsonl(train_data, f"{save_dir}/train.jsonl")
save_jsonl(val_data, f"{save_dir}/val.jsonl")
save_jsonl(test_data, f"{save_dir}/test.jsonl")
if training_enabled:
print(
f"Dataset generated at {save_dir} (train={len(train_data)} val={len(val_data)} test={len(test_data)})"
)
else:
print(
f"Dataset (val/test only) generated at {save_dir} (val={len(val_data)} test={len(test_data)})"
)
if __name__ == "__main__":
main()

View file

@ -0,0 +1,269 @@
import argparse
import os
import re
from glob import glob
import pandas as pd
from datasets import Dataset, load_dataset
from vllm import LLM, SamplingParams
STOP_STRINGS = {
"google/gemma-3-12b-it": ["<eos>", "<end_of_turn>"],
}
SYSTEM_TEMPLATE = (
"You are a creative and helpful assistant.\n"
"You are tasked with generating questions for reading comprehension tests.\n"
"You will be given a context and you need to generate questions and corresponding answers from the given context.\n"
"The questions should be highly specific to the information provided in the context, not general questions that suit any context.\n"
"**DO NOT** hallucinate or make up information."
)
# based on Make Your LLM Fully Utilize the Context (https://arxiv.org/pdf/2404.16811)
PROMPT_TEMPLATE = (
"### Instructions ###\n"
"Generate questions and corresponding answers from the given context. The questions should be highly specific to the "
"information provided in the context, not general questions that suit any context.\n\n"
"### Context ###\n"
"{context}\n\n\n"
"### Rules ###\n"
"Rules to follow when generating the questions:\n"
"1. The questions must be specific to the given context and fully answerable from information present in the given context.\n"
"2. Ask questions that are fact-seeking based on the information provided.\n"
"3. Make sure the questions are clear and unambiguous.\n"
"4. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
"the questions.\n"
"5. The questions should not overlap. They should be diverse, covering many aspects of the context.\n"
"6. Do not give away too much information in the questions. For example, ask 'Who is X?' instead of 'Who is X that did Y?' when Y is clear from the context.\n"
"7. Ignore the text formatting of the context, e.g., bold, italic, underline, etc.\n"
"8. Ignore typos, spacing, and grammatical errors in the context.\n\n"
"Rules to follow when generating the answers:\n"
"1. The answers must use the (implied) information provided in the context.\n"
"2. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
"the answers.\n"
"3. Do not just copy words from the context. Answer the question in your own words.\n"
"4. The answers should be detailed and comprehensive. Please include additional specific details from the context.\n\n"
"Respond with {n_qa_pairs} question-answer pairs.\n"
"Always use proper grammar and punctuation.\n"
"Try to use different question forms and styles.\n"
"Use simple words and make sure that the answers are clear and comprehensive.\n\n"
"The question-answer pairs should be in the following format:\n"
"Question 1: {{question_1}}\n"
"Answer 1: {{answer_1}}\n"
"Question 2: {{question_2}}\n"
"Answer 2: {{answer_2}}\n"
"..."
)
def get_prompt(context, n_qa_pairs):
prompt = PROMPT_TEMPLATE.format(context=context, n_qa_pairs=n_qa_pairs)
return prompt
def check_should_skip(txt: str, vllm_model: str) -> bool:
"""Check if the response should be skipped based on stop strings."""
for stop in STOP_STRINGS[vllm_model]:
if stop in txt[-len(stop) :]:
return (txt.split(stop)[0], False) # Found a valid stop string
return (txt, True) # No valid stop string found, skip this response
def postprocess_qa_pairs(res_txt: str):
"""
Postprocesses the QA pairs from the response text.
Args:
res_txt: The response text.
n_qa_pairs: The number of QA pairs.
Returns:
A tuple of two lists, the first containing the questions and the second containing the answers.
"""
# capture everything after each "Question {number}:" until "Answer"
res_txt = remove_think(res_txt)
q_pattern = r"Question \d+:(.*?)(?=Answer|$)" # thanks chatgpt
questions = re.findall(q_pattern, res_txt, flags=re.S)
a_pattern = r"Answer \d+:(.*?)(?=Question|$)" # thanks chatgpt
answers = re.findall(a_pattern, res_txt, flags=re.S)
if len(questions) != len(answers):
print(f"Warning---number of questions and answers do not match")
print(f"Number of questions: {len(questions)}")
print(f"Number of answers: {len(answers)}")
out_q = []
out_a = []
n_skips = 0
if (len(questions) > 0) and (len(answers) > 0):
n_gen_pairs = min(len(questions), len(answers))
has_left_over = n_gen_pairs < len(questions) or n_gen_pairs < len(answers)
for i in range(n_gen_pairs):
response = answers[i].strip()
question = questions[i].strip()
if not response or not question:
print(f"Skipping empty question or answer at index {i}")
continue
if (not has_left_over) and (i == n_gen_pairs - 1):
response, skip = check_should_skip(response, vllm_model)
if skip:
print(f"Skipping due to missing stop string")
n_skips += 1
continue
out_q.append(question.strip())
out_a.append(response.strip())
print(f"Skipped {n_skips} responses due to missing stop strings")
return out_q, out_a
def length_filter(sample, min_len, max_len):
return min_len <= len(sample["text"]) <= max_len
def remove_think(txt):
return txt.split("</think>")[-1]
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate QA pairs from FineWeb Edu dataset"
)
parser.add_argument(
"--vllm_model",
type=str,
default=os.environ.get("vllm_model", "google/gemma-2-27b-it"),
help="VLLM model to use for generation",
)
parser.add_argument(
"--shard_pattern",
type=str,
required=True,
help="Pattern to match shard files (e.g., '000_0000*')",
)
parser.add_argument(
"--n_qa_pairs",
type=int,
required=True,
help="Number of question-answer pairs to generate per context",
)
parser.add_argument(
"--min_length",
type=int,
default=0,
help="Minimum length of the context to consider for generation",
)
parser.add_argument(
"--max_length",
type=int,
default=2000,
help="Maximum length of the context to consider for generation",
)
parser.add_argument(
"--max_model_length",
type=int,
default=2**14,
help="Maximum length of the model input (context + prompt + response) in tokens",
)
parser.add_argument(
"--debug",
action="store_true",
help="Debug mode - process only first 100 samples",
)
args = parser.parse_args()
vllm_model = args.vllm_model
print(f"Using model: {vllm_model}")
llm_kwargs = dict(
model=vllm_model,
dtype="bfloat16",
enable_prefix_caching=True,
enable_chunked_prefill=True,
max_model_len=args.max_model_length,
limit_mm_per_prompt={"image": 0},
)
llm = LLM(**llm_kwargs)
tokenizer = llm.get_tokenizer()
shard_pattern = args.shard_pattern
n_qa_pairs = args.n_qa_pairs
paths = glob(
f"./data/raw_datasets/fineweb_edu/sample/100BT/{shard_pattern}.parquet"
)
split = "train[:100]" if args.debug else "train"
for path in paths:
ds = load_dataset(
"parquet",
data_files=path,
split=split,
)
ds = ds.filter(
length_filter,
fn_kwargs={"min_len": args.min_length, "max_len": args.max_length},
num_proc=8,
)
ctxs = [sample["text"] for sample in iter(ds)]
messages = [
[
{"role": "system", "content": SYSTEM_TEMPLATE},
{"role": "user", "content": get_prompt(ctx, n_qa_pairs)},
]
for ctx in ctxs
]
print(f"Generating from {len(messages)} contexts")
completions = llm.chat(
messages,
sampling_params=SamplingParams(
max_tokens=2048,
temperature=0.0,
# needed for checking if stop tokens are present
skip_special_tokens=False,
include_stop_str_in_output=True,
),
)
samples = []
for ctx, completion in zip(ctxs, completions):
questions, answers = postprocess_qa_pairs(completion.outputs[0].text)
samples.append(
{
"context": ctx,
"prompts_level_0": questions,
"responses_level_0": answers,
}
)
if args.debug:
print(f"{ctx=}")
print(f"{completion.outputs[0].text=}")
for q, a in zip(questions, answers):
print(f"{q=}")
print(f"{a=}")
print()
print("=" * 80)
print(f"Generated {len(samples)} samples")
df = pd.DataFrame(samples)
ds = Dataset.from_pandas(df)
val_ds = ds.take(10)
ds = ds.skip(10)
shard_name = path.split("/")[-1].split(".")[0]
shard_name += "_level_0"
if args.debug:
shard_name += "_debug"
ds.to_parquet(
f"data/raw_datasets/fw_qa_v2/min_{args.min_length}_to_{args.max_length}/{shard_name}.parquet"
)
val_ds.to_parquet(
f"data/raw_datasets/fw_qa_v2/min_{args.min_length}_to_{args.max_length}/{shard_name}_val.parquet"
)
print(
f"Saved to data/raw_datasets/fw_qa_v2/min_{args.min_length}_to_{args.max_length}/{shard_name}.parquet"
)
print(
f"Saved to data/raw_datasets/fw_qa_v2/min_{args.min_length}_to_{args.max_length}/{shard_name}_val.parquet"
)

View file

@ -0,0 +1,296 @@
import argparse
import gc
import os
import re
from glob import glob
from datasets import load_dataset
from vllm import LLM, SamplingParams
STOP_STRINGS = {
"google/gemma-3-12b-it": ["<eos>", "<end_of_turn>"],
}
SYSTEM_TEMPLATE = (
"You are a creative and helpful assistant.\n"
"You are tasked with generating questions for reading comprehension tests.\n"
"You will be given a context and you need to generate questions and corresponding answers from the given context.\n"
"The questions should be highly specific to the information provided in the context, not general questions that suit any context.\n"
"**DO NOT** hallucinate or make up information."
)
# based on Make Your LLM Fully Utilize the Context (https://arxiv.org/pdf/2404.16811)
PROMPT_TEMPLATE = (
"### Instructions ###\n"
"Generate questions and corresponding answers from the given context. The questions should be highly specific to the "
"information provided in the context, not general questions that suit any context.\n\n"
"### Context ###\n"
"{context}\n\n\n"
"### Example Question-Answer Pairs ###\n"
"{qa_pairs}\n\n\n"
"### Rules ###\n"
"Rules to follow when generating the questions:\n"
"1. The questions must be specific to the given context and fully answerable from information present in *or* implied from the given context.\n"
"2. The questions must *not* be redundant with the example questions-answer pairs provided.\n"
"3. You should prioritize fact-seeking questions. Consider reversal questions, e.g., asking 'What causes X to happen?' is valid when 'Y causes X' is presented in the context.\n"
"4. If all the facts in the context are already covered by the provided examples, you must generate *more complicated* questions that require reasoning beyond simple information retrieval.\nThis includes asking about information that can be inferred, requiring synthesizing information from multiple parts of the text, or understanding relationships between concepts, events, or individuals mentioned in the context. For example, if the context says 'The Eiffel Tower was completed in 1889 after 2 years of construction', you can ask 'When did the construction of the Eiffel Tower begin?'. Here's another example: if the context says 'Alice is Bob's mother. Bob is Charlie's Dad', you can ask 'Who is Charlie's grandmother?'.\n"
"5. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
"the questions.\n"
"6. The questions should not overlap. They should be diverse, covering many aspects of the context.\n"
"7. Do not give away too much information in the questions. For example, ask 'Who is X?' instead of 'Who is X that did Y?' when Y is clear from the context.\n"
"8. Ignore the text formatting of the context, e.g., bold, italic, underline, etc.\n"
"9. Ignore typos, spacing, and grammatical errors in the context.\n\n"
"Rules to follow when generating the answers:\n"
"1. The answers must use the (implied) information provided in the context.\n"
"2. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
"the answers.\n"
"3. Do not just copy words from the context. Answer the question in your own words.\n"
"4. The answers should be detailed and comprehensive. Please include additional specific details from the context.\n\n"
"Respond with {n_qa_pairs} question-answer pairs.\n"
"Always use proper grammar and punctuation.\n"
"Try to use different question forms and styles.\n"
"Use simple words and make sure that the answers are clear and comprehensive.\n\n"
"The question-answer pairs should be in the following format:\n"
"Question 1: {{question_1}}\n"
"Answer 1: {{answer_1}}\n"
"Question 2: {{question_2}}\n"
"Answer 2: {{answer_2}}\n"
"..."
)
def get_prompt(context, example_qa_pairs, n_qa_pairs):
prompt = PROMPT_TEMPLATE.format(
context=context,
qa_pairs=example_qa_pairs,
n_qa_pairs=n_qa_pairs,
)
return prompt
def check_should_skip(txt: str, vllm_model: str) -> bool:
"""Check if the response should be skipped based on stop strings."""
for stop in STOP_STRINGS[vllm_model]:
if stop in txt[-len(stop) :]:
return (txt.split(stop)[0], False) # Found a valid stop string
return (txt, True) # No valid stop string found, skip this response
def postprocess_qa_pairs(res_txt: str):
"""
Postprocesses the QA pairs from the response text.
Args:
res_txt: The response text.
n_qa_pairs: The number of QA pairs.
Returns:
A tuple of two lists, the first containing the questions and the second containing the answers.
"""
# capture everything after each "Question {number}:" until "Answer"
res_txt = remove_think(res_txt)
q_pattern = r"Question \d+:(.*?)(?=Answer|$)" # thanks chatgpt
questions = re.findall(q_pattern, res_txt, flags=re.S)
a_pattern = r"Answer \d+:(.*?)(?=Question|$)" # thanks chatgpt
answers = re.findall(a_pattern, res_txt, flags=re.S)
if len(questions) != len(answers):
print(f"Warning---number of questions and answers do not match")
print(f"Number of questions: {len(questions)}")
print(f"Number of answers: {len(answers)}")
out_q = []
out_a = []
n_skips = 0
if (len(questions) > 0) and (len(answers) > 0):
n_gen_pairs = min(len(questions), len(answers))
has_left_over = n_gen_pairs < len(questions) or n_gen_pairs < len(answers)
for i in range(n_gen_pairs):
response = answers[i].strip()
question = questions[i].strip()
if not response or not question:
print(f"Skipping empty question or answer at index {i}")
continue
if (not has_left_over) and (i == n_gen_pairs - 1):
response, skip = check_should_skip(response, vllm_model)
if skip:
print(f"Skipping due to missing stop string")
n_skips += 1
continue
out_q.append(question.strip())
out_a.append(response.strip())
print(f"Skipped {n_skips} responses due to missing stop strings")
return out_q, out_a
def flatten_list(l):
out = []
for x in l:
out += x
return out
def remove_think(txt):
return txt.split("</think>")[-1]
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate QA pairs from FineWeb Edu dataset"
)
parser.add_argument(
"--vllm_model",
type=str,
default=os.environ.get("vllm_model", "google/gemma-2-27b-it"),
help="VLLM model to use for generation",
)
parser.add_argument(
"--shard_pattern",
type=str,
required=True,
help="Pattern to match shard files (e.g., '000_0000*')",
)
parser.add_argument(
"--n_qa_pairs",
type=int,
required=True,
help="Number of question-answer pairs to generate per context",
)
parser.add_argument(
"--max_model_length",
type=int,
default=2**12,
help="Maximum length of the model input (context + prompt + response) in tokens",
)
parser.add_argument(
"--debug",
action="store_true",
help="Debug mode - process only first 100 samples",
)
args = parser.parse_args()
vllm_model = args.vllm_model
print(f"Using model: {vllm_model}")
llm_kwargs = dict(
model=vllm_model,
dtype="bfloat16",
enable_prefix_caching=True,
enable_chunked_prefill=True,
max_model_len=2**14,
limit_mm_per_prompt={"image": 0},
)
llm = LLM(**llm_kwargs)
tokenizer = llm.get_tokenizer()
shard_pattern = args.shard_pattern
n_qa_pairs = args.n_qa_pairs
paths = glob(f"./data/raw_datasets/fw_qa_v2/{shard_pattern}.parquet")
split = "train[:100]" if args.debug else "train"
for path in paths:
assert "_level" in path, (
"Path must contain '_level' to indicate the dataset level"
)
shard_name = path.split("/")[-1].split(".")[0].split("_debug")[0]
if "/" in shard_pattern:
shard_name = "/".join(shard_pattern.split("/")[:-1]) + "/" + shard_name
cur_level = int(shard_name.split("_level_")[-1])
next_level = cur_level + 1
ds = load_dataset(
"parquet",
data_files=path,
split=split,
)
prompt_cols = [col for col in ds.column_names if col.startswith("prompts")]
response_cols = [col for col in ds.column_names if col.startswith("responses")]
assert len(prompt_cols) > 0, "No prompt columns found in the dataset"
if len(prompt_cols) != len(response_cols):
raise ValueError(
"Number of prompt columns does not match number of response columns"
)
samples_data = []
for sample in iter(ds):
# Format existing QA pairs as examples
example_qa_pairs = ""
questions = flatten_list([sample[col] for col in prompt_cols])
answers = flatten_list([sample[col] for col in response_cols])
for i, (q, a) in enumerate(zip(questions, answers), 1):
example_qa_pairs += f"Question {i}: {q}\nAnswer {i}: {a}\n"
samples_data.append(
{"context": sample["context"], "example_qa_pairs": example_qa_pairs}
)
del ds
gc.collect()
messages = [
[
{"role": "system", "content": SYSTEM_TEMPLATE},
{
"role": "user",
"content": get_prompt(
sample["context"], sample["example_qa_pairs"], n_qa_pairs
),
},
]
for sample in samples_data
]
print(f"Generating from {len(messages)} contexts")
completions = llm.chat(
messages,
sampling_params=SamplingParams(
temperature=0.0,
# needed for checking if stop tokens are present
skip_special_tokens=False,
include_stop_str_in_output=True,
),
)
samples = []
for sample_data, completion in zip(samples_data, completions):
questions, answers = postprocess_qa_pairs(completion.outputs[0].text)
samples.append(
{
"context": sample_data["context"],
f"prompts_level_{next_level}": questions,
f"responses_level_{next_level}": answers,
}
)
if args.debug:
print(f"context={sample_data['context']}")
print(f"example_qa_pairs={sample_data['example_qa_pairs']}")
print(f"{completion.outputs[0].text=}")
for q, a in zip(questions, answers):
print(f"{q=}")
print(f"{a=}")
print()
print("=" * 80)
del samples_data
gc.collect()
print(f"Generated {len(samples)} samples")
ds = load_dataset(
"parquet",
data_files=path,
split=split,
)
ds = ds.add_column(
f"prompts_level_{next_level}",
[sample[f"prompts_level_{next_level}"] for sample in samples],
)
ds = ds.add_column(
f"responses_level_{next_level}",
[sample[f"responses_level_{next_level}"] for sample in samples],
)
shard_name_base = shard_name.split("_level_")[0]
shard_name = f"{shard_name_base}_level_{next_level}"
if args.debug:
shard_name += "_debug"
ds.to_parquet(f"data/raw_datasets/fw_qa_v2/{shard_name}.parquet")
print(f"Saved to data/raw_datasets/fw_qa_v2/{shard_name}.parquet")

387
data/gutenburg_sample.txt Normal file
View file

@ -0,0 +1,387 @@
The Project Gutenberg eBook, Addison, by William John Courthope
This eBook is for the use of anyone anywhere at no cost and with
almost no restrictions whatsoever. You may copy it, give it away or
re-use it under the terms of the Project Gutenberg License included
with this eBook or online at www.gutenberg.org
Title: Addison
Author: William John Courthope
Release Date: November 27, 2012 [eBook #41496]
Language: English
Character set encoding: ISO-8859-1
***START OF THE PROJECT GUTENBERG EBOOK ADDISON***
E-text prepared by the Online Distributed Proofreading Team
(http://www.pgdp.net) from page images generously made available by
Internet Archive (http://archive.org)
Note: Images of the original pages are available through
Internet Archive. See
http://archive.org/details/addison_00cour
Transcriber's note:
Text enclosed by underscores is in italics (_italics_).
Text enclosed by curly brackets is superscripted
(example: y{e}).
English Men of Letters
Edited by John Morley
ADDISON
by
W. J. COURTHOPE
Harper & Brothers Publishers
New York and London
1902
* * * * *
ENGLISH MEN OF LETTERS.
EDITED BY JOHN MORLEY.
JOHNSON Leslie Stephen.
GIBBON J. C. Morison.
SCOTT R. H. Hutton.
SHELLEY J. A. Symonds.
HUME T. H. Huxley.
GOLDSMITH William Black.
DEFOE William Minto.
BURNS J. C. Shairp.
SPENSER R. W. Church.
THACKERAY Anthony Trollope.
BURKE John Morley.
MILTON Mark Pattison.
HAWTHORNE Henry James, Jr.
SOUTHEY E. Dowden.
CHAUCER A. W. Ward.
BUNYAN J. A. Froude.
COWPER Goldwin Smith.
POPE Leslie Stephen.
BYRON John Nichol.
LOCKE Thomas Fowler.
WORDSWORTH F. Myers.
DRYDEN G. Saintsbury.
LANDOR Sidney Colvin.
DE QUINCEY David Masson.
LAMB Alfred Ainger.
BENTLEY R. C. Jebb.
DICKENS A. W. Ward.
GRAY E. W. Gosse.
SWIFT Leslie Stephen.
STERNE H. D. Traill.
MACAULAY J. Cotter Morison.
FIELDING Austin Dobson.
SHERIDAN Mrs. Oliphant.
ADDISON W. J. Courthope.
BACON R. W. Church.
COLERIDGE H. D. Traill.
SIR PHILIP SIDNEY J. A. Symonds.
KEATS Sidney Colvin.
CARLYLE John Nichol.
12mo, Cloth, 75 cents per volume.
_Other volumes in preparation._
PUBLISHED BY HARPER & BROTHERS, NEW YORK.
_Any of the above works will be sent by mail, postage prepaid, to any part
of the United States, Canada, or Mexico, on receipt of the price._
* * * * *
CONTENTS.
PAGE
CHAPTER I.
THE STATE OF ENGLISH SOCIETY AND LETTERS
AFTER THE RESTORATION 1
CHAPTER II.
ADDISON'S FAMILY AND EDUCATION 21
CHAPTER III.
ADDISON ON HIS TRAVELS 38
CHAPTER IV.
HIS EMPLOYMENT IN AFFAIRS OF STATE 53
CHAPTER V.
THE "TATLER" AND "SPECTATOR" 78
CHAPTER VI.
"CATO" 110
CHAPTER VII.
ADDISON'S QUARREL WITH POPE 125
CHAPTER VIII.
THE LAST YEARS OF HIS LIFE 139
CHAPTER IX.
THE GENIUS OF ADDISON 153
ADDISON.
CHAPTER I.
THE STATE OF ENGLISH SOCIETY AND LETTERS AFTER THE RESTORATION.
Of the four English men of letters whose writings most fully embody the
spirit of the eighteenth century, the one who provides the biographer with
the scantiest materials is Addison. In his _Journal to Stella_, his social
verses, and his letters to his friends, we have a vivid picture of those
relations with women and that protracted suffering which invest with such
tragic interest the history of Swift. Pope, by the publication of his own
correspondence, has enabled us, in a way that he never intended, to
understand the strange moral twist which distorted a nature by no means
devoid of noble instincts. Johnson was fortunate in the companionship of
perhaps the best biographer who ever lived. But of the real life and
character of Addison scarcely any contemporary record remains. The formal
narrative prefixed to his works by Tickell is, by that writer's own
admission, little more than a bibliography. Steele, who might have told us
more than any man about his boyhood and his manner of life in London, had
become estranged from his old friend before his death. No writer has
taken the trouble to preserve any account of the wit and wisdom that
enlivened the "little senate" at Button's. His own letters are, as a rule,
compositions as finished as his papers in the _Spectator_. Those features
in his character which excite the greatest interest have been delineated
by the hand of an enemy--an enemy who possessed an unrivalled power of
satirical portrait-painting, and was restrained by no regard for truth
from creating in the public mind such impressions about others as might
serve to heighten the favourable opinion of himself.
This absence of dramatic incident in Addison's life would lead us
naturally to conclude that he was deficient in the energy and passion
which cause a powerful nature to leave a mark upon its age. Yet such a
judgment would certainly be erroneous. Shy and reserved as he was, the
unanimous verdict of his most illustrious contemporaries is decisive as to
the respect and admiration which he excited among them. The man who could
exert so potent an influence over the mercurial Steele, who could
fascinate the haughty and cynical intellect of Swift, whose conversation,
by the admission of his satirist Pope, had in it something more charming
than that of any other man; of whom it was said that he might have been
chosen king if he wished it; such a man, though to the coarse perception
of Mandeville he might have seemed no more than "a parson in a tye-wig,"
can hardly have been deficient in force of character.
Nor would it have been possible for a writer distinguished by mere
elegance and refinement to leave a lasting impress on the literature and
society of his country. In one generation after another, men representing
opposing elements of rank, class, interest, and taste, have agreed in
acknowledging Addison's extraordinary merits. "Whoever wishes," says
Johnson--at the end of a biography strongly coloured with the
prepossessions of a semi-Jacobite Tory--"whoever wishes to attain an
English style, familiar but not coarse, and elegant but not ostentatious,
must give his days and nights to the volumes of Addison." "Such a mark of
national respect," says Macaulay, the best representative of middle-class
opinion in the present century, speaking of the statue erected to Addison
in Westminster Abbey, "was due to the unsullied statesman, to the
accomplished scholar, to the master of pure English eloquence, to the
consummate painter of life and manners. It was due, above all, to the
great satirist who alone knew how to use ridicule without abusing it; who,
without inflicting a wound, effected a great social reform, and who
reconciled wit and virtue after a long and disastrous separation, during
which wit had been led astray by profligacy, and virtue by fanaticism."
This verdict of a great critic is accepted by an age to which the grounds
of it are, perhaps, not very apparent. The author of any ideal creation--a
poem, a drama, or a novel--has an imprescriptible property in the fame of
his work. But to harmonise conflicting social elements, to bring order out
of chaos in the sphere of criticism, to form right ways of thinking about
questions of morals, taste, and breeding, are operations of which the
credit, though it is certainly to be ascribed to particular individuals,
is generally absorbed by society itself. Macaulay's eulogy is as just as
it is eloquent, but the pages of the _Spectator_ alone will hardly show
the reader why Addison should be so highly praised for having reconciled
wit with virtue. Nor, looking at him as a critic, will it appear a great
achievement to have pointed out to English society the beauties of
_Paradise Lost_, unless it be remembered that the taste of the preceding
generation still influenced Addison's contemporaries, and that in that
generation Cowley was accounted a greater poet than Milton.
To estimate Addison at his real value we must regard him as the chief
architect of Public Opinion in the eighteenth century. But here again we
are met by an initial difficulty, because it has become almost a
commonplace of contemporary criticism to represent the eighteenth century
as a period of sheer destruction. It is tacitly assumed by a school of
distinguished philosophical writers that we have arrived at a stage in the
world's history in which it is possible to take a positive and scientific
view of human affairs. As it is of course necessary that from such a
system all belief in the supernatural shall be jealously excluded, it has
not seemed impossible to write the history of Thought itself in the
eighteenth century. And in tracing the course of this supposed continuous
stream it is natural that all the great English writers of the period
should be described as in one way or another helping to pull down, or
vainly to strengthen, the theological barriers erected by centuries of
bigotry against the irresistible tide of enlightened progress.
It would be of course entirely out of place to discuss here the merits of
this new school of history. Those who consider that, whatever glimpses we
may obtain of the law and order of the universe, man is, as he always has
been and always will be, a mystery to himself, will hardly allow that the
operations of the human spirit can be traced in the dissecting-room. But
it is, in any case, obvious that to treat the great _imaginative_ writers
of any age as if they were only mechanical agents in an evolution of
thought is to do them grave injustice. Such writers are, above all things,
creative. Their first aim is to "show the very age and body of the time
his form and pressure." No work of the eighteenth century, composed in a
consciously destructive spirit, has taken its place among the acknowledged
classics of the language. Even the _Tale of a Tub_ is to be regarded as a
satire upon the aberrations of theologians from right reason, not upon the
principles of Christianity itself. The _Essay on Man_ has, no doubt,
logically a tendency towards Deism, but nobody ever read the poem for the
sake of its philosophy; and it is well known that Pope was much alarmed
when it was pointed out to him that his conclusions might be represented
as incompatible with the doctrines of revealed religion.
The truth indeed seems to be the exact converse of what is alleged by the
scientific historians. So far from the eighteenth century in England being
an age of destructive analysis, its energies were chiefly devoted to
political, social, and literary reconstruction. Whatever revolution in
faith and manners the English nation had undergone had been the work of
the two preceding centuries, and though the historic foundations of
society remained untouched, the whole form of the superstructure had been
profoundly modified.
"So tenacious are we," said Burke, towards the close of the last
century, "of our old ecclesiastical modes and fashions of institution
that very little change has been made in them since the fourteenth or
fifteenth centuries, adhering in this particular as in all else to our
old settled maxim never entirely nor at once to depart from antiquity.
We found these institutions on the whole favourable to morality and
discipline, and we thought they were susceptible of amendment without
altering the ground. We thought they were capable of receiving and
meliorating, and, above all, of preserving the accessories of science
and literature as the order of Providence should successively produce
them. And after all, with this Gothic and monkish education (for such
it is the groundwork), we may put in our claim to as ample and early
a share in all the improvements in science, in arts, and in literature
which have illuminated the modern world as any other nation in Europe.
We think one main cause of this improvement was our not despising the
patrimony of knowledge which was left us by our forefathers."
All this is, in substance, true of our political as well as our
ecclesiastical institutions. And yet, when Burke wrote, the great feudal
and mediæval structure of England had been so transformed by the Wars of
the Roses, the Reformation, the Rebellion, and the Revolution, that its
ancient outlines were barely visible. In so far, therefore, as his words
seem to imply that the social evolution he describes was produced by an
imperceptible and almost mechanical process of national instinct, the
impression they tend to create is entirely erroneous.
If we have been hitherto saved from such corruption as undermined the
republics of Italy, from the religious wars that so long enfeebled and
divided Germany, and from the Revolution that has severed modern France
from her ancient history, thanks for this are due partly, no doubt, to
favouring conditions of nature and society, but quite as much to the
genius of great individuals who prepared the mind of the nation for the
gradual assimilation of new ideas. Thus Langland and Wycliffe and their
numerous followers, long before the Reformation, had so familiarised the
minds of the people with their ideas of the Christian religion that the
Sovereign was able to assume the Headship of the Church without the shock
of a social convulsion. Fresh feelings and instincts grew up in the hearts
of whole classes of the nation without at first producing any change in
outward habits of life, and even without arousing a sense of their logical
incongruity. These mixed ideas were constantly brought before the
imagination in the works of the poets. Shakespeare abounds with passages
in which, side by side with the old feudal, monarchical, catholic, and
patriotic instincts of Englishmen, we find the sentiments of the Italian
Renaissance. Spenser conveys Puritan doctrines sometimes by the mouth of
shepherds, whose originals he had found in Theocritus and Virgil;
sometimes under allegorical forms derived from books of chivalry and the
ceremonial of the Catholic Church. Milton, the most rigidly Calvinistic of
all the English poets in his opinions, is also the most severely classical
in his style.
It was the task of Addison to carry on the reconciling traditions of our
literature. It is his praise to have accomplished his task under
conditions far more difficult than any that his predecessors had
experienced. What they had done was to give instinctive and characteristic
expression to the floating ideas of the society about them; what Addison
and his contemporaries did was to found a public opinion by a conscious
effort of reason and persuasion. Before the Civil Wars there had been at
least no visible breach in the principle of Authority in Church and State.
At the beginning of the eighteenth century constituted authority had been
recently overthrown; one king had been beheaded, another had been
expelled; the Episcopalian form of Church Government had been violently
displaced in favour of the Presbyterian, and had been with almost equal
violence restored. Whole classes of the population had been drawn into
opposing camps during the Civil War, and still stood confronting each
other with all the harsh antagonism of sentiment inherited from that
conflict. Such a bare summary alone is sufficient to indicate the nature
of the difficulties Addison had to encounter in his efforts to harmonise
public opinion; but a more detailed examination of the state of society
after the Restoration is required to place in its full light the
extraordinary merits of the success that he achieved.
There was, to begin with, a vehement opposition between town and country.
In the country the old ideas of Feudalism, modified by circumstances, but
vigorous and deep-rooted, still prevailed. True, the military system of
land-tenure had disappeared with the Restoration, but it was not so with
the relations of life, and the habits of thought and feeling which the
system had created. The features of surviving Feudalism have been
inimitably preserved for us in the character of Sir Roger de Coverley.
Living in the patriarchal fashion, in the midst of tenants and retainers,
who looked up to him as their chief, and for whose welfare and protection
he considered himself responsible, the country gentleman valued above all
things the principle of Loyalty. To the moneyed classes in the towns he
was instinctively opposed; he regarded their interests, both social and
commercial, as contrary to his own; he looked with dislike and suspicion
on the economical principles of government and conduct on which these
classes naturally rely. Even the younger sons of county families had in
Addison's day abandoned the custom, common enough in the feudal times, of
seeking their fortune in trade. Many a Will Wimble now spent his whole
life in the country, training dogs for his neighbours, fishing their
streams, making whips for their young heirs, and even garters for their
wives and daughters.[1]

8
data/sakana_wiki.txt Normal file
View file

@ -0,0 +1,8 @@
Sakana AI Co, Ltd. is a Japanese artificial intelligence company based in Tokyo.
Overview
Sakana AI's main research fields are evolution and collective intelligence of AI. The company's name is derived from the Japanese word さかな (sakana), which means fish. This represents the idea of a school of fish coming together and forming a coherent entity from simple rules, which is an analogy of collective intelligence.[2]
The company was founded by David Ha, Llion Jones and Ren Ito. Llion Jones co-authored the famous paper "Attention Is All You Need" when he was working for Google in 2017. The company raised $30M in its seed funding round from Lux Capital and Khosla Ventures.[3] The company raised approximately $200M from companies such as Mitsubishi UFJ, SMBC, Mizuho, Itochu, KDDI, Nomura and Nvidia in its series A funding round in 2024.[4]
In January 2024, Sakana AI developed a method to build new AI models by 'breeding' multiple existing models, which it sees as a means to democratise AI development, as this process does not require large computational resources.[5] Sakana AI is also developing a model called the AI Scientist, which automates the entire process of scientific research.[6] The Nikkei estimated the company's value at 19 billion yen in 2024.[7]

620
data/self_generate_qa.py Normal file
View file

@ -0,0 +1,620 @@
import argparse
import os
import random
import re
from glob import glob
import numpy as np
import yaml
from datasets import Dataset, load_dataset
from vllm import LLM, SamplingParams
from ctx_to_lora.data.definitions import (
CLOSED_QA_INTX_TEMPLATES,
RAW_DATA_DIR,
SELF_GEN_DATA_DIR,
)
from ctx_to_lora.data.processing import (
filter_none,
get_preprocessing_fn,
load_and_process_dataset,
tokenize_ctx_text,
)
from ctx_to_lora.data.self_gen_template import (
PRE_CTX,
PROMPT_TEMPLATE,
QA_PROMPT_TEMPLATE,
SELF_GEN_SYSTEM_MSG,
SELF_QA_INTX,
)
from ctx_to_lora.model_loading import get_tokenizer
from ctx_to_lora.utils import clear_gpu
STOP_STRINGS = {
"google/gemma-2-2b-it": ["<eos>", "<end_of_turn>"],
}
MODEL_CTX_LEN = {
"google/gemma-2-27b-it": 8192,
"google/gemma-2-2b-it": 8192,
"google/gemma-2-9b-it": 8192,
# qwen 4b has 256k ctx length but using lower max lengths is faster
"Qwen/Qwen3-4B-Instruct-2507": 2**13 + 2**12,
}
def truncate_middle_if_too_long(
input_ids: list[int],
max_length: int,
max_new_tokens: int = 256,
) -> list[int]:
"""
Truncate the middle of a list of tokens to fit within a maximum length.
Args:
tokens: List of token IDs
max_length: Maximum length for the truncated tokens
Returns:
List of truncated token IDs
"""
max_new_tokens_half = max_new_tokens // 2
# leave max_new_tokens for generation
half = max_length // 2 - max_new_tokens_half
if len(input_ids) > max_length:
return input_ids[:half] + input_ids[-half:]
return input_ids
def get_prompt(context: str, q: str, remove_qa_template: bool) -> str:
prompt = QA_PROMPT_TEMPLATE if not remove_qa_template else PROMPT_TEMPLATE
return prompt.format(context=context, question=q)
def add_closed_qa_prompt(q: str, closed_qa_prob: float = 0.1) -> str:
if random.random() <= closed_qa_prob:
q = random.choice(CLOSED_QA_INTX_TEMPLATES).format(input=q)
return q
def load_config(config_path: str) -> dict:
"""Load dataset names from YAML config file."""
with open(config_path) as f:
config = yaml.safe_load(f)
return config
def get_dataset_configs(
ds_names: list[str] | None,
config: dict | None,
split: str | None,
) -> list[tuple[str, str]]:
assert not (ds_names and config), "Cannot provide both ds_names and config"
if ds_names:
assert split, "When using ds_names, --split must be provided"
# Validate ds_names format
for ds_name in ds_names:
if not isinstance(ds_name, str):
raise ValueError(f"Invalid dataset name: {ds_name}")
return [(ds_name, split) for ds_name in ds_names]
if config:
dataset_configs = []
# Process train datasets
train_ds_names = config.get("train_ds_names", [])
# self_gen_train_ds_names = [
# (ds_name.split("/")[-1], "train")
# for ds_name in train_ds_names
# if ds_name.startswith("self_gen/")
# ]
self_gen_train_ds_names = [
(ds_name, "train")
for ds_name in train_ds_names
if ds_name.startswith("self_gen/")
]
if not self_gen_train_ds_names:
print("No self_gen datasets found in train_ds_names")
dataset_configs.extend(self_gen_train_ds_names)
# Process validation datasets
val_ds_names = config.get("val_ds_names", [])
self_gen_val_ds_names = [
(ds_name, "validation")
for ds_name in val_ds_names
if ds_name.startswith("self_gen/")
]
if not self_gen_val_ds_names:
print("No self_gen datasets found in val_ds_names")
dataset_configs.extend(self_gen_val_ds_names)
return dataset_configs
def create_messages(
ctxs: list[str],
questions: list[list[str]],
vllm_model: str,
system_template: str,
remove_qa_template: bool,
) -> list[list[dict]]:
"""Create chat messages for the model."""
# if "gemma" in vllm_model:
# gemma models do not support system messages
return [
[
{
"role": "user",
"content": (
system_template + "\n\n\n" + get_prompt(ctx, q, remove_qa_template)
).strip(),
}
]
for ctx, q_list in zip(ctxs, questions)
for q in q_list
]
# else:
# return [
# [
# {"role": "system", "content": system_template},
# {"role": "user", "content": get_prompt(ctx, q)},
# ]
# for ctx, q_list in zip(ctxs, questions)
# for q in q_list
# ]
def self_generate(
ds_name: str,
split: str,
args: argparse.Namespace,
llm: LLM,
system_template: str,
parquet_file: str | None = None,
do_truncate: bool = False,
) -> None:
"""Process a single dataset and generate QA pairs."""
shard_name = ""
# Conflict checks for ds_name-derived overrides
if ds_name is not None:
# temperature & closed_qa already handled later; add new ones
if "_temp_" in ds_name and args.temp != 0.0:
raise ValueError(
f"Multiple sources of truth for temperature: CLI arg --temp={args.temp} and dataset name contains temp specification."
)
if "_closed_qa_prob_" in ds_name and args.closed_qa_prob != 0.0:
raise ValueError(
f"Multiple sources of truth for closed_qa_prob: CLI arg --closed_qa_prob={args.closed_qa_prob} and dataset name contains closed_qa_prob specification."
)
# Base values from args
temp = args.temp
closed_qa_prob = args.closed_qa_prob
# Overrides from ds_name pattern if present
if ds_name is not None:
if "_temp_" in ds_name:
m = re.search(r"_temp_([\d.]+)", ds_name)
if m:
temp = float(m.group(1))
if "_closed_qa_prob_" in ds_name:
m = re.search(r"_closed_qa_prob_([\d.]+)", ds_name)
if m:
closed_qa_prob = float(m.group(1))
print(f"Processing dataset: {ds_name}, split: {split}")
print(f"Using temperature: {temp}")
print(f"Using closed QA prompt probability: {closed_qa_prob}")
if parquet_file:
print(f"Loading dataset from parquet file: {parquet_file}")
split = "train"
ds_name = "/".join(parquet_file.split(RAW_DATA_DIR)[-1].split("/")[:-1])
shard_name = "_" + os.path.basename(parquet_file).replace(".parquet", "")
ds = load_dataset(path="parquet", data_files=[parquet_file], split="train")
processing_fn = get_preprocessing_fn(ds_name, is_eval=False)
ds = ds.map(processing_fn, num_proc=8)
else:
ds_name = ds_name.split("/")[-1] # Extract just the dataset name
print(f"Loading dataset: {ds_name} with split: {split}")
kwargs = dict(ds_name=ds_name, split=split)
ds = load_and_process_dataset(**kwargs, num_proc=8, remove_cols=False)
print(f"Loaded dataset: {ds_name} with split: {split}")
if args.debug:
ds = ds.take(10)
ds = ds.filter(filter_none, batched=False, num_proc=8)
tk = get_tokenizer(args.vllm_model, train=True)
self_qa_intx_tokens = tk(SELF_QA_INTX, add_special_tokens=False)["input_ids"][1:]
if args.remove_qa_template:
self_qa_intx_tokens = tk("\n\n", add_special_tokens=False)["input_ids"]
n_self_qa_intx_tokens = len(self_qa_intx_tokens)
pre_ctx_tokens = tk(PRE_CTX, add_special_tokens=False)["input_ids"]
n_pre_ctx_tokens = len(pre_ctx_tokens)
sys_tokens = tk(system_template.split("\n")[0], add_special_tokens=False)[
"input_ids"
][:-1]
n_sys_tokens = len(sys_tokens)
os.environ["TOKENIZERS_PARALLELISM"] = "true"
ds = ds.map(
tokenize_ctx_text,
fn_kwargs={"tokenizer": tk},
batched=True,
batch_size=50_000,
keep_in_memory=True,
)
ctxs = [sample["context"] for sample in ds]
questions = [
[add_closed_qa_prompt(q, closed_qa_prob) for q in sample["prompts"] if q]
for sample in ds
]
questions = [q_list for q_list in ds["prompts"] if len(q_list) > 0]
print(f"Loaded {len(ctxs)} contexts and {len(questions)} questions")
k = 16
fpath = f"{SELF_GEN_DATA_DIR}/{args.vllm_model}_temp_{temp}_closed_qa_prob_{closed_qa_prob}/{ds_name}/{split}/ds{shard_name}"
chunk_size = 1_000
for chunk_idx, start in enumerate(range(0, len(ctxs), chunk_size)):
print(f"Processing chunk {chunk_idx}")
chunk_ctxs = ctxs[start : start + chunk_size]
chunk_questions = questions[start : start + chunk_size]
chunk_messages = create_messages(
chunk_ctxs,
chunk_questions,
args.vllm_model,
SELF_GEN_SYSTEM_MSG,
args.remove_qa_template,
)
if do_truncate:
# we should only do this for evaluation data
tokenized_contents = tk(
[m[0]["content"] for m in chunk_messages],
add_special_tokens=False,
return_attention_mask=False,
)
tokenized_contents["input_ids"] = [
truncate_middle_if_too_long(
ids,
max_length=MODEL_CTX_LEN[args.vllm_model],
max_new_tokens=args.max_new_tokens,
)
for ids in tokenized_contents["input_ids"]
]
contents = tk.batch_decode(
tokenized_contents["input_ids"], skip_special_tokens=True
)
for c, m in zip(contents, chunk_messages):
m[0]["content"] = c
print(f"Generating from {len(chunk_messages)} contexts")
# Clear GPU memory before processing the next chunk
clear_gpu()
execute_qa_generation(
fpath + f"_{chunk_idx:04d}",
args,
llm,
temp,
tk,
self_qa_intx_tokens,
n_self_qa_intx_tokens,
sys_tokens,
n_sys_tokens,
chunk_ctxs,
ds[start : start + chunk_size]["ctx_ids"],
chunk_questions,
chunk_messages,
k,
)
def execute_qa_generation(
fpath,
args,
llm,
temp,
tk,
self_qa_intx_tokens,
n_self_qa_intx_tokens,
sys_tokens,
n_sys_tokens,
ctxs,
ctx_ids,
questions,
messages,
k,
):
completions = llm.chat(
messages,
sampling_params=SamplingParams(
max_tokens=args.max_new_tokens,
logprobs=k,
temperature=temp,
seed=42,
spaces_between_special_tokens=False,
skip_special_tokens=False,
include_stop_str_in_output=True,
),
)
self_gen_data = {
ctx: {
"ctx_ids": ctx_ids,
"input_ids": [],
"response_start_end": [],
"logprobs_vals": [],
"logprobs_indices": [],
}
for ctx, ctx_ids in zip(ctxs, ctx_ids)
}
c = 0
n_skips = 0
sys_start = None
for ctx, q_list in zip(ctxs, questions):
# self_gen_data[ctx]["ctx_ids"] = ctx_ids
for i, _ in enumerate(q_list):
# response = completions[c + i].outputs[0].text
reason = completions[c + i].outputs[0].finish_reason
if reason != "stop":
# print(f"idx: {c + i}")
print(f"finish_reason: {completions[c + i].outputs[0].finish_reason}")
print(f"Skipping due to finish_reason={reason} != 'stop'")
n_skips += 1
continue
# includes the logprob before the first response token
# but excludes the logprob from eos token
logp = completions[c + i].outputs[0].logprobs
# len = num response tokens
n_response_tokens = len(completions[c + i].outputs[0].token_ids)
logp_indices = np.empty((n_response_tokens, k), dtype=np.int32)
# float-16 is better for this range
logp_vals = np.empty((n_response_tokens, k), dtype=np.float16)
assert len(logp) == n_response_tokens, (
f"Expected {n_response_tokens} logp entries, got {len(logp)}"
)
for li, info_d in enumerate(logp):
for j, (idx, tok_info) in enumerate(info_d.items()):
logp_indices[li, j] = idx
logp_vals[li, j] = tok_info.logprob
prompt_ids = completions[c + i].prompt_token_ids # 1d list
# token_ids only includes generated tokens, not the prompt
response_token_ids = completions[c + i].outputs[0].token_ids # 1d list
all_ids = prompt_ids + response_token_ids
res_start = len(prompt_ids)
res_end = res_start + n_response_tokens
if sys_start is None:
for ii in range(len(prompt_ids) - n_sys_tokens):
if prompt_ids[ii : ii + n_sys_tokens] == sys_tokens:
# found the start of the system message
sys_start = ii
break
q_start = None
for ii in range(
len(prompt_ids) - n_self_qa_intx_tokens,
-1,
-1,
):
if prompt_ids[ii : ii + n_self_qa_intx_tokens] == self_qa_intx_tokens:
# found the start of the user input
q_start = ii + n_self_qa_intx_tokens
break
# bos + question + eos + start model turn + response + eos
input_ids = all_ids[:sys_start] + all_ids[q_start:res_end]
# relative to the input_ids
res_start = res_start - q_start + sys_start
res_end = res_start + n_response_tokens
# arrays will be saved as nested lists of numbers
self_gen_data[ctx]["input_ids"].append(input_ids)
# assume single-turn chat
self_gen_data[ctx]["response_start_end"].append((res_start, res_end))
self_gen_data[ctx]["logprobs_vals"].append(logp_vals)
self_gen_data[ctx]["logprobs_indices"].append(logp_indices)
c += i + 1
print(f"Skipped {n_skips} responses due to missing stop strings")
samples = [
{
# "context": ctx,
# "prompts": q_list,
# "responses": self_gen_data[ctx]["responses"],
"ctx_ids": self_gen_data[ctx]["ctx_ids"],
"input_ids": self_gen_data[ctx]["input_ids"],
"response_start_end": self_gen_data[ctx]["response_start_end"],
# "prompt_start_end": self_gen_data[ctx]["prompt_start_end"],
"logprobs_vals": self_gen_data[ctx]["logprobs_vals"],
"logprobs_indices": self_gen_data[ctx]["logprobs_indices"],
}
for ctx, q_list in zip(ctxs, questions)
]
if args.debug:
for sample in samples:
# print(f"context={tk.decode(sample['ctx_ids'])}")
print(f"QA={[tk.decode(ids) for ids in sample['input_ids']]}")
for input_ids, (start, end) in zip(
sample["input_ids"], sample["response_start_end"]
):
print(f"start={start}, end={end}")
print(f"response={tk.decode(input_ids[start:end])}")
print(f"logprobs_vals={[x.shape for x in sample['logprobs_vals']]}")
print(f"logprobs_indices={[x.shape for x in sample['logprobs_indices']]}")
for indices in sample["logprobs_indices"]:
print(f"logprobs_indices={indices[-1]}")
print("=" * 80)
print(f"Generated {len(samples)} samples")
# random.shuffle(samples)
# Save results
# df = pd.DataFrame(samples)
# ds_out = Dataset.from_pandas(df)
ds_out = Dataset.from_list(samples)
# fpath = f"{SELF_GEN_DATA_DIR}/{args.vllm_model}_temp_{temp}_closed_qa_prob_{closed_qa_prob}/{ds_name}/{split}/ds{shard_name}"
if args.debug:
fpath += "_debug"
os.makedirs(os.path.dirname(fpath), exist_ok=True)
fpath = f"{fpath}.parquet"
ds_out.to_parquet(fpath)
print(f"Saved to {fpath}")
# Cleanup
del samples, ds_out, completions, messages, ctxs, questions
clear_gpu()
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Generate QA pairs using VLLM")
parser.add_argument(
"--vllm_model",
type=str,
required=True,
help="VLLM model name (e.g., google/gemma-2-2b-it)",
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode (process only 10 samples)",
)
# Either config file OR ds_names + split
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--config",
type=str,
help="Path to YAML config file with train_ds_names/val_ds_names",
)
group.add_argument(
"--ds_names",
type=str,
nargs="+",
help="List of dataset names/shard patterns",
)
group.add_argument(
"--glob_pattern",
type=str,
help="Glob pattern to match dataset names (e.g., 'data/raw_datasets/fw_qa_3/*')",
)
parser.add_argument(
"--split",
type=str,
help="Dataset split to use when using --ds_names (required with --ds_names)",
)
parser.add_argument(
"--temp",
type=float,
default=0.0,
help="Temperature for sampling (default: 0.0)",
)
parser.add_argument(
"--closed_qa_prob",
type=float,
default=0.0,
help="Probability of using closed QA prompt template (default: 0.0)",
)
parser.add_argument(
"--do_truncate",
action="store_true",
help="Truncate contexts to fit model context length",
)
parser.add_argument(
"--remove_qa_template",
action="store_true",
help="Remove QA template formatting from prompts",
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=256,
help="Maximum number of new tokens to generate (default: 256)",
)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
# Validate arguments
if args.ds_names and not args.split:
raise ValueError("--split is required when using --ds_names")
vllm_model = args.vllm_model
print(f"Using model: {vllm_model}")
# Setup model-specific configurations
llm_kwargs = dict(
model=vllm_model,
dtype="bfloat16",
enable_prefix_caching=True,
enable_chunked_prefill=True,
max_model_len=MODEL_CTX_LEN.get(vllm_model),
max_num_batched_tokens=16384,
max_num_seqs=32, # avoid oom when getting logprobs
)
print(f"{llm_kwargs=}")
llm = LLM(**llm_kwargs)
# Get dataset configs from config or CLI args
config = load_config(args.config) if args.config else None
if args.ds_names or args.config:
dataset_configs = get_dataset_configs(
ds_names=args.ds_names,
config=config,
split=args.split,
)
# Process each dataset
for ds_name, split in dataset_configs:
print(f"Processing dataset: {ds_name}, split: {split}")
self_generate(
ds_name, split, args, llm, SELF_GEN_SYSTEM_MSG, None, args.do_truncate
)
else:
assert args.glob_pattern, (
"glob_pattern must be provided if no ds_names or config"
)
files = glob(args.glob_pattern)
for file in files:
print(f"Processing file: {file}")
self_generate(
ds_name=None,
parquet_file=file,
split=args.split,
args=args,
llm=llm,
system_template=SELF_GEN_SYSTEM_MSG,
do_truncate=args.do_truncate,
)

685
demo/app.py Normal file
View file

@ -0,0 +1,685 @@
import os
import sys
from pathlib import Path
import gradio as gr
import torch
# Add the src directory to the path
sys.path.insert(0, str(Path(__file__).parent.parent))
from ctx_to_lora.data.processing import tokenize_ctx_text
from ctx_to_lora.model_loading import get_tokenizer
from ctx_to_lora.modeling import hypernet
sys.modules["ctx_to_lora.modeling_utils"] = hypernet
# Global state
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
modulated_model = None
chat_history = []
ctx_tokenizer = None
base_tokenizer = None
try:
DEFAULT_CONTEXT = Path("data/sakana_wiki.txt").read_text(encoding="utf-8").strip()
except FileNotFoundError:
DEFAULT_CONTEXT = ""
WARNING_MESSAGE = (
"⚠️ **Caution**: This is an educational proof-of-concept demonstration.\n"
"The model may generate inaccurate information or hallucinate facts."
)
FOOTER = """
This model is an experimental prototype and is only available for educational and research and development purposes. It is not suitable for commercial use or in environments where failure can have significant effects (mission-critical environments).
The use of this model is at the user's own risk and its performance and results is not guaranteed in any way.
Sakana AI is not responsible for any direct or indirect loss resulting from using this model, regardless of the outcome.
"""
def load_custom_chat_template(tokenizer, model_name):
if "gemma" in model_name.lower():
template_path = "chat_templates/google/gemma-2-2b-it.jinja"
if os.path.exists(template_path):
with open(template_path) as f:
template_content = f.read()
tokenizer.chat_template = template_content
print(f"Loaded custom chat template from {template_path}")
return True
return False
def get_available_checkpoints():
trained_d2l_checkpoints = {
str(path)
for path in Path().glob("trained_d2l/**/pytorch_model.bin")
if path.is_file()
}
run_output_checkpoints = {
str(path)
for path in Path().glob("train_outputs/runs/**/pytorch_model.bin")
if path.is_file()
}
checkpoints = sorted(trained_d2l_checkpoints) + sorted(
run_output_checkpoints - trained_d2l_checkpoints
)
return checkpoints if checkpoints else ["No checkpoints found"]
def load_checkpoint(
checkpoint_path: str,
) -> tuple[str, gr.update, gr.update, gr.update, gr.update]:
global modulated_model, ctx_tokenizer, base_tokenizer, chat_history
if not checkpoint_path or checkpoint_path == "No checkpoints found":
return (
"⚠️ Please select a valid checkpoint",
gr.update(),
gr.update(),
gr.update(),
gr.update(),
)
try:
print(f"Loading checkpoint: {checkpoint_path}")
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
state_dict = torch.load(checkpoint_path, weights_only=False)
modulated_model = ModulatedPretrainedModel.from_state_dict(
state_dict,
train=False,
use_flash_attn=True,
use_sequence_packing=False,
)
modulated_model = modulated_model.to(device).to(torch.bfloat16)
modulated_model.eval()
ctx_encoder_model_name_or_path = (
modulated_model.ctx_encoder_args.ctx_encoder_model_name_or_path
or modulated_model.base_model.config.name_or_path
)
ctx_tokenizer = get_tokenizer(ctx_encoder_model_name_or_path, train=False)
base_tokenizer = get_tokenizer(
modulated_model.base_model.config.name_or_path, train=False
)
load_custom_chat_template(
base_tokenizer, modulated_model.base_model.config.name_or_path
)
chat_history = [{"role": "system", "content": ""}]
model_name = modulated_model.base_model.config.name_or_path
success_msg = (
f"✅ Successfully loaded checkpoint!\n\nBase Model: {model_name}\n\n"
"You can now add context and start chatting."
)
return (
success_msg,
gr.update(interactive=True), # msg
gr.update(interactive=True), # send_btn
gr.update(interactive=True), # system_msg
gr.update(interactive=True), # clear_btn
)
except Exception as e:
import traceback
error_msg = (
f"❌ Error loading checkpoint:\n{str(e)}\n\n{traceback.format_exc()}"
)
print(error_msg)
return (
error_msg,
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False),
gr.update(interactive=False),
)
def process_context(context: str) -> dict:
context = context.strip() if context else ""
tokenized_contexts = tokenize_ctx_text({"context": [context]}, ctx_tokenizer)
ctx_ids = tokenized_contexts["ctx_ids"]
ctx_ids = [
torch.tensor(ctx_id, dtype=torch.long, device=device) for ctx_id in ctx_ids
]
ctx_attn_mask = [torch.ones_like(ids) for ids in ctx_ids]
ctx_attn_mask = [
torch.tensor(mask, dtype=torch.long, device=device) for mask in ctx_attn_mask
]
ctx_ids = torch.nn.utils.rnn.pad_sequence(
ctx_ids,
batch_first=True,
padding_value=0,
)
ctx_attn_mask = torch.nn.utils.rnn.pad_sequence(
ctx_attn_mask,
batch_first=True,
padding_value=0,
)
return {"ctx_ids": ctx_ids, "ctx_attn_mask": ctx_attn_mask}
def add_user_message(message: str, history):
if not message.strip():
return history, ""
return history + [[message, None]], ""
def generate_response(
history: list[list[str]],
system_msg: str,
context: str,
context_scaler: float,
bias_scaler: float,
):
global modulated_model, chat_history, ctx_tokenizer, base_tokenizer
if modulated_model is None:
history[-1][1] = "Please load a checkpoint first."
yield history
return
if not history or history[-1][0] is None:
yield history
return
try:
user_message = history[-1][0]
if system_msg.strip() and chat_history[0]["role"] == "system":
chat_history[0]["content"] = system_msg.strip()
chat_history.append({"role": "user", "content": user_message})
context = context.strip() if context else ""
print(f"Processing single context with scaler: {context_scaler}")
print(f"Bias scaler: {bias_scaler}")
with torch.inference_mode(), torch.amp.autocast(str(device)):
ctx_inputs = process_context(context)
ctx_ids = ctx_inputs["ctx_ids"].to(device)
ctx_attn_mask = ctx_inputs["ctx_attn_mask"].to(device)
scalers_tensor = torch.tensor(
[context_scaler], dtype=torch.float32, device=device
)
model_inputs = base_tokenizer.apply_chat_template(
chat_history, return_tensors="pt", add_generation_prompt=True
).to(device)
print(f"Context: {context}")
print(f"Chat history: {chat_history}")
outputs = modulated_model.generate(
ctx_ids=ctx_ids,
ctx_attn_mask=ctx_attn_mask,
n_ctx_chunks=torch.tensor([len(ctx_ids)], device=ctx_ids.device),
scalers=scalers_tensor,
bias_scaler=bias_scaler,
input_ids=model_inputs,
max_new_tokens=512,
do_sample=False,
temperature=0,
)
response = base_tokenizer.decode(
outputs[0][model_inputs.shape[1] :], skip_special_tokens=True
)
chat_history.append({"role": "assistant", "content": response})
words = response.split()
partial_response = ""
for word in words:
partial_response += word + " "
history[-1][1] = partial_response.strip()
yield history
history[-1][1] = response
yield history
except Exception as e:
import traceback
error_msg = f"❌ Error: {str(e)}"
print(f"Error generating response: {str(e)}\n\n{traceback.format_exc()}")
history[-1][1] = error_msg
yield history
def reset_chat(system_msg: str):
global chat_history
chat_history = [
{"role": "system", "content": system_msg.strip() if system_msg else ""}
]
return [[None, WARNING_MESSAGE]], "Chat history reset successfully!"
custom_css = """
:root {
color-scheme: light;
}
.gradio-container {
font-family: 'Inter', sans-serif;
}
.chat-container {
border-radius: 10px;
border: 2px solid #d1d5db;
}
.context-field {
border-radius: 8px;
padding: 15px;
margin-bottom: 10px;
border: 2px solid #d1d5db;
}
.status-box {
border-radius: 8px;
padding: 15px;
margin: 10px 0;
border: 2px solid #e5e7eb;
}
.primary-button {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: none;
border-radius: 6px;
color: white;
font-weight: 600;
}
.secondary-button {
background-color: #607d8b;
border: none;
border-radius: 6px;
color: white;
}
#chatbot {
height: 500px;
}
.instruction-text {
font-style: italic;
color: #666;
font-size: 0.9em;
margin-bottom: 10px;
}
.warning-box {
background-color: #fff3cd;
border: 2px solid #ffc107;
border-radius: 8px;
padding: 12px;
margin: 10px 0;
color: #856404;
font-size: 0.95em;
}
.warning-box strong {
color: #d97706;
}
.disabled-overlay {
background-color: #f5f5f5;
border: 3px dashed #999;
border-radius: 10px;
padding: 20px;
text-align: center;
color: #999;
}
.chat-disabled-notice {
border: 3px solid #f59e0b;
border-radius: 8px;
padding: 15px;
margin-bottom: 15px;
color: #92400e;
font-weight: 600;
text-align: center;
background-color: #fef3c7;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
}
.internalization-banner {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border-radius: 10px;
padding: 20px;
margin-bottom: 15px;
text-align: center;
font-weight: 600;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
border: 3px solid #5568d3;
}
.internalization-banner h3 {
margin: 0 0 10px 0;
font-size: 1.2em;
color: white;
}
.internalization-banner p {
margin: 5px 0;
font-size: 0.95em;
font-weight: 400;
color: white;
}
.context-section-header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: 3px solid #5568d3;
border-left: 6px solid #4c51bf;
padding: 15px;
margin-bottom: 15px;
border-radius: 6px;
color: white;
box-shadow: 0 2px 6px rgba(102, 126, 234, 0.3);
}
.context-section-header strong,
.context-section-header small {
color: white;
}
.chat-section-header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
border: 3px solid #2563eb;
border-left: 6px solid #1d4ed8;
padding: 15px;
margin-bottom: 15px;
border-radius: 6px;
color: white;
box-shadow: 0 2px 6px rgba(59, 130, 246, 0.3);
}
.chat-section-header strong,
.chat-section-header small {
color: white;
}
.panel-box {
background-color: rgba(249, 250, 251, 0.5);
border: 2px solid rgba(209, 213, 219, 0.5);
border-radius: 12px;
padding: 20px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
}
.chat-panel-box {
background-color: rgba(249, 250, 251, 0.5);
border: 2px solid rgba(209, 213, 219, 0.5);
border-radius: 12px;
padding: 20px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
}
#checkpoint-dropdown {
position: relative;
z-index: 20;
}
#checkpoint-dropdown [role="listbox"] {
z-index: 9999 !important;
}
/* Dark mode support */
.dark .panel-box,
.dark .chat-panel-box {
background-color: rgba(31, 41, 55, 0.5);
border: 2px solid rgba(75, 85, 99, 0.5);
}
.dark .context-field,
.dark .chat-container {
border-color: rgba(75, 85, 99, 0.5);
}
.dark .status-box {
border-color: rgba(55, 65, 81, 0.5);
}
"""
def create_demo():
with gr.Blocks(
title="Doc-to-LoRA Chat Interface",
theme=gr.themes.Soft(),
css=custom_css,
) as demo:
gr.Markdown(
"""
# 📜 Doc-to-LoRA Chat Interface
Load a hypernetwork checkpoint and chat with a context-modulated language model.
Add one context with a scaling parameter to influence the model's responses.
"""
)
gr.HTML(
"""
<div class="internalization-banner">
<h3>🧠 How Context Internalization Works</h3>
<p>📥 Contexts are processed by the hypernetworkto dynamically modulate the base model's parameters</p>
<p>🚫 Contexts are NOT passed as text to the base model they influence behavior internally</p>
<p>💬 Only your chat messages (below) are sent to the language model</p>
</div>
"""
)
gr.Markdown(
"""
### 📖 Usage Instructions
1. **Load a Checkpoint**: Select a hypernetwork checkpoint from the dropdown and click "Load Checkpoint"
2. **Configure Context**:
- Enter your context information in the text field
- Adjust the scaling slider to control context influence
3. **Set Bias Scaler**: Adjust the bias scaler to control overall model behavior
4. **Start Chatting**: Once the model is loaded, type your message and press Shift+Enter or click Send
5. **Reset**: Use the "Reset Chat" button to start a new conversation
💡 **Tip**: You can use context to provide background information or specific knowledge
that should influence the model's responses.
"""
)
with gr.Row():
with gr.Column(scale=1, elem_classes="panel-box"):
gr.Markdown("### 📦 Load Checkpoint")
gr.Markdown(
"*Select a trained hypernetwork checkpoint to begin.*",
elem_classes="instruction-text",
)
checkpoint_dropdown = gr.Dropdown(
choices=get_available_checkpoints(),
label="Select Checkpoint",
value=None,
interactive=True,
elem_id="checkpoint-dropdown",
)
load_btn = gr.Button("Load Checkpoint", variant="primary", size="lg")
status_box = gr.Textbox(
label="Status",
lines=8,
interactive=False,
elem_classes="status-box",
)
gr.Markdown("---")
gr.HTML(
"""
<div class="context-section-header">
<strong>🧠 Context Internalization (Hypernetwork Input)</strong><br>
<small>This context modulates the model internally it is NOT shown to the base model</small>
</div>
"""
)
context = gr.Textbox(
label="🧠 Context (Internalized via Hypernetwork)",
placeholder="Enter context to be internalized by the hypernetwork...",
lines=4,
value=DEFAULT_CONTEXT,
)
context_scaler = gr.Slider(
minimum=-2.0,
maximum=2.0,
step=0.01,
value=1.0,
label="Context Scaling",
)
gr.Markdown("---")
bias_scaler = gr.Slider(
minimum=-2.0,
maximum=2.0,
step=0.01,
value=1.0,
label="Bias Scaler",
info="A single scalar applied to bias parameters (independent of contexts)",
)
with gr.Column(scale=2, elem_classes="chat-panel-box"):
gr.HTML(
"""
<div class="chat-section-header">
<strong>💬 Chat Interface (Direct Input to Base Model)</strong><br>
<small>Your messages here are the ONLY text the base model sees contexts above influence it internally</small>
</div>
"""
)
chat_status_notice = gr.HTML(
"""
<div class="chat-disabled-notice">
🔒 <strong>Chat Disabled:</strong> Please load a checkpoint first to enable chat functionality.
</div>
""",
visible=True,
)
system_msg = gr.Textbox(
label="System Message (Optional - Sent to Base Model)",
placeholder="Load a checkpoint to enable chat...",
lines=2,
interactive=False,
)
chatbot = gr.Chatbot(
label="Conversation",
show_copy_button=True,
height=500,
elem_id="chatbot",
elem_classes="chat-container",
value=[
[
None,
"🔒 Chat is currently disabled. Please load a checkpoint from the left panel to begin chatting.",
]
],
)
with gr.Row():
msg = gr.Textbox(
label="Your Message (Sent Directly to Base Model)",
placeholder="⚠️ Load a checkpoint first to start chatting...",
lines=2,
scale=4,
interactive=False,
)
send_btn = gr.Button(
"🔒 Send (Disabled)",
variant="primary",
scale=1,
interactive=False,
)
with gr.Row():
clear_btn = gr.Button(
"🔒 Reset Chat (Disabled)",
variant="secondary",
interactive=False,
)
reset_status = gr.Textbox(label="Reset Status", visible=False)
load_btn.click(
fn=load_checkpoint,
inputs=[checkpoint_dropdown],
outputs=[status_box, msg, send_btn, system_msg, clear_btn],
).then(
fn=lambda: (
gr.update(visible=False),
gr.update(
placeholder="Type your message here... (Shift+Enter for new line)"
),
gr.update(value="Send"),
gr.update(value="🔄 Reset Chat"),
gr.update(value=[[None, WARNING_MESSAGE]]),
),
outputs=[chat_status_notice, msg, send_btn, clear_btn, chatbot],
)
msg.submit(
fn=add_user_message,
inputs=[msg, chatbot],
outputs=[chatbot, msg],
).then(
fn=generate_response,
inputs=[
chatbot,
system_msg,
context,
context_scaler,
bias_scaler,
],
outputs=[chatbot],
)
send_btn.click(
fn=add_user_message,
inputs=[msg, chatbot],
outputs=[chatbot, msg],
).then(
fn=generate_response,
inputs=[
chatbot,
system_msg,
context,
context_scaler,
bias_scaler,
],
outputs=[chatbot],
)
clear_btn.click(
fn=reset_chat,
inputs=[system_msg],
outputs=[chatbot, reset_status],
)
gr.Markdown(f"---\n{FOOTER.strip()}")
return demo
if __name__ == "__main__":
demo = create_demo()
demo.launch(
server_name="0.0.0.0",
server_port=7861,
share=False,
debug=True,
)

41
examples/python_api.py Normal file
View file

@ -0,0 +1,41 @@
# caveat: this interface only supports non-batched inputs
# for batched inference please see `src/ctx_to_lora/modeling/hypernet.py`
import torch
from ctx_to_lora.model_loading import get_tokenizer
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
# model loading
checkpoint_path = "trained_d2l/gemma_demo/checkpoint-80000/pytorch_model.bin"
state_dict = torch.load(checkpoint_path, weights_only=False)
model = ModulatedPretrainedModel.from_state_dict(
state_dict, train=False, use_sequence_packing=False
)
model.reset()
tokenizer = get_tokenizer(model.base_model.name_or_path)
# prepare data
doc = open("data/sakana_wiki.txt", "r").read()
chat = [{"role": "user", "content": "Tell me about Sakana AI."}]
chat_ids = tokenizer.apply_chat_template(
chat,
add_special_tokens=False,
return_attention_mask=False,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
# calls after internalization will be influenced by internalized info
model.internalize(doc)
outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
print(tokenizer.decode(outputs[0]))
# remove internalized info
# model.reset()
# without internalized info, the model will halucinate
# outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
# print(tokenizer.decode(outputs[0]))

24
install.sh Executable file
View file

@ -0,0 +1,24 @@
curl -LsSf https://astral.sh/uv/install.sh | sh
uv self update
uv venv --python 3.10 --seed
uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --torch-backend=cu124
uv sync
uv pip install tokenizers==0.21.0
uv pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
uv pip install flashinfer-python==0.2.2 -i https://flashinfer.ai/whl/cu124/torch2.6
# download squad dataset
HF_HUB_ENABLE_HF_TRANSFER=1 uv run huggingface-cli download --repo-type dataset rajpurkar/squad --local-dir data/raw_datasets/squad
uv run data/build_drop_compact.py
uv run data/build_pwc_compact.py
uv run data/build_ropes_compact.py
uv run data/build_squad_compact.py
# optional: needed for gated models
# uv run huggingface-cli login
# optional: needed for logging with wandb
# wandb login
# optional: dev
# uv run pre-commit install

86
pyproject.toml Normal file
View file

@ -0,0 +1,86 @@
[project]
name = "ctx-to-lora"
version = "0.0.1"
authors = [{ name = "Rujikorn Charakorn" }]
description = ""
readme = "README.md"
requires-python = ">= 3.10"
dependencies = [
"transformers==4.51.3",
"deepspeed==0.17.1",
"accelerate==1.6.0",
"datasets==3.6.0",
"setuptools",
"peft",
"jupyter",
"matplotlib",
"hf_transfer",
"torchmetrics",
"inflect",
"pre-commit",
"tensorboardX",
"wandb",
"fasttext-wheel",
"einops",
"jaxtyping",
"liger-kernel",
"tensorboard",
"flask",
"gradio>=4.40.0",
"pandas",
"plotly",
"rouge-score",
"vllm==0.8.5.post1",
"huggingface-hub[hf-transfer]>=0.32.0",
"opt-einsum>=3.4.0",
"kagglehub[hf-datasets]>=0.3.12",
"kaggle>=1.7.4.5",
"bitsandbytes>=0.46.1",
"google-cloud-storage>=3.2.0",
"wonderwords>=2.2.0",
"llmlingua>=0.2.2",
]
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[tool.pyright]
exclude = [
"**/node_modules",
"**/__pycache__",
"**/.*",
".venv",
".github",
".vscode",
"chat_templates",
"eval_results",
"configs",
"EditingLlama",
"icae_v2",
"lm-evaluation-harness",
"llm-comparator",
"LongBench",
"scripts",
"train_outputs",
"./data/",
"/data/",
"generated_tasks",
"outputs",
"plots",
"tmp",
"wandb",
".wandb",
".ruff_cache",
"assets",
]
typeCheckingMode = "off"
[tool.ruff]
line-length = 88
select = ["F401"] # remove unused imports
ignore = ["E", "F"]
[tool.isort]
profile = "black"
known_local_folder = ["ctx_to_lora"]

180
run_eval.py Normal file
View file

@ -0,0 +1,180 @@
import logging
from ctx_to_lora.eval_utils import run_eval
logger = logging.getLogger()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Evaluate a checkpoint")
parser.add_argument(
"--model_name_or_path",
type=str,
default=None,
help="Evaluate a base model from HuggingFace Hub, without loading checkpoint",
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help="Path to the checkpoint to evaluate",
)
parser.add_argument(
"--split",
type=str,
choices=["validation", "test"],
default="validation",
help="Which split to evaluate on",
)
parser.add_argument(
"--datasets",
type=str,
nargs="+",
help=(
"Specific datasets to evaluate on."
"If not provided, uses default from args.yaml"
),
)
parser.add_argument(
"--eval_batch_size",
type=int,
default=8,
help="Eval batch size for teacher forcing",
)
parser.add_argument(
"--eval_batch_size_gen",
type=int,
default=32,
help="Eval batch size for generation",
)
parser.add_argument(
"--max_val_samples_per_ds",
type=int,
default=-1,
help=(
"Maximum number of validation samples per dataset. "
"If -1, uses values from checkpoint config."
),
)
parser.add_argument(
"--max_test_samples_per_ds",
type=int,
default=500,
help=(
"Maximum number of validation samples per dataset. "
"If -1, uses values from checkpoint config."
),
)
parser.add_argument(
"--max_ctx_chunk_len",
type=int,
default=-1,
help="Maximum length of context chunk for evaluation",
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=256,
help="Maximum number of new tokens to generate during evaluation",
)
parser.add_argument(
"--remove_context",
action="store_true",
help="Remove context when evaluating the base model.",
)
parser.add_argument(
"--use_cd",
action="store_true",
help="Use context distillation model for evaluation.",
)
parser.add_argument(
"--cd_update_iterations",
type=int,
default=20,
help="Number of update iterations for context distillation during evaluation",
)
parser.add_argument(
"--cd_use_gen_q",
action="store_true",
help="Use generated queries for context distillation training.",
)
parser.add_argument(
"--q_gen_rounds",
type=int,
default=4,
help="Number of rounds of query generation for context distillation.",
)
parser.add_argument(
"--cd_batch_size",
type=int,
default=16,
help="Batch size for context distillation.",
)
parser.add_argument(
"--use_iterative_mode",
action="store_true",
help="Use iterative mode LoRA layer-by-layer generation",
)
parser.add_argument(
"--use_llmlingua",
action="store_true",
help="Use LLMLingua compression for evaluation",
)
parser.add_argument(
"--llmlingua_compression_rate",
type=float,
default=0.9,
help="Compression rate for LLMLingua",
)
parser.add_argument(
"--use_t2l",
action="store_true",
help="Use Text-to-LoRA model for evaluation",
)
parser.add_argument(
"--add_ctx_to_input",
action="store_true",
help="Add ctx to base model's input",
)
parser.add_argument(
"--truncate_if_too_long_inp",
action="store_true",
help="Truncate input sequences that are too long",
)
parser.add_argument(
"--truncate_if_too_long_ctx",
action="store_true",
help="Truncate ctx sequences that are too long",
)
parser.add_argument(
"--gen_lora_scaling",
type=float,
default=1.0,
)
parser.add_argument(
"--flip_ctx_inp",
action="store_true",
help="Flip the order of context and input",
)
parser.add_argument(
"--use_generative_adapter",
action="store_true",
help="Use generative adapter for evaluation",
)
cli_args = vars(parser.parse_args())
if cli_args["model_name_or_path"]:
assert cli_args["max_ctx_chunk_len"] <= 0, (
f"Evaluating base model shouldn't be used with `max_ctx_chunk_len`"
)
eval_batch_size_gen = cli_args.pop("eval_batch_size_gen")
eval_batch_size = cli_args.pop("eval_batch_size")
run_eval(
**cli_args,
eval_batch_size=eval_batch_size_gen,
generative=True,
)

View file

@ -0,0 +1,17 @@
from huggingface_hub import snapshot_download
if __name__ == "__main__":
self_gen_data_dir = "./data/raw_datasets/self_gen/"
snapshot_download(
"SakanaAI/self_gen_qa_d2l",
repo_type="dataset",
local_dir=self_gen_data_dir,
# we can filter based on model by using the `allow_patterns` argument
# based on https://huggingface.co/datasets/SakanaAI/self_gen_qa_d2l/tree/main
# we can use
# - `Qwen` for downloading the data for `Qwen/Qwen3-4B-Instruct-2507`
# - `google` for downloading the data for `google/gemma-2-2b-it`
# - `mistralai` for downloading the data for `mistralai/Mistral-7B-Instruct-v0.2`
#
# allow_patterns="google/*", # downloading the data for `google/gemma-2-2b-it`
)

14
scripts/main_exp/1-train.sh Executable file
View file

@ -0,0 +1,14 @@
#!/bin/bash
port=29051
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj --lora_r=8 \
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
--max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
--quantize_ctx_encoder=True

View file

@ -0,0 +1,30 @@
#!/bin/bash
port=29051
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj \
--lora_r=8 \
--eval_strategy=no \
--max_qas_len=512 \
--max_qas_per_sample=1 \
--per_rank_gen=True \
--per_layer_processing=True \
--gen_lora_l1_reg_coef=0.1 \
--max_steps=20000 \
--gradient_accumulation_steps=16 \
--max_packed_inp_len=1024 \
--max_packed_ctx_len=2048 \
--use_per_ctx_average_loss=True \
--use_kl_loss=True \
--quantize_ctx_encoder=True \
--torch_empty_cache_steps=10 \
--from_pretrained_checkpoint=train_outputs/runs/$RUN_NAME/checkpoint-80000/pytorch_model.bin \
--max_ctx_chunk_len=512 \
--min_ctx_chunk_len=25 \
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
--warmup_steps=2000 \
--learning_rate=2e-5

View file

@ -0,0 +1,25 @@
# D2L pipeline
### Data
You can either download the generated data (recommended, ~100 GB for each model) or generate them by youself.
Please see [`0-download_data.sh`](0-download_data.sh) for how to do model-specific data download.
```bash
# download training data for all three models (328GB)
uv run bash scripts/main_exp/0-download_data.sh
```
Generating data from scratch can take very long if not parallelized across multiple gpus.
```bash
# generate training data (takes very long if not parallelized across multiple gpus)
# optional: use the command below for generating data from scratch
# uv run bash scripts/main_exp/gen_data.sh
```
### Training
Simply run the training script once the data is ready.
```bash
# train
uv run bash scripts/main_exp/1-train.sh
```
### Evaluation
All evaluation scripts for reproducing the main results in the paper are included in [eval](eval/) directory.

View file

@ -0,0 +1,8 @@
# no truncation
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1
# w/ truncation
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1 --truncate_if_too_long_inp
# no context
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1 --remove_context

6
scripts/main_exp/eval/cd.sh Executable file
View file

@ -0,0 +1,6 @@
# qa
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=4
# longbench
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=1

View file

@ -0,0 +1 @@
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets $1 --split test --use_cd --cd_update_iterations 50 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=5 --cd_batch_size=2

View file

@ -0,0 +1,6 @@
# qa
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp
# longbench
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp

13
scripts/main_exp/eval/d2l.sh Executable file
View file

@ -0,0 +1,13 @@
# main results
# batched
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1
# iterative
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1 --use_iterative_mode
# query internalization
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad --split test --eval_batch_size_gen=1 --flip_ctx_inp
# replaced squad context
WANDB_MODE=disabled uv run python run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad_assistant_ctx_no_passage --split test
WANDB_MODE=disabled uv run python run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad_negative_no_passage --split test

View file

@ -0,0 +1,279 @@
import json
import os
import re
from argparse import Namespace
from difflib import SequenceMatcher
import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
from ctx_to_lora.model_loading import get_tokenizer
from ctx_to_lora.modeling.ctx_encoder import PerLayerActivations
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
from ctx_to_lora.modeling.lora_layer import apply_lora_to_layers
from ctx_to_lora.modeling.lora_merger import combine_lora
CLASS_NAMES = [
"tench",
"English springer",
"cassette player",
"chain saw",
"church",
"French horn",
"garbage truck",
"gas pump",
"golf ball",
"parachute",
]
CLASS_TO_INT = {name: i for i, name in enumerate(CLASS_NAMES)}
INPUT_TXT = f"What is in this image? Choose exactly one of the following classes: {', '.join(CLASS_NAMES)}. Response with only the correct class without any other text."
RUN_DIR = "train_outputs/runs/Oct16_02-37-04_slurm0-a3nodeset-8_94074_1d62ecb8"
def _normalize_text(text: str) -> str:
text = re.sub(r"[^a-z0-9\s]", " ", text.lower())
return " ".join(text.split())
def _normalize_compact(text: str) -> str:
return re.sub(r"[^a-z0-9]", "", text.lower())
def _build_alias_map():
alias_overrides = {
"english springer spaniel": "English springer",
"springer spaniel": "English springer",
"chainsaw": "chain saw",
"dump truck": "garbage truck",
"refuse truck": "garbage truck",
"garbage lorry": "garbage truck",
"fuel pump": "gas pump",
"gas station pump": "gas pump",
"cassette deck": "cassette player",
"cassette recorder": "cassette player",
"fish": "tench",
"tench fish": "tench",
"french horn instrument": "French horn",
"golfball": "golf ball",
"skydiving": "parachute",
"parachutist": "parachute",
}
alias_map = {}
def register(alias: str, canonical: str):
alias = _normalize_text(alias)
if alias:
alias_map[alias] = canonical
alias_map[_normalize_compact(alias)] = canonical
for name in CLASS_NAMES:
register(name, name)
register(name.replace(" ", ""), name)
register(name.replace(" ", "-"), name)
for alias, canonical in alias_overrides.items():
register(alias, canonical)
return alias_map
CLASS_ALIAS_MAP = _build_alias_map()
def pred_to_class_id(pred_txt: str) -> int:
norm_pred = _normalize_text(pred_txt)
compact_pred = _normalize_compact(pred_txt)
for alias, canonical in CLASS_ALIAS_MAP.items():
if alias and (alias in norm_pred or alias in compact_pred):
return CLASS_TO_INT[canonical]
pred_tokens = set(norm_pred.split())
best_class = None
best_token_hits = -1
for name in CLASS_NAMES:
class_tokens = set(_normalize_text(name).split())
if class_tokens and class_tokens.issubset(pred_tokens):
return CLASS_TO_INT[name]
hits = sum(token in pred_tokens for token in class_tokens)
if hits > best_token_hits:
best_token_hits = hits
best_class = name
best_ratio = -1.0
for name in CLASS_NAMES:
ratio = SequenceMatcher(None, norm_pred, _normalize_text(name)).ratio()
if ratio > best_ratio:
best_ratio = ratio
best_class = name
return CLASS_TO_INT[best_class]
def load_checkpoint():
checkpoint_path = f"{RUN_DIR}/checkpoint-80000/pytorch_model.bin"
state_dict = torch.load(checkpoint_path)
model = ModulatedPretrainedModel.from_state_dict(
state_dict,
train=False,
base_model_kwargs=dict(attn_implementation="flash_attention_2"),
use_flash_attn=True,
use_sequence_packing=False, # for generation
)
tokenizer = get_tokenizer("google/gemma-2-2b-it")
model.eval()
return model, tokenizer
def load_ctx_encoder():
model_id = "google/gemma-3-4b-it"
ctx_model = Gemma3ForConditionalGeneration.from_pretrained(
model_id, device_map="auto"
).eval()
ctx_encoder_config = Namespace(ctx_encoder_last_layer=26, keep_lm_head=True)
ctx_model.language_model = PerLayerActivations(
ctx_model.language_model, ctx_encoder_config
)
processor = AutoProcessor.from_pretrained(model_id)
return ctx_model, processor
def template_image(img, ctx_processor):
messages = [
{
"role": "system",
"content": [{"type": "text", "text": ""}],
},
{
"role": "user",
"content": [
{"type": "image", "image": img},
],
},
]
inputs = ctx_processor.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_dict=True,
return_tensors="pt",
)
return inputs
@torch.inference_mode
def get_ctx_features(ctx_inputs, ctx_encoder):
forward_outputs = ctx_encoder(**ctx_inputs, output_hidden_states=True)
ctx_features = torch.stack(forward_outputs.hidden_states, dim=1)
return ctx_features
def generate_loras(ctx_inputs, ctx_features):
generated_loras, _ = model.hypernet.generate_weights(
ctx_features, attn_mask=torch.ones_like(ctx_inputs["input_ids"])
)
generated_loras = combine_lora(
generated_loras,
n_chunks=torch.tensor((1,), device=model.device),
lora_bias=model.hypernet.get_head_bias()
if model.hypernet.config.use_bias
else None,
)
return generated_loras
def apply_loras(model, generated_loras):
n_queries = torch.ones(1, dtype=torch.int32, device=model.device)
apply_lora_to_layers(
model.base_model,
model.hypernet.layer_indices,
generated_loras,
n_queries,
)
if __name__ == "__main__":
model, base_tokenizer = load_checkpoint()
ctx_encoder, ctx_processor = load_ctx_encoder()
ds = load_dataset("frgfm/imagenette", "full_size", split="validation")
# ds = ds.shuffle().select(range(int(0.05 * len(ds))))
input_ids = base_tokenizer.apply_chat_template(
[{"role": "user", "content": INPUT_TXT}],
add_special_tokens=False,
return_attention_mask=False,
add_generation_prompt=True,
return_tensors="pt",
).to(model.device)
preds = []
pred_txts = []
corrects = []
labels = ds["label"]
for sample in tqdm(ds):
img = sample["image"]
ctx_inputs = template_image(img, ctx_processor).to(ctx_encoder.device)
ctx_features = get_ctx_features(ctx_inputs, ctx_encoder)
generated_loras = generate_loras(ctx_inputs, ctx_features)
apply_loras(model, generated_loras)
model_outputs = model.base_model.generate(
input_ids, max_new_tokens=256, do_sample=False
)
pred_txt = base_tokenizer.decode(
model_outputs[0][len(input_ids[0]) :], skip_special_tokens=True
)
pred_txts.append(pred_txt)
preds.append(pred_to_class_id(pred_txt))
is_correct = preds[-1] == labels[len(preds) - 1]
corrects.append(is_correct)
print(
f"GT: {CLASS_NAMES[labels[len(preds) - 1]]}, Pred: {pred_txt} -> {CLASS_NAMES[preds[-1]]}, Correct: {is_correct}"
)
acc = sum(corrects) / len(corrects)
print(f"Final accuracy: {acc:4f}")
jsonl_path = os.path.join(RUN_DIR, "imagenette_eval.jsonl")
meta_path = os.path.join(RUN_DIR, "imagenette_eval.meta.json")
with open(jsonl_path, "w") as f:
for i, (pred_txt, pred_id, label_id) in enumerate(
zip(pred_txts, preds, labels)
):
f.write(
json.dumps(
{
"index": i,
"label": int(label_id),
"label_name": CLASS_NAMES[label_id],
"pred_text": pred_txt,
"pred_class_id": int(pred_id),
"pred_class_name": CLASS_NAMES[pred_id],
"correct": bool(pred_id == label_id),
}
)
+ "\n"
)
meta = {
"dataset": "frgfm/imagenette",
"subset": "full_size",
"split": "validation",
"run_dir": RUN_DIR,
"prompt": INPUT_TXT,
"accuracy": float(acc),
"num_samples": len(preds),
"class_names": CLASS_NAMES,
}
with open(meta_path, "w") as f:
json.dump(meta, f, indent=2)
print(f"Wrote samples to {jsonl_path}")
print(f"Wrote metadata to {meta_path}")

View file

@ -0,0 +1,5 @@
for dataset in squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e; do
for rate in 0.9 0.8 0.6 0.4 0.2 0.1; do
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets "$dataset" --split test --eval_batch_size_gen=1 --use_llmlingua --llmlingua_compression_rate "$rate" --truncate_if_too_long_ctx
done
done

4
scripts/main_exp/eval/t2l.sh Executable file
View file

@ -0,0 +1,4 @@
# download t2l checkpoint
uv run huggingface-cli download SakanaAI/text-to-lora --local-dir . --include "trained_t2l/gemma_2b_t2l"
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen=1 --use_t2l

20
scripts/main_exp/gen_data.sh Executable file
View file

@ -0,0 +1,20 @@
# download fineweb_edu to `data/raw_datasets/fineweb_edu
uv run data/download_fineweb_edu.py
# generate qa data
# run from 000 to 013
for shard_id in $(seq -f "%03g" 0 13); do
uv run data/generate_fw_edu_qa_v2.py --shard_pattern "${shard_id}_00000" --n_qa_pairs=5 --vllm_model=google/gemma-3-12b-it --max_length=2000 --max_model_length=2048
uv run data/generate_fw_edu_qa_v2_repeat.py --shard_pattern "min_0_to_2000/${shard_id}*level_0" --n_qa_pairs=5 --vllm_model=google/gemma-3-12b-it
# self-generated response QA data
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --glob_pattern "data/raw_datasets/fw_qa_v2/min_0_to_2000/${shard_id}*_level_1*" --closed_qa_prob 1.0
done
# val split
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --glob_pattern 'data/raw_datasets/fw_qa_v2/min_0_to_2000/*_level_0_val.parquet'
# self-gen data for other ds
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --ds_names squad_compact ropes_compact drop_compact --split train --closed_qa_prob 1.0
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --ds_names pwc_compact --split train --closed_qa_prob 0.0

View file

@ -0,0 +1,29 @@
#!/bin/bash
#SBATCH --job-name=ctxlora
#SBATCH --nodes=1
#SBATCH --partition=a3
#SBATCH --gpus=8
#SBATCH --output=slurm_logs/%x-%j.out
#SBATCH --error=slurm_logs/%x-%j.out
port=$((10000 + ($SLURM_JOBID % 50000)))
echo "Using port: $port"
# port=29051
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj --lora_r=8 \
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
--max_steps=20000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
--quantize_ctx_encoder=True --ctx_encoder_model_name_or_path=google/gemma-3-4b-it \
--max_ctx_chunk_len=512 \
--min_ctx_chunk_len=25 \
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
--warmup_steps=2000 \
--learning_rate=2e-5 \
"$@"

View file

@ -0,0 +1,24 @@
#!/bin/bash
#SBATCH --job-name=ctxlora
#SBATCH --nodes=1
#SBATCH --partition=a3
#SBATCH --gpus=8
#SBATCH --output=slurm_logs/%x-%j.out
#SBATCH --error=slurm_logs/%x-%j.out
port=$((10000 + ($SLURM_JOBID % 50000)))
echo "Using port: $port"
# port=29051
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj --lora_r=8 \
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
--max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
--quantize_ctx_encoder=True --ctx_encoder_model_name_or_path=google/gemma-3-4b-it \
"$@"

15
scripts/main_exp/train_no_qa.sh Executable file
View file

@ -0,0 +1,15 @@
#!/bin/bash
port=29051
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
--num_processes=8 --gpu_ids all train.py \
configs/main_exp/self_gen_lv1_closed_qa_1_no_qa_l2l.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--target_modules=down_proj --lora_r=8 \
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
--max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
--quantize_ctx_encoder=True \
"$@"

1
scripts/niah/0-gen_data.sh Executable file
View file

@ -0,0 +1 @@
uv run data/generate_ctx_magic_number.py

42
scripts/niah/1-train.sh Executable file
View file

@ -0,0 +1,42 @@
#!/bin/bash
WANDB_MODE=disabled uv run train.py \
configs/niah_exp/ctx_magic_number_32_256.yaml \
--model_name_or_path=google/gemma-2-2b-it \
--num_train_epochs=1 \
--per_device_train_batch_size=-1 \
--gradient_accumulation_steps=16 \
--per_device_eval_batch_size=16 \
--exp_setup=hyper_lora \
--aggregator_type=perceiver \
--target_modules=down_proj \
--num_blocks=8 \
--num_self_attn_per_block=0 \
--num_pre_head_layers=1 \
--lora_r=8 \
--eval_steps=100 \
--logging_steps=10 \
--save_steps=1000 \
--learning_rate=4e-5 \
--lora_dropout=0.0 \
--neftune_noise_alpha=0 \
--per_rank_gen=True \
--per_layer_processing=True \
--gen_lora_l1_reg_coef=1.5 \
--use_sequence_packing=True \
--max_packed_inp_len=4096 \
--max_packed_ctx_len=4096 \
--dataloader_num_workers=0 \
--dataloader_prefetch_factor=None \
--eval_on_start=False \
--ctx_encoder_type=early_exit \
--n_latent_queries=208 \
--use_kl_loss=False \
--eval_on_start=True \
--max_ctx_chunk_len=512 \
--min_ctx_chunk_len=25 \
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
--max_val_samples_per_ds=100 \
--seed=1 \
--use_per_ctx_average_loss=True \
--torch_empty_cache_steps=10 \
"$@"

1
scripts/niah/2-eval-test.sh Executable file
View file

@ -0,0 +1 @@
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path CHECKPOINT_PATH --datasets ctx_magic_number_32_1024 ctx_magic_number_1024_2048 ctx_magic_number_3072_4096 ctx_magic_number_7168_8192 ctx_magic_number_15360_16384 ctx_magic_number_28672_32768 ctx_magic_number_57344_65536 ctx_magic_number_122880_131072 --max_ctx_chunk_len=1024 --split test --eval_batch_size_gen=4

1
scripts/niah/2-eval.sh Executable file
View file

@ -0,0 +1 @@
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path $CHECKPOINT_PATH --datasets ctx_magic_number_32_1024 ctx_magic_number_1024_2048 ctx_magic_number_2048_3072 ctx_magic_number_3072_4096 ctx_magic_number_4096_5120 ctx_magic_number_5120_6144 ctx_magic_number_6144_7168 ctx_magic_number_7168_8192 ctx_magic_number_8192_9216 ctx_magic_number_9216_10240 ctx_magic_number_10240_11264 ctx_magic_number_11264_12288 ctx_magic_number_12288_13312 ctx_magic_number_13312_14336 ctx_magic_number_14336_15360 ctx_magic_number_15360_16384 ctx_magic_number_16384_20480 ctx_magic_number_20480_24576 ctx_magic_number_24576_28672 ctx_magic_number_28672_32768 ctx_magic_number_32768_40960 ctx_magic_number_40960_49152 ctx_magic_number_49152_57344 ctx_magic_number_57344_65536 ctx_magic_number_65536_73728 ctx_magic_number_73728_81920 ctx_magic_number_81920_90112 ctx_magic_number_90112_98304 ctx_magic_number_98304_106496 ctx_magic_number_106496_114688 ctx_magic_number_114688_122880 ctx_magic_number_122880_131072 --max_ctx_chunk_len=1024 --split test --eval_batch_size_gen=4

8
scripts/niah/README.md Normal file
View file

@ -0,0 +1,8 @@
# NIAH experiment
```bash
# run the scripts in this order
# data generation is only needed to be run once
uv run bash scripts/niah/0-gen_data.sh
uv run bash scripts/niah/1-train.sh
uv run bash scripts/niah/2-eval.sh
```

View file

@ -0,0 +1,39 @@
WANDB_MODE=disabled run uv run train.py configs/niah_exp/ctx_magic_number_32_256.yaml \
--model_name_or_path=mistralai/Mistral-7B-Instruct-v0.2 \
--num_train_epochs=1 \
--per_device_train_batch_size=-1 \
--gradient_accumulation_steps=64 \
--per_device_eval_batch_size=16 \
--exp_setup=hyper_lora \
--aggregator_type=perceiver \
--target_modules=down_proj \
--num_blocks=8 \
--num_self_attn_per_block=0 \
--num_pre_head_layers=1 \
--lora_r=8 \
--eval_steps=100 \
--logging_steps=10 \
--save_steps=1000 \
--learning_rate=4e-5 \
--lora_dropout=0.0 \
--neftune_noise_alpha=0 \
--per_rank_gen=True \
--per_layer_processing=True \
--gen_lora_l1_reg_coef=2.0 \
--use_sequence_packing=True \
--max_packed_inp_len=1024 \
--max_packed_ctx_len=1024 \
--dataloader_num_workers=0 \
--dataloader_prefetch_factor=None \
--eval_on_start=False \
--ctx_encoder_type=early_exit \
--n_latent_queries=208 \
--use_kl_loss=False \
--eval_on_start=True \
--max_ctx_chunk_len=512 \
--min_ctx_chunk_len=25 \
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
--max_val_samples_per_ds=100 \
--seed=1 \
--use_per_ctx_average_loss=True \
--torch_empty_cache_steps=10

View file

@ -0,0 +1,39 @@
WANDB_MODE=disabled run uv run train.py configs/niah_exp/ctx_magic_number_32_256.yaml \
--model_name_or_path=Qwen/Qwen3-4B-Instruct-2507 \
--num_train_epochs=1 \
--per_device_train_batch_size=-1 \
--gradient_accumulation_steps=32 \
--per_device_eval_batch_size=16 \
--exp_setup=hyper_lora \
--aggregator_type=perceiver \
--target_modules=down_proj \
--num_blocks=8 \
--num_self_attn_per_block=0 \
--num_pre_head_layers=1 \
--lora_r=8 \
--eval_steps=100 \
--logging_steps=10 \
--save_steps=1000 \
--learning_rate=4e-5 \
--lora_dropout=0.0 \
--neftune_noise_alpha=0 \
--per_rank_gen=True \
--per_layer_processing=True \
--gen_lora_l1_reg_coef=0.5 \
--use_sequence_packing=True \
--max_packed_inp_len=2048 \
--max_packed_ctx_len=2048 \
--dataloader_num_workers=0 \
--dataloader_prefetch_factor=None \
--eval_on_start=False \
--ctx_encoder_type=early_exit \
--n_latent_queries=208 \
--use_kl_loss=False \
--eval_on_start=True \
--max_ctx_chunk_len=512 \
--min_ctx_chunk_len=25 \
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
--max_val_samples_per_ds=100 \
--seed=1 \
--use_per_ctx_average_loss=True \
--torch_empty_cache_steps=10

12
setup.py Normal file
View file

@ -0,0 +1,12 @@
# read the contents of the README file
from pathlib import Path
from setuptools import setup
this_directory = Path(__file__).parent
long_description = (this_directory / "README.md").read_text()
setup(
long_description=long_description,
long_description_content_type="text/markdown",
)

View file

596
src/ctx_to_lora/configs.py Normal file
View file

@ -0,0 +1,596 @@
import dataclasses
import os
import sys
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Literal, NewType
import torch
import yaml
from transformers import (
MODEL_FOR_CAUSAL_LM_MAPPING,
HfArgumentParser,
TrainingArguments,
)
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
DataClassType = NewType("DataClassType", Any)
class ArgumentParser(HfArgumentParser):
def parse_yaml_and_args(
self, yaml_arg: str, other_args: list[str] | None = None
) -> list[dataclass]:
"""
Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.
Args:
yaml_arg (`str`):
The path to the config file used
other_args (`List[str]`, *optional`):
A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
Returns:
[`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line
"""
arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
outputs = []
# strip other args list into dict of key-value pairs
other_args = {
arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args
}
used_args = {}
# overwrite the default/loaded value with the value provided to the command line
# adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
for data_yaml, data_class in zip(arg_list, self.dataclass_types):
keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
for arg, val in other_args.items():
# add only if in keys
if arg in keys:
if val in ["None", "none", "null", "NULL"]:
val = None
inputs[arg] = val
used_args[arg] = val
continue
base_type = data_yaml.__dataclass_fields__[arg].type
inputs[arg] = val
# cast type for ints, floats (default to strings)
if base_type in [int, float]:
inputs[arg] = base_type(val)
if base_type == list[str]:
inputs[arg] = [str(v) for v in val.split(",")]
# bool of a non-empty string is True, so we manually check for bools
if base_type == bool:
if val in ["true", "True"]:
inputs[arg] = True
else:
inputs[arg] = False
if base_type == dict:
inputs[arg] = yaml.load(val, Loader=yaml.FullLoader)
# add to used-args so we can check if double add
if arg not in used_args:
used_args[arg] = val
else:
raise ValueError(
f"Duplicate argument provided: {arg}, may cause unexpected behavior"
)
obj = data_class(**inputs)
outputs.append(obj)
for arg in other_args:
if arg not in used_args:
raise ValueError(f"Argument provided not found in dataclass: {arg}")
return outputs
def parse(self) -> DataClassType | tuple[DataClassType]:
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
# If we pass only one argument to the script and it's the path to a YAML file,
# let's parse it to get our arguments.
output = self.parse_yaml_file(os.path.abspath(sys.argv[1].split("=")[-1]))
# parse command line args and yaml file
elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
output = self.parse_yaml_and_args(
os.path.abspath(sys.argv[1].split("=")[-1]), sys.argv[2:]
)
# parse --config for the yaml path and other command line args
elif any([arg.startswith("--config") for arg in sys.argv]):
yaml_arg = [
arg
for arg in sys.argv[1:]
if arg.startswith("--config") and arg.endswith(".yaml")
][0]
other_args = [arg for arg in sys.argv[1:] if arg != yaml_arg]
output = self.parse_yaml_and_args(
os.path.abspath(yaml_arg.split("=")[-1]), other_args
)
# parse command line args only
else:
output = self.parse_args_into_dataclasses()
if len(output) == 1:
output = output[0]
return output
class ExperimentSetup(str, Enum):
HYPERLORA = "hyper_lora"
@dataclass
class TrainingArguments(TrainingArguments):
output_dir: str = field(
default="",
metadata={"help": "Placeholder. Will be overwritten by train.py"},
)
tf32: bool = field(
default=True,
metadata={"help": "Whether to use tf32 precision."},
)
bf16: bool = field(
default=True,
metadata={"help": "Whether to use bf16 precision."},
)
label_names: list[str] = field(
default=("labels",),
metadata={
"help": "List of strings to specify the label names in the dataset. "
"This is used to compute the loss and metrics."
},
)
include_for_metrics: list[str] = field(
default=("inputs",),
metadata={
"help": "List of strings to specify additional data to include in the `compute_metrics` function."
"Options: 'inputs', 'loss'."
},
)
per_device_eval_batch_size: int = field(
default=64,
metadata={
"help": "Batch size for evaluation. "
"If not set, will use the same as per_device_train_batch_size."
},
)
per_device_train_batch_size: int = field(
default=1,
metadata={
"help": "Batch size for training. "
"If not set, will use the same as per_device_eval_batch_size."
},
)
# TODO: use this! (check trainer.py for proper computation)
average_tokens_across_devices: bool = field(
default=False,
metadata={"help": "compute num_items_in_batch across devices."},
)
# mem leak if use persistent workers
# https://github.com/pytorch/pytorch/issues/62066
# https://github.com/huggingface/transformers/issues/30943
dataloader_persistent_workers: bool = field(
default=False,
metadata={
"help": "Whether to keep the workers alive after a dataset has been consumed once."
},
)
dataloader_prefetch_factor: int = field(
default=16,
metadata={"help": "Number of batches loaded in advance by each worker."},
)
dataloader_num_workers: int = field(
default=8,
metadata={"help": "Number of subprocesses to use for data loading."},
)
neftune_noise_alpha: float = field(
default=5.0,
metadata={"help": "Neftune noise alpha for the optimizer."},
)
learning_rate: float = field(
default=4e-5,
metadata={"help": "Initial learning rate."},
)
weight_decay: float = field(
default=0.01,
metadata={"help": "Weight decay for the optimizer."},
)
optim: str = field(
default="adamw_torch_fused",
metadata={"help": "Optimizer."},
)
adam_beta1: float = field(
default=0.9,
metadata={"help": "Adam beta 1."},
)
adam_beta2: float = field(
default=0.999,
metadata={"help": "Adam beta 2."},
)
adam_epsilon: float = field(
default=1e-8,
metadata={"help": "Adam epsilon."},
)
lr_scheduler_type: str = field(
default="cosine_with_min_lr",
metadata={"help": "Learning rate scheduler type."},
)
lr_scheduler_kwargs: dict = field(
default=None,
metadata={"help": "Learning rate scheduler kwargs."},
)
warmup_steps: int = field(
default=100,
metadata={"help": "Number of warmup steps."},
)
eval_on_start: bool = field(
default=False,
metadata={"help": "Whether to evaluate on the start of training."},
)
eval_strategy: str = field(
default="steps",
metadata={"help": "Evaluation strategy."},
)
eval_steps: int = field(
default=1_000,
metadata={"help": "Evaluation steps."},
)
metric_for_best_model: str = field(
default=None,
metadata={"help": "Metric for best model."},
)
load_best_model_at_end: bool = field(
default=False,
metadata={"help": "Whether to load the best model at the end of training."},
)
save_total_limit: int = field(
default=2,
metadata={"help": "Total number of checkpoints to save."},
)
save_strategy: str = field(
default="steps",
)
save_steps: int = field(
default=5_000,
)
save_safetensors: bool = field(
default=False,
)
logging_strategy: str = field(
default="steps",
)
logging_steps: int = field(
default=100,
)
use_liger_kernel: bool = field(
default=False,
)
remove_unused_columns: bool = field(
default=False,
)
# needed to avoid OOM by compute the metrics batch by batch
# w/o this the trainer stores logits of all sample in memory...
batch_eval_metrics: bool = field(
default=True,
)
logging_first_step: bool = field(
default=True,
metadata={"help": "Whether to log the first step."},
)
ddp_find_unused_parameters: bool = field(
default=False,
metadata={"help": "Whether to find unused parameters in DDP."},
)
ddp_timeout: int = field(
default=2**20,
metadata={"help": "Timeout for distributed data parallel training."},
)
@dataclass
class ModelArguments:
"""
Arguments for the base model.
"""
model_name_or_path: str = field(
default=None,
metadata={"help": ("Base model name or path.")},
)
use_flash_attn: bool = field(
default=True,
metadata={"help": "Whether to use flash attention."},
)
@dataclass
class LoRAArguments:
lora_r: int | None = field(
default=8,
metadata={"help": ("LoRA R value.")},
)
lora_dropout: float | None = field(
default=0.0,
metadata={"help": ("LoRA dropout.")},
)
target_modules: list[str] | None = field(
default=None,
metadata={"help": ("LoRA target modules.")},
)
@dataclass
class CtxTrainingArguments:
exp_setup: ExperimentSetup = field(
default=ExperimentSetup.HYPERLORA,
metadata={"help": "Experiment setup - LoRA, HyperLoRA, or full finetuning"},
)
from_pretrained_checkpoint: str = field(
default=None,
metadata={"help": "Path to the pretrained checkpoint."},
)
max_base_len: int | None = field(
default=2**13,
metadata={"help": "Maximum base length for training."},
)
use_sequence_packing: bool = field(
default=True,
metadata={"help": "Whether to use sequence packing."},
)
max_ctx_len: int = field(
default=-1,
metadata={"help": "Max context length. Overrides ctx tokenizer length."},
)
max_qas_len: int = field(
default=2**11,
metadata={
"help": "Maximum question-answering token length of each sample for training. "
"QA pairs that are longer than this value will be split up into multiple samples."
},
)
max_qas_per_sample: int = field(
default=-1,
metadata={
"help": "Max QA pair per context. If a context has more QA pairs than this value, "
"they will be split up into multiple samples."
},
)
num_chunk_probs: dict = field(
default=None,
metadata={"help": "Probability distribution over chunk nums."},
)
max_ctx_chunk_len: int = field(
default=-1,
metadata={
"help": "Max context chunk length. If a context is longer than this value, "
"it will be split up into multiple chunks."
},
)
min_ctx_chunk_len: int = field(
default=-1,
metadata={
"help": "Min context chunk length. Used only with random chunking training"
},
)
max_ctx_chunk_num: int | None = field(
default=None,
metadata={"help": "Max number of context chunks per sample."},
)
max_packed_inp_len: int | None = field(
default=2**14,
metadata={"help": "Maximum packed input length for training."},
)
max_packed_ctx_len: int | None = field(
# forward pass of the ctx encoder is cheaper --> longer packed len
default=2**15,
metadata={"help": "Maximum packed context length for training."},
)
max_new_tokens: int | None = field(
default=256,
metadata={"help": "Maximum new tokens for generation-based evaluation."},
)
gen_per_device_eval_batch_size: int | None = field(
default=1,
metadata={"help": "Per device evaluation batch size for generation."},
)
notes: str | None = field(
default=None,
metadata={"help": "Wandb notes for the experiment."},
)
use_kl_loss: bool = field(
default=False,
metadata={"help": "Whether to use KL loss."},
)
use_per_ctx_average_loss: bool = field(
default=False,
metadata={"help": "Whether to use per-context average loss."},
)
gen_lora_l1_reg_coef: float = field(
default=0.0,
metadata={"help": "L1 regularization coefficient for generated LoRAs."},
)
add_negative_prompt: bool = field(
default=False,
metadata={"help": "Whether to add negative prompt training."},
)
@dataclass
class DataArguments:
train_ds_names: list[str] = field(
default=None,
metadata={"help": "Training dataset names."},
)
streaming: bool = field(
default=False,
metadata={"help": "Whether to use streaming dataset for training."},
)
val_ds_names: list[str] | None = field(
default=None,
metadata={"help": "Validation dataset names."},
)
test_ds_names: list[str] | None = field(
default=None,
metadata={"help": "Test dataset names."},
)
max_train_samples_per_ds: int | None = field(
default=None,
metadata={"help": "Maximum number of training samples per dataset."},
)
max_val_samples_per_ds: int | None = field(
default=1000,
metadata={"help": "Maximum number of validation samples per dataset."},
)
max_test_samples_per_ds: int | None = field(
default=500,
metadata={"help": "Maximum number of test samples per dataset."},
)
@dataclass
class HypernetArguments:
latent_size: int = field(
default=512,
metadata={"help": "Latent size for HyperLoRA."},
)
use_light_weight_lora: bool = field(
default=False,
metadata={"help": "Whether to use light-weight LoRA."},
)
light_weight_latent_size: int = field(
default=128,
metadata={"help": "Latent size for light-weight LoRA."},
)
dropout_rate: float = field(
default=0.0,
metadata={"help": "Dropout rate for HyperLoRA."},
)
extra_modules: list[str] | None = field(
default=None,
metadata={"help": "Extra modules to train."},
)
per_rank_gen: bool = field(
default=False,
metadata={"help": "Whether to use per-rank generation."},
)
use_bias: bool = field(
default=True, metadata={"help": "Whether to include data-dependent LoRA"}
)
use_per_rank_bias: bool = field(
default=False, metadata={"help": "Whether to use per-rank bias."}
)
per_layer_processing: bool = field(
default=False,
metadata={"help": "Whether to use per-layer processing (after preceiver)."},
)
use_token_mixing: bool = field(
default=False,
metadata={"help": "Whether to use token mixing block."},
)
num_pre_head_layers: int = field(
default=1, metadata={"help": "# of layers before hypernet head"}
)
@dataclass
class CtxEncoderArguments:
ctx_encoder_model_name_or_path: str = field(
default=None,
metadata={"help": "Context encoder model name or path."},
)
ctx_encoder_type: Literal["embed_only", "per_layer_activations", "early_exit"] = (
field(
default="early_exit",
metadata={
"help": "Context encoder type. "
"Options: 'embed_only', 'per_layer_activations', 'early_exit'."
},
)
)
# used only with `early_exit` type
layer_idx: int | None = field(
default=None,
metadata={
"help": "Layer index for context encoder. "
"Default to L//4 where L is the number of layers of the ctx model. "
"Only used when ctx_encoder_type==early_exit"
},
)
quantize_ctx_encoder: bool = field(
default=False, metadata={"help": "Wheter to quantize the ctx encoder."}
)
ctx_encoder_last_layer: int | None = field(
default=None,
metadata={
"help": "Maximum number of layers for the context encoder. "
"Only used when ctx_encoder_type==per_layer_activations"
},
)
@dataclass
class AggregatorArguments:
aggregator_type: Literal["pooler", "perceiver"] = field(
default="perceiver",
metadata={"help": "Aggregator type for HyperLoRA."},
)
# pooler
pooling_type: str = field(
default="mean",
metadata={"help": "Pooling type for HyperLoRA."},
)
num_latent_factor: int = field(
default=8,
metadata={"help": "Number of latent factors for Perceiver."},
)
n_latent_queries: int = field(
default=208, # 26 * 8
metadata={"help": "Number of latent queries of Perceiver."},
)
num_blocks: int = field(
default=8,
metadata={"help": "Number of blocks for Perceiver."},
)
num_self_attn_per_block: int = field(
default=0,
metadata={"help": "Number of self-attention layers per block for Perceiver."},
)
shared_weights: bool = field(
default=False,
metadata={"help": "Whether to share weights across blocks for Perceiver."},
)
# needed for loading model from checkpoint
# see https://github.com/huggingface/transformers/pull/34632
torch.serialization.add_safe_globals(
[
DataArguments,
CtxTrainingArguments,
ModelArguments,
LoRAArguments,
TrainingArguments,
HypernetArguments,
AggregatorArguments,
CtxEncoderArguments,
]
)
if __name__ == "__main__":
print(ExperimentSetup)
print(ExperimentSetup.LORA)
print(ExperimentSetup.HYPER_LORA)
print(ExperimentSetup.FULL_FINETUNE)

View file

View file

@ -0,0 +1,145 @@
import numpy as np
import torch
from transformers.data import (
DataCollatorWithFlattening,
default_data_collator,
)
from ctx_to_lora.utils import check_is_iterable, concat_list
flattener = DataCollatorWithFlattening()
def flatten_if_not_packed(inp_list):
# no padding
sample = inp_list[0]
n = len(inp_list)
# training data is packed
if "position_ids" in sample:
if n == 1:
n_queries = sample.pop("n_queries")
n_ctx_chunks = sample.pop("n_ctx_chunks")
batch = default_data_collator(inp_list, return_tensors="pt")
batch["n_queries"] = torch.tensor(n_queries)
batch["n_ctx_chunks"] = torch.tensor(n_ctx_chunks)
return batch
elif n > 1:
raise NotImplementedError("Please use batch_size=1 when using packed data")
# when batch_size > 1 (never used?)
# return default_data_collator(concat_batch(inp_list), return_tensors="pt")
# for eval data (not packed) during training
need_flatten = check_is_iterable(sample["input_ids"][0])
assert not need_flatten, f"Validation data should not be nested."
n_queries = torch.ones(len(inp_list), dtype=torch.int32)
n_ctx_chunks = torch.tensor(
[len(example["ctx_ids"]) for example in inp_list], dtype=torch.int32
)
packed_inputs = flattener(inp_list, return_tensors="pt")
packed_inputs["n_queries"] = n_queries
packed_inputs["n_ctx_chunks"] = n_ctx_chunks
if "ctx_ids" in sample:
# HACK: assumes 1 ctx chunk here
# assert all(len(ctx_ids) == 1 for ctx_ids in sample["ctx_ids"]), (
# "ctx_ids can only have one chunk for eval. "
# "Please implement chunked ctx forward pass to handle this."
# )
ctx_ids = concat_list([example.pop("ctx_ids") for example in inp_list])
ctx_position_ids = torch.cat([torch.arange(len(ids)) for ids in ctx_ids])
ctx_ids = torch.tensor(concat_list(ctx_ids))
packed_inputs["ctx_ids"] = ctx_ids.unsqueeze(0)
packed_inputs["ctx_position_ids"] = ctx_position_ids.unsqueeze(0)
# for eval info
if "ctx_ids_len" in sample:
packed_inputs["ctx_ids_len"] = [
example["ctx_ids_len"] for example in inp_list
]
return packed_inputs
def eval_collator(inp_list, tokenizer):
# only used for teacher-forcing eval
# input is a list of tokenized sequences
padding_kwargs = dict(padding=True, padding_side="right", return_tensors="pt")
has_ctx_ids = "ctx_ids" in inp_list[0]
if has_ctx_ids:
# pad to the longest ctx_len in the batch
# which can have a different length from the input_ids, attn_mask, labels
ctx_ids = [example.pop("ctx_ids") for example in inp_list]
ctx_attn_mask = [torch.ones_like(x) for x in ctx_ids]
ctx_ids = torch.nn.utils.rnn.pad_sequence(
ctx_ids,
batch_first=True,
padding_value=0,
)
ctx_attn_mask = torch.nn.utils.rnn.pad_sequence(
ctx_attn_mask,
batch_first=True,
padding_value=0,
)
for inp in inp_list:
inp["attention_mask"] = torch.ones_like(inp["input_ids"])
labels = [x.pop("labels") for x in inp_list]
# need to pass the whole inp bc we also track the lengths (with specal keys)
padded_seq = tokenizer.pad(inp_list, **padding_kwargs)
# hacky explicit padding since the labels are not padded by default
labels = tokenizer.pad({"input_ids": labels}, **padding_kwargs)["input_ids"]
labels = torch.where(padded_seq["attention_mask"] == 0, -100, labels)
out = {**padded_seq, "labels": labels}
if has_ctx_ids:
out["ctx_ids"] = ctx_ids
out["ctx_attn_mask"] = ctx_attn_mask
return out
def generation_collator(inp_list, tokenizer):
padding_kwargs = dict(padding=True, padding_side="left", return_tensors="pt")
input_ids = [torch.tensor(x.pop("input_ids")) for x in inp_list]
labels = [x.pop("labels") for x in inp_list]
for i, label in enumerate(labels):
# we don't include the labels in the output during generation
# remove the response tokens
idx = np.argmax(label != -100)
idx = max(1, idx)
input_ids[i] = input_ids[i][:idx]
attn_mask = [torch.ones_like(x) for x in input_ids]
out = tokenizer.pad(
{"input_ids": input_ids, "attention_mask": attn_mask}, **padding_kwargs
)
if "ctx_ids" in inp_list[0]:
# pad to the longest ctx_len in the batch
# which can have a different length from the input_ids, attn_mask, labels
ctx_ids = [example.pop("ctx_ids") for example in inp_list]
n_chunks = [len(x) for x in ctx_ids]
ctx_ids = concat_list(ctx_ids)
ctx_ids = [torch.tensor(x) for x in ctx_ids]
ctx_attn_mask = [torch.ones_like(x) for x in ctx_ids]
ctx_ids = torch.nn.utils.rnn.pad_sequence(
ctx_ids,
batch_first=True,
padding_value=0,
)
ctx_attn_mask = torch.nn.utils.rnn.pad_sequence(
ctx_attn_mask,
batch_first=True,
padding_value=0,
)
out["ctx_ids"] = ctx_ids
out["ctx_attn_mask"] = ctx_attn_mask
out["n_ctx_chunks"] = torch.tensor(n_chunks, dtype=torch.int32)
return out

View file

@ -0,0 +1,250 @@
IGNORE_INDEX = -100
TRANSFORMED_DATA_DIR = "data/processed_datasets"
RAW_DATA_DIR = "data/raw_datasets/"
SELF_GEN_DATA_DIR = f"{RAW_DATA_DIR}/self_gen/"
# for chunking
CTX_AFFIXES = {
"google/gemma-2-2b-it": {
"prefix": [2, 106, 1645, 110], # <bos><start_of_turn>user\n\n\n
"suffix": [107, 108, 106, 2516, 108], # <end_of_turn>\n<start_of_turn>model\n
},
"mistralai/Mistral-7B-Instruct-v0.2": {
"prefix": [1, 733, 16289, 28793, 28705, 13, 13], # `<s> [INST] \n\n`
"suffix": [733, 28748, 16289, 28793], # ` [/INST] `
},
"Qwen/Qwen3-4B-Instruct-2507": {
# `<|im_start|>system\n<|im_end|>\n<|im_start|>user\n`
"prefix": [151644, 8948, 198, 151645, 198, 151644, 872, 198],
# `<|im_end|>\n<|im_start|>assistant\n`
"suffix": [151645, 198, 151644, 77091, 198],
},
}
LONGBENCH_TASKS = [
"longbench/qasper",
"longbench/multifieldqa_en",
"longbench/2wikimqa",
]
LONGBENCH_E_TASKS = [
"longbench/qasper_e",
"longbench/multifieldqa_en_e",
"longbench/2wikimqa_e",
]
DS_KWARGS = {
"pwc": dict(
train=dict(path="sggetao/PwC", split="train"),
validation=dict(path="sggetao/PwC", split="test[:900]"),
test=dict(path="sggetao/PwC", split="test"),
),
"pwc_compact": dict(
train=dict(
path="parquet",
data_files="data/raw_datasets/pwc_compact/train/ds.parquet",
split="train",
),
),
"pwc_compact_tiny": dict(
train=dict(
path="parquet",
data_files="data/raw_datasets/pwc_compact/train/ds.parquet",
# correspond to skipping to first 900 samples in the original dataset
split="train[60:200]",
),
),
"pwc_tiny": dict(
train=dict(path="sggetao/PwC", split="train[900:2000]"),
validation=dict(path="sggetao/PwC", split="train[:900]"),
),
"squad": dict(
train=dict(path="data/raw_datasets/squad", split="train"),
validation=dict(path="data/raw_datasets/squad", split="validation[:1000]"),
test=dict(path="data/raw_datasets/squad", split="validation"),
),
"squad_compact": dict(
train=dict(
path="parquet",
data_files="data/raw_datasets/squad_compact/train/ds.parquet",
# correspond to skipping to first 900 samples in the original dataset
split="train[180:]",
),
),
"squad_negative": dict(
test=dict(path="data/raw_datasets/squad", split="validation"),
),
"squad_assistant_ctx": dict(
test=dict(path="data/raw_datasets/squad", split="validation"),
),
"squad_negative_no_passage": dict(
test=dict(path="data/raw_datasets/squad", split="validation"),
),
"squad_assistant_ctx_no_passage": dict(
test=dict(path="data/raw_datasets/squad", split="validation"),
),
"drop": dict(
train=dict(
path="ucinlp/drop",
split="train",
),
validation=dict(
path="ucinlp/drop",
split="validation[:900]",
),
test=dict(
path="ucinlp/drop",
split="validation",
),
),
"drop_compact": dict(
train=dict(
path="parquet",
data_files="data/raw_datasets/drop_compact/train/ds.parquet",
# correspond to skipping to first 900 samples in the original dataset
split="train",
)
),
"ropes": dict(
train=dict(
path="allenai/ropes",
split="train",
),
validation=dict(
path="allenai/ropes",
split="validation[:900]",
),
test=dict(path="allenai/ropes", split="validation"),
),
"ropes_compact": dict(
train=dict(
path="parquet",
data_files="data/raw_datasets/ropes_compact/train/ds.parquet",
split="train",
)
),
}
# add ctx_numbers
tok_bins = [(64, 128), (128, 256), (256, 512)] + [
(512 + 256 * i, 512 + 256 * (i + 1)) for i in range(14)
]
tok_bins += [(32, 128), (128, 256), (256, 512), (512, 1024), (32, 1024)] + [
(1024 * i, 1024 * (i + 1)) for i in range(1, 16)
]
tok_bins += [(2**14 + 2**12 * (i), 2**14 + 2**12 * (i + 1)) for i in range(4)]
tok_bins += [(2**15 + 2**13 * (i), 2**15 + 2**13 * (i + 1)) for i in range(12)]
for toy_ds_name in ["ctx_numbers", "ctx_kv", "ctx_magic_number"]:
for tok_bin in tok_bins:
DS_KWARGS[f"{toy_ds_name}_{tok_bin[0]}_{tok_bin[1]}"] = dict(
train=dict(
path="json",
data_files=f"data/raw_datasets/{toy_ds_name}_{tok_bin[0]}_{tok_bin[1]}/train.jsonl",
split="train",
),
validation=dict(
path="json",
data_files=f"data/raw_datasets/{toy_ds_name}_{tok_bin[0]}_{tok_bin[1]}/val.jsonl",
split="train",
),
test=dict(
path="json",
data_files=f"data/raw_datasets/{toy_ds_name}_{tok_bin[0]}_{tok_bin[1]}/test.jsonl",
split="train",
),
)
# LongBench kwargs
for ds_name in LONGBENCH_TASKS + LONGBENCH_E_TASKS:
DS_KWARGS[ds_name] = dict(
test=dict(
path="THUDM/LongBench",
name=ds_name.split("/")[-1],
split="test",
)
)
CLOSED_QA_DATASETS = {
"longbench/qasper",
"longbench/multifieldqa_en",
"longbench/2wikimqa",
"squad",
"squad_negative",
"squad_assistant_ctx",
"squad_negative_no_passage",
"squad_assistant_ctx_no_passage",
"ropes",
"drop",
}
MULTI_ANSWER_DATASETS = {
"longbench/qasper",
"longbench/multifieldqa_en",
"longbench/hotpotqa",
"longbench/2wikimqa",
"squad",
"squad_negative",
"squad_assistant_ctx",
"squad_negative_no_passage",
"squad_assistant_ctx_no_passage",
"drop",
}
for ds_name in list(CLOSED_QA_DATASETS):
if ds_name.startswith("longbench/"):
CLOSED_QA_DATASETS.add(f"{ds_name}_e")
for ds_name in list(MULTI_ANSWER_DATASETS):
if ds_name.startswith("longbench/"):
MULTI_ANSWER_DATASETS.add(f"{ds_name}_e")
# for training closed qa datasets, e.g., hotpot_qa, squad, etc.
CLOSED_QA_INTX_TEMPLATES = [
"Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}",
"Answer without any explanation.\n\nQuestion: {input}",
"Based on the provided text, what is the answer to the following question? Provide only the answer.\n\nQuestion: {input}",
"Extract the answer to the question from the text. Be concise. Do not explain.\n\nQuestion: {input}",
"What is the answer to this question, based on the context? Respond with the answer only.\n\nQuestion: {input}",
"Provide a direct answer to the question using the given passages. Do not give any explanation.\n\nQuestion: {input}",
"Answer the question using only information from the provided text. No extra words.\n\nQuestion: {input}",
"From the passages, answer the question. Just the answer, please.\n\nQuestion: {input}",
"Give the answer to the question. Do not include any other text.\n\nQuestion: {input}",
"The answer to the question is in the text. Find it and state it clearly. No need for explanation.\n\nQuestion: {input}",
"Concisely answer the question based on the text provided. Don't include any other words. Just the answer.\n\nQuestion: {input}",
"Read the passages and answer the question with the minimal necessary words.\n\nQuestion: {input}",
"What is the direct response to the question, according to the text? Avoid explanation.\n\nQuestion: {input}",
"Please provide only the answer to the question, derived from the text.\n\nQuestion: {input}",
"Using the provided context, answer the question. Output the answer and nothing else.\n\nQuestion: {input}",
"Identify the answer in the text and present it without elaboration.\n\nQuestion: {input}",
"Answer the following question based on the text. Your answer should be brief and to the point. No explanation.\n\nQuestion: {input}",
"Based on the information given, what is the answer to the question? Only state the answer.\n\nQuestion: {input}",
"Find the answer to the question in the provided passages and write it down. No explanations.\n\nQuestion: {input}",
"The question is: {input}. Provide the answer based on the text, and nothing more.",
"Question: {input}\nAnswer directly based on the text provided. No extra words.",
"Question: {input}\nPlease provide the answer based on the text. No explanation is needed.",
]
EVAL_INTX_TEMPLATES = {
# binary (yes/no, a/b) qa given ctx
"ropes": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
# short-ctx reasoning
"drop": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
# short-ctx extractive qa
"squad": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
"squad_negative": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
"squad_assistant_ctx": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
"squad_negative_no_passage": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
"squad_assistant_ctx_no_passage": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
# longbench
"longbench/qasper": 'Answer the question as concisely as you can, using a single phrase or sentence if possible.\nIf the question cannot be answered based on the information in the article, write "unanswerable".\nIf the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}',
"longbench/multifieldqa_en": "Answer the following question. Only output the answer and do not output any other words.\n\nQuestion: {input}",
"longbench/2wikimqa": "Answer the following question. Only output the answer and do not output any other words.\n\nQuestion: {input}",
}
for ds_name in LONGBENCH_E_TASKS:
EVAL_INTX_TEMPLATES[ds_name] = EVAL_INTX_TEMPLATES[ds_name[:-2]]

View file

@ -0,0 +1,290 @@
# based on
# https://github.com/MeetKai/functionary/blob/aa3dbdd65f7e388f2386622606bdfeec95c2b863/functionary/train/packing/packed_dataset.py
import json
import logging
import os
import pprint
import numpy as np
from ctx_to_lora.utils import check_is_iterable, concat_list
logger = logging.getLogger()
def pack_data_points_by_length(
lens: list[list[int]],
ctx_lens: list[list[int]],
max_packed_inp_len: int,
max_packed_ctx_len: int,
max_size: int = -1,
) -> tuple[list[int], list[int]]:
if not lens:
return []
len_arr = np.array([sum(l) for l in lens], dtype=np.long)
ctx_len_arr = np.array([sum(l) for l in ctx_lens], dtype=np.long)
n = len(len_arr)
assert len(ctx_len_arr) == n, "Length of ctx_len_arr must match length of lens"
if n == 1:
return (
[0]
if len_arr[0] <= max_packed_inp_len and ctx_len_arr[0] <= max_packed_ctx_len
else []
)
# Create cumulative sum arrays for efficient range queries
cumsum_inp_len = np.cumsum(len_arr)
cumsum_ctx_len = np.cumsum(ctx_len_arr)
idx_pairs = []
i = 0
while i < n:
# Find the maximum j such that sum(lens[i:j+1]) <= max_packed_inp_len
start_sum_inp = cumsum_inp_len[i - 1] if i > 0 else 0
valid_ends_inp = (cumsum_inp_len[i:] - start_sum_inp) <= max_packed_inp_len
start_sum_ctx = cumsum_ctx_len[i - 1] if i > 0 else 0
valid_ends_ctx = (cumsum_ctx_len[i:] - start_sum_ctx) <= max_packed_ctx_len
valid_ends = valid_ends_inp & valid_ends_ctx
if not np.any(valid_ends):
# Single item exceeds max_packed_inp_len, skip it
logging.debug(
f"Skipping item {i} with input length {len_arr[i]} and context length {ctx_len_arr[i]}"
)
i += 1
continue
# Find the last valid index
max_valid_idx = i + np.where(valid_ends)[0][-1]
# Apply max_size constraint
if max_size > 0:
max_valid_idx = min(max_valid_idx, i + max_size - 1)
idx_pairs.append((i, max_valid_idx + 1))
i = max_valid_idx + 1
return idx_pairs
def pack_data_points_FA(
batch: dict[str, any],
) -> dict[str, np.ndarray]:
if not batch:
raise ValueError("Batch is empty")
# Pre-allocate lists with known sizes
total_ctx_len = sum(len(y) for x in batch["ctx_ids"] for y in x)
total_inp_len = sum(len(y) for x in batch["input_ids"] for y in x)
ctx_ids = np.empty(total_ctx_len, dtype=np.long)
ctx_position_ids = np.empty(total_ctx_len, dtype=np.long)
input_ids = np.empty(total_inp_len, dtype=np.long)
position_ids = np.empty(total_inp_len, dtype=np.long)
labels = np.empty(total_inp_len, dtype=np.long)
has_logprobs = "logprobs_vals" in batch
if has_logprobs:
sequences = zip(
batch["input_ids"],
batch["labels"],
batch["logprobs_vals"],
batch["logprobs_indices"],
)
n_labels = sum(len(y) for x in batch["logprobs_vals"] for y in x)
k = len(batch["logprobs_vals"][0][0][0]) # assuming all have same k
logprobs_vals = np.empty((n_labels, k), dtype=np.float32)
logprobs_indices = np.empty((n_labels, k), dtype=np.int32)
logits_offset = 0
else:
sequences = zip(batch["input_ids"], batch["labels"])
offset = 0
for sample in sequences:
input_ids_b, labels_b = sample[:2]
inp_start = offset
# compute position_ids for each sub-list in input_ids_b
local_start = inp_start
for ids_b in input_ids_b:
local_end = local_start + len(ids_b)
position_ids[local_start:local_end] = np.arange(len(ids_b), dtype=np.int32)
local_start = local_end
input_ids_b = concat_list(input_ids_b)
labels_b = concat_list(labels_b)
inp_len = len(input_ids_b)
inp_end = offset + inp_len
input_ids[inp_start:inp_end] = input_ids_b
labels[inp_start:inp_end] = labels_b
offset += inp_len
if has_logprobs:
logprobs_vals_b, logprobs_indices_b = sample[2:]
logprobs_vals_b = concat_list(logprobs_vals_b)
logprobs_indices_b = concat_list(logprobs_indices_b)
logits_len = len(logprobs_vals_b)
logprobs_vals[logits_offset : logits_offset + logits_len] = logprobs_vals_b
logprobs_indices[logits_offset : logits_offset + logits_len] = (
logprobs_indices_b
)
logits_offset += logits_len
ctx_offset = 0
for ctx_ids_b in batch["ctx_ids"]:
local_start = ctx_offset
for ctx_ids_b_item in ctx_ids_b:
local_end = local_start + len(ctx_ids_b_item)
ctx_position_ids[local_start:local_end] = np.arange(
len(ctx_ids_b_item), dtype=np.int32
)
local_start = local_end
ctx_ids_b = concat_list(ctx_ids_b)
ctx_len = len(ctx_ids_b)
ctx_start, ctx_end = ctx_offset, ctx_offset + ctx_len
ctx_ids[ctx_start:ctx_end] = ctx_ids_b
ctx_offset += ctx_len
out = {
"ctx_ids": ctx_ids,
"ctx_position_ids": ctx_position_ids,
"input_ids": input_ids,
"position_ids": position_ids,
"labels": labels,
}
if has_logprobs:
out["logprobs_vals"] = logprobs_vals
out["logprobs_indices"] = logprobs_indices
return out
def pack_batch(
batch: dict[str, any],
max_packed_inp_len: int,
max_packed_ctx_len: int,
max_packed_size: int = -1,
metadata_path: str = "",
) -> dict[str, any]:
need_flatten = check_is_iterable(batch["input_ids"][0][0])
assert need_flatten, (
f"Packing requires the input_ids to be nested "
f"(allowing multiple QAs per sample), but got {batch['input_ids'][0]}"
)
n_queries = [len(x) for x in batch["input_ids"]]
n_ctx_chunks = [len(x) for x in batch["ctx_ids"]]
inp_lens = [[len(y) for y in x] for x in batch["input_ids"]]
inp_count = len(inp_lens)
if "ctx_ids" not in batch:
raise ValueError("Batch must contain 'ctx_ids' and 'labels' keys")
# we do not pad so we can just take the length of the tokens
ctx_lens = [[len(y) for y in x] for x in batch["ctx_ids"]]
# Group indices
idx_pairs = pack_data_points_by_length(
inp_lens,
ctx_lens,
max_packed_inp_len,
max_packed_ctx_len,
max_packed_size,
)
# Pack groups
packed_batch = {
"ctx_ids": [],
"ctx_position_ids": [],
"input_ids": [],
"position_ids": [],
"labels": [],
"n_queries": [],
"n_ctx_chunks": [],
}
has_logprobs = "logprobs_vals" in batch
if has_logprobs:
packed_batch["logprobs_vals"] = []
packed_batch["logprobs_indices"] = []
packing_efficiency_ratios = []
ctx_packing_efficiency_ratios = []
for idx_pair in idx_pairs:
start_idx, end_idx = idx_pair[0], idx_pair[1]
group_items = {
"ctx_ids": batch["ctx_ids"][start_idx:end_idx],
"input_ids": batch["input_ids"][start_idx:end_idx],
"labels": batch["labels"][start_idx:end_idx],
}
if has_logprobs:
group_items["logprobs_vals"] = batch["logprobs_vals"][start_idx:end_idx]
group_items["logprobs_indices"] = batch["logprobs_indices"][
start_idx:end_idx
]
packed_item = pack_data_points_FA(group_items)
packed_batch["ctx_ids"].append(packed_item["ctx_ids"])
packed_batch["ctx_position_ids"].append(packed_item["ctx_position_ids"])
packed_batch["input_ids"].append(packed_item["input_ids"])
packed_batch["position_ids"].append(packed_item["position_ids"])
packed_batch["labels"].append(packed_item["labels"])
packed_batch["n_queries"].append(n_queries[start_idx:end_idx])
packed_batch["n_ctx_chunks"].append(n_ctx_chunks[start_idx:end_idx])
if has_logprobs:
packed_batch["logprobs_vals"].append(packed_item["logprobs_vals"])
packed_batch["logprobs_indices"].append(packed_item["logprobs_indices"])
if max_packed_inp_len > 0:
inp_efficiency = len(packed_item["input_ids"]) / max_packed_inp_len
packing_efficiency_ratios.append(inp_efficiency)
if max_packed_ctx_len > 0:
ctx_efficiency = len(packed_item["ctx_ids"]) / max_packed_ctx_len
ctx_packing_efficiency_ratios.append(ctx_efficiency)
# Calculate length statistics
packed_inp_lens_arr = np.array([len(x) for x in packed_batch["input_ids"]])
packed_ctx_lens_arr = np.array([len(x) for x in packed_batch["ctx_ids"]])
# Log performance statistics
avg_inp_packing_efficiency = (
np.mean(packing_efficiency_ratios) if packing_efficiency_ratios else 0
)
avg_ctx_packing_efficiency = (
np.mean(ctx_packing_efficiency_ratios) if ctx_packing_efficiency_ratios else 0
)
# Create packing statistics dictionary
packing_stats = {
"original_samples": inp_count,
"packed_samples": len(idx_pairs),
"avg_inp_packing_efficiency": float(avg_inp_packing_efficiency),
"avg_ctx_packing_efficiency": float(avg_ctx_packing_efficiency),
"input_ids_length_stats": {
"avg": float(np.mean(packed_inp_lens_arr)),
"std": float(np.std(packed_inp_lens_arr)),
"min": int(np.min(packed_inp_lens_arr)),
"max": int(np.max(packed_inp_lens_arr)),
},
"context_ids_length_stats": {
"avg": float(np.mean(packed_ctx_lens_arr)),
"std": float(np.std(packed_ctx_lens_arr)),
"min": int(np.min(packed_ctx_lens_arr)),
"max": int(np.max(packed_ctx_lens_arr)),
},
}
# Save to metadata_path if provided
if metadata_path:
os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
with open(metadata_path, "w") as f:
json.dump(packing_stats, f, indent=4)
logging.debug(f"Packing stats:\n{pprint.pformat(packing_stats, indent=2)}")
return packed_batch

View file

@ -0,0 +1,206 @@
import logging
import random
from collections.abc import Callable
from typing import Any
from ctx_to_lora.data.definitions import CLOSED_QA_INTX_TEMPLATES, EVAL_INTX_TEMPLATES
from ctx_to_lora.utils import concat_list
logger = logging.getLogger()
def closed_qa_prompting(prompt: str):
template = random.choice(CLOSED_QA_INTX_TEMPLATES)
return template.format(input=prompt)
def chat_to_str(messages: list[dict[str, str]]):
return "Below is the chat history from the current user.\n\n" + "\n\n".join(
[
"Message from: {role}\n{content}".format(
**{**m, "role": m["role"].capitalize()}
)
for m in messages
]
)
def get_preprocessing_fn(
ds_name: str,
is_eval: bool,
) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""
Get preprocessing function for a specific dataset.
Args:
ds_name: Name of the dataset
Returns:
A preprocessing function that takes and returns a dictionary
"""
f = lambda x: x
if ds_name.startswith("self_gen") or ds_name.endswith("_compact"):
# already processed data, do nothing
return f
if "fw_qa_v2" in ds_name:
def f(sample):
# get questions/answers from all levels in the ds
q_cols = [col for col in sample.keys() if col.startswith("prompts_level")]
r_cols = [col for col in sample.keys() if col.startswith("responses_level")]
questions = concat_list([sample[col] for col in q_cols])
responses = concat_list([sample[col] for col in r_cols])
min_len = min(len(questions), len(responses))
if min_len == 0:
return {
"context": None,
"prompts": None,
"responses": None,
}
return {
"context": sample["context"],
"prompts": questions[:min_len],
"responses": responses[:min_len],
}
elif ds_name.startswith("longbench"):
def f(sample):
return {
"context": sample["context"],
"prompt": sample["input"],
"response": sample["answers"][0],
}
elif ds_name == "pwc" or ds_name == "pwc_tiny":
# original pwc
def f(sample):
return {
"context": sample["input"],
"prompt": sample["prompt"],
"response": sample["answer"],
}
elif ds_name == "squad":
# original squad
def f(sample):
q = sample["question"]
prompt = closed_qa_prompting(q) if not is_eval else q
return {
"context": sample["context"],
"prompt": prompt,
"response": sample["answers"]["text"][0],
}
elif ds_name == "squad_assistant_ctx":
def f(sample):
return {
"context": "You are a useful AI assistant.",
"prompt": sample["context"] + "\n\n" + sample["question"],
"response": sample["answers"]["text"][0],
}
elif ds_name == "squad_negative":
with open("data/gutenburg_sample.txt") as f:
gutenburg_sample = f.read()
def f(sample):
return {
"context": gutenburg_sample,
"prompt": sample["context"] + "\n\n" + sample["question"],
"response": sample["answers"]["text"][0],
}
elif ds_name == "squad_negative_no_passage":
with open("data/gutenburg_sample.txt") as f:
gutenburg_sample = f.read()
def f(sample):
return {
"context": gutenburg_sample,
"prompt": sample["question"],
"response": sample["answers"]["text"][0],
}
elif ds_name == "squad_assistant_ctx_no_passage":
def f(sample):
return {
"context": "You are a useful AI assistant.",
"prompt": sample["question"],
"response": sample["answers"]["text"][0],
}
elif ds_name == "drop":
def f(sample):
q = sample["question"]
prompt = closed_qa_prompting(q) if not is_eval else q
return {
"context": sample["passage"],
"prompt": prompt,
"response": sample["answers_spans"]["spans"][0],
}
elif ds_name == "ropes":
ctx_template = "{background}\n{situation}"
def f(sample):
response = sample["answers"]["text"][0]
bg_txt = sample["background"]
situation_txt = sample["situation"]
ctx = ctx_template.format(background=bg_txt, situation=situation_txt)
q = sample["question"]
q = closed_qa_prompting(q) if not is_eval else q
return {"context": ctx, "prompt": q, "response": response}
if is_eval and (ds_name in EVAL_INTX_TEMPLATES):
prompt_template = EVAL_INTX_TEMPLATES[ds_name]
def eval_intx_decorator(f):
def g(sample):
sample = f(sample)
assert "prompt" in sample, (
f"Expected 'prompt' in sample, got {sample.keys()}"
)
sample["prompt"] = prompt_template.format(input=sample["prompt"])
return sample
return g
f = eval_intx_decorator(f)
def maybe_convert_to_list(f):
def g(sample):
sample = f(sample)
if "prompt" in sample:
sample["prompts"] = [sample.pop("prompt")]
if "response" in sample:
sample["responses"] = [sample.pop("response")]
return sample
return g
f = maybe_convert_to_list(f)
if "self_gen" not in ds_name:
def strip_response(f):
def g(sample):
sample = f(sample)
if "responses" in sample and bool(sample["responses"]):
sample["responses"] = [
r.strip() if isinstance(r, str) else r
for r in sample["responses"]
]
return sample
return g
f = strip_response(f)
return f

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,85 @@
STOP_STRINGS = {
"google/gemma-3-12b-it": ["<eos>", "<end_of_turn>"],
}
Q_GEN_SYSTEM_TEMPLATE = (
"You are a creative and helpful assistant.\n"
"You are tasked with generating questions for reading comprehension tests.\n"
"You will be given a context and you need to generate questions and corresponding answers from the given context.\n"
"The questions should be highly specific to the information provided in the context, not general questions that suit any context.\n"
"**DO NOT** hallucinate or make up information."
)
Q_GEN_PROMPT_TEMPLATE = (
"### Instructions ###\n"
"Generate questions and corresponding answers from the given context. The questions should be highly specific to the "
"information provided in the context, not general questions that suit any context.\n\n"
"### Context ###\n"
"{context}\n\n\n"
"### Rules ###\n"
"Rules to follow when generating the questions:\n"
"1. The questions must be specific to the given context and fully answerable from information present in the given context.\n"
"2. Ask questions that are fact-seeking based on the information provided.\n"
"3. Make sure the questions are clear and unambiguous.\n"
"4. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
"the questions.\n"
"5. The questions should not overlap. They should be diverse, covering many aspects of the context.\n"
"6. Do not give away too much information in the questions. For example, ask 'Who is X?' instead of 'Who is X that did Y?' when Y is clear from the context.\n"
"7. Ignore the text formatting of the context, e.g., bold, italic, underline, etc.\n"
"8. Ignore typos, spacing, and grammatical errors in the context.\n\n"
"Rules to follow when generating the answers:\n"
"1. The answers must use the (implied) information provided in the context.\n"
"2. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
"the answers.\n"
"3. Do not just copy words from the context. Answer the question in your own words.\n"
"4. The answers should be detailed and comprehensive. Please include additional specific details from the context.\n\n"
"Respond with {n_qa_pairs} question-answer pairs.\n"
"Always use proper grammar and punctuation.\n"
"Try to use different question forms and styles.\n"
"Use simple words and make sure that the answers are clear and comprehensive.\n\n"
"The question-answer pairs should be in the following format:\n"
"Question 1: {{question_1}}\n"
"Answer 1: {{answer_1}}\n"
"Question 2: {{question_2}}\n"
"Answer 2: {{answer_2}}\n"
"..."
)
Q_GEN_PROMPT_TEMPLATE_REPEAT = (
"### Instructions ###\n"
"Generate questions and corresponding answers from the given context. The questions should be highly specific to the "
"information provided in the context, not general questions that suit any context.\n\n"
"### Context ###\n"
"{context}\n\n\n"
"### Example Question-Answer Pairs ###\n"
"{qa_pairs}\n\n\n"
"### Rules ###\n"
"Rules to follow when generating the questions:\n"
"1. The questions must be specific to the given context and fully answerable from information present in *or* implied from the given context.\n"
"2. The questions must *not* be redundant with the example questions-answer pairs provided.\n"
"3. You should prioritize fact-seeking questions. Consider reversal questions, e.g., asking 'What causes X to happen?' is valid when 'Y causes X' is presented in the context.\n"
"4. If all the facts in the context are already covered by the provided examples, you must generate *more complicated* questions that require reasoning beyond simple information retrieval.\nThis includes asking about information that can be inferred, requiring synthesizing information from multiple parts of the text, or understanding relationships between concepts, events, or individuals mentioned in the context. For example, if the context says 'The Eiffel Tower was completed in 1889 after 2 years of construction', you can ask 'When did the construction of the Eiffel Tower begin?'. Here's another example: if the context says 'Alice is Bob's mother. Bob is Charlie's Dad', you can ask 'Who is Charlie's grandmother?'.\n"
"5. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
"the questions.\n"
"6. The questions should not overlap. They should be diverse, covering many aspects of the context.\n"
"7. Do not give away too much information in the questions. For example, ask 'Who is X?' instead of 'Who is X that did Y?' when Y is clear from the context.\n"
"8. Ignore the text formatting of the context, e.g., bold, italic, underline, etc.\n"
"9. Ignore typos, spacing, and grammatical errors in the context.\n\n"
"Rules to follow when generating the answers:\n"
"1. The answers must use the (implied) information provided in the context.\n"
"2. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
"the answers.\n"
"3. Do not just copy words from the context. Answer the question in your own words.\n"
"4. The answers should be detailed and comprehensive. Please include additional specific details from the context.\n\n"
"Respond with {n_qa_pairs} question-answer pairs.\n"
"Always use proper grammar and punctuation.\n"
"Try to use different question forms and styles.\n"
"Use simple words and make sure that the answers are clear and comprehensive.\n\n"
"The question-answer pairs should be in the following format:\n"
"Question 1: {{question_1}}\n"
"Answer 1: {{answer_1}}\n"
"Question 2: {{question_2}}\n"
"Answer 2: {{answer_2}}\n"
"..."
)

View file

@ -0,0 +1,16 @@
SELF_GEN_SYSTEM_MSG = "You are an honest and helpful assistant."
SELF_QA_INTX = (
"# System Instruction\n"
"- The information provided is up-to-date information and/or the user instruction.\n"
"- When the provided information is not relevant to the question, ***ignore*** it and answer the question based on your knowledge.\n"
"- If the provided information is related to the question, incorporate it in your response.\n"
"- If the provided information is an instruction, follow the instruction carefully.\n"
"\n---\n\n"
"# User Input\n"
)
PRE_CTX = "# Provided Information\n"
QA_PROMPT_TEMPLATE = PRE_CTX + "{context}\n\n---\n\n" + SELF_QA_INTX + "{question}"
PROMPT_TEMPLATE = "{context}\n\n{question}"

File diff suppressed because it is too large Load diff

163
src/ctx_to_lora/metrics.py Normal file
View file

@ -0,0 +1,163 @@
from collections import defaultdict
from collections.abc import Callable
import numpy as np
import torch
from rouge_score import rouge_scorer
from transformers import EvalPrediction
LENGTH_BINS = [
# finegrain bins
(0, 2**7 - 1),
(2**7, 2**8 - 1),
(2**8, 2**9 - 1),
# coarse bins
(0, 2**9 - 1),
(2**9, 2**10 - 1),
(2**10, 2**11 - 1),
(2**11, 2**12 - 1),
(2**12, 2**13 - 1),
(0, 2**13 - 1),
(2**13, 2**14 - 1),
(2**14, 2**15 - 1),
(2**15, float("inf")),
]
def get_length_bin(length: int):
"""Get the length bin for a given length."""
for i, (start, end) in enumerate(LENGTH_BINS):
if start <= length < end:
return (start, end)
def compute_rouge(pred_texts, label_texts):
out = defaultdict(list)
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
for pred_text, label_text in zip(pred_texts, label_texts):
scores = scorer.score(pred_text, label_text)
for k, v in scores.items():
out[f"{k}.f1"].append(v.fmeasure)
out_mean = dict()
for k in out:
out_mean[k] = np.mean(out[k])
return out_mean, out
@torch.inference_mode()
def compute_per_token_acc(shift_logits, shift_labels, valid_masks):
indices = torch.where(valid_masks)
acc = (shift_logits.argmax(-1) == shift_labels)[indices].float()
return {
"per_token_accs": acc.flatten().tolist(),
"n_per_token_accs": valid_masks.sum().item(),
}
@torch.inference_mode()
def compute_prefix_matching(shift_logits, shift_labels, valid_masks):
lengths = valid_masks.sum(dim=1)
is_wrong = (shift_logits.argmax(-1) != shift_labels) * valid_masks
is_correct = (shift_logits.argmax(-1) == shift_labels) * valid_masks
# NOTE: not reliable for multi-turn conversations
# ie, all tokens in the following user's turn will be correct
# still monotonically correlate with perf though
wrong_pos = torch.argmax(is_wrong, dim=1) - torch.argmax(valid_masks, dim=1)
perf = wrong_pos / lengths
# if all tokens are correct, set to 1
perf = torch.where(is_correct.sum(dim=1) == lengths, 1, perf)
return {
"prefix_matchings": perf.tolist(),
"n_prefix_matchings": valid_masks.shape[0],
}
@torch.inference_mode()
def compute_perplexity(shift_logits, shift_labels, valid_masks):
return {"perplexities_ph": [1], "n_perplexities_ph": 1}
class Evaluator:
def __init__(self, metric_fns: list[Callable]):
self.metric_fns = metric_fns
self.reset()
def reset(self):
self.accum_metrics = defaultdict(lambda: list((0,)))
self.count = defaultdict(lambda: list((0,)))
def update(self, shift_logits, shift_labels, valid_masks, lengths=None):
for metric_fn in self.metric_fns:
# overall metric
metric = metric_fn(shift_logits, shift_labels, valid_masks)
for k, v in metric.items():
key = k if not k.startswith("n_") else k[2:]
if k.startswith("n_"):
# prefix "n_" indicates the count of the metric
self.count[key].append(v)
else:
self.accum_metrics[key] += v
for start, end in LENGTH_BINS:
key_w_len = f"{key}_len_{start}-{end}"
if key_w_len not in self.accum_metrics:
# add key here so that it shows up in the output
self.accum_metrics[key_w_len] = [0]
self.count[key_w_len] = [0]
# split samples into length groups, calculate metric for each group
if lengths is not None:
for start, end in LENGTH_BINS:
logits, labels, masks = [], [], []
for logit, label, m, len in zip(
shift_logits, shift_labels, valid_masks, lengths
):
if isinstance(len, torch.Tensor):
len = len.item()
if start <= len < end:
logits.append(logit)
labels.append(label)
masks.append(m)
if not logits:
continue
metric = metric_fn(
torch.stack(logits), torch.stack(labels), torch.stack(masks)
)
for k, v in metric.items():
if k.startswith("n_"):
key = f"{k[2:]}_len_{start}-{end}"
self.count[key].append(v)
else:
key = f"{k}_len_{start}-{end}"
self.accum_metrics[key] += v
def compute(self):
# Get result across entire eval set
result = {
k: np.sum(v) / np.sum(self.count[k]) if len(v) > 1 else "None"
for k, v in self.accum_metrics.items()
}
# Reset batch statistics
self.reset()
return result
@torch.no_grad()
def compute_metrics(
eval_pred: EvalPrediction,
compute_result: bool,
evaluator: Evaluator,
) -> dict | None:
inputs = eval_pred.inputs
len_key = "ctx_ids_len" if "ctx_ids_len" in inputs else "input_ids_len"
lengths = inputs[len_key]
logits, labels = eval_pred.predictions, eval_pred.label_ids
shift_logits = logits[..., :-1, :]
shift_labels = labels[..., 1:]
valid_masks = torch.where(shift_labels != -100, 1, 0)
evaluator.update(shift_logits, shift_labels, valid_masks, lengths)
if compute_result:
return evaluator.compute()

View file

@ -0,0 +1,183 @@
import logging
import os
import torch
from peft import PeftModel
from peft import get_peft_config as _get_peft_config
from peft.utils import PeftType
from transformers import (
AutoModel,
AutoModelForCausalLM,
AutoTokenizer,
BitsAndBytesConfig,
Gemma3ForConditionalGeneration,
)
logger = logging.getLogger()
GEMMA_VISION_MODELS = [
"google/gemma-3-4b-it",
"google/gemma-3-12b-it",
"google/gemma-3-27b-it",
]
def check_is_vision_model(model_name):
return model_name in GEMMA_VISION_MODELS
def get_model_and_tokenizer(
model_name_or_path,
train,
requires_grad,
use_flash_attn=True,
peft_config=None,
model_kwargs=None,
tokenizer_kwargs=None,
use_q_lora=False,
device="cuda",
dtype=torch.bfloat16,
):
model = get_model(
model_name_or_path,
train,
requires_grad,
use_flash_attn,
peft_config,
model_kwargs,
use_q_lora,
device,
dtype,
)
tokenizer = get_tokenizer(model_name_or_path, tokenizer_kwargs, peft_config, train)
model.config.pad_token_id = tokenizer.pad_token_id
if getattr(model, "generation_config", None):
model.generation_config.pad_token_id = tokenizer.pad_token_id
return model, tokenizer
def get_tokenizer(
model_name_or_path, tokenizer_kwargs=None, peft_config=None, train=False
):
padding_side = "left" if not train else "right"
truncation_side = "left"
if tokenizer_kwargs is None:
tokenizer_kwargs = {}
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
add_bos_tokens=False,
add_eos_tokens=False,
padding_side=padding_side,
truncation_side=truncation_side,
trust_remote_code=True,
**tokenizer_kwargs,
)
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
template_path = f"chat_templates/{model_name_or_path}.jinja"
if not os.path.exists(template_path):
logger.warning(
f"Chat template not found at {template_path}. Using default template."
)
return tokenizer
logger.info(f"Using chat template from {template_path}")
chat_template = open(template_path).read()
chat_template = chat_template.replace(" ", "").replace("\n", "")
tokenizer.chat_template = chat_template
return tokenizer
def get_model(
model_name_or_path,
train,
requires_grad,
use_flash_attn=True,
peft_config=None,
model_kwargs=None,
use_q_lora=False,
device="cuda",
dtype=torch.bfloat16,
):
model_init_kwargs = dict(
pretrained_model_name_or_path=model_name_or_path,
device_map=device,
torch_dtype=dtype,
trust_remote_code=True,
attn_implementation="eager",
use_cache=None,
)
is_vision_model = check_is_vision_model(model_name_or_path)
if model_kwargs is not None:
model_init_kwargs.update(model_kwargs)
is_bidir_model = (
"bert" in model_name_or_path.lower() or "gte" in model_name_or_path.lower()
)
if use_flash_attn:
if "gte" not in model_name_or_path:
model_init_kwargs["attn_implementation"] = "flash_attention_2"
elif "gte" in model_name_or_path:
model_init_kwargs["attn_implementation"] = "sdpa"
if is_vision_model:
# always use sdpa for vision models
# model_init_kwargs["attn_implementation"] = "sdpa"
model_init_kwargs.pop("use_cache")
elif is_bidir_model:
model_init_kwargs["torch_dtype"] = torch.float32
model_init_kwargs.pop("use_cache")
if use_q_lora:
# https://huggingface.co/blog/4bit-transformers-bitsandbytes
# https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing
# see bitsandbytes for the quantization implementation https://github.com/bitsandbytes-foundation/bitsandbytes
# see unsloth https://huggingface.co/docs/trl/v0.7.11/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth
# does work currently bc it modifies the forward pass call of Linear
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
model_init_kwargs["quantization_config"] = bnb_config
logger.debug(f"Model init kwargs: {model_init_kwargs}")
if not is_vision_model:
if is_bidir_model:
model = AutoModel.from_pretrained(**model_init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(**model_init_kwargs)
else:
model = Gemma3ForConditionalGeneration.from_pretrained(**model_init_kwargs)
model = model.language_model
if peft_config is not None:
model = PeftModel(model, peft_config)
model.train(train)
for name, param in model.named_parameters():
param.requires_grad = requires_grad
return model
def get_lora_config(model_dir, **kwargs):
if "target_modules" not in kwargs or kwargs["target_modules"] is None:
logger.info("No target modules specified for LoRA.")
return None
r = kwargs.pop("lora_r", 8)
peft_conf_kwargs = dict(
r=r,
peft_type=PeftType.LORA,
base_model_name_or_path=model_dir,
task_type="CAUSAL_LM",
lora_dropout=kwargs.get("lora_dropout", 0.0),
lora_alpha=r ** (3 / 2) * 2,
)
peft_conf_kwargs.update(kwargs)
peft_config = _get_peft_config(peft_conf_kwargs)
return peft_config

View file

View file

@ -0,0 +1,211 @@
import logging
from dataclasses import dataclass
from enum import Enum
from einops import rearrange, repeat, unpack
from jaxtyping import Float, Integer
from torch import Tensor, nn
from transformers import (
PretrainedConfig,
PreTrainedModel,
)
from ctx_to_lora.configs import (
AggregatorArguments,
)
from ctx_to_lora.modeling.idefics2 import Idefics2Perceiver, Idefics2PerceiverConfig
from ctx_to_lora.pooling import POOL_FN
from ctx_to_lora.utils import (
get_num_layers,
)
logger = logging.getLogger()
class AGGREGATOR_TYPE(str, Enum):
POOLER = "pooler"
PERCEIVER = "perceiver"
@dataclass
class AggregatorConfig:
aggregator_type: AGGREGATOR_TYPE
num_layers: int
num_modules: int
num_extra_modules: int
output_size: int
feature_size: int
# pooler
pooling_type: POOL_FN
# perceiver
num_latent_factor: int
lora_r: int
per_rank_gen: bool
n_latent_queries: int
num_blocks: int
num_self_attn_per_block: int
shared_weights: bool
layer_to_layer_ctx_encoder: bool
def get_aggregator_config(
model: PreTrainedModel,
ctx_encoder_model_config: PretrainedConfig,
layer_to_layer_ctx_encoder: bool,
output_size: int,
num_modules: int,
num_extra_modules: int,
lora_r: int,
per_rank_gen: bool,
aggregator_args: AggregatorArguments,
):
return AggregatorConfig(
feature_size=ctx_encoder_model_config.hidden_size,
output_size=output_size,
num_layers=get_num_layers(model),
num_modules=num_modules,
num_extra_modules=num_extra_modules,
lora_r=lora_r,
per_rank_gen=per_rank_gen,
layer_to_layer_ctx_encoder=layer_to_layer_ctx_encoder,
**vars(aggregator_args),
)
class Perceiver(nn.Module):
"""perceiver w/ bottleneck size = n_modules * n_layers"""
def __init__(
self,
feature_size,
output_size,
num_layers,
num_modules,
num_extra_modules,
per_rank_gen,
lora_r,
num_latent_factor, # unused
layer_to_layer_ctx_encoder,
n_latent_queries,
*args,
**kwargs,
):
super().__init__()
assert num_extra_modules == 0
self.num_layers = num_layers
self.num_modules = num_modules
self.num_extra_modules = num_extra_modules
self.per_rank_gen = per_rank_gen
self.r = lora_r if self.per_rank_gen else 1
n_output_queries = num_layers * (num_modules * self.r + num_extra_modules)
self.layer_to_layer = layer_to_layer_ctx_encoder
if self.layer_to_layer:
n_output_queries = num_modules * self.r + num_extra_modules
self.config = Idefics2PerceiverConfig(
input_size=feature_size,
num_blocks=kwargs["num_blocks"],
num_self_attn_per_block=kwargs["num_self_attn_per_block"],
shared_weights=kwargs["shared_weights"],
n_latents=n_latent_queries,
intermediate_size_factor=4,
hidden_size=output_size,
attn_implementation="flash_attention_2",
)
self.decoder_config = Idefics2PerceiverConfig(
input_size=output_size,
num_blocks=1,
num_self_attn_per_block=0,
shared_weights=False,
n_latents=n_output_queries,
intermediate_size_factor=4,
hidden_size=output_size,
attn_implementation="flash_attention_2",
)
self.perceiver = Idefics2Perceiver(self.config, self.decoder_config)
self.iterative_mode = False
def enable_iterative_mode(self, x: bool):
self.iterative_mode = x
def forward(
self,
ctx_features: Float[Tensor, "bs seq_len feature_dim"]
| Float[Tensor, "bs x seq_len feature_dim"],
ctx_attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
ctx_position_ids: Integer[Tensor, "bs seq_len"] | None = None,
):
if self.layer_to_layer and not self.iterative_mode:
if ctx_attn_mask is not None:
ctx_attn_mask = repeat(
ctx_attn_mask,
"bs seq_len -> (num_layers bs) seq_len",
num_layers=self.num_layers,
)
ctx_features = rearrange(
ctx_features,
"bs num_layers seq_len feature_dim -> (num_layers bs) seq_len feature_dim",
)
if ctx_position_ids is not None:
ctx_position_ids = repeat(
ctx_position_ids,
"1 seq_len -> 1 (num_layers seq_len)",
num_layers=self.num_layers,
)
ctx_features = rearrange(
ctx_features,
"1 num_layers seq_len feature_dim -> 1 (num_layers seq_len) feature_dim",
)
x = self.perceiver(ctx_features, ctx_attn_mask, ctx_position_ids)
if self.layer_to_layer and self.iterative_mode:
lora_x = rearrange(
x,
"bs (n_modules r) d -> bs n_modules r d",
n_modules=self.num_modules,
r=self.r,
)
return lora_x, None
if self.layer_to_layer:
per_layer_size = self.num_modules * self.r + self.num_extra_modules
x = rearrange(
x,
"(num_layers bs) (per_layer_sz) d -> bs (num_layers per_layer_sz) d",
num_layers=self.num_layers,
per_layer_sz=per_layer_size,
)
lora_x, extra_x = unpack(
x,
[
[self.num_layers * self.num_modules * self.r],
[self.num_layers * self.num_extra_modules],
],
"bs * feature_dim",
)
lora_x = rearrange(
lora_x,
"bs (n_layers n_modules r) d -> bs n_layers n_modules r d",
n_modules=self.num_modules,
n_layers=self.num_layers,
r=self.r,
)
if not self.per_rank_gen:
lora_x = lora_x.squeeze(3)
extra_x = rearrange(
extra_x,
"bs (n_layers n_extra_modules) d -> bs n_layers n_extra_modules d",
n_extra_modules=self.num_extra_modules,
n_layers=self.num_layers,
)
return lora_x, extra_x
AGGREGATOR_CLS = {
AGGREGATOR_TYPE.PERCEIVER: Perceiver,
}

View file

@ -0,0 +1,632 @@
import gc
import json
import logging
import os
import re
from math import ceil
from typing import Any
import torch
from jaxtyping import Integer
from peft import PeftModel
from peft.tuners.tuners_utils import BaseTunerLayer, check_target_module_exists
from torch import Tensor, nn
from transformers import PreTrainedModel
from transformers.modeling_outputs import ModelOutput
from ctx_to_lora.data.definitions import CTX_AFFIXES
from ctx_to_lora.data.q_generation_template import (
Q_GEN_PROMPT_TEMPLATE,
Q_GEN_PROMPT_TEMPLATE_REPEAT,
Q_GEN_SYSTEM_TEMPLATE,
STOP_STRINGS,
)
from ctx_to_lora.data.self_gen_template import SELF_QA_INTX
from ctx_to_lora.utils import log_num_train_params
logger = logging.getLogger()
def get_q_gen_prompt(context, n_qa_pairs):
prompt = Q_GEN_PROMPT_TEMPLATE.format(context=context, n_qa_pairs=n_qa_pairs)
return prompt
def get_q_gen_prompt_repeat(context, qa_pairs, n_qa_pairs):
example_qa_pairs = ""
for i, (q, a) in enumerate(qa_pairs, 1):
example_qa_pairs += f"Question {i}: {q}\nAnswer {i}: {a}\n"
prompt = Q_GEN_PROMPT_TEMPLATE_REPEAT.format(
context=context,
qa_pairs=example_qa_pairs,
n_qa_pairs=n_qa_pairs,
)
return prompt
def check_should_skip(txt: str, vllm_model: str) -> bool:
"""Check if the response should be skipped based on stop strings."""
for stop in STOP_STRINGS[vllm_model]:
if stop in txt[-len(stop) :]:
return (txt.split(stop)[0], False) # Found a valid stop string
return (txt, True) # No valid stop string found, skip this response
def postprocess_qa_pairs(res_txt: str):
"""
Postprocesses the QA pairs from the response text.
Args:
res_txt: The response text.
n_qa_pairs: The number of QA pairs.
Returns:
A tuple of two lists, the first containing the questions and the second containing the answers.
"""
# capture everything after each "Question {number}:" until "Answer"
q_pattern = r"Question \d+:(.*?)(?=Answer|$)" # thanks chatgpt
questions = re.findall(q_pattern, res_txt, flags=re.S)
a_pattern = r"Answer \d+:(.*?)(?=Question|$)" # thanks chatgpt
answers = re.findall(a_pattern, res_txt, flags=re.S)
if len(questions) != len(answers):
print(f"Warning---number of questions and answers do not match")
print(f"Number of questions: {len(questions)}")
print(f"Number of answers: {len(answers)}")
out_q = []
out_a = []
n_skips = 0
if (len(questions) > 0) and (len(answers) > 0):
n_gen_pairs = min(len(questions), len(answers))
has_left_over = n_gen_pairs < len(questions) or n_gen_pairs < len(answers)
for i in range(n_gen_pairs):
response = answers[i].strip()
question = questions[i].strip()
if not response or not question:
print(f"Skipping empty question or answer at index {i}")
continue
if (not has_left_over) and (i == n_gen_pairs - 1):
response, skip = check_should_skip(response, "google/gemma-3-12b-it")
if skip:
print(f"Skipping due to missing stop string")
n_skips += 1
continue
out_q.append(question.strip())
out_a.append(response.strip())
print(f"Skipped {n_skips} responses due to missing stop strings")
return out_q, out_a
def build_messages(ctx_text: str, level: int, example_qa_pairs: list = None):
messages = [
{"role": "system", "content": Q_GEN_SYSTEM_TEMPLATE},
{
"role": "user",
"content": get_q_gen_prompt(ctx_text, 5)
if level == 0
else get_q_gen_prompt_repeat(ctx_text, example_qa_pairs, 5),
},
]
return messages
def get_shifted_label_pos(labels):
pos = torch.where(labels != -100)
# (batch_idx, token_idx)
return (pos[0], pos[1] - 1)
def logits_at_positions(outputs: ModelOutput, pos) -> Tensor:
logits = outputs.logits
return logits[pos[0], pos[1]]
def ctx_inp_split(
ctx_inp_ids, ctx_inp_sep_seq, pad_token_id, prefix_tokens=None, padding_side="right"
):
# Split each row in ctx_inp_ids at the first occurrence of ctx_inp_sep_seq
# Return the part after the separator for each row
batch_size = ctx_inp_ids.size(0)
sep_len = ctx_inp_sep_seq.size(0)
out_inp = []
out_ctx = []
for i in range(batch_size):
row = ctx_inp_ids[i]
# Find where the separator starts
for j in range(row.size(0) - sep_len + 1):
if torch.equal(row[j : j + sep_len], ctx_inp_sep_seq):
out_ctx.append(row[:j])
if prefix_tokens is not None:
out_inp.append(
torch.cat([prefix_tokens, row[j + sep_len :]], axis=-1)
)
else:
out_inp.append(row[j + sep_len :])
break
else:
# If separator not found
raise ValueError(f"Separator sequence not found in row {i}")
out_inp = torch.nn.utils.rnn.pad_sequence(
out_inp, batch_first=True, padding_value=pad_token_id, padding_side=padding_side
)
out_ctx = torch.nn.utils.rnn.pad_sequence(
out_ctx, batch_first=True, padding_value=pad_token_id, padding_side=padding_side
)
return out_ctx, out_inp
def get_peft_layers(model, peft_config):
out = []
for module_name, module in model.named_modules():
if not check_target_module_exists(peft_config, module_name):
continue
if not isinstance(module, BaseTunerLayer):
continue
# support just Linear layer for now
# all modules should be a leave module that is Linear layer
assert isinstance(module.base_layer, nn.Linear), (
"all modules should be a leave module that is Linear layer"
)
# this should always pass
name = module_name.split(".")[-1]
assert name in peft_config.target_modules
out.append(module)
return out
class CtxDistillModel(nn.Module):
def __init__(
self,
base_model: PeftModel,
prefix_tokens: Integer[Tensor, "n"],
ctx_inp_sep_seq: Integer[Tensor, "m"],
pad_token_id: int,
update_iterations: int,
reset: bool = True,
tokenizer=None,
q_model: PreTrainedModel | None = None,
q_tokenizer=None,
reprompt_ctx: bool = False,
lora_save_dir: str | None = None,
save_after_distill: bool = True,
q_gen_rounds: int = 4,
batch_size: int = 16,
):
super().__init__()
self.register_module("base_model", base_model)
self.register_module("q_model", q_model)
self.register_buffer("prefix_tokens", prefix_tokens)
self.register_buffer("ctx_inp_sep_seq", ctx_inp_sep_seq)
self.tokenizer = tokenizer
self.q_tokenizer = q_tokenizer
self.pad_token_id = pad_token_id
self.update_iterations = update_iterations
self.reprompt_ctx = reprompt_ctx
self.reset = reset
self.device = base_model.device
self.to(self.device)
self.q_gen_rounds = q_gen_rounds
# New save options
self.lora_save_dir = lora_save_dir
self.save_after_distill = save_after_distill
self.peft_config = base_model.peft_config["default"]
self.adapter_name = "default"
self.base_model.set_adapter("default")
for layer in get_peft_layers(self.base_model, self.peft_config):
for name, p in layer.named_parameters():
if "lora_A" in name or "lora_B" in name:
p.requires_grad = True
log_num_train_params(self.base_model)
self._init_optim()
# Mini-batch size for distillation updates
self.batch_size = batch_size
@property
def generation_config(self):
return self.base_model.generation_config
def _init_optim(self):
self.optimizer = torch.optim.AdamW(
[
p
for l in get_peft_layers(self.base_model, self.peft_config)
for p in l.parameters()
if p.requires_grad
],
lr=1e-4,
)
def reset_lora(self):
print("Resetting LoRA")
for layer in get_peft_layers(self.base_model, self.peft_config):
layer.reset_lora_parameters(self.adapter_name, init_lora_weights=True)
self._init_optim()
def save_lora(self):
"""
Save current LoRA adapter in PEFT format plus a lightweight JSON summary
for easy human inspection/manipulation.
"""
if self.lora_save_dir is None:
return
os.makedirs(self.lora_save_dir, exist_ok=True)
# Standard PEFT save (produces adapter_config.json + adapter_model.bin / safetensors)
self.base_model.save_pretrained(self.lora_save_dir)
# Human-readable summary of LoRA parameter shapes
summary = {
name: list(p.shape)
for name, p in self.base_model.named_parameters()
if ("lora_A" in name or "lora_B" in name) and p.requires_grad
}
with open(os.path.join(self.lora_save_dir, "lora_summary.json"), "w") as f:
json.dump(summary, f, indent=2)
print(f"LoRA adapter saved to {self.lora_save_dir}")
@torch.enable_grad()
def _distill_context(
self,
ctx_inp_res_ids: Integer[Tensor, "bs ctx_inp_length"],
ctx_inp_res_attention_mask: Integer[Tensor, "bs ctx_inp_length"],
teacher_labels: Integer[Tensor, "bs ctx_inp_length"],
inp_res_ids: Integer[Tensor, "bs inp_length"],
inp_res_attention_mask: Integer[Tensor, "bs inp_length"],
student_labels: Integer[Tensor, "bs inp_length"],
):
# Implements KD-style loss by computing teacher (with context) and student (no context)
# log-probs locally, using mini-batches for updates.
was_training = self.training
self.train()
num_samples = ctx_inp_res_ids.size(0)
mb = self.batch_size
num_batches = ceil(num_samples / mb)
total_steps = max(self.update_iterations * num_batches, 1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
self.optimizer, T_max=total_steps, eta_min=0.0
)
print(
f"Starting context distillation for {self.update_iterations} epochs, "
f"mini-batch size {mb} ({num_batches} batches/epoch), "
f"cosine LR schedule over {total_steps} steps"
)
for epoch in range(self.update_iterations):
# Shuffle order each epoch for SGD
perm = torch.randperm(num_samples, device=self.device)
epoch_loss = 0.0
for b in range(num_batches):
start = b * mb
end = min(start + mb, num_samples)
indices = perm[start:end]
b_ctx_ids = ctx_inp_res_ids[indices]
b_ctx_am = ctx_inp_res_attention_mask[indices]
b_teacher_labels = teacher_labels[indices]
b_inp_ids = inp_res_ids[indices]
b_inp_am = inp_res_attention_mask[indices]
b_student_labels = student_labels[indices]
# Compute teacher distribution (top-k) for this mini-batch
with torch.no_grad(), self.base_model.disable_adapter():
t_pos = get_shifted_label_pos(b_teacher_labels)
teacher_outputs = self.base_model(
b_ctx_ids, attention_mask=b_ctx_am
)
teacher_logits = logits_at_positions(teacher_outputs, t_pos)
K = 16
topk_vals, topk_idx = teacher_logits.topk(K, dim=-1)
teacher_denom = torch.logsumexp(
teacher_logits.float(), dim=-1, keepdim=True
)
teacher_p = (topk_vals - teacher_denom).exp().detach() # [N, K]
# Student forward and update for this mini-batch
self.optimizer.zero_grad()
s_pos = get_shifted_label_pos(b_student_labels)
student_outputs = self.base_model(b_inp_ids, attention_mask=b_inp_am)
student_logits = logits_at_positions(student_outputs, s_pos)
student_denom = torch.logsumexp(
student_logits.float(), dim=-1, keepdim=True
)
selected_student_logits = student_logits.gather(-1, topk_idx)
student_logq = selected_student_logits - student_denom # [N, K]
token_losses = -(teacher_p * student_logq).sum(dim=-1) # [N]
loss = token_losses.mean()
loss.backward()
self.optimizer.step()
scheduler.step()
epoch_loss += loss.detach().item()
cur_lr = self.optimizer.param_groups[0]["lr"]
print(
f"Epoch {epoch + 1}/{self.update_iterations}, "
f"batch {b + 1}/{num_batches}: loss={epoch_loss / num_batches:.4f}, lr={cur_lr:.6e}"
)
if not was_training:
self.eval()
def generate_questions(self, *args, **kwargs):
questions = self.q_model.generate(*args, **kwargs)
return questions
def teacher_generate(self, *args, **kwargs):
# rename for separate timing
return self.base_model.generate(*args, **kwargs)
def student_generate(self, *args, **kwargs):
# rename for separate timing
return self.base_model.generate(*args, **kwargs)
def get_lora_state(self, clone: bool = True):
"""
Return a dict of current LoRA parameter tensors.
clone=True returns detached cloned tensors (safe to store).
"""
return {
name: (p.detach().clone() if clone else p)
for name, p in self.base_model.named_parameters()
if ("lora_A" in name or "lora_B" in name)
}
def generate(
self,
*model_inputs_args: Any,
distill_only: bool = False,
**model_inputs_kwargs: dict[str, Any],
):
if self.reset:
self.reset_lora()
# teacher tokens
orig_ctx_inp_ids = model_inputs_kwargs.pop("input_ids")
ctx_inp_ids = orig_ctx_inp_ids.clone()
_, orig_inp_ids = ctx_inp_split(
ctx_inp_ids,
self.ctx_inp_sep_seq,
self.pad_token_id,
self.prefix_tokens,
padding_side="left",
)
ctx_inp_attention_mask = model_inputs_kwargs.pop("attention_mask")
if self.q_model is not None:
self.q_model.to(self.base_model.device)
# Extract context-only portion after separator (remove prefix tokens from first row)
ctx_ids_full, _ = ctx_inp_split(
ctx_inp_ids, self.ctx_inp_sep_seq, self.pad_token_id
) # [bs, var_len]
ctx_ids = ctx_ids_full[0, len(self.prefix_tokens) :]
ctx_txt = self.tokenizer.decode(ctx_ids, skip_special_tokens=True)
questions = []
answers = []
# Build multiple instruction variants
for lvl in range(self.q_gen_rounds):
messages = build_messages(
ctx_txt, lvl, zip(questions, answers) if lvl > 0 else None
)
q_inputs = self.q_tokenizer.apply_chat_template(
messages,
tokenize=True,
add_special_tokens=False,
padding=False,
truncation=False,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
)
q_inputs = {k: v.to(self.q_model.device) for k, v in q_inputs.items()}
with torch.no_grad():
question_outputs = self.generate_questions(
input_ids=q_inputs["input_ids"],
attention_mask=q_inputs["attention_mask"],
max_new_tokens=1024,
do_sample=False,
temperature=0.0,
eos_token_id=106, # <end_of_turn> for gemma-3-12b-it
)
# Slice off the prompt portion
gen_only = question_outputs[:, q_inputs["input_ids"].shape[-1] :]
res = self.q_tokenizer.batch_decode(gen_only, skip_special_tokens=False)
gen_q_list, gen_a_list = postprocess_qa_pairs(res[0])
questions += gen_q_list
answers += gen_a_list
if len(questions) == 0:
# when q_model refuses to provide questions
# only happens with sample 116 in longbench/multifieldqa_en_e
# in this case cd just doesn't work, we fall back to zero-shot answer
print(f"Warning---no questions generated, skipping distillation")
attention_mask = torch.where(
orig_inp_ids != self.pad_token_id, 1, 0
).long()
return self.student_generate(
orig_inp_ids, attention_mask=attention_mask, **model_inputs_kwargs
)
ctx_inp_messages = [
[{"role": "user", "content": f"{ctx_txt}\n\n{SELF_QA_INTX}\n\n{q}"}]
for q in questions
]
encoded_ctx_inp = self.tokenizer.apply_chat_template(
ctx_inp_messages,
tokenize=True,
add_special_tokens=False,
padding=True,
truncation=False,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
)
encoded_ctx_inp = {k: v.to(self.device) for k, v in encoded_ctx_inp.items()}
ctx_inp_ids = encoded_ctx_inp["input_ids"]
ctx_inp_attention_mask = encoded_ctx_inp["attention_mask"]
self.q_model.to("cpu")
gc.collect()
torch.cuda.empty_cache()
# sample responses first
ctx_inp_res_ids = self.teacher_generate(
ctx_inp_ids,
attention_mask=ctx_inp_attention_mask,
**model_inputs_kwargs,
)
ctx_inp_res_attention_mask = torch.where(
ctx_inp_res_ids != self.pad_token_id, 1, 0
).long()
ctx_inp_res_txt = self.tokenizer.batch_decode(ctx_inp_res_ids)
bs = ctx_inp_ids.shape[0]
res_len = ctx_inp_res_ids.shape[-1] - ctx_inp_ids.shape[-1]
res_ids = ctx_inp_res_ids[:, -res_len:] # correct
pads = torch.full_like(ctx_inp_ids, self.pad_token_id)
teacher_labels = torch.cat([pads, res_ids], dim=-1)
teacher_labels = torch.where(
teacher_labels != self.pad_token_id, teacher_labels, -100
)
# student tokens
_, inp_res_ids = ctx_inp_split(
ctx_inp_res_ids,
self.ctx_inp_sep_seq,
self.pad_token_id,
self.prefix_tokens,
padding_side="left",
)
inp_res_attention_mask = torch.where(
inp_res_ids != self.pad_token_id, 1, 0
).long()
student_labels = inp_res_ids.clone()
student_labels[:, :-res_len] = -100
student_labels = torch.where(
student_labels != self.pad_token_id, student_labels, -100
)
self._distill_context(
ctx_inp_res_ids,
ctx_inp_res_attention_mask,
teacher_labels,
inp_res_ids,
inp_res_attention_mask,
student_labels,
)
# Save LoRA after distillation if requested
if distill_only:
return self.get_lora_state()
model_inputs_kwargs.pop("attention_mask", None)
model_inputs_kwargs.pop("input_ids", None)
if self.reprompt_ctx:
attention_mask = torch.where(orig_ctx_inp_ids != self.pad_token_id, 1, 0)
model_outputs = self.student_generate(
orig_ctx_inp_ids, attention_mask=attention_mask, **model_inputs_kwargs
)
else:
attention_mask = torch.where(orig_inp_ids != self.pad_token_id, 1, 0).long()
model_outputs = self.student_generate(
orig_inp_ids, attention_mask=attention_mask, **model_inputs_kwargs
)
return model_outputs
if __name__ == "__main__":
from ctx_to_lora.data.processing import load_and_process_dataset
from ctx_to_lora.model_loading import get_lora_config, get_model_and_tokenizer
model_name = "google/gemma-2-2b-it"
q_model_name = "google/gemma-3-12b-it"
peft_config = get_lora_config(
model_name, r=8, target_modules=["down_proj"], lora_dropout=0.0
)
peft_config.lora_alpha = 16
model, tokenizer = get_model_and_tokenizer(
model_name,
train=False,
requires_grad=False,
peft_config=peft_config,
)
q_model, q_tokenizer = get_model_and_tokenizer(
q_model_name,
train=False,
requires_grad=False,
peft_config=peft_config,
)
ds = load_and_process_dataset("pwc", split="train", num_proc=8)
ctx = ds[0]["context"]
inp = ds[1]["prompts"][0]
# Build a simple context/input pair separated by a unique token sequence
sep_text = SELF_QA_INTX
# ctx = "# Provided Information\nMy name is Tan."
# ctx = "# Provided Info"
prompt = f"{ctx}\n\n{sep_text}\n\n{inp}"
messages = [{"role": "user", "content": prompt}]
encoded = tokenizer.apply_chat_template(
messages,
return_tensors="pt",
return_dict=True,
add_generation_prompt=True,
)
# encoded = tokenizer(prompt, return_tensors="pt")
encoded = {k: v.to(model.device) for k, v in encoded.items()}
prefix_tokens = CTX_AFFIXES[model_name]["prefix"]
prefix_tokens = torch.tensor(prefix_tokens, dtype=torch.long)
sep_ids = (
tokenizer(sep_text.strip("\n"), add_special_tokens=False, return_tensors="pt")
.input_ids[0]
.to(model.device)
)
cd_model = CtxDistillModel(
base_model=model,
prefix_tokens=prefix_tokens,
ctx_inp_sep_seq=sep_ids,
pad_token_id=tokenizer.pad_token_id,
update_iterations=100,
q_model=q_model,
q_tokenizer=q_tokenizer,
tokenizer=tokenizer,
reprompt_ctx=False,
lora_save_dir="./saved_lora_adapter", # example save path
save_after_distill=True,
)
with torch.no_grad():
for _ in range(1):
base_model_res = model.generate(
input_ids=encoded["input_ids"],
attention_mask=encoded["attention_mask"],
max_new_tokens=256,
do_sample=False,
)
print(
f"Base model response:{tokenizer.batch_decode(base_model_res, skip_special_tokens=False)}"
)
outputs = cd_model.generate(
input_ids=encoded["input_ids"],
attention_mask=encoded["attention_mask"],
max_new_tokens=256,
do_sample=False,
)
print(
f"Student response: {tokenizer.batch_decode(outputs, skip_special_tokens=False)}"
)

View file

@ -0,0 +1,158 @@
import logging
from contextlib import contextmanager
from enum import Enum
import torch
from torch import nn
from transformers import PreTrainedModel
from ctx_to_lora.configs import CtxEncoderArguments
from ctx_to_lora.utils import get_base_model
logger = logging.getLogger()
@contextmanager
def early_exit(base_model: PreTrainedModel, exit_layer: int):
try:
layers = base_model.layers
base_model.layers = layers[:exit_layer]
yield base_model
finally:
base_model.layers = layers
@contextmanager
def maybe_add_batch_dim(kwargs):
try:
batched_input = False
batched_attn_mask = False
if (
"input_ids" in kwargs
and kwargs["input_ids"] is not None
and len(kwargs["input_ids"].shape) == 1
):
kwargs["input_ids"] = kwargs["input_ids"].unsqueeze(0)
batched_input = True
if (
"attention_mask" in kwargs
and kwargs["attention_mask"] is not None
and isinstance(kwargs["attention_mask"], torch.Tensor)
and len(kwargs["attention_mask"].shape) == 1
):
kwargs["attention_mask"] = kwargs["attention_mask"].unsqueeze(0)
batched_attn_mask = True
yield batched_input, batched_attn_mask
finally:
if batched_input:
kwargs["input_ids"] = kwargs["input_ids"].squeeze(0)
if batched_attn_mask:
kwargs["attention_mask"] = kwargs["attention_mask"].squeeze(0)
class EarlyExit(nn.Module):
def __init__(self, base_model: PreTrainedModel, config: CtxEncoderArguments):
super().__init__()
base_model = get_base_model(base_model)
if "gte" in base_model.config.name_or_path:
base_model.encoder.layer = base_model.encoder.layer[: config.layer_idx]
else:
base_model.layers = base_model.layers[: config.layer_idx]
self.base_model = base_model
@property
def config(self):
return self.base_model.config
@torch.no_grad()
def forward(self, **kwargs):
model_outputs = self.base_model(**kwargs)
return model_outputs.last_hidden_state
class EmbeddingOnly(nn.Module):
def __init__(self, base_model: PreTrainedModel, config: CtxEncoderArguments):
super().__init__()
self.base_model = base_model
@property
def config(self):
return self.base_model.config
@torch.no_grad()
def forward(self, **kwargs):
kwargs["output_hidden_states"] = True # Force output of hidden states
outputs = self.base_model(**kwargs)
# Return the embeddings only
return outputs.hidden_states[0] # The first hidden state is the embeddings
class PerLayerActivations(nn.Module):
def __init__(self, base_model: PreTrainedModel, config: CtxEncoderArguments):
super().__init__()
self.keep_lm_head = getattr(config, "keep_lm_head", False)
if not self.keep_lm_head:
base_model = get_base_model(base_model) # remove lm head
else:
base_model.lm_head = nn.Identity()
# -1 to remove last attn block
if config.ctx_encoder_last_layer is not None:
last_layer = config.ctx_encoder_last_layer - 1
else:
last_layer = -1
if self.keep_lm_head:
base_model.model.layers = base_model.model.layers[:last_layer]
else:
base_model.layers = base_model.layers[:last_layer]
self.base_model = base_model
@property
def config(self):
return self.base_model.config
def get_input_embeddings(self):
return self.base_model.get_input_embeddings()
def set_input_embeddings(self, value):
self.base_model.set_input_embeddings(value)
def get_output_embeddings(self):
return self.base_model.get_output_embeddings()
def set_output_embeddings(self, new_embeddings):
self.base_model.set_output_embeddings(new_embeddings)
def set_decoder(self, decoder):
self.base_model.set_decoder(decoder)
def get_decoder(self):
return self.base_model.get_decoder()
@torch.no_grad()
def forward(self, **kwargs):
kwargs["output_hidden_states"] = True # Force output of hidden states
outputs = self.base_model(**kwargs)
# Return all layers' activations except the last one
# from embeddings to the input of the last attn block
# Shape: (batch_size, num_layers, seq_len, hidden_size)
if self.keep_lm_head:
return outputs
else:
return torch.stack(outputs.hidden_states, dim=1)
class CTX_ENCODER_TYPE(str, Enum):
EARLY_EXIT = "early_exit"
EMBED_ONLY = "embed_only"
PER_LAYER_ACTIVATIONS = "per_layer_activations"
CTX_ENCODER_CLS = {
CTX_ENCODER_TYPE.EARLY_EXIT: EarlyExit,
CTX_ENCODER_TYPE.EMBED_ONLY: EmbeddingOnly,
CTX_ENCODER_TYPE.PER_LAYER_ACTIVATIONS: PerLayerActivations,
}

View file

@ -0,0 +1,70 @@
import requests
import torch
def call_generate(
input_txt: str,
context_txt: str,
window_size: int | None = None,
max_new_tokens: int | None = None,
host: str = "http://127.0.0.1:8989",
timeout: int = 120,
) -> torch.Tensor:
"""Send the prompt to the API server and return the generated token tensor."""
payload: dict[str, object] = {
"input_txt": input_txt,
"context_txt": context_txt,
}
if window_size is not None:
payload["window_size"] = int(window_size)
if max_new_tokens is not None:
payload["max_new_tokens"] = int(max_new_tokens)
response = requests.post(f"{host}/generate", json=payload, timeout=timeout)
response.raise_for_status()
data = response.json()
if "output" not in data:
raise ValueError(f"Unexpected response payload: {data}")
return torch.tensor([data["output"]])
def check_server_health(host: str = "http://127.0.0.1:8989", timeout: int = 60) -> None:
"""Check if the API server is healthy and responding."""
try:
response = requests.get(f"{host}/health", timeout=timeout)
response.raise_for_status()
data = response.json()
if data.get("status") != "ok":
raise RuntimeError(f"Server is not healthy: {data}")
except requests.exceptions.ConnectionError as e:
raise RuntimeError(
f"Cannot connect to server at {host}. Is the server running?"
) from e
except requests.exceptions.Timeout as e:
raise RuntimeError(f"Server health check timed out after {timeout}s") from e
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Server health check failed: {e}") from e
print("Server is healthy.")
class GenerativeAdapter(torch.nn.Module):
def __init__(self, model, tokenizer):
super().__init__()
self.base_model = model # placeholder
self.tokenizer = tokenizer
check_server_health()
@property
def generation_config(self):
return self.base_model.generation_config
def generate(self, *args, **kwargs):
ctx_ids = kwargs["ctx_ids"]
input_ids = kwargs["input_ids"]
assert ctx_ids.shape[0] == 1
assert input_ids.shape[0] == 1
context_txt = self.tokenizer.decode(ctx_ids[0])
input_txt = self.tokenizer.decode(input_ids[0])
outputs = call_generate(input_txt, context_txt)
return outputs

View file

@ -0,0 +1,930 @@
import logging
from collections.abc import Iterable
from dataclasses import dataclass
from functools import partial
from math import sqrt
from typing import Any
import torch
from einops import unpack
from einops.layers.torch import EinMix as Mix
from jaxtyping import Float, Integer
from peft import (
LoraConfig,
LoraRuntimeConfig,
PeftConfig,
PeftModel,
)
from peft.tuners._buffer_dict import BufferDict
from peft.tuners.tuners_utils import BaseTunerLayer, check_target_module_exists
from peft.utils import PeftType, TaskType
from torch import Tensor, nn
from transformers import (
PretrainedConfig,
PreTrainedModel,
)
from transformers.modeling_outputs import ModelOutput
from transformers.models.modernbert.modeling_modernbert import ModernBertModel
from ctx_to_lora.configs import (
AggregatorArguments,
CtxEncoderArguments,
HypernetArguments,
)
from ctx_to_lora.data.processing import tokenize_ctx_text
from ctx_to_lora.model_loading import (
get_model,
get_tokenizer,
)
from ctx_to_lora.modeling.aggregator import (
AGGREGATOR_CLS,
AggregatorConfig,
get_aggregator_config,
)
from ctx_to_lora.modeling.ctx_encoder import CTX_ENCODER_CLS, CTX_ENCODER_TYPE
from ctx_to_lora.modeling.lora_layer import (
apply_lora_to_layers,
lora_forward,
lora_forward_packed,
)
from ctx_to_lora.modeling.lora_merger import combine_lora
from ctx_to_lora.utils import (
get_layers,
get_num_layers,
get_peft_in_out_features,
get_peft_modules,
)
logger = logging.getLogger()
@dataclass
class HypernetConfig:
latent_size: int
use_light_weight_lora: bool
light_weight_latent_size: int
per_rank_gen: bool
use_per_rank_bias: bool
use_bias: bool
per_layer_processing: bool
use_token_mixing: bool
num_pre_head_layers: int
dropout_rate: float
lora_config: LoraConfig
extra_modules: list[str] | None
base_hidden_size: int
layer_indices: Iterable[int]
feature_sizes: tuple[dict[str, int], dict[str, int]]
aggregator_config: AggregatorConfig
def get_hypernet_config(
model: PreTrainedModel,
ctx_encoder_model_config: PretrainedConfig,
hypernet_args: HypernetArguments,
aggregator_args: AggregatorArguments,
ctx_encoder_args: CtxEncoderArguments,
):
num_modules = 0
lora_config = getattr(model, "peft_config", None)
if lora_config is not None:
lora_config = lora_config["default"]
num_modules += len(lora_config.target_modules)
num_extra_modules = len(hypernet_args.extra_modules or [])
indices = torch.arange(get_num_layers(model), device=model.device)
return HypernetConfig(
**vars(hypernet_args),
base_hidden_size=model.config.hidden_size,
lora_config=lora_config,
layer_indices=indices,
feature_sizes=get_peft_in_out_features(model, peft_config=lora_config),
aggregator_config=get_aggregator_config(
model,
ctx_encoder_model_config,
ctx_encoder_args.ctx_encoder_type == CTX_ENCODER_TYPE.PER_LAYER_ACTIVATIONS,
hypernet_args.latent_size,
num_modules,
num_extra_modules,
lora_config.r,
hypernet_args.per_rank_gen,
aggregator_args,
),
)
def get_init_peft_weights(model: PeftModel, peft_config: PeftConfig = None):
if peft_config is None:
peft_config = model.peft_config["default"]
peft_weights = {module_name: dict() for module_name in peft_config.target_modules}
adapter_name = "default"
for module_name, module in model.named_modules():
if not check_target_module_exists(peft_config, module_name):
continue
if not isinstance(module, BaseTunerLayer):
continue
# support just Linear layer for now
# all modules should be a leave module that is Linear layer
assert isinstance(module.base_layer, nn.Linear), (
"all modules should be a leave module that is Linear layer"
)
# this should always pass
name = module_name.split(".")[-1]
assert name in peft_config.target_modules
for submodule_name, submodule in module.named_modules():
if not isinstance(submodule, (nn.ModuleDict, nn.ParameterDict, BufferDict)):
continue
if adapter_name not in submodule:
continue
if submodule_name not in peft_weights[name]:
peft_weights[name][submodule_name] = submodule[adapter_name]
else:
smod1 = peft_weights[name][submodule_name]
smod2 = submodule[adapter_name]
assert type(smod1) == type(smod2)
return peft_weights
class ResMLPBlock(nn.Module):
def __init__(
self,
input_size: int,
hidden_size: int,
output_size: int,
dropout_rate: float = 0,
):
super().__init__()
layers = []
layers = [
nn.LayerNorm(input_size),
nn.Dropout(dropout_rate),
nn.Linear(input_size, hidden_size),
nn.SiLU(),
nn.Dropout(dropout_rate),
nn.Linear(hidden_size, output_size),
nn.LayerNorm(output_size),
]
self.mlp = nn.Sequential(*layers)
def forward(self, x):
return x + self.mlp(x)
class ResMLPBlockPerLayer(nn.Module):
def __init__(
self,
n_layers: int,
input_size: int,
hidden_size: int,
output_size: int,
):
super().__init__()
layers = [
nn.LayerNorm(input_size),
Mix(
"bs n_layers n_modules r d_in -> bs n_layers n_modules r d_hid",
weight_shape="n_layers d_in d_hid",
bias_shape="n_layers d_hid",
n_layers=n_layers,
d_in=input_size,
d_hid=hidden_size,
),
nn.SiLU(),
Mix(
"bs n_layers n_modules r d_hid -> bs n_layers n_modules r d_out",
weight_shape="n_layers d_hid d_out",
bias_shape="n_layers d_out",
n_layers=n_layers,
d_hid=hidden_size,
d_out=output_size,
),
nn.LayerNorm(output_size),
]
self.layers = nn.Sequential(*layers)
def forward(self, x):
return x + self.layers(x)
class HyperLoRA(nn.Module):
def __init__(self, config: HypernetConfig):
super().__init__()
# aggregator output [bs, n_layers, n_modules, feature_dim]
# by mixing the pooled features with layer embs and module embs (for pooling)
# or via a perceiver w/ bottleneck size = n_modules * n_layers
self.config = config
logger.debug(f"HyperLoRA config: {self.config}")
self.iterative_mode = False
self._init_model()
def _init_model(self):
self.agg_config = self.config.aggregator_config
self.aggregator = AGGREGATOR_CLS[self.agg_config.aggregator_type](
**vars(self.agg_config)
)
self.lora_config = self.config.lora_config
self.r = self.lora_config.r
self.target_modules = (
tuple(sorted(self.lora_config.target_modules)) if self.lora_config else None
)
self.num_modules = len(self.target_modules) if self.target_modules else 0
self.extra_modules = (
self.config.extra_modules if self.config.extra_modules else None
)
self.num_extra_modules = len(self.extra_modules) if self.extra_modules else 0
self.layer_indices = self.config.layer_indices
self.n_layers = len(self.layer_indices)
self.d_in, self.d_out = self.config.feature_sizes
self.d_latent = self.config.latent_size
if self.target_modules:
if self.config.per_layer_processing:
layers = [
ResMLPBlockPerLayer(
self.n_layers,
self.d_latent,
self.d_latent * 4,
self.d_latent,
)
for _ in range(self.config.num_pre_head_layers)
]
else:
layers = [
ResMLPBlock(
input_size=self.config.latent_size,
hidden_size=self.config.latent_size * 4,
output_size=self.config.latent_size,
dropout_rate=getattr(self.config, "dropout_rate", 0),
)
for _ in range(self.config.num_pre_head_layers)
]
self.layers = nn.Sequential(*layers)
self.d_lora = max(self.d_in[m] + self.d_out[m] for m in self.target_modules)
self.bias_A = nn.ParameterDict(
{
m: nn.Parameter(
torch.normal(
0,
0.2 / (self.d_in[m] * self.r) ** 0.5,
(self.n_layers, self.r, self.d_in[m]),
)
)
for m in self.target_modules
}
)
self.bias_B = nn.ParameterDict(
{
m: nn.Parameter(torch.zeros((self.n_layers, self.r, self.d_out[m])))
for m in self.target_modules
}
)
self.scaler_A = nn.ParameterDict(
{
m: nn.Parameter(torch.ones((1, self.n_layers, self.r, 1)))
for m in self.target_modules
}
)
self.scaler_B = nn.ParameterDict(
{
m: nn.Parameter(torch.zeros((1, self.n_layers, self.r, 1)))
for m in self.target_modules
}
)
n_modules = len(self.target_modules)
# have to do this otherwise doesnt work with adamw_torch_fused
# has something to do with the bias shape (n_modules r d_lora)
# when n_modules == 1, adamw_torch_fused complains about device/layout
# but when n_modules > 1, it works fine
if n_modules == 1:
self.head = Mix(
"bs n_layers n_modules r d_latent -> bs n_layers n_modules r d_lora",
weight_shape="n_layers d_latent d_lora",
bias_shape=None, # no bias
n_layers=len(self.layer_indices),
d_latent=self.config.latent_size,
r=self.config.lora_config.r,
d_lora=self.d_lora,
)
else:
self.head = Mix(
"bs n_layers n_modules r d_latent -> bs n_layers n_modules r d_lora",
weight_shape="n_layers n_modules d_latent d_lora",
bias_shape=None, # no bias
n_layers=len(self.layer_indices),
n_modules=n_modules,
d_latent=self.config.latent_size,
r=self.config.lora_config.r,
d_lora=self.d_lora,
)
def get_head_bias(self):
bias_dict = dict()
for module in self.target_modules:
bias_A = self.bias_A[module]
bias_B = self.bias_B[module]
bias_dict[module] = dict(A=bias_A, B=bias_B)
return bias_dict
def _to_lora_dict(
self, flat_loras: Float[Tensor, "bs n_layers n_modules r max_io_dim"]
) -> dict[str, dict[str, Float[Tensor, "bs n_layers r _"]]]:
if self.target_modules is None:
return None
# list of [bs, n_layers, r, in_d_outim]
# and in_d_outim might vary across modules
loras = unpack(
flat_loras,
[[] for _ in range(len(self.target_modules))],
"bs n_layers * r max_io_dim",
)
# dict of {module:
# {A: [bs, n_layers, r, d_inim],
# B: [bs, n_layers, r, d_outim]}}
lora_dict = dict()
for module, lora in zip(self.target_modules, loras):
A, B = unpack(
lora[..., : self.d_in[module] + self.d_out[module]],
[[self.d_in[module]], [self.d_out[module]]],
"bs n_layers r *",
)
# apparently doing A * self.scaler_A is slow due to broadcasting
A = torch.einsum("ijkl,ijkl->ijkl", A, self.scaler_A[module])
B = torch.einsum("ijkl,ijkl->ijkl", B, self.scaler_B[module])
lora_dict[module] = dict(A=A, B=B)
return lora_dict
def _to_layernorm_dict(
self, flat_layernorms: Float[Tensor, "bs n_layers n_modules hidden_size"]
) -> dict[str, Float[Tensor, "bs n_layers hidden_size"]]:
if self.extra_modules is None:
return None
layernorms = unpack(
flat_layernorms,
[[] for _ in range(len(self.extra_modules))],
"bs n_layers * hidden_size",
)
return {k: v for k, v in zip(self.extra_modules, layernorms)}
def enable_iterative_mode(self, x: bool):
self.iterative_mode = x
self.aggregator.enable_iterative_mode(x)
def forward(
self,
features: Float[Tensor, "bs seq_len feature_dim"],
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
position_ids: Integer[Tensor, "bs seq_len"] | None = None,
n_ctx_chunks: Integer[Tensor, "n_ctx"] | None = None,
):
# [bs, n_layers, n_total_modules, r, feature_dim]
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
if self.aggregator.layer_to_layer and self.iterative_mode:
# iterative inference
# features: [bs num_layers seq_len feature_dim]
bs, n_layers = features.shape[0:2]
lora_emb = torch.empty(
(bs, n_layers, self.num_modules, self.r, self.config.latent_size),
device=features.device,
)
for i in range(n_layers):
lora_emb[:, i], _ = self.aggregator(
features[:, i], attn_mask, position_ids
)
else:
# batched inference
lora_emb, _ = self.aggregator(features, attn_mask, position_ids)
# [bs, n_layers, n_modules, r, max_in_d_outim]
flat_loras = None
if self.target_modules:
lora_emb = self.layers(lora_emb)
norm = torch.norm(lora_emb, dim=-1, keepdim=True)
norm_lora_emb = lora_emb / norm
flat_loras = self.head(norm_lora_emb)
flat_layernorms = None
return flat_loras, flat_layernorms
def generate_weights(
self,
features: Float[Tensor, "bs seq_len feature_dim"],
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
position_ids: Integer[Tensor, "bs seq_len"] | None = None,
):
flat_loras, flat_layernorms = self.forward(features, attn_mask, position_ids)
return self._to_lora_dict(flat_loras), self._to_layernorm_dict(flat_layernorms)
class ModulatedPretrainedModel(nn.Module):
def __init__(
self,
base_model: PeftModel,
hypernet_config: HypernetConfig,
ctx_encoder_args: CtxEncoderArguments,
use_base_input_as_ctx: bool = False,
# need non-packed inputs for generation
use_sequence_packing: bool = True,
user_defined_scaling: float = 1,
inp_compressor=None,
):
assert not use_base_input_as_ctx
super().__init__()
self.device = base_model.device
self.peft_config = base_model.peft_config["default"]
self.hypernet_config = hypernet_config
self.ctx_encoder_args = ctx_encoder_args
self.use_base_input_as_ctx = use_base_input_as_ctx
self.use_sequence_packing = use_sequence_packing
self.user_defined_scaling = user_defined_scaling
self.inp_compressor = inp_compressor
self.model_accepts_loss_kwargs = True
self.generated_loras = None
self.register_module("base_model", base_model)
self._init_model()
self._bias_hyper_init()
@classmethod
def from_state_dict(
cls,
state_dict: dict,
train: bool = True,
base_model_kwargs: dict = None,
use_flash_attn: bool = True,
**kwargs: Any,
):
lora_config = state_dict["hypernet_config"].lora_config
print(f"lora_config: {lora_config}")
model_name_or_path = state_dict["base_model_name_or_path"]
base_model = get_model(
model_name_or_path,
train=train,
requires_grad=False,
peft_config=lora_config,
model_kwargs=base_model_kwargs,
use_flash_attn=use_flash_attn,
)
hypernet_config = state_dict["hypernet_config"]
if getattr(hypernet_config, "num_pre_head_layers", None) is None:
hypernet_config.num_pre_head_layers = 4
if getattr(hypernet_config, "use_per_rank_bias", None) is None:
hypernet_config.use_per_rank_bias = False
if getattr(hypernet_config, "use_bias", None) is None:
hypernet_config.use_bias = True
ctx_encoder_args = state_dict["ctx_encoder_args"]
model = cls(base_model, hypernet_config, ctx_encoder_args, **kwargs)
model.load_state_dict(state_dict)
return model
def patch_lora_forward(self):
layers = get_layers(self.base_model)
lora_forward_fn = (
lora_forward_packed if self.use_sequence_packing else lora_forward
)
for layer_idx in self.hypernet.layer_indices:
for module_info in get_peft_modules(layers[layer_idx], self.peft_config):
name = module_info["name"]
module = module_info["module"]
if getattr(module, "patched_forward", False):
continue
logger.debug(f"Applying LoRA forward to {name}")
module.forward_orig = module.forward
module.patched_forward = True
module.forward = partial(
lora_forward_fn,
self=module,
lora_dropout_p=self.peft_config.lora_dropout,
scaling=self.peft_config.lora_alpha,
)
def _init_model(self):
# disable adapter of the base model
# this only works with LoRA(?)
# we disable to avoid peft lora computation
self.base_model.disable_adapter_layers()
self.hypernet = (
HyperLoRA(self.hypernet_config).to(self.device).to(torch.float32)
)
self.patch_lora_forward()
ctx_model_name = self.ctx_encoder_args.ctx_encoder_model_name_or_path
if ctx_model_name is None:
ctx_model_name = self.base_model.config.name_or_path
# use an explicit copy of the base model
# for using with "modules_to_save"
base_model_attn_impl = self.base_model.config._attn_implementation
logger.debug(f"ctx_model_name: {ctx_model_name}")
logger.debug(f"base_model.config._attn_implementation: {base_model_attn_impl}")
encoder_model = get_model(
ctx_model_name,
train=self.base_model.training,
requires_grad=False,
use_flash_attn=base_model_attn_impl == "flash_attention_2",
use_q_lora=self.ctx_encoder_args.quantize_ctx_encoder,
)
self.ctx_encoder = CTX_ENCODER_CLS[self.ctx_encoder_args.ctx_encoder_type](
encoder_model, self.ctx_encoder_args
)
# delegate to base_model
@property
def config(self):
return self.base_model.config
@property
def generation_config(self):
return self.base_model.generation_config
@property
def vocab_size(self):
return self.base_model.vocab_size
def get_input_embeddings(self):
return self.base_model.get_input_embeddings()
@torch.no_grad()
def _bias_hyper_init(self):
if self.hypernet.extra_modules:
self.hypernet.extra_head.weight.data[:] = 0
self.hypernet.extra_head.bias.data[:] = 0
if self.hypernet.target_modules:
peft_weights = get_init_peft_weights(
self.base_model, self.hypernet.lora_config
)
logger.debug(f"peft_weights: {peft_weights}")
r = self.hypernet_config.lora_config.r
nn.init.normal_(
self.hypernet.head.weight,
mean=0,
std=0.5
/ sqrt(self.hypernet.config.latent_size + self.hypernet.d_lora * r),
# the head outputs per rank lora --> divide by r to scale down grad
)
def state_dict(self, *args, **kwargs):
# we assume ctx_encoder and base model is frozen here
if len([p for p in self.ctx_encoder.parameters() if p.requires_grad]):
raise ValueError("ctx_encoder contains trainable parameters")
if len([p for p in self.base_model.parameters() if p.requires_grad]):
raise ValueError("base model contains trainable parameters")
state_dict = self.hypernet.state_dict(*args, **kwargs)
state_dict["base_model_name_or_path"] = self.base_model.name_or_path
state_dict["hypernet_config"] = self.hypernet_config
state_dict["ctx_encoder_args"] = self.ctx_encoder_args
return state_dict
def load_state_dict(self, state_dict: dict, *args, **kwargs):
self.base_model_name_or_path = state_dict.pop("base_model_name_or_path")
self.hypernet_config = state_dict.pop("hypernet_config")
self.ctx_encoder_args = state_dict.pop("ctx_encoder_args")
if self.base_model_name_or_path != self.base_model.name_or_path:
raise ValueError(
f"Base model name or path mismatch. "
f"The base model given is: {self.base_model.name_or_path}, "
f"but the loaded name is: {self.base_model_name_or_path}"
)
self._init_model()
def remove_compile_prefix(sd: dict[str, Tensor]) -> dict[str, Tensor]:
COMPILED_PREFIX = "_orig_mod."
for k in list(sd.keys()):
if k.startswith(COMPILED_PREFIX):
sd[k[len(COMPILED_PREFIX) :]] = sd.pop(k)
return sd
load_result = self.hypernet.load_state_dict(
remove_compile_prefix(state_dict),
strict=True, # , *args, **kwargs
)
logger.info(f"load result: {load_result}")
return load_result
def generate_weights(
self,
ctx_ids: Integer[Tensor, "bs ctx_len"],
ctx_attn_mask: Integer[Tensor, "bs ctx_len"] | None = None,
ctx_position_ids: Integer[Tensor, "bs ctx_len"] | None = None,
**kwargs: Any,
):
with torch.no_grad():
ctx_encoder_kwargs = dict(
input_ids=ctx_ids,
attention_mask=ctx_attn_mask,
position_ids=ctx_position_ids,
)
if isinstance(self.ctx_encoder.base_model, ModernBertModel):
position_ids = ctx_position_ids.flatten()
indices = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
# [bsz + 1]
cu_seqlens = torch.cat(
(
indices[position_ids == 0],
torch.tensor(
position_ids.size(),
device=position_ids.device,
dtype=torch.int32,
),
)
)
ctx_encoder_kwargs = dict(
input_ids=ctx_ids.squeeze(0),
cu_seqlens=cu_seqlens,
max_seqlen=position_ids.max() + 1,
attention_mask=-1,
seq_len=-1,
batch_size=-1,
)
ctx_features = self.ctx_encoder(**ctx_encoder_kwargs, **kwargs)
if isinstance(self.ctx_encoder.base_model, ModernBertModel):
ctx_features = ctx_features.unsqueeze(0)
if self.user_defined_scaling == 1:
return self.hypernet.generate_weights(
ctx_features, ctx_attn_mask, ctx_position_ids
)
lora_dict, _ = self.hypernet.generate_weights(
ctx_features, ctx_attn_mask, ctx_position_ids
)
for module in lora_dict:
lora_dict[module]["A"] = lora_dict[module]["A"] * self.user_defined_scaling
lora_dict[module]["B"] = lora_dict[module]["B"] * self.user_defined_scaling
return lora_dict, None
def enable_iterative_mode(self, x: bool):
self.hypernet.enable_iterative_mode(x)
def forward(
self,
ctx_ids: Integer[Tensor, "n_ctx ctx_len"] | None = None,
ctx_attn_mask: Integer[Tensor, "n_ctx ctx_len"] | None = None,
ctx_position_ids: Integer[Tensor, "n_ctx ctx_len"] | None = None,
n_ctx_chunks: Integer[Tensor, "n_ctx"] | None = None,
n_queries: Integer[Tensor, "n_ctx"] | None = None,
return_generated_lora: bool | None = False,
*model_inputs_args: Any,
**model_inputs_kwargs: dict[str, Any],
) -> tuple | ModelOutput:
"""Forward pass of the modulated model."""
generated_loras = None
generated_layernorms = None
if ctx_ids is None and not self.use_base_input_as_ctx:
logger.warning(
(
"*" * 100,
"\n\nNo ctx_features provided, using the base model for forward pass\n\n",
"*" * 100,
)
)
else:
if self.use_base_input_as_ctx:
ctx_ids = (
model_inputs_kwargs["input_ids"]
if "input_ids" in model_inputs_kwargs
else model_inputs_args[0]
)
ctx_attn_mask = (
model_inputs_kwargs["attention_mask"]
if "attention_mask" in model_inputs_kwargs
else None
)
ctx_position_ids = (
model_inputs_kwargs["position_ids"]
if "position_ids" in model_inputs_kwargs
else None
)
generated_loras, generated_layernorms = self.generate_weights(
ctx_ids, ctx_attn_mask, ctx_position_ids
)
if generated_loras is not None:
generated_loras = combine_lora(
generated_loras,
n_ctx_chunks,
lora_bias=self.hypernet.get_head_bias()
if self.hypernet.config.use_bias
else None,
)
# input_ids in model_inputs_kwargs contains only
# prompt + response (for hypernet training)
position_ids = (
model_inputs_kwargs["position_ids"]
if "position_ids" in model_inputs_kwargs
else None
)
if n_queries is None:
if ctx_position_ids is None:
n_queries = torch.ones(
ctx_ids.shape[0], dtype=torch.int32, device=self.device
)
else:
# quite redundant (we do cu_seqlens many places)
# TODO: compute cu_seqlens here and propagate that
n_queries = torch.ones(
(ctx_position_ids == 0).sum(),
dtype=torch.int32,
device=self.device,
)
apply_lora_to_layers(
self.base_model,
self.hypernet.layer_indices,
generated_loras,
n_queries,
position_ids,
)
model_outputs = self.base_model(*model_inputs_args, **model_inputs_kwargs)
if return_generated_lora:
return model_outputs, (generated_loras, generated_layernorms)
else:
return model_outputs
def combine_lora(self, *args, **kwargs):
# for timing
return combine_lora(*args, **kwargs)
def apply_lora_to_layers(self, *args, **kwargs):
# for timing
return apply_lora_to_layers(*args, **kwargs)
# for simple api usage
def internalize(self, ctx_str: str):
ctx_tokenizer = get_tokenizer(self.ctx_encoder.base_model.name_or_path)
ctx_ids = tokenize_ctx_text(dict(context=[ctx_str]), ctx_tokenizer)["ctx_ids"]
return self._internalize_from_ids(torch.tensor(ctx_ids, device=self.device))
def _internalize_from_ids(
self,
ctx_ids: Integer[Tensor, "n_ctx ctx_len"] | None = None,
ctx_attn_mask: Integer[Tensor, "n_ctx ctx_len"] | None = None,
ctx_position_ids: Integer[Tensor, "n_ctx ctx_len"] | None = None,
):
self.patch_lora_forward()
if ctx_attn_mask is None and ctx_position_ids is None:
assert ctx_ids.shape[0] == 1
ctx_attn_mask = torch.ones_like(ctx_ids)
generated_loras, generated_layernorms = self.generate_weights(
ctx_ids, ctx_attn_mask, ctx_position_ids
)
self.generated_loras = generated_loras
def reset(self):
self.generated_loras = None
layers = get_layers(self.base_model)
for layer_idx in self.hypernet.layer_indices:
for module_info in get_peft_modules(layers[layer_idx], self.peft_config):
name = module_info["name"]
module = module_info["module"]
logger.debug(f"Resetting forward for {name}")
module.forward = module.forward_orig
module.patched_forward = False
@torch.inference_mode()
def generate(
self,
ctx_ids: Integer[Tensor, "n_chunks ctx_length"] | None = None,
ctx_attn_mask: Integer[Tensor, "n_chunks ctx_length"] | None = None,
ctx_position_ids: Integer[Tensor, "n_chunks ctx_length"] | None = None,
n_ctx_chunks: Integer[Tensor, "n_ctx"] | None = None,
n_queries: Integer[Tensor, "n_ctx"] | None = None,
scalers: Float[Tensor, "n_ctx"] | None = None,
bias_scaler: float | None = None,
*model_inputs_args: Any,
**model_inputs_kwargs: dict[str, Any],
):
generated_loras = None
generated_layernorms = None
if (
ctx_ids is None
and not self.generated_loras
and not self.use_base_input_as_ctx
):
print(
"*" * 100
+ "\n\nNo ctx_ids provided, using the base model for generation\n\n"
+ "*" * 100
)
elif ctx_ids is None and self.generated_loras:
generated_loras = self.generated_loras
if n_ctx_chunks is None:
n_ctx_chunks = torch.tensor((1,), device=self.device)
print(
"*" * 100
+ "\n\nUsing internalized LoRAs for generation\n\n"
+ "*" * 100
)
else:
if self.use_base_input_as_ctx:
ctx_ids = (
model_inputs_kwargs["input_ids"]
if "input_ids" in model_inputs_kwargs
else model_inputs_args[0]
)
ctx_attn_mask = (
model_inputs_kwargs["attention_mask"]
if "attention_mask" in model_inputs_kwargs
else None
)
ctx_position_ids = (
model_inputs_kwargs["position_ids"]
if "position_ids" in model_inputs_kwargs
else None
)
generated_loras, generated_layernorms = self.generate_weights(
ctx_ids, ctx_attn_mask, ctx_position_ids
)
if generated_loras is not None:
generated_loras = self.combine_lora(
generated_loras,
n_ctx_chunks,
lora_bias=self.hypernet.get_head_bias()
if self.hypernet.config.use_bias
else None,
scalers=scalers,
bias_scaler=bias_scaler,
)
# apply lora hook to the base model
# TODO: we dont this position_ids for generation?
position_ids = (
model_inputs_kwargs["position_ids"]
if "position_ids" in model_inputs_kwargs
else None
)
if n_queries is None:
if ctx_position_ids is None:
n_queries = torch.ones(
model_inputs_kwargs["input_ids"].shape[0],
dtype=torch.int32,
device=self.device,
)
else:
# quite redundant (we do cu_seqlens many places)
# TODO: compute cu_seqlens here and propagate that
n_queries = torch.ones(
(ctx_position_ids == 0).sum(),
dtype=torch.int32,
device=self.device,
)
apply_lora_to_layers(
self.base_model,
self.hypernet.layer_indices,
generated_loras,
n_queries,
position_ids,
)
model_outputs = self.base_model.generate(
*model_inputs_args, **model_inputs_kwargs
)
return model_outputs
# needed for loading model from checkpoint
# see https://github.com/huggingface/transformers/pull/34632
torch.serialization.add_safe_globals(
[
AggregatorConfig,
LoraConfig,
HypernetConfig,
PeftType,
TaskType,
LoraRuntimeConfig,
set, # for real?
]
)

View file

@ -0,0 +1,765 @@
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Idefics2 model."""
import math
import torch
from torch import nn
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
from transformers.modeling_utils import PreTrainedModel
from transformers.models.idefics2.configuration_idefics2 import Idefics2Config
from transformers.utils import (
add_start_docstrings,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
if is_flash_attn_2_available():
from flash_attn.bert_padding import unpad_input
from transformers.modeling_flash_attention_utils import _flash_attention_forward
logger = logging.get_logger(__name__)
class Idefics2PerceiverConfig(PretrainedConfig):
r"""
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the perceiver block.
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
n_latents (`int`, *optional*, defaults to 64):
Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
resampler_depth (`int`, *optional*, defaults to 3):
Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (<= 3).
n_heads (`int`, *optional*, defaults to 16):
Number of heads in each Transformer block (for multi-headed self-attention).
head_dim (`int`, *optional*, defaults to 96):
Dimensionality of each head projection in the Transformer block.
num_key_value_heads (`int`, *optional*, defaults to 4):
Number of key-value heads in the perceiver attention block.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
"""
model_type = "idefics2_perceiver"
def __init__(
self,
input_size: int,
num_blocks: int,
num_self_attn_per_block: int,
shared_weights: bool,
intermediate_size_factor: int,
hidden_act="silu",
hidden_size=4096,
rms_norm_eps=1e-06,
n_latents=64,
n_heads=16,
head_dim=128,
num_key_value_heads=4,
attention_dropout=0.0,
**kwargs,
):
self.num_blocks = num_blocks
self.num_self_attn_per_block = num_self_attn_per_block
self.shared_weights = shared_weights
self.input_size = input_size
self.intermediate_size_factor = intermediate_size_factor
# for perceiver
self.hidden_act = hidden_act
self.hidden_size = hidden_size
self.rms_norm_eps = rms_norm_eps
self.n_latents = n_latents
self.n_heads = n_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.attention_dropout = attention_dropout
if self.num_key_value_heads > self.n_heads:
raise ValueError(
f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
f" n_heads={self.n_heads}"
)
super().__init__(**kwargs)
class Idefics2MLP(nn.Module):
def __init__(
self,
hidden_size: int,
intermediate_size: int,
output_size: int,
hidden_act: str,
):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
self.act_fn = ACT2FN[hidden_act]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
IDEFICS2_START_DOCSTRING = r"""
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
etc.)
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
and behavior.
Parameters:
config ([`Idefics2Config`] or [`Idefics2VisionConfig`]):
Model configuration class with all the parameters of the model. Initializing with a config file does not
load the weights associated with the model, only the configuration. Check out the
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare Idefics2 Model outputting raw hidden-states without any specific head on top.",
IDEFICS2_START_DOCSTRING,
)
class Idefics2PreTrainedModel(PreTrainedModel):
config_class = Idefics2Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = [
"Idefics2VisionAttention",
"Idefics2MLP",
"Idefics2PerceiverLayer",
"Idefics2DecoderLayer",
]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
def _init_weights(self, module):
std = (
self.config.initializer_range
if hasattr(self.config, "initializer_range")
else 0.02
)
if hasattr(module, "class_embedding"):
module.class_embedding.data.normal_(mean=0.0, std=std)
if isinstance(module, (nn.Linear, nn.Conv2d)):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(
batch, num_key_value_heads, n_rep, slen, head_dim
)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics2
class Idefics2RMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Idefics2RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
class Idefics2PerceiverAttention(nn.Module):
def __init__(self, config, layer_idx: int | None = None) -> None:
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
super().__init__()
self.config = config
self.layer_idx = None
self.hidden_size = config.hidden_size
self.num_heads = config.n_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.attention_dropout = config.attention_dropout
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=False
)
self.k_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.v_proj = nn.Linear(
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
)
self.o_proj = nn.Linear(
self.num_heads * self.head_dim, self.hidden_size, bias=False
)
self.is_causal = False
def forward(
self,
latents: torch.Tensor,
context: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: tuple[torch.Tensor] | None = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
"""
Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
Args:
latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
attention_mask (`torch.Tensor`, *optional*): Tensor of shape [bsz, 1, seq, n_latents] representing attention mask.
position_ids (`torch.LongTensor`, *optional*): Tensor of shape [bsz, seq] representing position indices of each input token.
past_key_value (`Tuple[torch.Tensor]`, *optional*): Tuple of tensors containing cached key and value states.
output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_value for caching.
"""
bsz, q_len, _ = latents.size()
kv_seq_len = q_len + context.size()[1]
hidden_states = torch.concat([context, latents], dim=-2)
query_states = self.q_proj(latents)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, kv_seq_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, kv_seq_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
past_key_value = getattr(self, "past_key_value", past_key_value)
if past_key_value is not None:
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx
)
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(
query_states, key_states.transpose(2, 3)
) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
# NO LONGER EXIST Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
# TODO cyril: modular
class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
"""
Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
# Ignore copy
def forward(
self,
latents: torch.Tensor,
is_cross_attn: bool,
context: torch.Tensor | None = None,
attention_mask: torch.LongTensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
bsz, q_len, _ = latents.size()
query_states = self.q_proj(latents)
if is_cross_attn:
kv_inp = context
else:
kv_inp = latents
key_states = self.k_proj(kv_inp)
value_states = self.v_proj(kv_inp)
# query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
query_states = query_states.view(
*latents.shape[:2], self.num_heads, self.head_dim
)
key_states = key_states.view(
*kv_inp.shape[:2], self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
*kv_inp.shape[:2], self.num_key_value_heads, self.head_dim
).transpose(1, 2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
dropout_rate = 0.0 if not self.training else self.attention_dropout
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
# therefore the input hidden states gets silently casted in float32. Hence, we need
# cast them back in float16 just to be sure everything works as expected.
input_dtype = query_states.dtype
if input_dtype == torch.float32:
if torch.is_autocast_enabled():
target_dtype = torch.get_autocast_gpu_dtype()
# Handle the case where the model is quantized
elif hasattr(self.config, "_pre_quantization_dtype"):
target_dtype = self.config._pre_quantization_dtype
else:
target_dtype = self.q_proj.weight.dtype
logger.warning_once(
f"The input hidden states seems to be silently casted in float32, this might be related to"
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
f" {target_dtype}."
)
query_states = query_states.to(target_dtype)
key_states = key_states.to(target_dtype)
value_states = value_states.to(target_dtype)
# Reashape to the expected shape for Flash Attention
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
attn_output = _flash_attention_forward(
query_states,
key_states,
value_states,
attention_mask,
q_len,
dropout=dropout_rate,
position_ids=position_ids,
sliding_window=None,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
**kwargs,
)
attn_output = attn_output.reshape(
bsz, q_len, self.num_heads * self.head_dim
).contiguous()
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
IDEFICS2_PERCEIVER_ATTENTION_CLASSES = {
# "eager": Idefics2PerceiverAttention,
"flash_attention_2": Idefics2PerceiverFlashAttention2,
}
class Idefics2PerceiverLayer(nn.Module):
def __init__(self, config, is_cross_attn: bool):
super().__init__()
self.hidden_size = config.hidden_size
self.n_latents = config.n_latents
self.rms_norm_eps = config.rms_norm_eps
self.is_cross_attn = is_cross_attn
self.input_latents_layernorm = Idefics2RMSNorm(
self.hidden_size, eps=self.rms_norm_eps
)
self.input_context_layernorm = (
Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
if self.is_cross_attn
else torch.nn.Identity()
)
self.self_attn = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
config._attn_implementation
](config)
self.post_attention_layernorm = Idefics2RMSNorm(
self.hidden_size, eps=self.rms_norm_eps
)
self.pre_ff_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self.post_ff_layernorm = Idefics2RMSNorm(
self.hidden_size, eps=self.rms_norm_eps
)
self.mlp = Idefics2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.hidden_size * 4,
output_size=config.hidden_size,
hidden_act=config.hidden_act,
)
def forward(
self,
latents: torch.Tensor,
context: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: tuple[torch.Tensor] | None = None,
output_attentions: bool | None = False,
use_cache: bool | None = False,
**kwargs,
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
"""
Args:
latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, sequence_length)` where padding elements are indicated by 0.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = latents
latents = self.input_latents_layernorm(latents)
context = self.input_context_layernorm(context)
latents, self_attn_weights, present_key_value = self.self_attn(
latents=latents,
is_cross_attn=self.is_cross_attn,
context=context,
attention_mask=attention_mask,
position_ids=position_ids,
**kwargs,
)
latents = self.post_attention_layernorm(latents)
latents = residual + latents
residual = latents
# latents = self.post_attention_layernorm(latents)
latents = self.pre_ff_layernorm(latents)
latents = self.mlp(latents)
latents = self.post_ff_layernorm(latents)
latents = residual + latents
outputs = (latents,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
IDEFICS2_INPUTS_DOCSTRING = r"""
Args:
context (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`):
The hidden states of the image after vision encoder and modality projection.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
"""
@add_start_docstrings(
"Idefics2 perceiver resampler model that performs `depth` blocks of cross-attention with a fixed ",
"`n_latents` inputs to decrease embedding sequence length. The Resampler acts as a form of learned pooling and ",
"is derived from [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)",
IDEFICS2_START_DOCSTRING,
)
class Idefics2PerceiverResampler(Idefics2PreTrainedModel):
_supports_sdpa = False
config_class = Idefics2PerceiverConfig
def __init__(self, config) -> None:
super().__init__(config)
self.num_blocks = config.num_blocks
self.num_self_attn_per_block = config.num_self_attn_per_block
self.shared_weights = config.shared_weights
self.hidden_size = config.hidden_size
self.hidden_act = config.hidden_act
self.n_latents = config.n_latents
self.rms_norm_eps = config.rms_norm_eps
# Create Latents for Perceiver
self.latents_q = nn.Parameter(torch.randn(self.n_latents, self.hidden_size))
# First block
assert config.num_blocks > 0
first_x_attn = [Idefics2PerceiverLayer(config, is_cross_attn=True)]
first_self_attn_block = [
Idefics2PerceiverLayer(config, is_cross_attn=False)
for _ in range(config.num_self_attn_per_block)
]
self.layers = nn.ModuleList(first_x_attn + first_self_attn_block)
for layer_idx in range(1, config.num_blocks):
# cross-attention at the beginning of each block
if self.shared_weights:
if layer_idx == 1:
second_x_attn = Idefics2PerceiverLayer(config, is_cross_attn=True)
x_attn = second_x_attn
else:
x_attn = Idefics2PerceiverLayer(config, is_cross_attn=True)
self.layers.append(x_attn)
# self-attention
for i in range(config.num_self_attn_per_block):
if self.shared_weights:
self_attn = first_self_attn_block[i]
else:
self_attn = Idefics2PerceiverLayer(config, is_cross_attn=False)
self.layers.append(self_attn)
self.layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
assert self._use_flash_attention_2
def forward(
self,
context: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
) -> torch.Tensor:
# seq embed -> bsz seq embed
if position_ids is None:
bsz = context.shape[0]
else:
# flattened packed sequence
bsz = torch.where(position_ids == 0, 1, 0).sum()
latents = self.latents_q.unsqueeze(0).expand((bsz, *self.latents_q.size()))
attention_mask = (
_prepare_4d_attention_mask(
attention_mask, latents.dtype, tgt_len=self.n_latents
)
if not self._use_flash_attention_2
else attention_mask
)
compressed_context = latents
cu_seq_lens_q = torch.tensor(
[self.n_latents] * (bsz + 1), device=context.device, dtype=torch.int32
) * torch.arange(bsz + 1, device=context.device, dtype=torch.int32)
max_length_q = self.n_latents
# cu_seq_lens_k = None
# max_length_k = None
if attention_mask is not None:
logger.warning_once("Using attention mask for resampler")
context, _, cu_seq_lens_k, max_length_k, _ = unpad_input(
context, attention_mask
)
context = context.unsqueeze(0)
position_ids = True # goes down flash attn path that uses cu_seq_lens
elif position_ids is not None:
logger.warning_once("Using position ids for resampler")
position_ids = position_ids.flatten()
indices = torch.arange(
position_ids.size(0), device=position_ids.device, dtype=torch.int32
)
# [bsz + 1]
cu_seq_lens_k = torch.cat(
(
indices[position_ids == 0],
torch.tensor(
position_ids.size(),
device=position_ids.device,
dtype=torch.int32,
),
)
)
max_length_k = position_ids.max() + 1
else:
raise ValueError("either position_ids or attention_mask is required")
x_attn_kwargs = dict(
position_ids=position_ids,
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_k,
max_length_q=max_length_q,
max_length_k=max_length_k,
)
self_attn_position_ids = torch.arange(
self.n_latents, device=context.device, dtype=torch.int32
).repeat(1, bsz)
self_attn_kwargs = dict(
# attention_mask=self_attn_mask,
position_ids=self_attn_position_ids,
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_q,
max_length_q=max_length_q,
max_length_k=max_length_q,
)
for i, layer in enumerate(self.layers):
inp_kwargs = dict(
latents=compressed_context,
context=context,
past_key_value=None,
output_attentions=False,
use_cache=False,
)
if layer.is_cross_attn:
attn_kwargs = {**inp_kwargs, **x_attn_kwargs}
else:
attn_kwargs = {**inp_kwargs, **self_attn_kwargs}
layer_outputs = layer(**attn_kwargs)
compressed_context = layer_outputs[0]
compressed_context = self.layernorm(compressed_context)
return compressed_context
class Idefics2Perceiver(Idefics2PreTrainedModel):
def __init__(
self,
encoder_config: Idefics2PerceiverConfig,
decoder_config: Idefics2PerceiverConfig,
):
super().__init__(encoder_config)
self.modality_projection = Idefics2MLP(
hidden_size=encoder_config.input_size,
intermediate_size=encoder_config.intermediate_size_factor
* encoder_config.input_size,
output_size=encoder_config.hidden_size,
hidden_act=encoder_config.hidden_act,
)
self.encoder = Idefics2PerceiverResampler._from_config(encoder_config)
self.decoder = Idefics2PerceiverResampler._from_config(decoder_config)
def forward(
self,
context: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
):
if position_ids is None:
bsz = context.shape[0]
else:
bsz = torch.where(position_ids == 0, 1, 0).sum()
projected_inputs = self.modality_projection(context)
# [bsz, n_latents, dim]
latents = self.encoder(
context=projected_inputs,
attention_mask=attention_mask,
position_ids=position_ids,
)
latent_position_ids = torch.arange(
self.encoder.n_latents, device=context.device
).unsqueeze(0)
latent_position_ids = torch.tile(latent_position_ids, (1, bsz))
outputs = self.decoder(latents, position_ids=latent_position_ids)
return outputs
__all__ = [
"Idefics2Perceiver",
]

View file

@ -0,0 +1,143 @@
import torch
from llmlingua import PromptCompressor
from torch import nn
from ctx_to_lora.data.definitions import CTX_AFFIXES
class LLMLinguaModel(nn.Module):
def __init__(self, model, tokenizer, compression_rate):
super().__init__()
self.base_model = model
self.compressor = PromptCompressor(
model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
use_llmlingua2=True, # Whether to use llmlingua-2
)
model_name = self.base_model.name_or_path
self.register_buffer("prefix", torch.tensor(CTX_AFFIXES[model_name]["prefix"]))
self.register_buffer("suffix", torch.tensor(CTX_AFFIXES[model_name]["suffix"]))
self.len_prefix = len(self.prefix)
self.len_suffix = len(self.suffix)
self.tokenizer = tokenizer
self.compression_rate = compression_rate
@property
def generation_config(self):
return self.base_model.generation_config
def compress(self, prompt_txt: str, rate: float):
return self.compressor.compress_prompt(
prompt_txt, rate=rate, force_tokens=["\n", "?"]
)
def compress_tokens(self, input_ids, query_text):
bs = input_ids.shape[0]
txt = self.tokenizer.batch_decode(
input_ids[:, self.len_prefix : -self.len_suffix]
)
q_start_idx = txt[0].rfind(query_text)
ctx_txt = txt[0][:q_start_idx]
q_txt = txt[0][q_start_idx:]
compressed_txt = self.compress(ctx_txt, rate=self.compression_rate)
compressed_ids = self.tokenizer(
compressed_txt["compressed_prompt"] + "\n\n" + q_txt,
return_attention_mask=False,
add_special_tokens=False,
return_tensors="pt",
)["input_ids"].to(self.base_model.device)
out = torch.cat(
[self.prefix.expand(bs, -1), compressed_ids, self.suffix.expand(bs, -1)],
dim=-1,
)
return out
def generate(self, *args, **kwargs):
# take ctx_ids
# strip prefix and suffix
# ctx_ids is left padded
ctx_ids = kwargs["ctx_ids"][:, self.len_prefix : -self.len_suffix]
# decode ctx_ids to ctx_txt
ctx_txt = self.tokenizer.batch_decode(ctx_ids)
compressed_ctx_txt = self.compress(ctx_txt, rate=self.compression_rate)
compressed_ctx_ids = self.tokenizer(
compressed_ctx_txt["compressed_prompt"] + "\n\n",
return_attention_mask=False,
add_special_tokens=False,
return_tensors="pt",
).to(self.base_model.device)
bs = ctx_ids.shape[0]
ctx_inp_ids = torch.cat(
[
self.prefix.expand(bs, -1),
compressed_ctx_ids["input_ids"],
kwargs["input_ids"][:, self.len_prefix :],
],
dim=-1,
)
attn_mask = torch.ones_like(ctx_inp_ids)
for k in [
"ctx_ids",
"ctx_attn_mask",
"n_ctx_chunks",
"input_ids",
"attention_mask",
]:
kwargs.pop(k, None)
return self.base_model.generate(ctx_inp_ids, attention_mask=attn_mask, **kwargs)
if __name__ == "__main__":
from ctx_to_lora.model_loading import get_model_and_tokenizer
model, tokenizer = get_model_and_tokenizer(
"google/gemma-2-2b-it",
train=False,
requires_grad=False,
)
# Demo: build wrapper, create a toy context + prompt, run compression + generation.
device = "cuda"
llm = LLMLinguaModel(model, tokenizer).to(device)
# Toy context and user prompt
context_text = (
"This is a short illustrative context about large language models and compression. "
"They can reduce prompt length while preserving meaning."
)
user_prompt = (
"Summarize the context in one concise sentence." # what we want model to do
)
# Tokenize raw context (core) without special tokens
core_ctx_ids = tokenizer.apply_chat_template(
[[{"role": "user", "content": context_text}]],
tokenize=True,
add_generation_prompt=True,
return_attention_mask=False,
padding=False,
truncation=False,
return_tensors="pt",
add_special_tokens=False,
).to(device)
# Input prompt tokens (what follows the contextual block)
input_ids = tokenizer.apply_chat_template(
[[{"role": "user", "content": user_prompt}]],
tokenize=True,
add_generation_prompt=True,
return_attention_mask=False,
padding=False,
truncation=False,
return_tensors="pt",
add_special_tokens=False,
).to(device)
print("Original context length (chars):", len(context_text))
# Run generation (may vary depending on model capabilities)
output_ids = llm.generate(ctx_ids=core_ctx_ids, input_ids=input_ids)
# Decode only the tail beyond supplied input for readability
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
print(f"\nFull generated text:\n{generated_text}")

View file

@ -0,0 +1,107 @@
from collections.abc import Iterable
from functools import partial
from operator import attrgetter
import torch
import torch.nn.functional as F
from einops import einsum
from jaxtyping import Float, Integer
from torch import Tensor
from ctx_to_lora.utils import get_layers
def lora_forward(
x: Float[Tensor, "tot_q seq_len d_in"],
n_qs: Integer[Tensor, "n_ctx"],
tot_q: int,
A: Float[Tensor, "n_ctx r d_in"],
B: Float[Tensor, "n_ctx r d_out"],
lora_dropout_p: float,
scaling: float,
self,
*args,
**kwargs,
) -> Float[Tensor, "tot_q seq_len d_out"]:
# A: [n_ctx, r, d_in] -> [tot_q, r, d_in]
A = A.repeat_interleave(n_qs, dim=0, output_size=tot_q)
# B: [n_ctx, d_out, r] -> [tot_q, d_out, r]
B = B.repeat_interleave(n_qs, dim=0, output_size=tot_q)
base_out = torch.nn.Linear.forward(self, x, *args, **kwargs)
x = x.to(A.dtype)
delta_x = F.dropout(x, p=lora_dropout_p, training=self.training)
delta_x = einsum(A, delta_x, "tot_q r d_in, tot_q s_len d_in -> tot_q s_len r")
delta_x = einsum(B, delta_x, "tot_q r d_out, tot_q s_len r -> tot_q s_len d_out")
delta_x = delta_x * scaling
return (base_out + delta_x).to(base_out.dtype)
def lora_forward_packed(
x: Float[Tensor, "1 tot_len d_in"],
n_qs: Integer[Tensor, "n_ctx"],
tot_q: int,
seq_lens: Integer[Tensor, "tot_q"],
tot_len: int,
A: Float[Tensor, "n_ctx r d_in"],
B: Float[Tensor, "n_ctx r d_out"],
lora_dropout_p: float,
scaling: float,
self,
*args,
**kwargs,
) -> Float[Tensor, "1 tot_len d_out"]:
# bs of x should be 1 in this case
base_out = torch.nn.Linear.forward(self, x, *args, **kwargs)
x = x.to(A.dtype)
delta_x = F.dropout(x, p=lora_dropout_p, training=self.training)
repeated_A = A.repeat_interleave(n_qs, dim=0, output_size=tot_q)
repeated_A = repeated_A.repeat_interleave(seq_lens, dim=0, output_size=tot_len)
repeated_B = B.repeat_interleave(n_qs, dim=0, output_size=tot_q)
repeated_B = repeated_B.repeat_interleave(seq_lens, dim=0, output_size=tot_len)
delta_x = einsum(
repeated_A, delta_x, "tot_len r d_in, bs tot_len d_in -> bs tot_len r"
)
delta_x = einsum(
repeated_B, delta_x, "tot_len r d_out, bs tot_len r -> bs tot_len d_out"
)
delta_x = delta_x * scaling
return (base_out + delta_x).to(base_out.dtype)
def apply_lora_to_layers(
model: torch.nn.Module,
layer_indices: Iterable[int],
generated_loras: dict[str, dict[str, Float[Tensor, "n_ctx n_layers r _"]]],
n_qs: Integer[Tensor, "n_ctx"],
position_ids: Integer[Tensor, "bs seq_len"] = None,
) -> None:
layers = get_layers(model)
if position_ids is not None:
position_ids = position_ids.squeeze(0)
seq_lens = position_ids[torch.where(position_ids == 0)[0][1:] - 1]
seq_lens = torch.cat(
[seq_lens, torch.tensor([position_ids[-1]], device=seq_lens.device)]
)
seq_lens += 1
tot_len = seq_lens.sum().item()
tot_q = n_qs.sum().item()
for layer_idx in layer_indices:
layer = layers[layer_idx]
for mname in generated_loras:
if mname in ["q_proj", "k_proj", "v_proj", "o_proj", "qkv_proj"]:
long_mname = f"self_attn.{mname}"
elif mname in ["down_proj", "up_proj", "gate_proj"]:
long_mname = f"mlp.{mname}"
module = attrgetter(long_mname)(layer)
A = generated_loras[mname]["A"][:, layer_idx]
B = generated_loras[mname]["B"][:, layer_idx]
module.forward = partial(module.forward, n_qs=n_qs, tot_q=tot_q, A=A, B=B)
if position_ids is not None:
module.forward = partial(
module.forward, seq_lens=seq_lens, tot_len=tot_len
)

View file

@ -0,0 +1,78 @@
"""
Utilities for merging / aggregating LoRA adapters coming from multiple chunks.
"""
import torch
from einops import rearrange
from jaxtyping import Float, Integer
from torch import Tensor
def compute_rank(n_lora, rank):
return (n_lora + 1) * rank
def combine_lora(
generated_loras: dict[str, dict[str, Tensor]],
n_chunks: Integer[Tensor, "n_ctx"],
lora_bias: dict[str, dict[str, Tensor]] | None = None,
scalers: Float[Tensor, "n_ctx"] | None = None,
bias_scaler: float | None = None,
) -> dict[str, dict[str, Tensor]]:
total_chunks = int(n_chunks.sum())
if bias_scaler is None:
bias_scaler = 1
# Assume all modules share same base rank r
first_module = next(iter(generated_loras))
sampled_lora = generated_loras[first_module]["A"]
base_rank = sampled_lora.shape[-2]
device = sampled_lora.device
dtype = sampled_lora.dtype
max_rank_needed = int(compute_rank(n_chunks.max(), base_rank))
combined_loras: dict[str, dict[str, Tensor]] = {
module: {"A": None, "B": None} for module in generated_loras.keys()
}
rank_dim = 2
num_groups = len(n_chunks)
rank_per_group = (n_chunks * base_rank).tolist()
bias_tensor = None
for module_name, module_loras in generated_loras.items():
for matrix_key in ("A", "B"):
if lora_bias is not None:
bias_tensor = lora_bias[module_name][matrix_key]
loras = module_loras[matrix_key]
if (scalers is not None) and (matrix_key == "A"):
loras = loras * scalers[:, None, None, None]
flat_loras = rearrange(
loras, "tot_chunks n_layers r dim -> 1 n_layers (tot_chunks r) dim"
)
per_group_deltas = flat_loras.split(rank_per_group, dim=rank_dim)
combined_shape = [num_groups, *per_group_deltas[0].shape[1:]]
combined_shape[rank_dim] = max_rank_needed
combined = torch.zeros(*combined_shape, device=device, dtype=dtype)
for g, deltas in enumerate(per_group_deltas):
combined_rank = deltas.shape[rank_dim]
# Build slice pattern, slice up to combined_rank.
# slice_pattern = [g, slice(None), slice(None), slice(None)]
# slice_pattern[rank_dim] = slice(combined_rank)
combined[g, :, :combined_rank, :] = deltas
if bias_tensor is not None:
# bias_slice_pattern = [g, slice(None), slice(None), slice(None)]
# bias_slice_pattern[rank_dim] = slice(
# combined_rank, combined_rank + base_rank
# )
combined[g, :, combined_rank : combined_rank + base_rank, :] = (
bias_tensor * bias_scaler
)
combined_loras[module_name][matrix_key] = combined
return combined_loras

View file

@ -0,0 +1,160 @@
from functools import partial
import torch
from peft import PeftConfig
from torch import nn
from ctx_to_lora.modeling.lora_layer import apply_lora_to_layers, lora_forward
from ctx_to_lora.modeling.text_to_lora_impl import (
embed_texts,
get_layers,
get_peft_config,
load_hypermod,
)
from ctx_to_lora.utils import get_peft_modules
class TextToLoRA(nn.Module):
def __init__(self, model_name_or_path, prefix_tokens, device):
assert model_name_or_path == "google/gemma-2-2b-it"
super().__init__()
hypermod_dir = "trained_t2l/gemma_2b_t2l"
peft_config = get_peft_config(
PeftConfig.from_json_file(f"{hypermod_dir}/adapter_config.json")
)
# ours lora forward pass uses alpha directly
peft_config.lora_alpha = peft_config.lora_alpha / peft_config.r
self.prefix_tokens = prefix_tokens
self.device = device
(
_,
self.t2l_model,
self.base_model,
self.tokenizer,
self.emb_model,
self.emb_tokenizer,
self.task_desc_format_fn,
self.pooling_fn,
) = load_hypermod(hypermod_dir, device)
layer_indices = range(len(get_layers(self.base_model)))
self.layer_indices = torch.tensor(
layer_indices, dtype=torch.long, device=device
)
# patch base model forward pass to use lora
layers = get_layers(self.base_model)
lora_forward_fn = lora_forward
for layer_idx in self.layer_indices:
for module_info in get_peft_modules(layers[layer_idx], peft_config):
module = module_info["module"]
module.forward = partial(
lora_forward_fn,
self=module,
lora_dropout_p=peft_config.lora_dropout,
scaling=peft_config.lora_alpha,
)
@property
def generation_config(self):
return self.base_model.generation_config
def generate_weights(self, ctx_txt: str):
# generate loras
ctx_emb = embed_texts(
[ctx_txt],
self.emb_model,
self.emb_tokenizer,
self.task_desc_format_fn,
self.pooling_fn,
self.device,
)
encoder_out = self.t2l_model.task_encoder(ctx_emb)
encoded_task_emb = encoder_out["encoded_task_emb"].detach()
lora_A, lora_B = dict(), dict()
lora_dict = dict()
for target_module in self.t2l_model.target_modules:
factorized_delta_w = self.t2l_model.get_delta_weights(
self.layer_indices,
target_module,
encoded_task_emb.expand(self.layer_indices.shape[0], -1),
factorized=True,
)
# lora_A[target_module]: [n_layers, r, d_in]
# lora_A[target_module]: [n_layers, d_out, r]
lora_A[target_module], lora_B[target_module] = factorized_delta_w
# convert to lora format used by lora_forward
# dict of {module:
# {A: [bs, n_layers, r, d_inim],
# B: [bs, n_layers, r, d_outim]}}
lora_dict[target_module] = dict(
A=lora_A[target_module].unsqueeze(0),
B=lora_B[target_module].transpose(-1, -2).unsqueeze(0),
)
return lora_dict
def generate(self, *args, **kwargs):
ctx_ids_full = kwargs["ctx_ids"]
ctx_txt = self.tokenizer.decode(
ctx_ids_full[0, len(self.prefix_tokens) :], skip_special_tokens=True
)
generated_loras = self.generate_weights(ctx_txt)
apply_lora_to_layers(
self.base_model,
self.layer_indices,
generated_loras,
n_qs=torch.tensor([1], device=self.device),
position_ids=None,
)
kwargs.pop("ctx_ids", None)
kwargs.pop("ctx_attn_mask", None)
kwargs.pop("n_ctx_chunks", None)
return self.base_model.generate(*args, **kwargs)
if __name__ == "__main__":
from transformers import AutoTokenizer
from ctx_to_lora.data.definitions import CTX_AFFIXES
from ctx_to_lora.data.processing import load_and_process_dataset
model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
ds = load_and_process_dataset("pwc", split="train", num_proc=8)
ctx = ds[0]["context"]
inp = ds[1]["prompts"][0]
ctx_ids = tokenizer.apply_chat_template(
[{"role": "user", "content": ctx}], return_tensors="pt", return_dict=True
)
input_ids = tokenizer.apply_chat_template(
[{"role": "user", "content": ctx + "\n\n" + inp}],
return_tensors="pt",
return_dict=True,
)
ctx_ids = {k: v.to("cuda") for k, v in ctx_ids.items()}
input_ids = {k: v.to("cuda") for k, v in input_ids.items()}
prefix_tokens = CTX_AFFIXES[model_name]["prefix"]
prefix_tokens = torch.tensor(prefix_tokens, dtype=torch.long)
t2l_model = TextToLoRA(
model_name,
prefix_tokens,
device="cuda",
)
with torch.no_grad():
for _ in range(1):
outputs = t2l_model.generate(
**input_ids,
ctx_ids=ctx_ids["input_ids"],
max_new_tokens=256,
do_sample=False,
)
print(
f"Student response: {tokenizer.batch_decode(outputs, skip_special_tokens=False)}"
)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,54 @@
from enum import Enum
import torch
from jaxtyping import Float, Integer
from torch import Tensor
POOL_FN = Enum("POOL_FN", ["MEAN", "MAX", "LAST_TOKEN"])
def inv_bool_mask(m: Integer[Tensor, "bs seq_len"]) -> Integer[Tensor, "bs seq_len 1"]:
return (m - 1).bool().unsqueeze(-1)
def get_pooling_fn(pooling_type: str):
if pooling_type == POOL_FN.MEAN:
return mean_pool
elif pooling_type == POOL_FN.MAX:
return max_pool
elif pooling_type == POOL_FN.LAST_TOKEN:
return last_token_pool
def mean_pool(
features: Float[Tensor, "bs seq_len feature_dim"],
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
) -> Float[Tensor, "bs 1 feature_dim"]:
if attn_mask is not None:
features = features.masked_fill(inv_bool_mask(attn_mask), 0)
return features.sum(dim=1) / attn_mask.sum(dim=1).unsqueeze(1)
def max_pool(
features: Float[Tensor, "bs seq_len feature_dim"],
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
) -> Float[Tensor, "bs 1 feature_dim"]:
if attn_mask is not None:
features = features.masked_fill(inv_bool_mask(attn_mask), -float("inf"))
return torch.max(features, dim=1)
def last_token_pool(
features: Float[Tensor, "bs seq_len feature_dim"],
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
) -> Float[Tensor, "bs feature_dim"]:
left_padding = attn_mask[:, -1].sum() == attn_mask.shape[0]
if left_padding:
return features[:, -1]
else:
sequence_lengths = attn_mask.sum(dim=1) - 1
batch_size = features.shape[0]
return features[
torch.arange(batch_size, device=features.device),
sequence_lengths,
]

View file

View file

@ -0,0 +1,353 @@
"""Lightweight CUDA memory tracking utilities for measuring per-method peak memory usage.
Usage:
x = SomeClass()
add_memory_tracker(x.some_method, "some_method") # wraps the bound method in-place
x.some_method(...)
print_aggregate_memory_stats("some_method")
Design notes:
- Mirrors the API of timer.py but records CUDA memory (peak increase in bytes) per call.
- add_memory_tracker mutates the instance method with a wrapper (idempotent: double wrap avoided).
- Global registry: { name: [int, ...] } storing per-call peak memory increase (bytes).
- print_aggregate_memory_stats prints summary stats (count, total, mean, median, min, max, p95, last).
- If CUDA or torch is unavailable, wrappers degrade gracefully (no measurements recorded).
Metrics collected per call (if CUDA available):
- peak_increase_bytes: (torch.cuda.max_memory_allocated() - start_allocated)
This captures the maximum additional memory pressure during the call.
Caveats:
- Rapid allocate/free patterns entirely inside the call still reflect peak transient usage.
- Asynchronous CUDA ops: we synchronize before and after to improve accuracy. This may
slightly affect performance timings but is necessary for memory correctness.
"""
from __future__ import annotations
from collections.abc import Callable
from statistics import mean, median, stdev
from typing import Any
try: # Optional dependency handling
import torch # type: ignore
except Exception: # pragma: no cover - torch absence path
torch = None # type: ignore
# Global memory registry: name -> list of peak memory increases (bytes)
MEMORY_REGISTRY: dict[str, list[int]] = {}
def _cuda_available() -> bool:
return bool(torch is not None and torch.cuda.is_available())
def add_memory_tracker(func: Callable, name: str) -> None:
"""Attach a CUDA memory tracking wrapper to a bound method.
Parameters
----------
func : Callable
A *bound* instance method (instance.method). Raises ValueError if unbound.
name : str
Key under which memory stats are recorded in MEMORY_REGISTRY.
"""
if not hasattr(func, "__self__") or getattr(func, "__self__") is None:
if getattr(func, "__is_memory_wrapper__", False): # already wrapped
return
raise ValueError(
"add_memory_tracker expects a bound method: call with instance.method"
)
instance = func.__self__
method_name = getattr(func, "__name__", None)
if method_name is None:
raise ValueError("Cannot determine method name for provided callable")
existing = getattr(instance, method_name, None)
if getattr(existing, "__is_memory_wrapper__", False): # idempotent
return
orig_bound = func
def tracked(*args: Any, **kwargs: Any): # noqa: D401 - simple wrapper
if not _cuda_available():
return orig_bound(*args, **kwargs)
# Synchronize to get a clean baseline
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start_alloc = torch.cuda.memory_allocated()
try:
return orig_bound(*args, **kwargs)
finally:
torch.cuda.synchronize()
peak_alloc = torch.cuda.max_memory_allocated()
peak_increase = peak_alloc - start_alloc
# Record only if positive (avoid negative due to potential race, though improbable)
if peak_increase < 0:
peak_increase = 0
MEMORY_REGISTRY.setdefault(name, []).append(int(peak_increase))
tracked.__name__ = method_name
tracked.__doc__ = getattr(orig_bound, "__doc__")
tracked.__qualname__ = getattr(orig_bound, "__qualname__", method_name)
tracked.__is_memory_wrapper__ = True # type: ignore[attr-defined]
tracked.__wrapped__ = orig_bound # type: ignore[attr-defined]
tracked.__memory_name__ = name # type: ignore[attr-defined]
setattr(instance, method_name, tracked)
def _format_bytes(num_bytes: float) -> str:
"""Human-readable byte formatting (base-2)."""
if num_bytes < 1024:
return f"{int(num_bytes):5d}B"
units = ["KiB", "MiB", "GiB", "TiB"]
value = float(num_bytes)
for u in units:
value /= 1024.0
if value < 1024.0:
return f"{value:8.3f}{u}"
return f"{value:8.3f}PiB" # Extremely unlikely
def compute_aggregate_memory_stats(
name: str | None = None,
) -> dict[str, dict[str, float]] | None:
"""Compute aggregate CUDA memory statistics for specific trackers.
Parameters
----------
name : Optional[str]
Specific tracker name to compute stats for. If None, all trackers are computed.
Returns
-------
Optional[Dict[str, Dict[str, float]]]
None if no data, else dict mapping tracker names to their stats.
Each stats dict contains: count, total, mean, median, min, max, p95, last, std.
"""
if not MEMORY_REGISTRY:
return None
keys = [name] if name else sorted(MEMORY_REGISTRY.keys())
valid_keys = [k for k in keys if k in MEMORY_REGISTRY and MEMORY_REGISTRY[k]]
if not valid_keys:
return None
result = {}
for k in valid_keys:
data = MEMORY_REGISTRY[k]
data_sorted = sorted(data)
cnt = len(data)
total = sum(data)
avg = mean(data)
med = median(data)
std = stdev(data) if cnt > 1 else 0.0
mn = data_sorted[0]
mx = data_sorted[-1]
p95_index = int(0.95 * (cnt - 1))
p95 = data_sorted[p95_index]
last = data[-1]
result[k] = {
"count": float(cnt),
"total": float(total),
"mean": float(avg),
"median": float(med),
"min": float(mn),
"max": float(mx),
"p95": float(p95),
"last": float(last),
"std": float(std),
}
return result
def save_memory_stats_csv(file_path: str, name: str | None = None) -> None:
"""Save aggregate CUDA memory statistics to a CSV file.
Parameters
----------
file_path : str
Path where the CSV file will be saved.
name : Optional[str]
Specific tracker name to export. If None, all trackers are exported.
"""
import csv
stats = compute_aggregate_memory_stats(name)
if stats is None:
raise ValueError("No memory data available to export")
with open(file_path, "w", newline="") as csvfile:
fieldnames = [
"name",
"count",
"total",
"mean",
"median",
"min",
"max",
"p95",
"last",
"std",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for tracker_name, data in stats.items():
row = {"name": tracker_name}
row.update(data)
writer.writerow(row)
def print_aggregate_memory_stats(name: str | None = None) -> None:
"""Print aggregate CUDA memory stats for one or all tracked names.
Parameters
----------
name : Optional[str]
Specific name to report; if None, report all.
"""
stats = compute_aggregate_memory_stats(name)
if stats is None:
print("[mem] No memory data collected.")
return
keys = [name] if name else sorted(MEMORY_REGISTRY.keys())
missing = [k for k in keys if k not in MEMORY_REGISTRY or not MEMORY_REGISTRY[k]]
if missing:
print(f"[mem] No data for: {', '.join(missing)}")
if not stats:
return
header = (
f"{'name':20} {'count':>6} {'total':>12} {'mean':>12} {'median':>12} "
f"{'min':>12} {'max':>12} {'p95':>12} {'std':>12} {'last':>12}"
)
print(header)
print("-" * len(header))
for k, data in stats.items():
print(
f"{k:20} {int(data['count']):6d} {_format_bytes(data['total']):>12} "
f"{_format_bytes(data['mean']):>12} {_format_bytes(data['median']):>12} "
f"{_format_bytes(data['min']):>12} {_format_bytes(data['max']):>12} "
f"{_format_bytes(data['p95']):>12} {_format_bytes(data['std']):>12} "
f"{_format_bytes(data['last']):>12}"
)
def compute_global_memory_stats() -> dict[str, float] | None:
"""Compute aggregate stats across all recorded memory entries.
Returns
-------
Optional[Dict[str, float]]
None if no data; else dict with count, total, mean, median, min, max, p95, std.
"""
if not MEMORY_REGISTRY:
return None
all_values: list[int] = []
for lst in MEMORY_REGISTRY.values():
all_values.extend(lst)
if not all_values:
return None
data_sorted = sorted(all_values)
cnt = len(all_values)
total = float(sum(all_values))
avg = mean(all_values)
med = median(all_values)
std = stdev(all_values) if cnt > 1 else 0.0
mn = float(data_sorted[0])
mx = float(data_sorted[-1])
p95_index = int(0.95 * (cnt - 1))
p95 = float(data_sorted[p95_index])
return {
"count": float(cnt),
"total": total,
"mean": float(avg),
"median": float(med),
"min": mn,
"max": mx,
"p95": p95,
"std": float(std),
}
def print_global_memory_stats() -> None:
"""Pretty-print global stats across all memory registry entries."""
stats = compute_global_memory_stats()
if stats is None:
print("[mem] No memory data collected.")
return
header = (
f"{'scope':20} {'count':>6} {'total':>12} {'mean':>12} {'median':>12} "
f"{'min':>12} {'max':>12} {'p95':>12} {'std':>12}"
)
print(header)
print("-" * len(header))
print(
f"{'<ALL>':20} {int(stats['count']):6d} {_format_bytes(stats['total']):>12} "
f"{_format_bytes(stats['mean']):>12} {_format_bytes(stats['median']):>12} "
f"{_format_bytes(stats['min']):>12} {_format_bytes(stats['max']):>12} "
f"{_format_bytes(stats['p95']):>12} {_format_bytes(stats['std']):>12}"
)
def reset_memory_trackers() -> None:
"""Clear all recorded memory tracking data."""
MEMORY_REGISTRY.clear()
__all__ = [
"MEMORY_REGISTRY",
"add_memory_tracker",
"compute_aggregate_memory_stats",
"save_memory_stats_csv",
"print_aggregate_memory_stats",
"compute_global_memory_stats",
"print_global_memory_stats",
"reset_memory_trackers",
]
if __name__ == "__main__": # Simple demonstration
class Demo:
def __init__(self, device: str | None = None):
self.device = device or ("cuda" if _cuda_available() else "cpu")
def allocate(self, n: int = 1_000_000) -> int:
if not _cuda_available():
# Fallback: just create a CPU tensor
_ = [0] * n # noqa: F841
return n
import torch # local import to avoid mypy confusion
t = torch.empty(n, dtype=torch.float32, device=self.device)
# Perform an op to ensure allocation
t.uniform_() # noqa: F841
return t.numel()
def noalloc(self): # method with negligible allocation
return 42
demo = Demo()
add_memory_tracker(demo.allocate, "alloc")
add_memory_tracker(demo.noalloc, "noalloc")
add_memory_tracker(demo.allocate, "alloc") # idempotent
for _ in range(5):
demo.allocate(200_000)
demo.noalloc()
print("\nAll memory stats:\n")
print_aggregate_memory_stats()
print("\nSingle (alloc):\n")
print_aggregate_memory_stats("alloc")
print("\nGlobal memory stats:\n")
print_global_memory_stats()

View file

@ -0,0 +1,332 @@
"""Lightweight timing utilities for attaching runtime measurement to object methods.
Usage:
x = SomeClass()
add_timer(x.some_method, "some_method") # wraps the bound method in-place
x.some_method(...)
print_aggregate_timer_stats("some_method")
Design notes:
- add_timer mutates the instance by replacing the bound method with a timing wrapper.
- Multiple calls to add_timer on the same (already wrapped) method are ignored to avoid double timing.
- Global registry: { name: [float, ...] } storing individual call durations.
- print_aggregate_timer_stats prints summary stats (count, total, mean, median, min, max, p95, last).
"""
from __future__ import annotations
from collections.abc import Callable
from statistics import mean, median, stdev
from time import perf_counter
from typing import Any
# Global timer registry: name -> list of durations (seconds)
TIMER_REGISTRY: dict[str, list[float]] = {}
class _TimerWrapperMarker:
"""Mixin marker to identify already wrapped callables."""
__slots__ = ("__wrapped_name__",)
def __init__(self, wrapped_name: str):
self.__wrapped_name__ = wrapped_name
def add_timer(func: Callable, name: str) -> None:
"""Attach a timing wrapper to a bound method.
Parameters
----------
func : Callable
A *bound* method (e.g., instance.method). If an unbound function is provided
it will raise a ValueError (explicitness keeps behavior predictable).
name : str
Key under which durations are recorded in TIMER_REGISTRY.
"""
# Basic validation (permit already wrapped functions for idempotency even though
# they sit as plain functions on the instance dict and thus lack __self__).
if not hasattr(func, "__self__") or getattr(func, "__self__") is None:
if getattr(func, "__is_timer_wrapper__", False): # already wrapped, no-op
return
raise ValueError("add_timer expects a bound method: call with instance.method")
instance = func.__self__ # The object instance
method_name = getattr(func, "__name__", None)
if method_name is None:
raise ValueError("Cannot determine method name for provided callable")
# Prevent double-wrapping (idempotent behavior)
existing = getattr(instance, method_name, None)
if getattr(existing, "__is_timer_wrapper__", False): # Already wrapped
return
orig_bound = func # capture original bound method
def timed(*args: Any, **kwargs: Any): # noqa: D401 - simple wrapper
start = perf_counter()
try:
return orig_bound(*args, **kwargs)
finally:
elapsed = perf_counter() - start
TIMER_REGISTRY.setdefault(name, []).append(elapsed)
# Mark wrapper to avoid double wrapping; preserve introspection hints.
timed.__name__ = method_name
timed.__doc__ = getattr(orig_bound, "__doc__")
timed.__qualname__ = getattr(orig_bound, "__qualname__", method_name)
timed.__is_timer_wrapper__ = True # type: ignore[attr-defined]
timed.__wrapped__ = orig_bound # type: ignore[attr-defined]
timed.__timer_name__ = name # type: ignore[attr-defined]
setattr(instance, method_name, timed)
def _format_seconds(sec: float) -> str:
if sec >= 1:
return f"{sec:8.3f}s"
if sec >= 1e-3:
return f"{sec * 1e3:8.3f}ms"
if sec >= 1e-6:
return f"{sec * 1e6:8.3f}µs"
return f"{sec * 1e9:8.3f}ns"
def compute_aggregate_timer_stats(
name: str | None = None,
) -> dict[str, dict[str, float]] | None:
"""Compute aggregate timing statistics for specific timers.
Parameters
----------
name : Optional[str]
Specific timer name to compute stats for. If None, all timers are computed.
Returns
-------
Optional[Dict[str, Dict[str, float]]]
None if no data, else dict mapping timer names to their stats.
Each stats dict contains: count, total, mean, median, min, max, p95, last, std.
"""
if not TIMER_REGISTRY:
return None
keys = [name] if name else sorted(TIMER_REGISTRY.keys())
valid_keys = [k for k in keys if k in TIMER_REGISTRY and TIMER_REGISTRY[k]]
if not valid_keys:
return None
result = {}
for k in valid_keys:
data = TIMER_REGISTRY[k]
data_sorted = sorted(data)
cnt = len(data)
total = sum(data)
avg = mean(data)
med = median(data)
std = stdev(data) if cnt > 1 else 0.0
mn = data_sorted[0]
mx = data_sorted[-1]
p95_index = int(0.95 * (cnt - 1))
p95 = data_sorted[p95_index]
last = data[-1]
result[k] = {
"count": float(cnt),
"total": total,
"mean": avg,
"median": med,
"min": mn,
"max": mx,
"p95": p95,
"last": last,
"std": std,
}
return result
def save_timer_stats_csv(file_path: str, name: str | None = None) -> None:
"""Save aggregate timing statistics to a CSV file.
Parameters
----------
file_path : str
Path where the CSV file will be saved.
name : Optional[str]
Specific timer name to export. If None, all timers are exported.
"""
import csv
stats = compute_aggregate_timer_stats(name)
if stats is None:
raise ValueError("No timing data available to export")
with open(file_path, "w", newline="") as csvfile:
fieldnames = [
"name",
"count",
"total",
"mean",
"median",
"min",
"max",
"p95",
"last",
"std",
]
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
for timer_name, data in stats.items():
row = {"name": timer_name}
row.update(data)
writer.writerow(row)
def print_aggregate_timer_stats(name: str | None = None) -> None:
"""Print aggregate timing statistics.
Parameters
----------
name : Optional[str]
Specific timer name to report. If None, all timers are reported.
"""
stats = compute_aggregate_timer_stats(name)
if stats is None:
print("[timer] No timing data collected.")
return
keys = [name] if name else sorted(TIMER_REGISTRY.keys())
missing = [k for k in keys if k not in TIMER_REGISTRY or not TIMER_REGISTRY[k]]
if missing:
print(f"[timer] No data for: {', '.join(missing)}")
if not stats:
return
header = (
f"{'name':20} {'count':>6} {'total':>10} {'mean':>10} {'median':>10} "
f"{'min':>10} {'max':>10} {'p95':>10} {'std':>10} {'last':>10}"
)
print(header)
print("-" * len(header))
for k, data in stats.items():
print(
f"{k:20} {int(data['count']):6d} {_format_seconds(data['total']):>10} "
f"{_format_seconds(data['mean']):>10} {_format_seconds(data['median']):>10} "
f"{_format_seconds(data['min']):>10} {_format_seconds(data['max']):>10} "
f"{_format_seconds(data['p95']):>10} {_format_seconds(data['std']):>10} "
f"{_format_seconds(data['last']):>10}"
)
def compute_global_timer_stats() -> dict[str, float] | None:
"""Compute aggregate statistics across all recorded timer values.
Returns
-------
Optional[Dict[str, float]]
None if no data recorded, else a dict with keys:
count, total, mean, median, min, max, p95, std.
"""
if not TIMER_REGISTRY:
return None
all_values: list[float] = []
for lst in TIMER_REGISTRY.values():
all_values.extend(lst)
if not all_values:
return None
data_sorted = sorted(all_values)
cnt = len(all_values)
total = sum(all_values)
avg = mean(all_values)
med = median(all_values)
std = stdev(all_values) if cnt > 1 else 0.0
mn = data_sorted[0]
mx = data_sorted[-1]
p95_index = int(0.95 * (cnt - 1))
p95 = data_sorted[p95_index]
return {
"count": float(cnt), # keep uniform numeric type
"total": total,
"mean": avg,
"median": med,
"min": mn,
"max": mx,
"p95": p95,
"std": std,
}
def print_global_timer_stats() -> None:
"""Pretty-print global aggregate stats across all timer entries."""
stats = compute_global_timer_stats()
if stats is None:
print("[timer] No timing data collected.")
return
header = (
f"{'scope':20} {'count':>6} {'total':>10} {'mean':>10} {'median':>10} "
f"{'min':>10} {'max':>10} {'p95':>10} {'std':>10}"
)
print(header)
print("-" * len(header))
print(
f"{'<ALL>':20} {int(stats['count']):6d} {_format_seconds(stats['total']):>10} "
f"{_format_seconds(stats['mean']):>10} {_format_seconds(stats['median']):>10} "
f"{_format_seconds(stats['min']):>10} {_format_seconds(stats['max']):>10} "
f"{_format_seconds(stats['p95']):>10} {_format_seconds(stats['std']):>10}"
)
def reset_timers() -> None:
"""Reset (clear) all recorded timing data."""
TIMER_REGISTRY.clear()
__all__ = [
"TIMER_REGISTRY",
"add_timer",
"compute_aggregate_timer_stats",
"save_timer_stats_csv",
"print_aggregate_timer_stats",
"compute_global_timer_stats",
"print_global_timer_stats",
"reset_timers",
]
if __name__ == "__main__":
# Example usage / simple self-test.
import random
import time
class Demo:
def f1(self, n: int = 20_000) -> int:
# CPU-bound work
s = 0
for i in range(n):
s += i * i
return s
def f2(self) -> None:
# Simulate I/O or waiting
time.sleep(random.uniform(0.001, 0.003))
demo = Demo()
# Attach timers (idempotent: calling again does nothing harmful)
add_timer(demo.f1, "f1")
add_timer(demo.f2, "f2")
add_timer(demo.f1, "f1") # demonstrate double-wrap prevention
for _ in range(5):
demo.f1(10_000)
demo.f2()
print("\nAll timers:\n")
print_aggregate_timer_stats()
print("\nSingle timer (f1):\n")
print_aggregate_timer_stats("f1")
print("\nGlobal aggregated stats:\n")
print_global_timer_stats()

View file

@ -0,0 +1,303 @@
"""Unified tracking interface combining timing and CUDA memory usage.
Primary API
-----------
add_tracker(bound_method, name)
Wraps a bound instance method so that each invocation records:
- wall-clock duration (seconds) in timer.TIMER_REGISTRY[name]
- CUDA peak memory increase (bytes) in cuda_memory_tracker.MEMORY_REGISTRY[name]
(only if CUDA + torch available; otherwise memory list may stay empty / absent)
print_tracker_stats(name=None)
Convenience printer that delegates to time + memory aggregate printers.
Design
------
We implement a single wrapper (instead of nesting the individual timer & memory
wrappers) to avoid multiple layers of indirection and to ensure the measured
CUDA memory footprint reflects only the original method's body (excluding the
separate timing wrapper's slight overhead). The wrapper is idempotent: repeated
calls to add_tracker on the same method are ignored.
This file depends on sibling modules:
- tracker.timer
- tracker.cuda_memory_tracker
Both registries remain the single source of truth; no additional registry is introduced.
"""
from __future__ import annotations
from collections.abc import Callable
from time import perf_counter
from typing import Any
# Support both package (relative) and direct script execution.
try: # Package / normal import path
from .cuda_memory_tracker import ( # type: ignore
MEMORY_REGISTRY,
compute_aggregate_memory_stats,
print_aggregate_memory_stats,
print_global_memory_stats,
reset_memory_trackers,
save_memory_stats_csv,
)
from .timer import ( # type: ignore
TIMER_REGISTRY,
compute_aggregate_timer_stats,
print_aggregate_timer_stats,
print_global_timer_stats,
reset_timers,
save_timer_stats_csv,
)
except Exception: # pragma: no cover - fallback when executed directly
import pathlib
import sys
_this_file = pathlib.Path(__file__).resolve()
# project root is two levels up from tracker/ (i.e., .../src)
_src_root = _this_file.parents[2]
if str(_src_root) not in sys.path:
sys.path.insert(0, str(_src_root))
try:
from ctx_to_lora.tracker.cuda_memory_tracker import ( # type: ignore
MEMORY_REGISTRY,
compute_aggregate_memory_stats,
print_aggregate_memory_stats,
print_global_memory_stats,
reset_memory_trackers,
save_memory_stats_csv,
)
from ctx_to_lora.tracker.timer import ( # type: ignore
TIMER_REGISTRY,
compute_aggregate_timer_stats,
print_aggregate_timer_stats,
print_global_timer_stats,
reset_timers,
save_timer_stats_csv,
)
except Exception as e: # If still failing, raise a clearer error.
raise ImportError(
f"Failed to import tracking dependencies; ensure project root on PYTHONPATH. Original: {e}"
)
try: # Optional torch import (lazy fallback if unavailable)
import torch # type: ignore
except Exception: # pragma: no cover - torch absence path
torch = None # type: ignore
__all__ = [
"add_tracker",
"compute_tracker_stats",
"save_tracker_stats_csv",
"print_tracker_stats",
"print_global_tracker_stats",
"reset_trackers",
]
def _cuda_available() -> bool:
return bool(torch is not None and torch.cuda.is_available())
def add_tracker(func: Callable, name: str) -> None:
"""Attach a combined time + CUDA memory tracking wrapper to a bound method.
Parameters
----------
func : Callable
A bound instance method (instance.method). Raises ValueError if unbound.
name : str
Registry key under which metrics are stored.
"""
if not hasattr(func, "__self__") or getattr(func, "__self__") is None:
# Permit idempotent re-calls if already wrapped.
if getattr(func, "__is_tracker_wrapper__", False):
return
raise ValueError(
"add_tracker expects a bound method: call with instance.method"
)
instance = func.__self__ # underlying object
method_name = getattr(func, "__name__", None)
if method_name is None:
raise ValueError("Cannot determine method name for provided callable")
existing = getattr(instance, method_name, None)
if getattr(
existing, "__is_tracker_wrapper__", False
): # Already wrapped via unified tracker
return
# If already individually wrapped by timer or memory tracker, we still wrap only once more;
# future calls to add_tracker will become no-ops.
orig_bound = existing if existing is not None else func
def tracked(*args: Any, **kwargs: Any): # noqa: D401 - combined wrapper
use_cuda = _cuda_available()
if use_cuda:
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
start_alloc = torch.cuda.memory_allocated()
start_time = perf_counter()
try:
return orig_bound(*args, **kwargs)
finally:
elapsed = perf_counter() - start_time
TIMER_REGISTRY.setdefault(name, []).append(elapsed)
if use_cuda:
torch.cuda.synchronize()
peak_alloc = torch.cuda.max_memory_allocated()
peak_increase = peak_alloc - start_alloc
if peak_increase < 0: # Safety guard (should not happen)
peak_increase = 0
MEMORY_REGISTRY.setdefault(name, []).append(int(peak_increase))
# Introspection / idempotency markers
tracked.__name__ = method_name
tracked.__qualname__ = getattr(orig_bound, "__qualname__", method_name)
tracked.__doc__ = getattr(orig_bound, "__doc__")
tracked.__wrapped__ = orig_bound # type: ignore[attr-defined]
tracked.__is_tracker_wrapper__ = True # type: ignore[attr-defined]
tracked.__is_timer_wrapper__ = True # type: ignore[attr-defined]
tracked.__is_memory_wrapper__ = True # type: ignore[attr-defined]
tracked.__tracker_name__ = name # type: ignore[attr-defined]
setattr(instance, method_name, tracked)
def compute_tracker_stats(
name: str | None = None,
) -> dict[str, dict[str, Any]] | None:
"""Compute both timing and memory stats for a given name (or all if None).
Parameters
----------
name : Optional[str]
Specific tracker name; if None, computes all.
Returns
-------
Optional[Dict[str, Dict[str, Any]]]
None if no data, else dict with 'timing' and 'memory' keys containing
their respective aggregate statistics.
"""
timer_stats = compute_aggregate_timer_stats(name)
memory_stats = compute_aggregate_memory_stats(name)
if timer_stats is None and memory_stats is None:
return None
return {
"timing": timer_stats or {},
"memory": memory_stats or {},
}
def save_tracker_stats_csv(file_path: str, name: str | None = None) -> None:
"""Save both timing and memory stats to separate CSV files.
Parameters
----------
file_path : str
Base path for CSV files. Will create file_path_timing.csv and file_path_memory.csv
name : Optional[str]
Specific tracker name to export. If None, all trackers are exported.
"""
import os
os.makedirs(os.path.dirname(file_path), exist_ok=True)
base_path = os.path.splitext(file_path)[0]
timer_path = f"{base_path}_timing.csv"
memory_path = f"{base_path}_memory.csv"
# Save timing stats if available
timer_stats = compute_aggregate_timer_stats(name)
if timer_stats is not None:
save_timer_stats_csv(timer_path, name)
# Save memory stats if available
memory_stats = compute_aggregate_memory_stats(name)
if memory_stats is not None:
save_memory_stats_csv(memory_path, name)
# If no data at all, raise an error
if timer_stats is None and memory_stats is None:
print("No tracking data available to export")
def print_tracker_stats(name: str | None = None) -> None:
"""Print both timing and memory stats for a given name (or all if None).
Parameters
----------
name : Optional[str]
Specific tracker name; if None, prints all.
"""
print("[tracker] Timing stats:")
print_aggregate_timer_stats(name)
print("\n[tracker] CUDA memory stats:")
print_aggregate_memory_stats(name)
def print_global_tracker_stats() -> None:
"""Print global aggregate timing and memory stats."""
print("[tracker] Global timing stats:")
print_global_timer_stats()
print("\n[tracker] Global CUDA memory stats:")
print_global_memory_stats()
def reset_trackers() -> None:
"""Reset all timer and memory tracking data."""
reset_timers()
reset_memory_trackers()
if __name__ == "__main__": # Demonstration
import random
import time
class Demo:
def compute(self, n: int = 25_000) -> int:
# CPU-bound work
s = 0
for i in range(n):
s += i * i
return s
def gpu_alloc(self, n: int = 500_000):
if not _cuda_available():
# Simulate light wait to differentiate timing
time.sleep(random.uniform(0.01, 0.05))
return None
t = torch.empty(n, dtype=torch.float32, device="cuda")
t.uniform_() # ensure usage
return t.sum().item()
demo = Demo()
add_tracker(demo.compute, "compute")
add_tracker(demo.gpu_alloc, "gpu_alloc")
# Idempotent re-call
add_tracker(demo.compute, "compute")
for _ in range(5):
demo.compute(15_000)
demo.gpu_alloc(300_000)
print_tracker_stats()
print("\n--- Global Combined Stats ---\n")
print_global_tracker_stats()
# Demonstrate CSV export
print("\n[tracker] Saving stats to CSV files...\n")
csv_path = "/tmp/tracker_demo_stats.csv"
save_tracker_stats_csv(csv_path)
print(f"Exported timing stats to: {csv_path.replace('.csv', '_timing.csv')}")
print(f"Exported memory stats to: {csv_path.replace('.csv', '_memory.csv')}")
print("\n[tracker] Resetting registries...\n")
reset_trackers()
print_tracker_stats()

458
src/ctx_to_lora/trainer.py Normal file
View file

@ -0,0 +1,458 @@
import logging
import torch
from torch import nn
from transformers import Trainer
from transformers.trainer_pt_utils import get_parameter_names
from transformers.trainer_utils import IntervalStrategy
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
logger = logging.getLogger()
def per_ctx_loss_ce(inputs, labels, loss):
# loss still has masked out elem (0 at labels=-100)
n_queries_per_ctx = inputs["n_queries"].tolist()
position_ids = inputs["position_ids"].squeeze(0)
# account only label positions
label_mask = labels.squeeze(0) != -100
label_pos_ids = label_mask * position_ids
label_pos_ids_diff = label_pos_ids.diff(
append=torch.tensor([0], device=position_ids.device)
)
# assumes the input starts with non-assistant tokens
start_label_pos = torch.where((label_pos_ids_diff > 0) * ~label_mask)[0]
end_label_pos = torch.where((label_pos_ids_diff < 0) * label_mask)[0]
label_seq_lens = end_label_pos - start_label_pos
# these stack and split can be optimized but let's keep it simple
# mean across tokens of each q
qa_losses = torch.stack(
[
loss[start : start + llen].mean()
for start, llen in zip(start_label_pos, label_seq_lens)
]
)
# mean across queries of each ctx
per_ctx_losses = [ql.mean() for ql in torch.split(qa_losses, n_queries_per_ctx)]
# per-ctx loss
loss = torch.stack(per_ctx_losses)
return loss
def per_ctx_loss_kl(inputs, labels, loss):
# loss is compact (label indices selected)
n_queries_per_ctx = inputs["n_queries"].tolist()
position_ids = inputs["position_ids"].squeeze(0)
# account only label positions
label_mask = labels.squeeze(0) != -100
label_pos_ids = label_mask * position_ids
label_pos_ids_diff = label_pos_ids.diff(
append=torch.tensor([0], device=position_ids.device)
)
# assumes the input starts with non-assistant tokens
start_label_pos = torch.where((label_pos_ids_diff > 0) * ~label_mask)[0]
end_label_pos = torch.where((label_pos_ids_diff < 0) * label_mask)[0]
label_seq_lens = end_label_pos - start_label_pos
# find equiv start indices in the already sliced loss vector
cu_label_seq_lens = torch.cumsum(label_seq_lens, dim=0)
start_indices = torch.cat(
(
torch.tensor([0], device=cu_label_seq_lens.device),
cu_label_seq_lens[:-1],
)
)
# these stack and split can be optimized but let's keep it simple
# mean across tokens of each q
qa_losses = torch.stack(
[loss[start:end].mean() for start, end in zip(start_indices, cu_label_seq_lens)]
)
# mean across queries of each ctx
per_ctx_losses = [ql.mean() for ql in torch.split(qa_losses, n_queries_per_ctx)]
# per-ctx loss
loss = torch.stack(per_ctx_losses)
return loss
class ModulatedModelTrainer(Trainer):
# modified from the base Trainer to support per-context average loss
def get_batch_samples(self, epoch_iterator, num_batches, device):
# only used with `use_per_ctx_average_loss=True`
batch_samples = []
num_items_in_batch = None
for _ in range(num_batches):
try:
batch_samples.append(next(epoch_iterator))
except StopIteration:
break
count_num_items_in_batch = (
len(batch_samples) > 0
and "labels" in batch_samples[0]
and "n_ctx_chunks" in batch_samples[0]
)
if count_num_items_in_batch:
num_items_in_batch = dict()
num_items_in_batch["ctx"] = torch.tensor(
sum([batch["n_ctx_chunks"].numel() for batch in batch_samples])
).to(device)
# should we avg over num chunks?
# num_items_in_batch["ctx"] = sum(
# [(batch["ctx_position_ids"] == 0).sum() for batch in batch_samples]
# )
num_items_in_batch["labels"] = sum(
[(batch["labels"].ne(-100)).sum() for batch in batch_samples]
).to(device)
if num_items_in_batch is not None:
if self.args.average_tokens_across_devices:
for k in num_items_in_batch:
num_items_in_batch[k] = self.accelerator.gather(
num_items_in_batch[k]
).sum()
if torch.is_tensor(num_items_in_batch):
num_items_in_batch = num_items_in_batch.to(device)
if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
# In the DataParallel case, convert the scalar tensor into a 1-dim tensor
num_items_in_batch = num_items_in_batch.unsqueeze(0)
return batch_samples, num_items_in_batch
class DistillationTrainer(ModulatedModelTrainer):
def __init__(self, *args, **kwargs):
self.gen_lora_l1_reg_coef = kwargs.pop("gen_lora_l1_reg_coef", 0.0)
self.use_per_ctx_average_loss = kwargs.pop("use_per_ctx_average_loss", False)
super().__init__(*args, **kwargs)
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
# NOTE: the loss output from this fn will be ***added***
# meaning that we should always scale the loss wrt `num_items_in_batch`
# (average over the number of items in the accumulated batch)
is_train = num_items_in_batch is not None
labels = inputs.pop("labels", None)
label_pos = torch.where(labels != -100)
outputs, (gen_loras, _) = model(**inputs, return_generated_lora=True)
if "logprobs_vals" not in inputs:
return (torch.tensor(0.0), outputs) if return_outputs else torch.tensor(0.0)
target_logp = inputs.pop("logprobs_vals").squeeze(0)
indices = inputs.pop("logprobs_indices").squeeze(0)
assert label_pos[0].shape[0] == target_logp.shape[0], (
"Label positions and target log probabilities should have the same # tokens."
f"Got : {label_pos[0].shape[0]=} and {target_logp.shape[0]=}"
)
##### KL loss
outputs_logits = outputs.logits[label_pos[0], label_pos[1] - 1] # shift back 1
logq_full_denom = torch.logsumexp(outputs_logits, dim=-1, keepdim=True) # (N,1)
selected_logits = outputs_logits.gather(1, indices) # (N,K)
# log softmax at selected indices
logq_selected = selected_logits - logq_full_denom
p = target_logp.exp()
loss = -(p * logq_selected).sum(dim=-1)
# teacher_logp = torch.full_like(outputs_logits, -torch.inf)
# teacher_logp.scatter_(1, indices, target_logp)
# # reduction = "batchmean" if num_items_in_batch is None else "sum"
# p = teacher_logp.exp()
# logq = nn.functional.log_softmax(outputs_logits, dim=-1)
# loss = -torch.sum(p * logq, dim=-1)
if self.use_per_ctx_average_loss:
loss = per_ctx_loss_kl(inputs, labels, loss)
if is_train:
if self.use_per_ctx_average_loss:
loss = loss.sum() / num_items_in_batch["ctx"]
else:
loss = loss.sum() / num_items_in_batch["labels"]
else:
# eval
loss = loss.mean()
# if reduction == "batchmean":
# loss = loss.mean()
# elif reduction == "sum":
# # loss does not scale with grad acc
# # num_items_in_batch does
# # this works for both token-avg and ctx-avg
# # loss = loss.sum() / num_items_in_batch
# `num_items_in_batch` is # tokens if `args.use_ctx_average_loss=False``
# loss = loss.sum() / num_items_in_batch
#####
##### unpack gen lora dict and compute regularization loss
l1_norm = 0
n_modules = len(gen_loras)
for module, lora in gen_loras.items():
l1_norm += lora["A"].abs().sum(0).mean() + lora["B"].abs().sum(0).mean()
l1_norm /= n_modules
if is_train:
# during eval `num_items_in_batch` will be None
l1_norm /= num_items_in_batch["ctx"]
total_loss = loss + self.gen_lora_l1_reg_coef * l1_norm
#####
scaler = self.args.gradient_accumulation_steps if is_train else 1
if self.args.average_tokens_across_devices and is_train:
total_loss *= self.accelerator.num_processes
scaler *= self.accelerator.num_processes
# rough estimate of the losses (we only log the values from one step)
if (self.state.global_step == 1 and self.args.logging_first_step) or (
self.args.logging_strategy == IntervalStrategy.STEPS
and self.state.global_step % self.state.logging_steps == 0
):
# compensate `num_items_in_batch` division
self.log(
{
"kl_loss": loss.item() * scaler,
"gen_lora_l1_norm": l1_norm.item() * scaler,
}
)
return (total_loss, outputs) if return_outputs else total_loss
def causal_lm_ce_loss(
logits,
labels,
vocab_size: int,
num_items_in_batch: torch.Tensor | None = None,
ignore_index: int = -100,
shift_labels: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
# Upcast to float if we need to compute the loss to avoid potential precision issues
logits = logits.float()
if shift_labels is None:
# Shift so that tokens < n predict n
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
logits = logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(logits.device)
# loss = fixed_cross_entropy(
# logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
# )
loss = nn.functional.cross_entropy(logits, shift_labels, reduction="none")
return loss
class CrossEntropyTrainer(ModulatedModelTrainer):
def __init__(self, *args, **kwargs):
self.gen_lora_l1_reg_coef = kwargs.pop("gen_lora_l1_reg_coef", 0.0)
self.use_per_ctx_average_loss = kwargs.pop("use_per_ctx_average_loss", False)
super().__init__(*args, **kwargs)
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
"""
How the loss is computed by Trainer.
By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
is_train = num_items_in_batch is not None
labels = inputs.pop("labels", None)
outputs, (gen_loras, _) = model(**inputs, return_generated_lora=True)
# [1, tot_seq_len]
logits = outputs.logits
# [tot_seq_len]
loss = causal_lm_ce_loss(logits, labels, self.model.vocab_size)
if self.use_per_ctx_average_loss:
loss = per_ctx_loss_ce(inputs, labels, loss)
if is_train:
if self.use_per_ctx_average_loss:
loss = loss.sum() / num_items_in_batch["ctx"]
else:
loss = loss.sum() / num_items_in_batch["labels"]
else:
# eval
loss = loss.mean()
#####
# if is_train:
# if self.use_per_ctx_average_loss:
# loss_kwargs["num_items_in_batch"] = num_items_in_batch["ctx"]
# else:
# loss_kwargs["num_items_in_batch"] = num_items_in_batch["labels"]
# inputs = {**inputs, **loss_kwargs}
# outputs, (gen_loras, _) = model(**inputs, return_generated_lora=True)
# # Save past state if it exists
# if self.args.past_index >= 0:
# self._past = outputs[self.args.past_index]
# if labels is not None:
# unwrapped_model = self.accelerator.unwrap_model(model)
# if _is_peft_model(unwrapped_model):
# model_name = unwrapped_model.base_model.model._get_name()
# else:
# model_name = unwrapped_model._get_name()
# # User-defined compute_loss function
# if self.compute_loss_func is not None:
# loss = self.compute_loss_func(
# outputs, labels, num_items_in_batch=num_items_in_batch["labels"]
# )
# elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
# loss = self.label_smoother(outputs, labels, shift_labels=True)
# else:
# loss = self.label_smoother(outputs, labels)
# else:
# if isinstance(outputs, dict) and "loss" not in outputs:
# raise ValueError(
# "The model did not return a loss from the inputs, "
# "only the following keys: "
# f"{','.join(outputs.keys())}. "
# "For reference, the inputs it received are "
# f"{','.join(inputs.keys())}."
# )
# # We don't use .loss here since the model may return tuples instead of ModelOutput.
# loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
#####
##### unpack gen lora dict and compute regularization loss
l1_norm = 0
n_modules = len(gen_loras)
for module, lora in gen_loras.items():
l1_norm += lora["A"].abs().sum(0).mean() + lora["B"].abs().sum(0).mean()
l1_norm /= n_modules
if is_train:
# during eval `num_items_in_batch` will be None
l1_norm /= num_items_in_batch["ctx"]
total_loss = loss + self.gen_lora_l1_reg_coef * l1_norm
#####
scaler = self.args.gradient_accumulation_steps if is_train else 1
if self.args.average_tokens_across_devices and is_train:
total_loss *= self.accelerator.num_processes
scaler *= self.accelerator.num_processes
# rough estimate of the losses (we only log the values from one step)
if (self.state.global_step == 1 and self.args.logging_first_step) or (
self.args.logging_strategy == IntervalStrategy.STEPS
and self.state.global_step % self.state.logging_steps == 0
):
# compensate `num_items_in_batch` division
self.log(
{
"ce_loss": loss.item() * scaler,
"gen_lora_l1_norm": l1_norm.item() * scaler,
}
)
return (total_loss, outputs) if return_outputs else total_loss
def get_decay_parameter_names(model) -> list[str]:
"""
Get all parameter names that weight decay will be applied to.
This function filters out parameters in two ways:
1. By layer type (nn.Embedding)
2. By parameter name patterns (containing 'bias', 'layernorm', 'rmsnorm'
or 'latents_q' [perceiver's latent queries]).
"""
decay_parameters = get_parameter_names(
model,
[nn.Embedding, nn.LayerNorm],
["scaler", "bias", "layernorm", "rmsnorm", "latents_q"],
)
return decay_parameters
def train_model(
model,
training_args,
train_dataset=None,
val_dataset=None,
train_collator=None,
compute_metrics=None,
):
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
logger.info(f"Resuming from the checkpoint: {checkpoint}")
trainer_kwargs = dict(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=val_dataset,
data_collator=train_collator,
compute_metrics=compute_metrics,
)
is_modulated_model = isinstance(model, ModulatedPretrainedModel)
trainer_cls = Trainer
if is_modulated_model:
logger.info("Training with modulated model.")
trainer_cls = CrossEntropyTrainer
trainer_kwargs["gen_lora_l1_reg_coef"] = training_args.gen_lora_l1_reg_coef
trainer_kwargs["use_per_ctx_average_loss"] = (
training_args.use_per_ctx_average_loss
)
del training_args.gen_lora_l1_reg_coef
del training_args.use_per_ctx_average_loss
if training_args.use_kl_loss:
logger.info("Training with distillation loss. Using DistillationTrainer.")
trainer_cls = DistillationTrainer
del training_args.use_kl_loss
if training_args.auto_find_batch_size:
# set the batch size to some high number
# which will be lowered by the Trainer
training_args.per_device_train_batch_size = 128
trainer = trainer_cls(**trainer_kwargs)
# if getattr(trainer, "use_per_ctx_average_loss", False):
# trainer.get_batch_samples = trainer.get_batch_samples_ctx
# MONKEY PATCH: remove embedding layers from weight decay
trainer.get_decay_parameter_names = get_decay_parameter_names
# Trainer loads the best model after training
# is done when load_best_model_at_end=True
train_result = trainer.train(resume_from_checkpoint=checkpoint)
trainer.log_metrics("train", train_result.metrics)
trainer.save_model()
# TODO: add benchmark eval?
# clear_gpu()

275
src/ctx_to_lora/utils.py Normal file
View file

@ -0,0 +1,275 @@
import ast
import gc
import hashlib
import logging
import os
import random
import string
import time
from collections.abc import Iterable
from contextlib import contextmanager
from enum import Enum
import torch
import yaml
from peft import PeftConfig, PeftModel
from peft.tuners.tuners_utils import BaseTunerLayer, check_target_module_exists
from peft.utils import get_peft_model_state_dict
TRAINING_TASK = Enum("TRAINING_TASK", ["CAUSAL_LM", "COMPLETION"])
logger = logging.getLogger()
# taken from https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
@contextmanager
def evaluating(*models):
"""Temporarily switch to evaluation mode."""
is_training = [model.training if model is not None else False for model in models]
try:
for model in models:
if model is not None:
model.eval()
yield models
finally:
for model, training in zip(models, is_training):
if model is not None:
model.train(training)
def get_layers(model):
if hasattr(model, "model"):
return get_layers(model.model)
return model.layers
def get_num_layers(model):
return len(get_layers(model))
def get_base_model(model):
if hasattr(model, "model"):
return get_base_model(model.model)
return model
def get_num_params(model):
total_params = 0
trainable_params = 0
for p in model.parameters():
total_params += p.numel()
if p.requires_grad:
trainable_params += p.numel()
return total_params, trainable_params
def log_num_train_params(model):
logger.debug("Trainable model parameters:")
for name, p in model.named_parameters():
if p.requires_grad:
logger.debug(f"{name}, dtype:{p.dtype}")
num_total_params, num_trainable_params = get_num_params(model)
logger.info(
f"trainable params: {num_trainable_params:,d} "
f"|| all params: {num_total_params:,d} "
f"|| trainable%: {100 * num_trainable_params / num_total_params:.4f}"
)
def get_run_name(seed_str: str | None = None):
if not seed_str:
uuid = "".join(
[random.choice(string.ascii_letters + string.digits) for _ in range(8)]
)
run_name = time.strftime("%Y%m%d-%H%M%S") + f"_{uuid}"
else:
# Generate a UUID from the seed string
hash_object = hashlib.sha256(seed_str.encode())
uuid = hash_object.hexdigest()[:8] # Take the first 8 characters of the hash
run_name = seed_str + f"_{uuid}"
return run_name
def try_convert(s):
try:
return ast.literal_eval(s)
except:
return s
def extract_cli_args(argv: list[str]):
out = dict()
for elem in argv:
if elem.endswith(".yaml"):
out["config"] = elem
elif elem.startswith("--"):
k, v = elem.split("=")
k = k.split("--")[1]
v = try_convert(v)
# if k.startswith('env_'):
# k = k.split('_')[1]
out[k] = v
return out
def setup_logging(output_dir, debug=False):
global logger
os.makedirs(output_dir, exist_ok=True)
log_formatter = logging.Formatter(
fmt="%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
stream_level = logging.DEBUG if debug else logging.INFO
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(log_formatter)
stream_handler.setLevel(stream_level)
logger.addHandler(stream_handler)
log_path = f"{output_dir}/debug.log"
debug_handler = logging.FileHandler(log_path, delay=True)
debug_handler.setLevel(logging.DEBUG)
debug_handler.setFormatter(log_formatter)
logger.addHandler(debug_handler)
logger.setLevel(logging.DEBUG)
logger.info(f"Logging to: {log_path}")
def validate_args(args_list):
# there shouldn't be overlap between args
keys = set()
for args in args_list:
logger.debug(args)
args_keys = set(vars(args).keys())
assert len(keys & args_keys) == 0, "Overlap between args"
keys |= args_keys
def save_yaml(data, path):
# Filter out non-primitive fields
data = {
k: v
for k, v in data.items()
if isinstance(v, (int, float, str, bool, list, dict, type(None)))
}
with open(path, "w") as file:
yaml.dump(data, file)
def get_peft_modules(model: PeftModel, peft_config: PeftConfig) -> list[dict[str, str]]:
return [
{"name": name, "module": module}
for name, module in model.named_modules()
if name.split(".")[-1] in peft_config.target_modules
and isinstance(module, BaseTunerLayer)
and check_target_module_exists(peft_config, name)
]
def get_peft_in_out_features(
model: PeftModel,
peft_config: PeftConfig | None = None,
) -> tuple[dict[str, int], dict[str, int]]:
if peft_config is None:
return None, None
in_features = dict()
out_features = dict()
for module_info in get_peft_modules(model, peft_config):
module_name = module_info["name"]
module = module_info["module"]
# support just Linear layer for now
# all modules should be a leave module that is Linear layer
assert isinstance(module.base_layer, torch.nn.Linear), (
"all modules should be a leave module that is Linear layer"
)
# this should always pass
name = module_name.split(".")[-1]
assert name in peft_config.target_modules
if name not in in_features:
in_features[name] = module.in_features
out_features[name] = module.out_features
else:
# assumes each module has the same input and output features
assert in_features[name] == module.in_features
assert out_features[name] == module.out_features
return in_features, out_features
def generated_lora_to_state_dict(
lora_dict: dict,
module_names: dict,
target_modules: list[str],
layer_indices: Iterable[int],
) -> dict:
lora_state_dict = dict()
for target_module in target_modules:
for layer_idx in layer_indices:
for module_name in module_names[target_module][layer_idx]:
if "lora_A" in module_name:
lora_state_dict[module_name] = (
lora_dict[target_module]["A"][layer_idx].cpu().contiguous()
)
elif "lora_B" in module_name:
lora_state_dict[module_name] = (
lora_dict[target_module]["B"][layer_idx].cpu().contiguous()
)
else:
raise ValueError(f"Unexpected module name: {module_name}")
return lora_state_dict
def get_lora_module_names(
model: PeftModel,
target_modules: list[str],
layer_indices: Iterable[int],
) -> dict[str, list[str]]:
module_names = {
target_module: [[] for _ in range(len(layer_indices))]
for target_module in target_modules
}
for k in get_peft_model_state_dict(model):
if "lora" not in k:
continue
layer_idx = int(k.split("layers.")[-1].split(".")[0])
if layer_idx in layer_indices:
for target_module in target_modules:
if target_module in k:
module_names[target_module][layer_idx].append(k)
break
return module_names
def compile_linear(model):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
module.compile()
def clear_gpu():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_max_memory_cached()
def concat_list(l):
out = []
for x in l:
out += x
return out
def check_is_iterable(x):
try:
iter(x)
except TypeError:
return False
return True

419
train.py Executable file
View file

@ -0,0 +1,419 @@
import contextlib
import logging
import os
from copy import deepcopy
from functools import partial
import numpy as np
import torch
import wandb
from datasets import disable_caching
from peft import PeftModel
from transformers import (
AutoConfig,
set_seed,
)
from ctx_to_lora.configs import (
AggregatorArguments,
ArgumentParser,
CtxEncoderArguments,
CtxTrainingArguments,
DataArguments,
ExperimentSetup,
HypernetArguments,
LoRAArguments,
ModelArguments,
TrainingArguments,
)
from ctx_to_lora.data.collator import ( # train_packed_collator,; DefaultDataCollator,
flatten_if_not_packed,
)
from ctx_to_lora.data.processing import get_tokenized_dataset, pack
from ctx_to_lora.metrics import (
Evaluator,
compute_metrics,
compute_per_token_acc,
compute_perplexity,
compute_prefix_matching,
)
from ctx_to_lora.model_loading import (
check_is_vision_model,
get_lora_config,
get_model_and_tokenizer,
get_tokenizer,
)
from ctx_to_lora.modeling.hypernet import (
ModulatedPretrainedModel,
get_hypernet_config,
)
from ctx_to_lora.trainer import train_model
from ctx_to_lora.utils import (
compile_linear,
extract_cli_args,
get_run_name,
log_num_train_params,
save_yaml,
setup_logging,
validate_args,
)
logger = logging.getLogger()
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
def main():
############ Argument parsing
parser = ArgumentParser(
(
DataArguments,
CtxTrainingArguments,
ModelArguments,
LoRAArguments,
TrainingArguments,
HypernetArguments,
AggregatorArguments,
CtxEncoderArguments,
)
)
(
data_args,
ctx_args,
model_args,
lora_args,
training_args,
hypernet_args,
aggregator_args,
ctx_encoder_args,
) = parser.parse()
# there shouldn't be overlap between args
validate_args(
[
data_args,
ctx_args,
model_args,
lora_args,
training_args,
hypernet_args,
aggregator_args,
ctx_encoder_args,
]
)
assert ctx_args.use_sequence_packing, (
f"Please set use_sequence_packing=True in {ctx_args}. It's faster!"
)
set_seed(training_args.seed)
checkpoint_dir = training_args.resume_from_checkpoint
# should be the same across processes
# still possible to have a name crash though
# logging_dir is just "runs/DATE_TIME_HOSTNAME"
slurm_job_id = f"_{os.getenv('SLURM_JOB_ID')}" if os.getenv("SLURM_JOB_ID") else ""
run_name = (
get_run_name(seed_str=training_args.logging_dir.strip("runs/") + slurm_job_id)
if not checkpoint_dir
else checkpoint_dir.strip("/").split("/")[-2]
)
output_dir = f"train_outputs/runs/{run_name}"
setup_logging(output_dir, debug=os.getenv("DEBUG", False))
logger.debug(f"CMD: {' '.join(os.sys.argv)}")
cli_args = extract_cli_args(os.sys.argv)
save_yaml(cli_args, f"{output_dir}/cli_args.yaml")
if "config" in cli_args:
config_name = os.path.basename(cli_args["config"]).split(".yaml")[0]
os.environ["WANDB_TAGS"] = config_name
run_name = os.path.basename(output_dir)
training_args.run_name = run_name
training_args.output_dir = output_dir
training_args.logging_dir = output_dir
if (
training_args.lr_scheduler_type == "cosine_with_min_lr"
and training_args.lr_scheduler_kwargs is None
):
training_args.lr_scheduler_kwargs = {"min_lr": 1e-7}
args = {
**vars(deepcopy(data_args)),
**vars(deepcopy(ctx_args)),
**vars(deepcopy(model_args)),
**vars(deepcopy(lora_args)),
**vars(deepcopy(training_args)),
**vars(deepcopy(hypernet_args)),
**vars(deepcopy(aggregator_args)),
**vars(deepcopy(ctx_encoder_args)),
}
args["deepspeed_plugin"] = None
logger.debug(f"args: {args}")
save_yaml(args, f"{output_dir}/args.yaml")
############ Model setup
if not ctx_args.from_pretrained_checkpoint:
model_name = model_args.model_name_or_path
base_model, tokenizer = get_model_and_tokenizer(
**vars(model_args),
train=True,
requires_grad=False,
peft_config=get_lora_config(model_name, **vars(lora_args)),
)
ctx_name = ctx_encoder_args.ctx_encoder_model_name_or_path
if ctx_name is not None:
ctx_encoder_model_config = AutoConfig.from_pretrained(
ctx_name, trust_remote_code=True
)
if ("Llama" in ctx_name and "Vision" in ctx_name) or check_is_vision_model(
ctx_name
):
ctx_encoder_model_config = ctx_encoder_model_config.text_config
ctx_tokenizer = get_tokenizer(ctx_name, train=True)
else:
ctx_name = base_model.base_model.config.name_or_path
ctx_encoder_model_config = base_model.config
ctx_tokenizer = tokenizer
if ctx_args.exp_setup == ExperimentSetup.HYPERLORA:
logger.info("Using HyperLoRA")
if not ctx_args.from_pretrained_checkpoint:
hypernet_config = get_hypernet_config(
base_model,
ctx_encoder_model_config,
hypernet_args,
aggregator_args,
ctx_encoder_args,
)
if ctx_encoder_args.layer_idx is None:
ctx_encoder_args.layer_idx = (
ctx_encoder_model_config.num_hidden_layers // 4
)
logger.info(
f"Using the first {ctx_encoder_args.layer_idx} layers"
" as the context encoder"
)
ctx_name = ctx_encoder_args.ctx_encoder_model_name_or_path
if ctx_encoder_args.ctx_encoder_last_layer is None and (
ctx_name is not None and ctx_name != base_model.name_or_path
):
logger.info(
f"Setting ctx_encoder_last_layer to {base_model.name_or_path} max layers"
f":{base_model.config.num_hidden_layers}"
)
ctx_encoder_args.ctx_encoder_last_layer = (
base_model.config.num_hidden_layers
)
model = ModulatedPretrainedModel(
base_model, hypernet_config, ctx_encoder_args
)
else:
if checkpoint_dir:
ctx_args.from_pretrained_checkpoint = (
f"{checkpoint_dir}/pytorch_model.bin"
)
logger.info(
f"Loading from checkpoint: {ctx_args.from_pretrained_checkpoint}"
)
model = ModulatedPretrainedModel.from_state_dict(
torch.load(ctx_args.from_pretrained_checkpoint, weights_only=False),
train=True,
use_flash_attn=model_args.use_flash_attn,
)
tokenizer = get_tokenizer(model.base_model.config.name_or_path, train=True)
ctx_name = model.ctx_encoder_args.ctx_encoder_model_name_or_path
if ctx_name is None:
ctx_name = model.base_model.config.name_or_path
ctx_tokenizer = get_tokenizer(ctx_name, train=True)
training_args.gen_lora_l1_reg_coef = ctx_args.gen_lora_l1_reg_coef
training_args.use_kl_loss = ctx_args.use_kl_loss
training_args.use_per_ctx_average_loss = ctx_args.use_per_ctx_average_loss
if len([p for p in model.ctx_encoder.parameters() if p.requires_grad]):
raise ValueError("ctx_encoder contains trainable parameters")
if len([p for p in model.base_model.parameters() if p.requires_grad]):
raise ValueError("base model contains trainable parameters")
model.hypernet.compile(fullgraph=True, mode="max-autotune")
else:
# activate LoRA
base_model_config = AutoConfig.from_pretrained(
model_args.model_name_or_path, trust_remote_code=True
)
base_model_config.save_pretrained(output_dir)
logger.info("Using LoRA")
model.set_adapter("default")
model = torch.compile(model)
model.train()
logger.debug(model)
log_num_train_params(model)
############ Dataset setup
logger.info("Loading dataset...")
add_ctx_to_chat = not isinstance(model, ModulatedPretrainedModel)
ctx_model_max_len = model.ctx_encoder.config.max_position_embeddings
if ctx_args.max_ctx_len > 0:
ctx_model_max_len = ctx_args.max_ctx_len
if ctx_args.max_ctx_chunk_len <= 0:
# set default chunk size to max length of the ctx encoder
ctx_args.max_ctx_chunk_len = ctx_model_max_len
if ctx_args.num_chunk_probs is not None:
ctx_args.num_chunk_probs = {
int(k): float(v) for k, v in ctx_args.num_chunk_probs.items()
}
_get_tokenized_dataset = partial(
get_tokenized_dataset,
max_qas_len=ctx_args.max_qas_len,
max_qas_per_sample=ctx_args.max_qas_per_sample,
base_model_max_len=model.base_model.config.max_position_embeddings,
tokenizer=tokenizer,
ctx_model_max_len=ctx_model_max_len,
ctx_tokenizer=ctx_tokenizer,
add_ctx_to_chat=add_ctx_to_chat,
add_negative_prompt=ctx_args.add_negative_prompt,
max_ctx_chunk_len=ctx_args.max_ctx_chunk_len,
min_ctx_chunk_len=ctx_args.min_ctx_chunk_len,
num_chunk_probs=ctx_args.num_chunk_probs,
max_ctx_chunk_num=ctx_args.max_ctx_chunk_num,
use_kl_loss=ctx_args.use_kl_loss,
)
splits = ["train"]
if training_args.eval_strategy != "no":
splits.append("validation")
tokenized_ds = {split: {} for split in splits}
for split, ds_names in zip(
splits,
[data_args.train_ds_names, data_args.val_ds_names],
):
if not ds_names:
continue
ctx_mgr = (
training_args.main_process_first()
if split == "train"
else contextlib.nullcontext()
)
with ctx_mgr:
# process and tokenize on the main process
# then other replicas can just load the cached dataset
# we dont save cache for validation ds
for ds_name in ds_names:
ds = _get_tokenized_dataset(ds_name, split)
base_name = os.path.basename(ds_name)
if ds_name.startswith("self_gen/"):
ds_name = "self_gen/" + base_name
else:
ds_name = base_name
tokenized_ds[split][ds_name] = ds
train_ds = tokenized_ds["train"]
if data_args.max_train_samples_per_ds is not None:
for ds_name, ds in train_ds.items():
if data_args.max_train_samples_per_ds >= len(ds):
continue
train_ds[ds_name] = ds.take(data_args.max_train_samples_per_ds)
logging.info(f"train_ds: {train_ds}")
val_ds = dict()
if "validation" in tokenized_ds:
n_val_samples = data_args.max_val_samples_per_ds
for ds_name, ds in tokenized_ds["validation"].items():
if ds is None:
# take some samples from train_ds
ds = train_ds[ds_name].take(n_val_samples)
train_ds[ds_name] = train_ds[ds_name].skip(n_val_samples)
val_ds[ds_name] = ds
val_indices = np.random.permutation(len(ds))[:n_val_samples]
val_ds[ds_name] = val_ds[ds_name].select(val_indices)
with training_args.main_process_first():
train_ds = pack(
train_ds,
ctx_args.max_packed_inp_len,
ctx_args.max_packed_ctx_len,
max_packed_size=-1,
seed=training_args.seed,
num_proc=30,
)
logger.info("Setting per_device_train_batch_size to 1")
training_args.per_device_train_batch_size = 1
logger.info(f"train_ds: {train_ds}")
logger.info(f"val_ds: {val_ds}")
collator = flatten_if_not_packed
if isinstance(model, ModulatedPretrainedModel):
if isinstance(model.base_model, PeftModel):
base_model = model.base_model.base_model
else:
base_model = model.base_model
if ctx_name is not None:
logger.info("Compiling ctx_encoder_model")
ctx_base_model = model.ctx_encoder.base_model
compile_linear(ctx_base_model)
elif isinstance(model, PeftModel):
base_model = model.base_model
logger.info("Compiling base_model")
base_model.compile(fullgraph=True, mode="max-autotune")
if LOCAL_RANK == 0:
wandb.init(
project=os.getenv("WANDB_PROJECT"),
name=run_name,
group=run_name,
config=args,
tags=os.getenv("WANDB_TAGS").split(","),
notes=ctx_args.notes,
resume="allow",
)
else:
wandb.init(mode="disabled")
train_model(
model,
training_args,
train_ds,
val_ds,
collator,
compute_metrics=partial(
compute_metrics,
evaluator=Evaluator(
[compute_per_token_acc, compute_prefix_matching, compute_perplexity]
),
),
)
logger.info(f"Training run finished and saved to {output_dir}")
if __name__ == "__main__":
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
os.environ["TOKENIZERS_PARALLELISM"] = "true"
os.environ["WANDB_DIR"] = ".wandb/"
os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT") or "ctx_to_lora"
os.environ["WANDB_WATCH"] = ""
os.environ["WANDB_CONSOLE"] = "off"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["OMP_NUM_THREADS"] = "23"
torch._dynamo.config.capture_scalar_outputs = True
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
if os.getenv("DEBUG", False):
disable_caching()
main()

6441
uv.lock generated Normal file

File diff suppressed because it is too large Load diff

122
watcher.py Normal file
View file

@ -0,0 +1,122 @@
import itertools
import os
import time
from argparse import Namespace
from glob import glob
import wandb
import yaml
from ctx_to_lora.eval_utils import run_eval
from ctx_to_lora.utils import clear_gpu
CP_PATTERN = "train_outputs/runs/*/checkpoint*/pytorch_model.bin"
def flatten(l):
return itertools.chain.from_iterable(l)
# handmade file watcher using glob
# not using watchdog because there are too many saved files
# but we want to just watch CP_PATTERN files
class Watcher:
def __init__(self, patterns):
self.patterns = patterns
self.files = self.get_files()
self.last_files = self.files
def get_files(self):
return set(flatten(glob(pattern) for pattern in self.patterns))
def watch(self):
self.files = self.get_files()
new_files = self.files - self.last_files
return sorted(list(new_files))
def update(self, file):
if file in self.last_files:
return
self.last_files.add(file)
print(f"Added {file} to evaluated files.")
def save_state(self):
with open("watcher_state.yaml", "w") as f:
yaml.dump({"last_files": self.last_files}, f)
def load_state(self):
if not os.path.exists("watcher_state.yaml"):
return
with open("watcher_state.yaml") as f:
state = yaml.safe_load(f)
self.last_files = state["last_files"]
if __name__ == "__main__":
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ["WANDB_PROJECT"] = "ctx_to_lora"
watcher = Watcher([CP_PATTERN])
watcher.load_state()
print("Watching for new files...")
while True:
time.sleep(10)
new_files = watcher.watch()
for file in new_files:
# workaround to prevent loading incomplete checkpoints
time.sleep(20)
if not os.path.exists(file):
# cp is delete before we can read it
continue
run_dir = file.split("/checkpoint")[0]
run_name = run_dir.split("/")[-1]
print(f"Evaluating {file}")
args = Namespace(**yaml.unsafe_load(open(f"{run_dir}/args.yaml")))
curstep = int(file.split("checkpoint-")[1].split("/")[0])
wandb_kwargs = {
"project": os.getenv("WANDB_PROJECT"),
"group": run_name,
"name": f"{run_name}-eval",
"id": f"{run_name}-eval",
"resume": "allow",
}
wandb.init(**wandb_kwargs)
# TODO: have to change this for bigger models
eval_batch_size = 8
eval_batch_size_gen = 8
metrics = {}
# try:
# # metrics = run_eval(
# # checkpoint_path=file,
# # eval_batch_size=eval_batch_size,
# # split="validation",
# # generative=False,
# # )
# except FileNotFoundError as e:
# print(f"Error evaluating {file}: {e}. The checkpoint might be deleted.")
# continue
try:
gen_metrics = run_eval(
checkpoint_path=file,
split="validation",
eval_batch_size=eval_batch_size_gen,
max_ctx_chunk_len=args.max_ctx_chunk_len,
generative=True,
)
except FileNotFoundError as e:
print(f"The checkpoint might be deleted. Error evaluating {file}: {e}.")
gen_metrics = {}
file = ""
metrics.update(gen_metrics)
for k in metrics:
wandb.log(metrics[k], step=curstep)
wandb.finish()
print(f"Logged metrics: {metrics}")
print("=" * 80)
clear_gpu()
watcher.update(file)
watcher.save_state()

31
webui/SELF_GEN_VIEWER.md Normal file
View file

@ -0,0 +1,31 @@
# Self-Gen Data Viewer
Thanks Claude.
Running the viewer
```bash
uv run self_gen_viewer.py
```
Then open your browser and go to: **http://localhost:5001**
## Usage
1. **Select a Model Folder**: Choose from the dropdown list (e.g., `google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0`)
2. **Select a Parquet File**: Once a folder is selected, available parquet files will appear
3. **Set Number of Samples**: Adjust the sample count (default: 100, max: 1000)
4. **Click "Load Data"**: View the visualized data with context and Q&A pairs
## Data Structure
The viewer expects data in the following structure:
```
data/raw_datasets/self_gen/
├── google/
│ └── gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/
│ └── fw_qa_v2/
│ └── *.parquet
└── mistralai/
└── Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/
└── *.parquet
```

170
webui/self_gen_viewer.py Normal file
View file

@ -0,0 +1,170 @@
import traceback
from pathlib import Path
from datasets import load_dataset
from flask import Flask, jsonify, render_template, request
from transformers import AutoTokenizer
app = Flask(__name__)
# Base path for self_gen data
BASE_DATA_PATH = Path(__file__).parent.parent / "data" / "raw_datasets" / "self_gen"
# Cache for tokenizers
tokenizer_cache = {}
def get_tokenizer(model_path):
"""Get or create tokenizer with caching"""
if model_path not in tokenizer_cache:
try:
tokenizer_cache[model_path] = AutoTokenizer.from_pretrained(model_path)
except Exception as e:
print(f"Error loading tokenizer for {model_path}: {e}")
return None
return tokenizer_cache[model_path]
def discover_folders():
"""Discover all model folders in self_gen directory"""
folders = []
if not BASE_DATA_PATH.exists():
return folders
for vendor_dir in BASE_DATA_PATH.iterdir():
if vendor_dir.is_dir():
for model_dir in vendor_dir.iterdir():
if model_dir.is_dir():
rel_path = model_dir.relative_to(BASE_DATA_PATH)
folders.append(str(rel_path))
return sorted(folders)
def discover_parquet_files(folder_path):
"""Discover all parquet files in a folder"""
full_path = BASE_DATA_PATH / folder_path
parquet_files = []
if full_path.exists():
for parquet_file in full_path.glob("**/*.parquet"):
rel_path = parquet_file.relative_to(full_path)
parquet_files.append(str(rel_path))
return sorted(parquet_files)
def extract_model_name_from_folder(folder_path):
"""Extract base model name from folder path"""
# e.g., "google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0" -> "google/gemma-2-2b-it"
parts = folder_path.split("/")
if len(parts) >= 2:
vendor = parts[0]
model_part = parts[1].split("_temp_")[0]
return f"{vendor}/{model_part}"
return None
@app.route("/")
def index():
"""Main page"""
folders = discover_folders()
return render_template("self_gen_viewer.html", folders=folders)
@app.route("/api/folders")
def api_folders():
"""API endpoint to get available folders"""
folders = discover_folders()
return jsonify({"folders": folders})
@app.route("/api/parquet_files")
def api_parquet_files():
"""API endpoint to get parquet files in a folder"""
folder = request.args.get("folder", "")
if not folder:
return jsonify({"error": "No folder specified"}), 400
files = discover_parquet_files(folder)
return jsonify({"files": files})
@app.route("/api/load_data")
def api_load_data():
"""API endpoint to load and display data from a parquet file"""
folder = request.args.get("folder", "")
parquet_file = request.args.get("file", "")
num_samples = int(request.args.get("num_samples", 100))
if not folder or not parquet_file:
return jsonify({"error": "Missing parameters"}), 400
try:
# Construct full path
full_path = BASE_DATA_PATH / folder / parquet_file
if not full_path.exists():
return jsonify({"error": f"File not found: {full_path}"}), 404
# Extract model name for tokenizer
model_name = extract_model_name_from_folder(folder)
if not model_name:
return jsonify({"error": "Could not extract model name from folder"}), 400
# Load tokenizer
tokenizer = get_tokenizer(model_name)
if tokenizer is None:
return jsonify({"error": f"Could not load tokenizer for {model_name}"}), 500
# Load dataset
ds = load_dataset(
"parquet", data_files=str(full_path), split=f"train[:{num_samples}]"
)
# Process samples
samples = []
for i, sample in enumerate(ds):
processed_sample = {
"index": i,
"ctx": tokenizer.decode(sample["ctx_ids"], skip_special_tokens=False)
if "ctx_ids" in sample
else "N/A",
"questions": [],
}
# Decode input_ids if present
if "input_ids" in sample:
if isinstance(sample["input_ids"][0], list):
# Multiple Q&A pairs
processed_sample["questions"] = [
tokenizer.decode(qa, skip_special_tokens=False)
for qa in sample["input_ids"]
]
else:
# Single item
processed_sample["questions"] = [
tokenizer.decode(sample["input_ids"], skip_special_tokens=False)
]
samples.append(processed_sample)
return jsonify(
{
"success": True,
"num_samples": len(samples),
"model_name": model_name,
"file_path": str(parquet_file),
"samples": samples,
}
)
except Exception as e:
traceback.print_exc()
return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
if __name__ == "__main__":
print(f"Data path: {BASE_DATA_PATH}")
print(f"Available folders: {discover_folders()}")
app.run(debug=True, host="0.0.0.0", port=5001)

View file

@ -0,0 +1,422 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Self-Gen Data Viewer</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
min-height: 100vh;
padding: 20px;
}
.container {
max-width: 1400px;
margin: 0 auto;
background: white;
border-radius: 15px;
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.3);
padding: 30px;
}
h1 {
color: #667eea;
margin-bottom: 25px;
text-align: center;
font-size: 2.5em;
}
.controls {
display: flex;
flex-direction: column;
gap: 15px;
margin-bottom: 25px;
}
.control-group {
display: flex;
flex-direction: column;
}
.samples-row {
display: grid;
grid-template-columns: 1fr auto;
gap: 15px;
align-items: end;
}
label {
font-weight: 600;
margin-bottom: 5px;
color: #333;
font-size: 0.9em;
}
select,
input {
padding: 10px;
border: 2px solid #ddd;
border-radius: 8px;
font-size: 1em;
transition: border-color 0.3s;
}
select:focus,
input:focus {
outline: none;
border-color: #667eea;
}
button {
padding: 10px 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
border-radius: 8px;
cursor: pointer;
font-size: 1em;
font-weight: 600;
transition: transform 0.2s, box-shadow 0.2s;
margin-top: auto;
}
button:hover {
transform: translateY(-2px);
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
}
button:active {
transform: translateY(0);
}
button:disabled {
background: #ccc;
cursor: not-allowed;
transform: none;
}
.info-box {
background: #f8f9fa;
padding: 15px;
border-radius: 8px;
margin-bottom: 20px;
border-left: 4px solid #667eea;
}
.info-box p {
margin: 5px 0;
color: #555;
}
.loading {
text-align: center;
padding: 40px;
color: #667eea;
font-size: 1.2em;
}
.error {
background: #fee;
color: #c33;
padding: 15px;
border-radius: 8px;
margin: 20px 0;
border-left: 4px solid #c33;
}
.sample {
background: #f8f9fa;
border: 1px solid #e0e0e0;
border-radius: 10px;
padding: 20px;
margin-bottom: 20px;
transition: box-shadow 0.3s;
}
.sample:hover {
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
}
.sample-header {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 10px 15px;
border-radius: 6px;
margin-bottom: 15px;
font-weight: 600;
}
.section {
margin-bottom: 20px;
}
.section-title {
font-weight: 600;
color: #667eea;
margin-bottom: 10px;
font-size: 1.1em;
border-bottom: 2px solid #667eea;
padding-bottom: 5px;
}
.content {
background: white;
padding: 15px;
border-radius: 6px;
border: 1px solid #e0e0e0;
white-space: pre-wrap;
word-wrap: break-word;
font-family: 'Courier New', monospace;
font-size: 0.9em;
line-height: 1.6;
max-height: 400px;
overflow-y: auto;
}
.question-item {
background: #fff;
padding: 12px;
border-radius: 6px;
border: 1px solid #ddd;
margin-bottom: 10px;
}
.question-number {
background: #667eea;
color: white;
padding: 3px 8px;
border-radius: 4px;
font-size: 0.85em;
font-weight: 600;
display: inline-block;
margin-bottom: 8px;
}
#results {
margin-top: 30px;
}
.load-more {
text-align: center;
margin-top: 20px;
}
.stats {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 15px;
margin-bottom: 20px;
}
.stat-card {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 15px;
border-radius: 8px;
text-align: center;
}
.stat-value {
font-size: 2em;
font-weight: 700;
}
.stat-label {
font-size: 0.9em;
opacity: 0.9;
}
::-webkit-scrollbar {
width: 8px;
height: 8px;
}
::-webkit-scrollbar-track {
background: #f1f1f1;
border-radius: 4px;
}
::-webkit-scrollbar-thumb {
background: #667eea;
border-radius: 4px;
}
::-webkit-scrollbar-thumb:hover {
background: #764ba2;
}
</style>
</head>
<body>
<div class="container">
<h1>🔍 Self-Gen Data Viewer</h1>
<div class="controls">
<div class="control-group">
<label for="folder-select">Model Folder:</label>
<select id="folder-select">
<option value="">-- Select a folder --</option>
{% for folder in folders %}
<option value="{{ folder }}">{{ folder }}</option>
{% endfor %}
</select>
</div>
<div class="control-group">
<label for="file-select">Parquet File:</label>
<select id="file-select" disabled>
<option value="">-- Select a file --</option>
</select>
</div>
<div class="samples-row">
<div class="control-group">
<label for="num-samples">Number of Samples:</label>
<input type="number" id="num-samples" value="100" min="1" max="1000" step="10">
</div>
<button id="load-btn" disabled>Load Data</button>
</div>
</div>
<div id="info" style="display: none;"></div>
<div id="results"></div>
</div>
<script>
const folderSelect = document.getElementById('folder-select');
const fileSelect = document.getElementById('file-select');
const numSamplesInput = document.getElementById('num-samples');
const loadBtn = document.getElementById('load-btn');
const infoDiv = document.getElementById('info');
const resultsDiv = document.getElementById('results');
// When folder is selected, load parquet files
folderSelect.addEventListener('change', async () => {
const folder = folderSelect.value;
fileSelect.innerHTML = '<option value="">-- Select a file --</option>';
fileSelect.disabled = true;
loadBtn.disabled = true;
if (!folder) return;
try {
const response = await fetch(`/api/parquet_files?folder=${encodeURIComponent(folder)}`);
const data = await response.json();
if (data.files && data.files.length > 0) {
data.files.forEach(file => {
const option = document.createElement('option');
option.value = file;
option.textContent = file;
fileSelect.appendChild(option);
});
fileSelect.disabled = false;
} else {
alert('No parquet files found in this folder');
}
} catch (error) {
alert('Error loading files: ' + error.message);
}
});
// Enable load button when file is selected
fileSelect.addEventListener('change', () => {
loadBtn.disabled = !fileSelect.value;
});
// Load data button
loadBtn.addEventListener('click', loadData);
async function loadData() {
const folder = folderSelect.value;
const file = fileSelect.value;
const numSamples = numSamplesInput.value;
if (!folder || !file) return;
resultsDiv.innerHTML = '<div class="loading">⏳ Loading data...</div>';
infoDiv.style.display = 'none';
try {
const response = await fetch(
`/api/load_data?folder=${encodeURIComponent(folder)}&file=${encodeURIComponent(file)}&num_samples=${numSamples}`
);
const data = await response.json();
if (data.error) {
resultsDiv.innerHTML = `<div class="error"><strong>Error:</strong> ${data.error}</div>`;
return;
}
displayData(data);
} catch (error) {
resultsDiv.innerHTML = `<div class="error"><strong>Error:</strong> ${error.message}</div>`;
}
}
function displayData(data) {
// Show info box
infoDiv.style.display = 'block';
infoDiv.innerHTML = `
<div class="info-box">
<div class="stats">
<div class="stat-card">
<div class="stat-value">${data.num_samples}</div>
<div class="stat-label">Samples</div>
</div>
<div class="stat-card">
<div class="stat-value">${data.samples.length > 0 ? data.samples[0].questions.length : 0}</div>
<div class="stat-label">Questions per Sample</div>
</div>
</div>
<p><strong>Model:</strong> ${data.model_name}</p>
<p><strong>File:</strong> ${data.file_path}</p>
</div>
`;
// Show samples
let html = '';
data.samples.forEach(sample => {
html += `
<div class="sample">
<div class="sample-header">Sample #${sample.index}</div>
<div class="section">
<div class="section-title">📝 Context</div>
<div class="content">${escapeHtml(sample.ctx)}</div>
</div>
<div class="section">
<div class="section-title">❓ Questions & Answers (${sample.questions.length})</div>
${sample.questions.map((q, i) => `
<div class="question-item">
<span class="question-number">Q&A ${i + 1}</span>
<div class="content">${escapeHtml(q)}</div>
</div>
`).join('')}
</div>
</div>
`;
});
resultsDiv.innerHTML = html;
}
function escapeHtml(text) {
const div = document.createElement('div');
div.textContent = text;
return div.innerHTML;
}
</script>
</body>
</html>