mirror of
https://github.com/SakanaAI/doc-to-lora.git
synced 2026-05-24 14:15:15 +02:00
Doc-to-LoRA release
This commit is contained in:
commit
1abe8ae16d
92 changed files with 22131 additions and 0 deletions
34
.gitignore
vendored
Normal file
34
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
.venv/
|
||||
.cursorrules
|
||||
tmp/
|
||||
openai_batches/
|
||||
models/
|
||||
results/
|
||||
datasets/
|
||||
llm-comparator/
|
||||
trained_d2l/
|
||||
# Ignore data files but not scripts
|
||||
/data/processed_datasets/
|
||||
/data/raw_datasets/
|
||||
/data/distil/
|
||||
__pycache__/
|
||||
*egg-info/
|
||||
.vscode/
|
||||
.ipynb_checkpoints/
|
||||
wandb/
|
||||
*.bak
|
||||
*outputs/
|
||||
plots/
|
||||
*.out
|
||||
*.err
|
||||
*.pt
|
||||
*.pth
|
||||
*.bin
|
||||
*.safetensors
|
||||
*tfevents*
|
||||
watcher_state.yaml
|
||||
eval_results/
|
||||
.github/
|
||||
*.code-workspace
|
||||
.wandb/
|
||||
.ruff_cache/
|
||||
24
.pre-commit-config.yaml
Normal file
24
.pre-commit-config.yaml
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
repos:
|
||||
- repo: https://github.com/asottile/pyupgrade
|
||||
rev: v3.20.0
|
||||
hooks:
|
||||
- id: pyupgrade
|
||||
args:
|
||||
- --py310-plus
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
# Ruff version.
|
||||
rev: v0.11.10
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
types_or: [python, pyi]
|
||||
args: [--fix]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
types_or: [python, pyi]
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
args:
|
||||
- --profile=black
|
||||
109
README.md
Normal file
109
README.md
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
<div align="center">
|
||||
<h1>Doc-to-LoRA (D2L): Learning to Instantly Internalize Contexts</h1>
|
||||
:sparkles:<a href="https://pub.sakana.ai/doc-to-lora/">Interactive Web</a> |
|
||||
:newspaper:<a href="https://x.com/SakanaAILabs">X</a> |
|
||||
:scroll:<a href="https://arxiv.org/abs/2602.15902">Paper</a> |
|
||||
:hugs:<a href="https://huggingface.co/SakanaAI">Hugging Face</a> |
|
||||
:octocat:<a href="https://github.com/SakanaAI/doc-to-lora">GitHub</a>
|
||||
<br>A reference implementation of Doc-to-LoRA (D2L).<br>
|
||||
</div>
|
||||
<div align="center">
|
||||
<img height="300px" src="assets/overview_animation.gif" />
|
||||
</div>
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ Installation
|
||||
```
|
||||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
./install.sh
|
||||
```
|
||||
|
||||
## 🤗 Pre-Trained Models
|
||||
```
|
||||
uv run huggingface-cli login
|
||||
uv run huggingface-cli download SakanaAI/doc-to-lora --local-dir trained_d2l --include "*/"
|
||||
```
|
||||
|
||||
## 🚀 Python API Usage
|
||||
```python
|
||||
# caveat: this interface only supports non-batched inputs
|
||||
# for batched inference please see `src/ctx_to_lora/modeling/hypernet.py`
|
||||
import torch
|
||||
|
||||
from ctx_to_lora.model_loading import get_tokenizer
|
||||
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
|
||||
|
||||
# model loading
|
||||
checkpoint_path = "trained_d2l/gemma_demo/checkpoint-80000/pytorch_model.bin"
|
||||
state_dict = torch.load(checkpoint_path, weights_only=False)
|
||||
model = ModulatedPretrainedModel.from_state_dict(
|
||||
state_dict, train=False, use_sequence_packing=False
|
||||
)
|
||||
model.reset()
|
||||
tokenizer = get_tokenizer(model.base_model.name_or_path)
|
||||
|
||||
# prepare data
|
||||
doc = open("data/sakana_wiki.txt", "r").read()
|
||||
chat = [{"role": "user", "content": "Tell me about Sakana AI."}]
|
||||
chat_ids = tokenizer.apply_chat_template(
|
||||
chat,
|
||||
add_special_tokens=False,
|
||||
return_attention_mask=False,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
|
||||
# calls after internalization will be influenced by internalized info
|
||||
model.internalize(doc)
|
||||
|
||||
outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
|
||||
|
||||
# remove internalized info
|
||||
# model.reset()
|
||||
|
||||
# without internalized info, the model will halucinate
|
||||
# outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
|
||||
# print(tokenizer.decode(outputs[0]))
|
||||
```
|
||||
|
||||
### 🎮 Interactive Demo
|
||||
```bash
|
||||
uv run demo/app.py
|
||||
```
|
||||
<div align="center">
|
||||
<h3>Video Demo</h3>
|
||||
<video src="https://github.com/user-attachments/assets/16781365-5ec2-4c1c-b4f4-aeeebe3c2be5" controls autoplay muted playsinline preload="metadata" width="900"></video>
|
||||
</div>
|
||||
|
||||
### 🧪 Experimental Scripts
|
||||
To run any of the following scripts, use `uv run $PATH_TO_SCRIPT` from the root of this project.
|
||||
|
||||
|
||||
| Experiment | Data prep | Training | Evaluation | Notes |
|
||||
| ------------------------------------ | ------------------------------------- | ----------------------------- | ---------------------------- | ----------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| [Main experiment](scripts/main_exp/) | `scripts/main_exp/0-download_data.sh` | `scripts/main_exp/1-train.sh` | `scripts/main_exp/eval/*.sh` | Downloading data is fastest; regenerate only if you need fresh synthetic data. Evaluation scripts reproduce the main paper metrics. |
|
||||
| [NIAH](scripts/niah/) | `scripts/niah/0-gen_data.sh` | `scripts/niah/1-train.sh` | `scripts/niah/2-eval.sh` | Run the scripts in order; data generation only needs to happen once |
|
||||
|
||||
|
||||
### 🔬 Self-Generated Data Viewer
|
||||
After downloading/generating the data, we can see samples of the data using this script.
|
||||
```bash
|
||||
uv run webui/self_gen_viewer.py
|
||||
```
|
||||
See more info at [webui/SELF_GEN_VIEWER.md](webui/SELF_GEN_VIEWER.md).
|
||||
|
||||
### 📚 Citation
|
||||
```bibtex
|
||||
@techreport{sakana2025doc-to-lora,
|
||||
title = {{Doc-to-LoRA: Learning to Instantly Internalize Contexts}},
|
||||
author = {Rujikorn Charakorn and Edoardo Cetin and Shinnosuke Uesaka and Robert Tjarko Lange},
|
||||
institution = {Sakana AI},
|
||||
year = {2026},
|
||||
month = {Febuary},
|
||||
note = {Technical Report}
|
||||
}
|
||||
```
|
||||
17
accelerate_config.yaml
Normal file
17
accelerate_config.yaml
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: MULTI_GPU
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: true
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 8
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
BIN
assets/cover.png
Normal file
BIN
assets/cover.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 2.3 MiB |
BIN
assets/overview_animation.gif
Normal file
BIN
assets/overview_animation.gif
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 4.5 MiB |
18
chat_templates/Qwen/Qwen3-4B-Instruct-2507.jinja
Normal file
18
chat_templates/Qwen/Qwen3-4B-Instruct-2507.jinja
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
{%- if messages[0].role == 'system' %}
|
||||
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
|
||||
{%- endif %}
|
||||
{%- for message in messages %}
|
||||
{%- if message.content is string %}
|
||||
{%- set content = message.content %}
|
||||
{%- else %}
|
||||
{%- set content = '' %}
|
||||
{%- endif %}
|
||||
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
|
||||
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
||||
{%- elif message.role == "assistant" %}
|
||||
{{- '<|im_start|>' + message.role + '\n' }}{% generation %}{{ content + '<|im_end|>' }}{% endgeneration %}{{ '\n' }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
||||
24
chat_templates/google/gemma-2-2b-it.jinja
Normal file
24
chat_templates/google/gemma-2-2b-it.jinja
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
{{- bos_token }}
|
||||
{%- if messages[0]['role'] == 'system' %}
|
||||
{%- set system_message = messages[0]['content'] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
{% for message in loop_messages %}
|
||||
{% if (message['role'] == 'assistant') %}
|
||||
{% set role = 'model' %}
|
||||
{% else %}
|
||||
{% set role = message['role'] %}
|
||||
{% endif %}
|
||||
{%- if message['role'] == 'user' and loop.first and system_message is defined %}
|
||||
{{ '<start_of_turn>' + role + '\n' + system_message + '\n\n' + message['content'] | trim + '<end_of_turn>\n' }}
|
||||
{%- elif message['role'] == 'assistant' %}
|
||||
{{ '<start_of_turn>' + role + '\n' }}{% generation %}{{ message['content'] + '<end_of_turn>' }}{% endgeneration %}{{ '\n' }}
|
||||
{%- else %}
|
||||
{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}
|
||||
{%- endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt %}
|
||||
{{'<start_of_turn>model\n'}}
|
||||
{% endif %}
|
||||
24
chat_templates/mistralai/Mistral-7B-Instruct-v0.2.jinja
Normal file
24
chat_templates/mistralai/Mistral-7B-Instruct-v0.2.jinja
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
{%- if messages[0]['role'] == 'system' %}
|
||||
{%- set system_message = messages[0]['content'] %}
|
||||
{%- set loop_messages = messages[1:] %}
|
||||
{%- else %}
|
||||
{%- set loop_messages = messages %}
|
||||
{%- endif %}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- for message in loop_messages %}
|
||||
{%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}
|
||||
{{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}
|
||||
{%- endif %}
|
||||
{%- if message['role'] == 'user' %}
|
||||
{%- if loop.first and system_message is defined %}
|
||||
{{- ' [INST] ' + system_message + '\n\n' + message['content'] + ' [/INST] ' }}
|
||||
{%- else %}
|
||||
{{- ' [INST] ' + message['content'] + ' [/INST] ' }}
|
||||
{%- endif %}
|
||||
{%- elif message['role'] == 'assistant' %}
|
||||
{% generation %}{{ message['content'] + eos_token }}{% endgeneration %}
|
||||
{%- else %}
|
||||
{{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}
|
||||
{%- endif %}
|
||||
{%- endfor %}
|
||||
31
configs/main_exp/mistral/self_gen_lv1_closed_qa_1_l2l.yaml
Normal file
31
configs/main_exp/mistral/self_gen_lv1_closed_qa_1_l2l.yaml
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# LoRA
|
||||
lora_r: 8
|
||||
lora_dropout: 0.0
|
||||
target_modules:
|
||||
- down_proj
|
||||
|
||||
use_kl_loss: true
|
||||
|
||||
ctx_encoder_type: per_layer_activations
|
||||
n_latent_queries: 8
|
||||
num_blocks: 9
|
||||
num_self_attn_per_block: 0
|
||||
|
||||
gradient_accumulation_steps: 11
|
||||
max_packed_inp_len: 6144
|
||||
max_packed_ctx_len: 6144
|
||||
|
||||
# data
|
||||
train_ds_names:
|
||||
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/fw_qa_v2/min_0_to_2000/train/*level_1*.parquet
|
||||
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_0.0/pwc_compact
|
||||
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/squad_compact
|
||||
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/ropes_compact
|
||||
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/drop_compact
|
||||
|
||||
val_ds_names:
|
||||
- squad
|
||||
- pwc
|
||||
- drop
|
||||
- ropes
|
||||
- self_gen/mistralai/Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_0.0/fw_qa_v2/min_0_to_2000/train/*level_0_val*.parquet
|
||||
30
configs/main_exp/qwen/self_gen_lv1_closed_qa_1_l2l.yaml
Normal file
30
configs/main_exp/qwen/self_gen_lv1_closed_qa_1_l2l.yaml
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# LoRA
|
||||
lora_r: 8
|
||||
lora_dropout: 0.0
|
||||
target_modules:
|
||||
- down_proj
|
||||
|
||||
use_kl_loss: true
|
||||
|
||||
ctx_encoder_type: per_layer_activations
|
||||
n_latent_queries: 8
|
||||
num_blocks: 9
|
||||
num_self_attn_per_block: 0
|
||||
|
||||
gradient_accumulation_steps: 11
|
||||
max_packed_inp_len: 6144
|
||||
max_packed_ctx_len: 6144
|
||||
|
||||
# data
|
||||
train_ds_names:
|
||||
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_1.0/fw_qa_v2/min_0_to_2000/train/*level_1*.parquet
|
||||
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_0.0/pwc_compact
|
||||
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_1.0/squad_compact
|
||||
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_1.0/ropes_compact
|
||||
- self_gen/Qwen/Qwen3-4B-Instruct-2507_temp_0.0_closed_qa_prob_1.0/drop_compact
|
||||
|
||||
val_ds_names:
|
||||
- squad
|
||||
- pwc
|
||||
- drop
|
||||
- ropes
|
||||
31
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml
Normal file
31
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# LoRA
|
||||
lora_r: 8
|
||||
lora_dropout: 0.0
|
||||
target_modules:
|
||||
- down_proj
|
||||
|
||||
use_kl_loss: true
|
||||
|
||||
ctx_encoder_type: per_layer_activations
|
||||
n_latent_queries: 8
|
||||
num_blocks: 9
|
||||
num_self_attn_per_block: 0
|
||||
|
||||
gradient_accumulation_steps: 11
|
||||
max_packed_inp_len: 6144
|
||||
max_packed_ctx_len: 6144
|
||||
|
||||
# data
|
||||
train_ds_names:
|
||||
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/fw_qa_v2/min_0_to_2000/train/*level_1*.parquet
|
||||
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_0.0/pwc_compact
|
||||
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/squad_compact
|
||||
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/ropes_compact
|
||||
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/drop_compact
|
||||
|
||||
val_ds_names:
|
||||
- squad
|
||||
- pwc
|
||||
- drop
|
||||
- ropes
|
||||
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_0.0/fw_qa_v2/min_0_to_2000/train/*level_0_val*.parquet
|
||||
27
configs/main_exp/self_gen_lv1_closed_qa_1_no_qa_l2l.yaml
Normal file
27
configs/main_exp/self_gen_lv1_closed_qa_1_no_qa_l2l.yaml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# LoRA
|
||||
lora_r: 8
|
||||
lora_dropout: 0.0
|
||||
target_modules:
|
||||
- down_proj
|
||||
|
||||
use_kl_loss: true
|
||||
|
||||
ctx_encoder_type: per_layer_activations
|
||||
n_latent_queries: 8
|
||||
num_blocks: 9
|
||||
num_self_attn_per_block: 0
|
||||
|
||||
gradient_accumulation_steps: 11
|
||||
max_packed_inp_len: 6144
|
||||
max_packed_ctx_len: 6144
|
||||
|
||||
# data
|
||||
train_ds_names:
|
||||
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/fw_qa_v2/min_0_to_2000/train/*level_1*.parquet
|
||||
|
||||
val_ds_names:
|
||||
- squad
|
||||
- pwc
|
||||
- drop
|
||||
- ropes
|
||||
- self_gen/google/gemma-2-2b-it_temp_0.0_closed_qa_prob_0.0/fw_qa_v2/min_0_to_2000/train/*level_0_val*.parquet
|
||||
14
configs/niah_exp/ctx_magic_number_32_256.yaml
Normal file
14
configs/niah_exp/ctx_magic_number_32_256.yaml
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# LoRA
|
||||
lora_r: 8
|
||||
lora_dropout: 0.0
|
||||
target_modules:
|
||||
- down_proj
|
||||
|
||||
# data
|
||||
train_ds_names:
|
||||
- ctx_magic_number_32_128
|
||||
- ctx_magic_number_128_256
|
||||
|
||||
val_ds_names:
|
||||
- ctx_magic_number_32_128
|
||||
- ctx_magic_number_128_256
|
||||
42
data/build_drop_compact.py
Normal file
42
data/build_drop_compact.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import gc
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds_name = "ucinlp/drop"
|
||||
|
||||
for split in ["train", "validation"]:
|
||||
ctx_qa_dict = dict()
|
||||
ds = load_dataset(ds_name, split=split)
|
||||
print(f"Original size: {len(ds)}")
|
||||
for i, sample in tqdm(enumerate(ds)):
|
||||
ctx = sample["passage"]
|
||||
if ctx not in ctx_qa_dict:
|
||||
ctx_qa_dict[ctx] = {"prompts": [], "responses": []}
|
||||
question = sample["question"]
|
||||
answer = sample["answers_spans"]["spans"][0]
|
||||
ctx_qa_dict[ctx]["prompts"].append(question)
|
||||
ctx_qa_dict[ctx]["responses"].append(answer)
|
||||
|
||||
print(f"Unique contexts: {len(ctx_qa_dict)}")
|
||||
# convert ctx_qa_dict to a list of dictionaries
|
||||
samples = [
|
||||
{
|
||||
"context": ctx,
|
||||
"prompts": ctx_qa_dict[ctx]["prompts"],
|
||||
"responses": ctx_qa_dict[ctx]["responses"],
|
||||
}
|
||||
for ctx in ctx_qa_dict
|
||||
]
|
||||
print(f"Sampled data: {samples[0]}")
|
||||
# breakpoint()
|
||||
# save to a new dataset
|
||||
ds = Dataset.from_list(samples)
|
||||
|
||||
save_path = f"./data/raw_datasets/drop_compact/{split}/ds.parquet"
|
||||
print(f"Saving dataset to {save_path}")
|
||||
ds.to_parquet(save_path)
|
||||
print("=" * 80)
|
||||
del ds, samples, ctx_qa_dict
|
||||
gc.collect()
|
||||
43
data/build_pwc_compact.py
Normal file
43
data/build_pwc_compact.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import gc
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds_name = "sggetao/PwC"
|
||||
|
||||
for split in ["train", "test"]:
|
||||
ctx_qa_dict = dict()
|
||||
ds = load_dataset(ds_name, split=split)
|
||||
print(f"Original size: {len(ds)}")
|
||||
for i, sample in tqdm(enumerate(ds)):
|
||||
ctx = sample["input"]
|
||||
if ctx not in ctx_qa_dict:
|
||||
ctx_qa_dict[ctx] = {"prompts": [], "responses": []}
|
||||
# question = closed_qa_prompting(sample["prompt"])
|
||||
question = sample["prompt"]
|
||||
answer = sample["answer"]
|
||||
ctx_qa_dict[ctx]["prompts"].append(question)
|
||||
ctx_qa_dict[ctx]["responses"].append(answer)
|
||||
|
||||
print(f"Unique contexts: {len(ctx_qa_dict)}")
|
||||
# convert ctx_qa_dict to a list of dictionaries
|
||||
samples = [
|
||||
{
|
||||
"context": ctx,
|
||||
"prompts": ctx_qa_dict[ctx]["prompts"],
|
||||
"responses": ctx_qa_dict[ctx]["responses"],
|
||||
}
|
||||
for ctx in ctx_qa_dict
|
||||
]
|
||||
print(f"Sampled data: {samples[0]}")
|
||||
# breakpoint()
|
||||
# save to a new dataset
|
||||
ds = Dataset.from_list(samples)
|
||||
|
||||
save_path = f"./data/raw_datasets/pwc_compact/{split}/ds.parquet"
|
||||
print(f"Saving dataset to {save_path}")
|
||||
ds.to_parquet(save_path)
|
||||
print("=" * 80)
|
||||
del ds, samples, ctx_qa_dict
|
||||
gc.collect()
|
||||
45
data/build_ropes_compact.py
Normal file
45
data/build_ropes_compact.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
import gc
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds_name = "allenai/ropes"
|
||||
|
||||
for split in ["train", "validation"]:
|
||||
ctx_qa_dict = dict()
|
||||
ds = load_dataset(ds_name, split=split)
|
||||
print(f"Original size: {len(ds)}")
|
||||
for i, sample in tqdm(enumerate(ds)):
|
||||
ctx_template = "{background}\n{situation}"
|
||||
response = sample["answers"]["text"][0]
|
||||
bg_txt = sample["background"]
|
||||
situation_txt = sample["situation"]
|
||||
ctx = ctx_template.format(background=bg_txt, situation=situation_txt)
|
||||
q = sample["question"]
|
||||
if ctx not in ctx_qa_dict:
|
||||
ctx_qa_dict[ctx] = {"prompts": [], "responses": []}
|
||||
ctx_qa_dict[ctx]["prompts"].append(q)
|
||||
ctx_qa_dict[ctx]["responses"].append(response)
|
||||
|
||||
print(f"Unique contexts: {len(ctx_qa_dict)}")
|
||||
# convert ctx_qa_dict to a list of dictionaries
|
||||
samples = [
|
||||
{
|
||||
"context": ctx,
|
||||
"prompts": ctx_qa_dict[ctx]["prompts"],
|
||||
"responses": ctx_qa_dict[ctx]["responses"],
|
||||
}
|
||||
for ctx in ctx_qa_dict
|
||||
]
|
||||
print(f"Sampled data: {samples[0]}")
|
||||
# breakpoint()
|
||||
# save to a new dataset
|
||||
ds = Dataset.from_list(samples)
|
||||
|
||||
save_path = f"./data/raw_datasets/ropes_compact/{split}/ds.parquet"
|
||||
print(f"Saving dataset to {save_path}")
|
||||
ds.to_parquet(save_path)
|
||||
print("=" * 80)
|
||||
del ds, samples, ctx_qa_dict
|
||||
gc.collect()
|
||||
42
data/build_squad_compact.py
Normal file
42
data/build_squad_compact.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
import gc
|
||||
|
||||
from datasets import Dataset, load_dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
if __name__ == "__main__":
|
||||
ds_name = "data/raw_datasets/squad"
|
||||
|
||||
for split in ["train", "validation"]:
|
||||
ctx_qa_dict = dict()
|
||||
ds = load_dataset(ds_name, split=split)
|
||||
print(f"Original size: {len(ds)}")
|
||||
for i, sample in tqdm(enumerate(ds)):
|
||||
ctx = sample["context"]
|
||||
if ctx not in ctx_qa_dict:
|
||||
ctx_qa_dict[ctx] = {"prompts": [], "responses": []}
|
||||
question = sample["question"]
|
||||
answer = sample["answers"]["text"][0]
|
||||
ctx_qa_dict[ctx]["prompts"].append(question)
|
||||
ctx_qa_dict[ctx]["responses"].append(answer)
|
||||
|
||||
print(f"Unique contexts: {len(ctx_qa_dict)}")
|
||||
# convert ctx_qa_dict to a list of dictionaries
|
||||
samples = [
|
||||
{
|
||||
"context": ctx,
|
||||
"prompts": ctx_qa_dict[ctx]["prompts"],
|
||||
"responses": ctx_qa_dict[ctx]["responses"],
|
||||
}
|
||||
for ctx in ctx_qa_dict
|
||||
]
|
||||
print(f"Sampled data: {samples[0]}")
|
||||
# breakpoint()
|
||||
# save to a new dataset
|
||||
ds = Dataset.from_list(samples)
|
||||
|
||||
save_path = f"./data/raw_datasets/squad_compact/{split}/ds.parquet"
|
||||
print(f"Saving dataset to {save_path}")
|
||||
ds.to_parquet(save_path)
|
||||
print("=" * 80)
|
||||
del ds, samples, ctx_qa_dict
|
||||
gc.collect()
|
||||
10
data/download_fineweb_edu.py
Normal file
10
data/download_fineweb_edu.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
from huggingface_hub import snapshot_download
|
||||
|
||||
if __name__ == "__main__":
|
||||
fw_dir = "./data/raw_datasets/fineweb_edu/"
|
||||
snapshot_download(
|
||||
"HuggingFaceFW/fineweb-edu",
|
||||
repo_type="dataset",
|
||||
local_dir=fw_dir,
|
||||
allow_patterns="sample/100BT/*",
|
||||
)
|
||||
288
data/generate_ctx_magic_number.py
Normal file
288
data/generate_ctx_magic_number.py
Normal file
|
|
@ -0,0 +1,288 @@
|
|||
import argparse
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
|
||||
# -----------------------------
|
||||
# Config knobs (edit or use CLI)
|
||||
# -----------------------------
|
||||
TOKENS_PER_BLOCK = 40 # rough heuristic tokens per noise block
|
||||
BASE_SAMPLES_PER_BIN = (
|
||||
320_000 # training samples budget scaler only (val/test fixed at 1000 each)
|
||||
)
|
||||
RNG_SEED = 42
|
||||
NOISE_BLOCK = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
|
||||
SPECIAL_TPL = "The special magic number is {magic_number}."
|
||||
SEP = "\n" # between blocks
|
||||
|
||||
|
||||
def save_jsonl(data: list[dict], filepath: str) -> None:
|
||||
parent_dir = os.path.dirname(filepath)
|
||||
if parent_dir:
|
||||
os.makedirs(parent_dir, exist_ok=True)
|
||||
with open(filepath, "w") as f:
|
||||
for entry in data:
|
||||
json.dump(entry, f)
|
||||
f.write("\n")
|
||||
|
||||
|
||||
essential_digits4 = lambda: f"{random.randint(0, 9_999):04d}"
|
||||
|
||||
|
||||
def _choose_position(total_blocks: int, depth_bin: int) -> int:
|
||||
"""Choose an insertion index for the special sentence within [0, total_blocks-1]
|
||||
such that its relative depth falls within the depth bin [i/10, (i+1)/10).
|
||||
"""
|
||||
if total_blocks <= 0:
|
||||
return 0
|
||||
# Use floor for start and ceil for end to cover boundaries evenly
|
||||
start = math.floor(total_blocks * (depth_bin / 10))
|
||||
end = math.ceil(total_blocks * ((depth_bin + 1) / 10)) - 1
|
||||
# clamp
|
||||
start = max(0, min(start, total_blocks - 1))
|
||||
end = max(start, min(end, total_blocks - 1))
|
||||
return random.randint(start, end)
|
||||
|
||||
|
||||
def _build_example(total_blocks: int, depth_bin: int) -> dict:
|
||||
"""Build one example with a special line inserted among noise blocks.
|
||||
|
||||
total_blocks: total number of blocks in the final context (including the special one)
|
||||
depth_bin: integer in [0, 9]
|
||||
"""
|
||||
total_blocks = max(1, total_blocks)
|
||||
|
||||
# Prepare blocks
|
||||
magic = essential_digits4()
|
||||
special_line = SPECIAL_TPL.format(magic_number=magic)
|
||||
|
||||
# We'll have (total_blocks - 1) noise blocks and 1 special line
|
||||
noise_count = max(0, total_blocks - 1)
|
||||
blocks: list[str] = [NOISE_BLOCK for _ in range(noise_count)]
|
||||
|
||||
insert_at = _choose_position(total_blocks, depth_bin)
|
||||
# Insert special line at the desired position within the final sequence
|
||||
# If noise_count == 0, we just return special
|
||||
if noise_count == 0:
|
||||
final_blocks = [special_line]
|
||||
else:
|
||||
# Compose by interleaving noise and inserting special at index
|
||||
# Build a list of length `total_blocks` and fill
|
||||
final_blocks = []
|
||||
noise_idx = 0
|
||||
for idx in range(total_blocks):
|
||||
if idx == insert_at:
|
||||
final_blocks.append(special_line)
|
||||
else:
|
||||
final_blocks.append(blocks[noise_idx])
|
||||
noise_idx += 1
|
||||
|
||||
context = SEP.join(final_blocks)
|
||||
prompt = "What is the special magic number? Reply with only the number."
|
||||
response = magic
|
||||
return {"context": context, "prompt": prompt, "response": response}
|
||||
|
||||
|
||||
def generate_examples(n: int, k: int) -> list[dict]:
|
||||
"""Generate n examples (all for block length k) evenly across 10 depth bins."""
|
||||
if n <= 0:
|
||||
return []
|
||||
base = n // 10
|
||||
rem = n % 10
|
||||
counts = [base + (1 if i < rem else 0) for i in range(10)]
|
||||
out: list[dict] = []
|
||||
for depth_bin, c in enumerate(counts):
|
||||
for _ in range(c):
|
||||
out.append(_build_example(total_blocks=k, depth_bin=depth_bin))
|
||||
random.shuffle(out)
|
||||
return out
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate noise-wrapped special magic number dataset (similar structure to generate_ctx_kv.py)",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=RNG_SEED, help="Random seed")
|
||||
parser.add_argument(
|
||||
"--tokenizer-name",
|
||||
type=str,
|
||||
default="google/gemma-2-2b-it",
|
||||
help=("Tokenizer name"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-samples-per-bin",
|
||||
type=int,
|
||||
default=BASE_SAMPLES_PER_BIN,
|
||||
help="Baseline number of TRAINING samples per token bin (scaled by bin width). Validation & test are always 1000 each.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-prefix",
|
||||
type=str,
|
||||
default="data/raw_datasets/ctx_magic_number",
|
||||
help="Output directory prefix (bin range will be appended)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokens-per-block",
|
||||
"--tokens-per-pair",
|
||||
dest="tokens_per_block",
|
||||
type=int,
|
||||
default=TOKENS_PER_BLOCK,
|
||||
help="Heuristic tokens per noise block for bucketing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--only-first-n-bins",
|
||||
type=int,
|
||||
default=None,
|
||||
help="For quick tests: only generate the first N token bins",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Print a small sample and exit without writing files",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
random.seed(args.seed)
|
||||
|
||||
# ----------------------------------------------------
|
||||
# Optional: report tokenizer-based token length stats
|
||||
# ----------------------------------------------------
|
||||
if args.tokenizer_name:
|
||||
try:
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
except Exception as e: # pragma: no cover
|
||||
raise RuntimeError(
|
||||
"Failed to import transformers. Install it or omit --tokenizer-name."
|
||||
) from e
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, use_fast=True)
|
||||
noise_token_count = len(tokenizer(NOISE_BLOCK).input_ids)
|
||||
special_example = SPECIAL_TPL.format(magic_number="0000")
|
||||
special_token_count = len(tokenizer(special_example).input_ids)
|
||||
print(
|
||||
f"[Tokenizer: {args.tokenizer_name}] Noise block tokens: {noise_token_count} | Special line tokens: {special_token_count}"
|
||||
)
|
||||
|
||||
tok_bins = [(32, 128), (128, 256), (256, 512), (512, 1024), (32, 1024)] + [
|
||||
(1024 * i, 1024 * (i + 1)) for i in range(1, 16)
|
||||
]
|
||||
tok_bins += [(2**14 + 2**12 * (i), 2**14 + 2**12 * (i + 1)) for i in range(4)]
|
||||
tok_bins += [(2**15 + 2**13 * (i), 2**15 + 2**13 * (i + 1)) for i in range(12)]
|
||||
if args.only_first_n_bins is not None:
|
||||
tok_bins = tok_bins[: args.only_first_n_bins]
|
||||
|
||||
if args.tokenizer_name:
|
||||
max_hi = max(hi for _, hi in tok_bins)
|
||||
|
||||
def measure_len(k: int) -> int:
|
||||
if k == 1:
|
||||
ctx = SPECIAL_TPL.format(magic_number="0000")
|
||||
else:
|
||||
blocks = [NOISE_BLOCK] * (k - 1) + [
|
||||
SPECIAL_TPL.format(magic_number="0000")
|
||||
]
|
||||
ctx = SEP.join(blocks)
|
||||
return len(tokenizer(ctx).input_ids)
|
||||
|
||||
lengths: list[int] = [0]
|
||||
k = 1
|
||||
while True:
|
||||
L = measure_len(k)
|
||||
lengths.append(L)
|
||||
if L >= max_hi:
|
||||
break
|
||||
k += 1
|
||||
|
||||
len_bins = []
|
||||
for lo, hi in tok_bins:
|
||||
k_lo = None
|
||||
for kk in range(1, len(lengths)):
|
||||
if lengths[kk] >= lo:
|
||||
k_lo = kk
|
||||
break
|
||||
if k_lo is None or lengths[k_lo] >= hi:
|
||||
len_bins.append((0, 0))
|
||||
continue
|
||||
k_hi = len(lengths)
|
||||
for kk in range(k_lo, len(lengths)):
|
||||
if lengths[kk] >= hi:
|
||||
k_hi = kk
|
||||
break
|
||||
len_bins.append((k_lo, k_hi))
|
||||
|
||||
base_tokens = lengths[1]
|
||||
delta = (lengths[2] - lengths[1]) if len(lengths) > 2 else 0
|
||||
print(
|
||||
f"Using tokenizer-measured block ranges. base_tokens={base_tokens} approx_delta={delta}"
|
||||
)
|
||||
else:
|
||||
len_bins = [
|
||||
(lo // args.tokens_per_block, hi // args.tokens_per_block)
|
||||
for (lo, hi) in tok_bins
|
||||
]
|
||||
|
||||
if args.dry_run:
|
||||
for lb in len_bins:
|
||||
if lb[1] > lb[0]:
|
||||
k = max(1, lb[0])
|
||||
sample = generate_examples(10, k)
|
||||
print("Sample entry:")
|
||||
print(json.dumps(sample[0], indent=2))
|
||||
break
|
||||
return
|
||||
# -----------------------------------------------
|
||||
# Main generation per token bin
|
||||
# -----------------------------------------------
|
||||
TARGET_VAL = 1000
|
||||
TARGET_TEST = 1000
|
||||
for len_bin, tok_bin in zip(len_bins, tok_bins):
|
||||
if len_bin[1] <= len_bin[0]:
|
||||
print(f"Skipping token bin {tok_bin} (no valid block counts)")
|
||||
continue
|
||||
k_start = max(1, len_bin[0])
|
||||
k_end = max(1, len_bin[1])
|
||||
k_values = list(range(k_start, k_end))
|
||||
bin_size = len(k_values)
|
||||
save_dir = f"{args.out_prefix}_{tok_bin[0]}_{tok_bin[1]}"
|
||||
training_enabled = tok_bin[1] <= 1024 # unchanged policy
|
||||
if training_enabled:
|
||||
train_data: list[dict] = []
|
||||
# Distribute training budget across k values.
|
||||
# Scale: per_k = base_samples_per_bin / bin_size
|
||||
per_k_train = max(1, args.base_samples_per_bin // max(1, bin_size))
|
||||
for k in k_values:
|
||||
train_data += generate_examples(per_k_train, k)
|
||||
val_data: list[dict] = []
|
||||
test_data: list[dict] = []
|
||||
base_val = TARGET_VAL // bin_size
|
||||
rem_val = TARGET_VAL % bin_size
|
||||
base_test = TARGET_TEST // bin_size
|
||||
rem_test = TARGET_TEST % bin_size
|
||||
for idx, k in enumerate(k_values):
|
||||
n_val_k = base_val + (1 if idx < rem_val else 0)
|
||||
n_test_k = base_test + (1 if idx < rem_test else 0)
|
||||
if n_val_k:
|
||||
val_data += generate_examples(n_val_k, k)
|
||||
if n_test_k:
|
||||
test_data += generate_examples(n_test_k, k)
|
||||
random.shuffle(val_data)
|
||||
random.shuffle(test_data)
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
if training_enabled:
|
||||
save_jsonl(train_data, f"{save_dir}/train.jsonl")
|
||||
save_jsonl(val_data, f"{save_dir}/val.jsonl")
|
||||
save_jsonl(test_data, f"{save_dir}/test.jsonl")
|
||||
if training_enabled:
|
||||
print(
|
||||
f"Dataset generated at {save_dir} (train={len(train_data)} val={len(val_data)} test={len(test_data)})"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Dataset (val/test only) generated at {save_dir} (val={len(val_data)} test={len(test_data)})"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
269
data/generate_fw_edu_qa_v2.py
Normal file
269
data/generate_fw_edu_qa_v2.py
Normal file
|
|
@ -0,0 +1,269 @@
|
|||
import argparse
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
|
||||
import pandas as pd
|
||||
from datasets import Dataset, load_dataset
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
STOP_STRINGS = {
|
||||
"google/gemma-3-12b-it": ["<eos>", "<end_of_turn>"],
|
||||
}
|
||||
|
||||
SYSTEM_TEMPLATE = (
|
||||
"You are a creative and helpful assistant.\n"
|
||||
"You are tasked with generating questions for reading comprehension tests.\n"
|
||||
"You will be given a context and you need to generate questions and corresponding answers from the given context.\n"
|
||||
"The questions should be highly specific to the information provided in the context, not general questions that suit any context.\n"
|
||||
"**DO NOT** hallucinate or make up information."
|
||||
)
|
||||
|
||||
# based on Make Your LLM Fully Utilize the Context (https://arxiv.org/pdf/2404.16811)
|
||||
PROMPT_TEMPLATE = (
|
||||
"### Instructions ###\n"
|
||||
"Generate questions and corresponding answers from the given context. The questions should be highly specific to the "
|
||||
"information provided in the context, not general questions that suit any context.\n\n"
|
||||
"### Context ###\n"
|
||||
"{context}\n\n\n"
|
||||
"### Rules ###\n"
|
||||
"Rules to follow when generating the questions:\n"
|
||||
"1. The questions must be specific to the given context and fully answerable from information present in the given context.\n"
|
||||
"2. Ask questions that are fact-seeking based on the information provided.\n"
|
||||
"3. Make sure the questions are clear and unambiguous.\n"
|
||||
"4. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
|
||||
"the questions.\n"
|
||||
"5. The questions should not overlap. They should be diverse, covering many aspects of the context.\n"
|
||||
"6. Do not give away too much information in the questions. For example, ask 'Who is X?' instead of 'Who is X that did Y?' when Y is clear from the context.\n"
|
||||
"7. Ignore the text formatting of the context, e.g., bold, italic, underline, etc.\n"
|
||||
"8. Ignore typos, spacing, and grammatical errors in the context.\n\n"
|
||||
"Rules to follow when generating the answers:\n"
|
||||
"1. The answers must use the (implied) information provided in the context.\n"
|
||||
"2. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
|
||||
"the answers.\n"
|
||||
"3. Do not just copy words from the context. Answer the question in your own words.\n"
|
||||
"4. The answers should be detailed and comprehensive. Please include additional specific details from the context.\n\n"
|
||||
"Respond with {n_qa_pairs} question-answer pairs.\n"
|
||||
"Always use proper grammar and punctuation.\n"
|
||||
"Try to use different question forms and styles.\n"
|
||||
"Use simple words and make sure that the answers are clear and comprehensive.\n\n"
|
||||
"The question-answer pairs should be in the following format:\n"
|
||||
"Question 1: {{question_1}}\n"
|
||||
"Answer 1: {{answer_1}}\n"
|
||||
"Question 2: {{question_2}}\n"
|
||||
"Answer 2: {{answer_2}}\n"
|
||||
"..."
|
||||
)
|
||||
|
||||
|
||||
def get_prompt(context, n_qa_pairs):
|
||||
prompt = PROMPT_TEMPLATE.format(context=context, n_qa_pairs=n_qa_pairs)
|
||||
return prompt
|
||||
|
||||
|
||||
def check_should_skip(txt: str, vllm_model: str) -> bool:
|
||||
"""Check if the response should be skipped based on stop strings."""
|
||||
for stop in STOP_STRINGS[vllm_model]:
|
||||
if stop in txt[-len(stop) :]:
|
||||
return (txt.split(stop)[0], False) # Found a valid stop string
|
||||
return (txt, True) # No valid stop string found, skip this response
|
||||
|
||||
|
||||
def postprocess_qa_pairs(res_txt: str):
|
||||
"""
|
||||
Postprocesses the QA pairs from the response text.
|
||||
|
||||
Args:
|
||||
res_txt: The response text.
|
||||
n_qa_pairs: The number of QA pairs.
|
||||
|
||||
Returns:
|
||||
A tuple of two lists, the first containing the questions and the second containing the answers.
|
||||
"""
|
||||
# capture everything after each "Question {number}:" until "Answer"
|
||||
res_txt = remove_think(res_txt)
|
||||
q_pattern = r"Question \d+:(.*?)(?=Answer|$)" # thanks chatgpt
|
||||
questions = re.findall(q_pattern, res_txt, flags=re.S)
|
||||
|
||||
a_pattern = r"Answer \d+:(.*?)(?=Question|$)" # thanks chatgpt
|
||||
answers = re.findall(a_pattern, res_txt, flags=re.S)
|
||||
|
||||
if len(questions) != len(answers):
|
||||
print(f"Warning---number of questions and answers do not match")
|
||||
print(f"Number of questions: {len(questions)}")
|
||||
print(f"Number of answers: {len(answers)}")
|
||||
|
||||
out_q = []
|
||||
out_a = []
|
||||
n_skips = 0
|
||||
if (len(questions) > 0) and (len(answers) > 0):
|
||||
n_gen_pairs = min(len(questions), len(answers))
|
||||
has_left_over = n_gen_pairs < len(questions) or n_gen_pairs < len(answers)
|
||||
for i in range(n_gen_pairs):
|
||||
response = answers[i].strip()
|
||||
question = questions[i].strip()
|
||||
if not response or not question:
|
||||
print(f"Skipping empty question or answer at index {i}")
|
||||
continue
|
||||
if (not has_left_over) and (i == n_gen_pairs - 1):
|
||||
response, skip = check_should_skip(response, vllm_model)
|
||||
if skip:
|
||||
print(f"Skipping due to missing stop string")
|
||||
n_skips += 1
|
||||
continue
|
||||
out_q.append(question.strip())
|
||||
out_a.append(response.strip())
|
||||
print(f"Skipped {n_skips} responses due to missing stop strings")
|
||||
|
||||
return out_q, out_a
|
||||
|
||||
|
||||
def length_filter(sample, min_len, max_len):
|
||||
return min_len <= len(sample["text"]) <= max_len
|
||||
|
||||
|
||||
def remove_think(txt):
|
||||
return txt.split("</think>")[-1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate QA pairs from FineWeb Edu dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm_model",
|
||||
type=str,
|
||||
default=os.environ.get("vllm_model", "google/gemma-2-27b-it"),
|
||||
help="VLLM model to use for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard_pattern",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Pattern to match shard files (e.g., '000_0000*')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_qa_pairs",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Number of question-answer pairs to generate per context",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--min_length",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Minimum length of the context to consider for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_length",
|
||||
type=int,
|
||||
default=2000,
|
||||
help="Maximum length of the context to consider for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_model_length",
|
||||
type=int,
|
||||
default=2**14,
|
||||
help="Maximum length of the model input (context + prompt + response) in tokens",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Debug mode - process only first 100 samples",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
vllm_model = args.vllm_model
|
||||
print(f"Using model: {vllm_model}")
|
||||
llm_kwargs = dict(
|
||||
model=vllm_model,
|
||||
dtype="bfloat16",
|
||||
enable_prefix_caching=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_model_len=args.max_model_length,
|
||||
limit_mm_per_prompt={"image": 0},
|
||||
)
|
||||
|
||||
llm = LLM(**llm_kwargs)
|
||||
tokenizer = llm.get_tokenizer()
|
||||
shard_pattern = args.shard_pattern
|
||||
n_qa_pairs = args.n_qa_pairs
|
||||
|
||||
paths = glob(
|
||||
f"./data/raw_datasets/fineweb_edu/sample/100BT/{shard_pattern}.parquet"
|
||||
)
|
||||
|
||||
split = "train[:100]" if args.debug else "train"
|
||||
for path in paths:
|
||||
ds = load_dataset(
|
||||
"parquet",
|
||||
data_files=path,
|
||||
split=split,
|
||||
)
|
||||
ds = ds.filter(
|
||||
length_filter,
|
||||
fn_kwargs={"min_len": args.min_length, "max_len": args.max_length},
|
||||
num_proc=8,
|
||||
)
|
||||
|
||||
ctxs = [sample["text"] for sample in iter(ds)]
|
||||
messages = [
|
||||
[
|
||||
{"role": "system", "content": SYSTEM_TEMPLATE},
|
||||
{"role": "user", "content": get_prompt(ctx, n_qa_pairs)},
|
||||
]
|
||||
for ctx in ctxs
|
||||
]
|
||||
|
||||
print(f"Generating from {len(messages)} contexts")
|
||||
completions = llm.chat(
|
||||
messages,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=2048,
|
||||
temperature=0.0,
|
||||
# needed for checking if stop tokens are present
|
||||
skip_special_tokens=False,
|
||||
include_stop_str_in_output=True,
|
||||
),
|
||||
)
|
||||
samples = []
|
||||
for ctx, completion in zip(ctxs, completions):
|
||||
questions, answers = postprocess_qa_pairs(completion.outputs[0].text)
|
||||
samples.append(
|
||||
{
|
||||
"context": ctx,
|
||||
"prompts_level_0": questions,
|
||||
"responses_level_0": answers,
|
||||
}
|
||||
)
|
||||
if args.debug:
|
||||
print(f"{ctx=}")
|
||||
print(f"{completion.outputs[0].text=}")
|
||||
for q, a in zip(questions, answers):
|
||||
print(f"{q=}")
|
||||
print(f"{a=}")
|
||||
print()
|
||||
print("=" * 80)
|
||||
|
||||
print(f"Generated {len(samples)} samples")
|
||||
df = pd.DataFrame(samples)
|
||||
ds = Dataset.from_pandas(df)
|
||||
val_ds = ds.take(10)
|
||||
ds = ds.skip(10)
|
||||
|
||||
shard_name = path.split("/")[-1].split(".")[0]
|
||||
shard_name += "_level_0"
|
||||
if args.debug:
|
||||
shard_name += "_debug"
|
||||
ds.to_parquet(
|
||||
f"data/raw_datasets/fw_qa_v2/min_{args.min_length}_to_{args.max_length}/{shard_name}.parquet"
|
||||
)
|
||||
val_ds.to_parquet(
|
||||
f"data/raw_datasets/fw_qa_v2/min_{args.min_length}_to_{args.max_length}/{shard_name}_val.parquet"
|
||||
)
|
||||
print(
|
||||
f"Saved to data/raw_datasets/fw_qa_v2/min_{args.min_length}_to_{args.max_length}/{shard_name}.parquet"
|
||||
)
|
||||
print(
|
||||
f"Saved to data/raw_datasets/fw_qa_v2/min_{args.min_length}_to_{args.max_length}/{shard_name}_val.parquet"
|
||||
)
|
||||
296
data/generate_fw_edu_qa_v2_repeat.py
Normal file
296
data/generate_fw_edu_qa_v2_repeat.py
Normal file
|
|
@ -0,0 +1,296 @@
|
|||
import argparse
|
||||
import gc
|
||||
import os
|
||||
import re
|
||||
from glob import glob
|
||||
|
||||
from datasets import load_dataset
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
STOP_STRINGS = {
|
||||
"google/gemma-3-12b-it": ["<eos>", "<end_of_turn>"],
|
||||
}
|
||||
|
||||
SYSTEM_TEMPLATE = (
|
||||
"You are a creative and helpful assistant.\n"
|
||||
"You are tasked with generating questions for reading comprehension tests.\n"
|
||||
"You will be given a context and you need to generate questions and corresponding answers from the given context.\n"
|
||||
"The questions should be highly specific to the information provided in the context, not general questions that suit any context.\n"
|
||||
"**DO NOT** hallucinate or make up information."
|
||||
)
|
||||
|
||||
# based on Make Your LLM Fully Utilize the Context (https://arxiv.org/pdf/2404.16811)
|
||||
PROMPT_TEMPLATE = (
|
||||
"### Instructions ###\n"
|
||||
"Generate questions and corresponding answers from the given context. The questions should be highly specific to the "
|
||||
"information provided in the context, not general questions that suit any context.\n\n"
|
||||
"### Context ###\n"
|
||||
"{context}\n\n\n"
|
||||
"### Example Question-Answer Pairs ###\n"
|
||||
"{qa_pairs}\n\n\n"
|
||||
"### Rules ###\n"
|
||||
"Rules to follow when generating the questions:\n"
|
||||
"1. The questions must be specific to the given context and fully answerable from information present in *or* implied from the given context.\n"
|
||||
"2. The questions must *not* be redundant with the example questions-answer pairs provided.\n"
|
||||
"3. You should prioritize fact-seeking questions. Consider reversal questions, e.g., asking 'What causes X to happen?' is valid when 'Y causes X' is presented in the context.\n"
|
||||
"4. If all the facts in the context are already covered by the provided examples, you must generate *more complicated* questions that require reasoning beyond simple information retrieval.\nThis includes asking about information that can be inferred, requiring synthesizing information from multiple parts of the text, or understanding relationships between concepts, events, or individuals mentioned in the context. For example, if the context says 'The Eiffel Tower was completed in 1889 after 2 years of construction', you can ask 'When did the construction of the Eiffel Tower begin?'. Here's another example: if the context says 'Alice is Bob's mother. Bob is Charlie's Dad', you can ask 'Who is Charlie's grandmother?'.\n"
|
||||
"5. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
|
||||
"the questions.\n"
|
||||
"6. The questions should not overlap. They should be diverse, covering many aspects of the context.\n"
|
||||
"7. Do not give away too much information in the questions. For example, ask 'Who is X?' instead of 'Who is X that did Y?' when Y is clear from the context.\n"
|
||||
"8. Ignore the text formatting of the context, e.g., bold, italic, underline, etc.\n"
|
||||
"9. Ignore typos, spacing, and grammatical errors in the context.\n\n"
|
||||
"Rules to follow when generating the answers:\n"
|
||||
"1. The answers must use the (implied) information provided in the context.\n"
|
||||
"2. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
|
||||
"the answers.\n"
|
||||
"3. Do not just copy words from the context. Answer the question in your own words.\n"
|
||||
"4. The answers should be detailed and comprehensive. Please include additional specific details from the context.\n\n"
|
||||
"Respond with {n_qa_pairs} question-answer pairs.\n"
|
||||
"Always use proper grammar and punctuation.\n"
|
||||
"Try to use different question forms and styles.\n"
|
||||
"Use simple words and make sure that the answers are clear and comprehensive.\n\n"
|
||||
"The question-answer pairs should be in the following format:\n"
|
||||
"Question 1: {{question_1}}\n"
|
||||
"Answer 1: {{answer_1}}\n"
|
||||
"Question 2: {{question_2}}\n"
|
||||
"Answer 2: {{answer_2}}\n"
|
||||
"..."
|
||||
)
|
||||
|
||||
|
||||
def get_prompt(context, example_qa_pairs, n_qa_pairs):
|
||||
prompt = PROMPT_TEMPLATE.format(
|
||||
context=context,
|
||||
qa_pairs=example_qa_pairs,
|
||||
n_qa_pairs=n_qa_pairs,
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
def check_should_skip(txt: str, vllm_model: str) -> bool:
|
||||
"""Check if the response should be skipped based on stop strings."""
|
||||
for stop in STOP_STRINGS[vllm_model]:
|
||||
if stop in txt[-len(stop) :]:
|
||||
return (txt.split(stop)[0], False) # Found a valid stop string
|
||||
return (txt, True) # No valid stop string found, skip this response
|
||||
|
||||
|
||||
def postprocess_qa_pairs(res_txt: str):
|
||||
"""
|
||||
Postprocesses the QA pairs from the response text.
|
||||
|
||||
Args:
|
||||
res_txt: The response text.
|
||||
n_qa_pairs: The number of QA pairs.
|
||||
|
||||
Returns:
|
||||
A tuple of two lists, the first containing the questions and the second containing the answers.
|
||||
"""
|
||||
# capture everything after each "Question {number}:" until "Answer"
|
||||
res_txt = remove_think(res_txt)
|
||||
q_pattern = r"Question \d+:(.*?)(?=Answer|$)" # thanks chatgpt
|
||||
questions = re.findall(q_pattern, res_txt, flags=re.S)
|
||||
|
||||
a_pattern = r"Answer \d+:(.*?)(?=Question|$)" # thanks chatgpt
|
||||
answers = re.findall(a_pattern, res_txt, flags=re.S)
|
||||
|
||||
if len(questions) != len(answers):
|
||||
print(f"Warning---number of questions and answers do not match")
|
||||
print(f"Number of questions: {len(questions)}")
|
||||
print(f"Number of answers: {len(answers)}")
|
||||
|
||||
out_q = []
|
||||
out_a = []
|
||||
n_skips = 0
|
||||
if (len(questions) > 0) and (len(answers) > 0):
|
||||
n_gen_pairs = min(len(questions), len(answers))
|
||||
has_left_over = n_gen_pairs < len(questions) or n_gen_pairs < len(answers)
|
||||
for i in range(n_gen_pairs):
|
||||
response = answers[i].strip()
|
||||
question = questions[i].strip()
|
||||
if not response or not question:
|
||||
print(f"Skipping empty question or answer at index {i}")
|
||||
continue
|
||||
if (not has_left_over) and (i == n_gen_pairs - 1):
|
||||
response, skip = check_should_skip(response, vllm_model)
|
||||
if skip:
|
||||
print(f"Skipping due to missing stop string")
|
||||
n_skips += 1
|
||||
continue
|
||||
out_q.append(question.strip())
|
||||
out_a.append(response.strip())
|
||||
print(f"Skipped {n_skips} responses due to missing stop strings")
|
||||
|
||||
return out_q, out_a
|
||||
|
||||
|
||||
def flatten_list(l):
|
||||
out = []
|
||||
for x in l:
|
||||
out += x
|
||||
return out
|
||||
|
||||
|
||||
def remove_think(txt):
|
||||
return txt.split("</think>")[-1]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate QA pairs from FineWeb Edu dataset"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vllm_model",
|
||||
type=str,
|
||||
default=os.environ.get("vllm_model", "google/gemma-2-27b-it"),
|
||||
help="VLLM model to use for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard_pattern",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Pattern to match shard files (e.g., '000_0000*')",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n_qa_pairs",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Number of question-answer pairs to generate per context",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_model_length",
|
||||
type=int,
|
||||
default=2**12,
|
||||
help="Maximum length of the model input (context + prompt + response) in tokens",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Debug mode - process only first 100 samples",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
vllm_model = args.vllm_model
|
||||
print(f"Using model: {vllm_model}")
|
||||
llm_kwargs = dict(
|
||||
model=vllm_model,
|
||||
dtype="bfloat16",
|
||||
enable_prefix_caching=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_model_len=2**14,
|
||||
limit_mm_per_prompt={"image": 0},
|
||||
)
|
||||
|
||||
llm = LLM(**llm_kwargs)
|
||||
tokenizer = llm.get_tokenizer()
|
||||
shard_pattern = args.shard_pattern
|
||||
n_qa_pairs = args.n_qa_pairs
|
||||
|
||||
paths = glob(f"./data/raw_datasets/fw_qa_v2/{shard_pattern}.parquet")
|
||||
|
||||
split = "train[:100]" if args.debug else "train"
|
||||
for path in paths:
|
||||
assert "_level" in path, (
|
||||
"Path must contain '_level' to indicate the dataset level"
|
||||
)
|
||||
shard_name = path.split("/")[-1].split(".")[0].split("_debug")[0]
|
||||
if "/" in shard_pattern:
|
||||
shard_name = "/".join(shard_pattern.split("/")[:-1]) + "/" + shard_name
|
||||
cur_level = int(shard_name.split("_level_")[-1])
|
||||
next_level = cur_level + 1
|
||||
ds = load_dataset(
|
||||
"parquet",
|
||||
data_files=path,
|
||||
split=split,
|
||||
)
|
||||
prompt_cols = [col for col in ds.column_names if col.startswith("prompts")]
|
||||
response_cols = [col for col in ds.column_names if col.startswith("responses")]
|
||||
assert len(prompt_cols) > 0, "No prompt columns found in the dataset"
|
||||
if len(prompt_cols) != len(response_cols):
|
||||
raise ValueError(
|
||||
"Number of prompt columns does not match number of response columns"
|
||||
)
|
||||
|
||||
samples_data = []
|
||||
for sample in iter(ds):
|
||||
# Format existing QA pairs as examples
|
||||
example_qa_pairs = ""
|
||||
questions = flatten_list([sample[col] for col in prompt_cols])
|
||||
answers = flatten_list([sample[col] for col in response_cols])
|
||||
for i, (q, a) in enumerate(zip(questions, answers), 1):
|
||||
example_qa_pairs += f"Question {i}: {q}\nAnswer {i}: {a}\n"
|
||||
|
||||
samples_data.append(
|
||||
{"context": sample["context"], "example_qa_pairs": example_qa_pairs}
|
||||
)
|
||||
del ds
|
||||
gc.collect()
|
||||
|
||||
messages = [
|
||||
[
|
||||
{"role": "system", "content": SYSTEM_TEMPLATE},
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_prompt(
|
||||
sample["context"], sample["example_qa_pairs"], n_qa_pairs
|
||||
),
|
||||
},
|
||||
]
|
||||
for sample in samples_data
|
||||
]
|
||||
|
||||
print(f"Generating from {len(messages)} contexts")
|
||||
completions = llm.chat(
|
||||
messages,
|
||||
sampling_params=SamplingParams(
|
||||
temperature=0.0,
|
||||
# needed for checking if stop tokens are present
|
||||
skip_special_tokens=False,
|
||||
include_stop_str_in_output=True,
|
||||
),
|
||||
)
|
||||
samples = []
|
||||
for sample_data, completion in zip(samples_data, completions):
|
||||
questions, answers = postprocess_qa_pairs(completion.outputs[0].text)
|
||||
samples.append(
|
||||
{
|
||||
"context": sample_data["context"],
|
||||
f"prompts_level_{next_level}": questions,
|
||||
f"responses_level_{next_level}": answers,
|
||||
}
|
||||
)
|
||||
if args.debug:
|
||||
print(f"context={sample_data['context']}")
|
||||
print(f"example_qa_pairs={sample_data['example_qa_pairs']}")
|
||||
print(f"{completion.outputs[0].text=}")
|
||||
for q, a in zip(questions, answers):
|
||||
print(f"{q=}")
|
||||
print(f"{a=}")
|
||||
print()
|
||||
print("=" * 80)
|
||||
|
||||
del samples_data
|
||||
gc.collect()
|
||||
|
||||
print(f"Generated {len(samples)} samples")
|
||||
ds = load_dataset(
|
||||
"parquet",
|
||||
data_files=path,
|
||||
split=split,
|
||||
)
|
||||
ds = ds.add_column(
|
||||
f"prompts_level_{next_level}",
|
||||
[sample[f"prompts_level_{next_level}"] for sample in samples],
|
||||
)
|
||||
ds = ds.add_column(
|
||||
f"responses_level_{next_level}",
|
||||
[sample[f"responses_level_{next_level}"] for sample in samples],
|
||||
)
|
||||
|
||||
shard_name_base = shard_name.split("_level_")[0]
|
||||
shard_name = f"{shard_name_base}_level_{next_level}"
|
||||
if args.debug:
|
||||
shard_name += "_debug"
|
||||
ds.to_parquet(f"data/raw_datasets/fw_qa_v2/{shard_name}.parquet")
|
||||
print(f"Saved to data/raw_datasets/fw_qa_v2/{shard_name}.parquet")
|
||||
387
data/gutenburg_sample.txt
Normal file
387
data/gutenburg_sample.txt
Normal file
|
|
@ -0,0 +1,387 @@
|
|||
The Project Gutenberg eBook, Addison, by William John Courthope
|
||||
|
||||
|
||||
This eBook is for the use of anyone anywhere at no cost and with
|
||||
almost no restrictions whatsoever. You may copy it, give it away or
|
||||
re-use it under the terms of the Project Gutenberg License included
|
||||
with this eBook or online at www.gutenberg.org
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Title: Addison
|
||||
|
||||
|
||||
Author: William John Courthope
|
||||
|
||||
|
||||
|
||||
Release Date: November 27, 2012 [eBook #41496]
|
||||
|
||||
Language: English
|
||||
|
||||
Character set encoding: ISO-8859-1
|
||||
|
||||
|
||||
***START OF THE PROJECT GUTENBERG EBOOK ADDISON***
|
||||
|
||||
|
||||
E-text prepared by the Online Distributed Proofreading Team
|
||||
(http://www.pgdp.net) from page images generously made available by
|
||||
Internet Archive (http://archive.org)
|
||||
|
||||
|
||||
|
||||
Note: Images of the original pages are available through
|
||||
Internet Archive. See
|
||||
http://archive.org/details/addison_00cour
|
||||
|
||||
|
||||
Transcriber's note:
|
||||
|
||||
Text enclosed by underscores is in italics (_italics_).
|
||||
|
||||
Text enclosed by curly brackets is superscripted
|
||||
(example: y{e}).
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
English Men of Letters
|
||||
|
||||
Edited by John Morley
|
||||
|
||||
ADDISON
|
||||
|
||||
by
|
||||
|
||||
W. J. COURTHOPE
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Harper & Brothers Publishers
|
||||
New York and London
|
||||
1902
|
||||
|
||||
* * * * *
|
||||
|
||||
ENGLISH MEN OF LETTERS.
|
||||
|
||||
EDITED BY JOHN MORLEY.
|
||||
|
||||
JOHNSON Leslie Stephen.
|
||||
GIBBON J. C. Morison.
|
||||
SCOTT R. H. Hutton.
|
||||
SHELLEY J. A. Symonds.
|
||||
HUME T. H. Huxley.
|
||||
GOLDSMITH William Black.
|
||||
DEFOE William Minto.
|
||||
BURNS J. C. Shairp.
|
||||
SPENSER R. W. Church.
|
||||
THACKERAY Anthony Trollope.
|
||||
BURKE John Morley.
|
||||
MILTON Mark Pattison.
|
||||
HAWTHORNE Henry James, Jr.
|
||||
SOUTHEY E. Dowden.
|
||||
CHAUCER A. W. Ward.
|
||||
BUNYAN J. A. Froude.
|
||||
COWPER Goldwin Smith.
|
||||
POPE Leslie Stephen.
|
||||
BYRON John Nichol.
|
||||
LOCKE Thomas Fowler.
|
||||
WORDSWORTH F. Myers.
|
||||
DRYDEN G. Saintsbury.
|
||||
LANDOR Sidney Colvin.
|
||||
DE QUINCEY David Masson.
|
||||
LAMB Alfred Ainger.
|
||||
BENTLEY R. C. Jebb.
|
||||
DICKENS A. W. Ward.
|
||||
GRAY E. W. Gosse.
|
||||
SWIFT Leslie Stephen.
|
||||
STERNE H. D. Traill.
|
||||
MACAULAY J. Cotter Morison.
|
||||
FIELDING Austin Dobson.
|
||||
SHERIDAN Mrs. Oliphant.
|
||||
ADDISON W. J. Courthope.
|
||||
BACON R. W. Church.
|
||||
COLERIDGE H. D. Traill.
|
||||
SIR PHILIP SIDNEY J. A. Symonds.
|
||||
KEATS Sidney Colvin.
|
||||
CARLYLE John Nichol.
|
||||
|
||||
12mo, Cloth, 75 cents per volume.
|
||||
|
||||
_Other volumes in preparation._
|
||||
|
||||
PUBLISHED BY HARPER & BROTHERS, NEW YORK.
|
||||
|
||||
_Any of the above works will be sent by mail, postage prepaid, to any part
|
||||
of the United States, Canada, or Mexico, on receipt of the price._
|
||||
|
||||
* * * * *
|
||||
|
||||
|
||||
|
||||
CONTENTS.
|
||||
|
||||
|
||||
PAGE
|
||||
|
||||
CHAPTER I.
|
||||
THE STATE OF ENGLISH SOCIETY AND LETTERS
|
||||
AFTER THE RESTORATION 1
|
||||
|
||||
CHAPTER II.
|
||||
ADDISON'S FAMILY AND EDUCATION 21
|
||||
|
||||
CHAPTER III.
|
||||
ADDISON ON HIS TRAVELS 38
|
||||
|
||||
CHAPTER IV.
|
||||
HIS EMPLOYMENT IN AFFAIRS OF STATE 53
|
||||
|
||||
CHAPTER V.
|
||||
THE "TATLER" AND "SPECTATOR" 78
|
||||
|
||||
CHAPTER VI.
|
||||
"CATO" 110
|
||||
|
||||
CHAPTER VII.
|
||||
ADDISON'S QUARREL WITH POPE 125
|
||||
|
||||
CHAPTER VIII.
|
||||
THE LAST YEARS OF HIS LIFE 139
|
||||
|
||||
CHAPTER IX.
|
||||
THE GENIUS OF ADDISON 153
|
||||
|
||||
|
||||
|
||||
|
||||
ADDISON.
|
||||
|
||||
|
||||
|
||||
|
||||
CHAPTER I.
|
||||
|
||||
THE STATE OF ENGLISH SOCIETY AND LETTERS AFTER THE RESTORATION.
|
||||
|
||||
|
||||
Of the four English men of letters whose writings most fully embody the
|
||||
spirit of the eighteenth century, the one who provides the biographer with
|
||||
the scantiest materials is Addison. In his _Journal to Stella_, his social
|
||||
verses, and his letters to his friends, we have a vivid picture of those
|
||||
relations with women and that protracted suffering which invest with such
|
||||
tragic interest the history of Swift. Pope, by the publication of his own
|
||||
correspondence, has enabled us, in a way that he never intended, to
|
||||
understand the strange moral twist which distorted a nature by no means
|
||||
devoid of noble instincts. Johnson was fortunate in the companionship of
|
||||
perhaps the best biographer who ever lived. But of the real life and
|
||||
character of Addison scarcely any contemporary record remains. The formal
|
||||
narrative prefixed to his works by Tickell is, by that writer's own
|
||||
admission, little more than a bibliography. Steele, who might have told us
|
||||
more than any man about his boyhood and his manner of life in London, had
|
||||
become estranged from his old friend before his death. No writer has
|
||||
taken the trouble to preserve any account of the wit and wisdom that
|
||||
enlivened the "little senate" at Button's. His own letters are, as a rule,
|
||||
compositions as finished as his papers in the _Spectator_. Those features
|
||||
in his character which excite the greatest interest have been delineated
|
||||
by the hand of an enemy--an enemy who possessed an unrivalled power of
|
||||
satirical portrait-painting, and was restrained by no regard for truth
|
||||
from creating in the public mind such impressions about others as might
|
||||
serve to heighten the favourable opinion of himself.
|
||||
|
||||
This absence of dramatic incident in Addison's life would lead us
|
||||
naturally to conclude that he was deficient in the energy and passion
|
||||
which cause a powerful nature to leave a mark upon its age. Yet such a
|
||||
judgment would certainly be erroneous. Shy and reserved as he was, the
|
||||
unanimous verdict of his most illustrious contemporaries is decisive as to
|
||||
the respect and admiration which he excited among them. The man who could
|
||||
exert so potent an influence over the mercurial Steele, who could
|
||||
fascinate the haughty and cynical intellect of Swift, whose conversation,
|
||||
by the admission of his satirist Pope, had in it something more charming
|
||||
than that of any other man; of whom it was said that he might have been
|
||||
chosen king if he wished it; such a man, though to the coarse perception
|
||||
of Mandeville he might have seemed no more than "a parson in a tye-wig,"
|
||||
can hardly have been deficient in force of character.
|
||||
|
||||
Nor would it have been possible for a writer distinguished by mere
|
||||
elegance and refinement to leave a lasting impress on the literature and
|
||||
society of his country. In one generation after another, men representing
|
||||
opposing elements of rank, class, interest, and taste, have agreed in
|
||||
acknowledging Addison's extraordinary merits. "Whoever wishes," says
|
||||
Johnson--at the end of a biography strongly coloured with the
|
||||
prepossessions of a semi-Jacobite Tory--"whoever wishes to attain an
|
||||
English style, familiar but not coarse, and elegant but not ostentatious,
|
||||
must give his days and nights to the volumes of Addison." "Such a mark of
|
||||
national respect," says Macaulay, the best representative of middle-class
|
||||
opinion in the present century, speaking of the statue erected to Addison
|
||||
in Westminster Abbey, "was due to the unsullied statesman, to the
|
||||
accomplished scholar, to the master of pure English eloquence, to the
|
||||
consummate painter of life and manners. It was due, above all, to the
|
||||
great satirist who alone knew how to use ridicule without abusing it; who,
|
||||
without inflicting a wound, effected a great social reform, and who
|
||||
reconciled wit and virtue after a long and disastrous separation, during
|
||||
which wit had been led astray by profligacy, and virtue by fanaticism."
|
||||
|
||||
This verdict of a great critic is accepted by an age to which the grounds
|
||||
of it are, perhaps, not very apparent. The author of any ideal creation--a
|
||||
poem, a drama, or a novel--has an imprescriptible property in the fame of
|
||||
his work. But to harmonise conflicting social elements, to bring order out
|
||||
of chaos in the sphere of criticism, to form right ways of thinking about
|
||||
questions of morals, taste, and breeding, are operations of which the
|
||||
credit, though it is certainly to be ascribed to particular individuals,
|
||||
is generally absorbed by society itself. Macaulay's eulogy is as just as
|
||||
it is eloquent, but the pages of the _Spectator_ alone will hardly show
|
||||
the reader why Addison should be so highly praised for having reconciled
|
||||
wit with virtue. Nor, looking at him as a critic, will it appear a great
|
||||
achievement to have pointed out to English society the beauties of
|
||||
_Paradise Lost_, unless it be remembered that the taste of the preceding
|
||||
generation still influenced Addison's contemporaries, and that in that
|
||||
generation Cowley was accounted a greater poet than Milton.
|
||||
|
||||
To estimate Addison at his real value we must regard him as the chief
|
||||
architect of Public Opinion in the eighteenth century. But here again we
|
||||
are met by an initial difficulty, because it has become almost a
|
||||
commonplace of contemporary criticism to represent the eighteenth century
|
||||
as a period of sheer destruction. It is tacitly assumed by a school of
|
||||
distinguished philosophical writers that we have arrived at a stage in the
|
||||
world's history in which it is possible to take a positive and scientific
|
||||
view of human affairs. As it is of course necessary that from such a
|
||||
system all belief in the supernatural shall be jealously excluded, it has
|
||||
not seemed impossible to write the history of Thought itself in the
|
||||
eighteenth century. And in tracing the course of this supposed continuous
|
||||
stream it is natural that all the great English writers of the period
|
||||
should be described as in one way or another helping to pull down, or
|
||||
vainly to strengthen, the theological barriers erected by centuries of
|
||||
bigotry against the irresistible tide of enlightened progress.
|
||||
|
||||
It would be of course entirely out of place to discuss here the merits of
|
||||
this new school of history. Those who consider that, whatever glimpses we
|
||||
may obtain of the law and order of the universe, man is, as he always has
|
||||
been and always will be, a mystery to himself, will hardly allow that the
|
||||
operations of the human spirit can be traced in the dissecting-room. But
|
||||
it is, in any case, obvious that to treat the great _imaginative_ writers
|
||||
of any age as if they were only mechanical agents in an evolution of
|
||||
thought is to do them grave injustice. Such writers are, above all things,
|
||||
creative. Their first aim is to "show the very age and body of the time
|
||||
his form and pressure." No work of the eighteenth century, composed in a
|
||||
consciously destructive spirit, has taken its place among the acknowledged
|
||||
classics of the language. Even the _Tale of a Tub_ is to be regarded as a
|
||||
satire upon the aberrations of theologians from right reason, not upon the
|
||||
principles of Christianity itself. The _Essay on Man_ has, no doubt,
|
||||
logically a tendency towards Deism, but nobody ever read the poem for the
|
||||
sake of its philosophy; and it is well known that Pope was much alarmed
|
||||
when it was pointed out to him that his conclusions might be represented
|
||||
as incompatible with the doctrines of revealed religion.
|
||||
|
||||
The truth indeed seems to be the exact converse of what is alleged by the
|
||||
scientific historians. So far from the eighteenth century in England being
|
||||
an age of destructive analysis, its energies were chiefly devoted to
|
||||
political, social, and literary reconstruction. Whatever revolution in
|
||||
faith and manners the English nation had undergone had been the work of
|
||||
the two preceding centuries, and though the historic foundations of
|
||||
society remained untouched, the whole form of the superstructure had been
|
||||
profoundly modified.
|
||||
|
||||
"So tenacious are we," said Burke, towards the close of the last
|
||||
century, "of our old ecclesiastical modes and fashions of institution
|
||||
that very little change has been made in them since the fourteenth or
|
||||
fifteenth centuries, adhering in this particular as in all else to our
|
||||
old settled maxim never entirely nor at once to depart from antiquity.
|
||||
We found these institutions on the whole favourable to morality and
|
||||
discipline, and we thought they were susceptible of amendment without
|
||||
altering the ground. We thought they were capable of receiving and
|
||||
meliorating, and, above all, of preserving the accessories of science
|
||||
and literature as the order of Providence should successively produce
|
||||
them. And after all, with this Gothic and monkish education (for such
|
||||
it is the groundwork), we may put in our claim to as ample and early
|
||||
a share in all the improvements in science, in arts, and in literature
|
||||
which have illuminated the modern world as any other nation in Europe.
|
||||
We think one main cause of this improvement was our not despising the
|
||||
patrimony of knowledge which was left us by our forefathers."
|
||||
|
||||
All this is, in substance, true of our political as well as our
|
||||
ecclesiastical institutions. And yet, when Burke wrote, the great feudal
|
||||
and mediæval structure of England had been so transformed by the Wars of
|
||||
the Roses, the Reformation, the Rebellion, and the Revolution, that its
|
||||
ancient outlines were barely visible. In so far, therefore, as his words
|
||||
seem to imply that the social evolution he describes was produced by an
|
||||
imperceptible and almost mechanical process of national instinct, the
|
||||
impression they tend to create is entirely erroneous.
|
||||
|
||||
If we have been hitherto saved from such corruption as undermined the
|
||||
republics of Italy, from the religious wars that so long enfeebled and
|
||||
divided Germany, and from the Revolution that has severed modern France
|
||||
from her ancient history, thanks for this are due partly, no doubt, to
|
||||
favouring conditions of nature and society, but quite as much to the
|
||||
genius of great individuals who prepared the mind of the nation for the
|
||||
gradual assimilation of new ideas. Thus Langland and Wycliffe and their
|
||||
numerous followers, long before the Reformation, had so familiarised the
|
||||
minds of the people with their ideas of the Christian religion that the
|
||||
Sovereign was able to assume the Headship of the Church without the shock
|
||||
of a social convulsion. Fresh feelings and instincts grew up in the hearts
|
||||
of whole classes of the nation without at first producing any change in
|
||||
outward habits of life, and even without arousing a sense of their logical
|
||||
incongruity. These mixed ideas were constantly brought before the
|
||||
imagination in the works of the poets. Shakespeare abounds with passages
|
||||
in which, side by side with the old feudal, monarchical, catholic, and
|
||||
patriotic instincts of Englishmen, we find the sentiments of the Italian
|
||||
Renaissance. Spenser conveys Puritan doctrines sometimes by the mouth of
|
||||
shepherds, whose originals he had found in Theocritus and Virgil;
|
||||
sometimes under allegorical forms derived from books of chivalry and the
|
||||
ceremonial of the Catholic Church. Milton, the most rigidly Calvinistic of
|
||||
all the English poets in his opinions, is also the most severely classical
|
||||
in his style.
|
||||
|
||||
It was the task of Addison to carry on the reconciling traditions of our
|
||||
literature. It is his praise to have accomplished his task under
|
||||
conditions far more difficult than any that his predecessors had
|
||||
experienced. What they had done was to give instinctive and characteristic
|
||||
expression to the floating ideas of the society about them; what Addison
|
||||
and his contemporaries did was to found a public opinion by a conscious
|
||||
effort of reason and persuasion. Before the Civil Wars there had been at
|
||||
least no visible breach in the principle of Authority in Church and State.
|
||||
At the beginning of the eighteenth century constituted authority had been
|
||||
recently overthrown; one king had been beheaded, another had been
|
||||
expelled; the Episcopalian form of Church Government had been violently
|
||||
displaced in favour of the Presbyterian, and had been with almost equal
|
||||
violence restored. Whole classes of the population had been drawn into
|
||||
opposing camps during the Civil War, and still stood confronting each
|
||||
other with all the harsh antagonism of sentiment inherited from that
|
||||
conflict. Such a bare summary alone is sufficient to indicate the nature
|
||||
of the difficulties Addison had to encounter in his efforts to harmonise
|
||||
public opinion; but a more detailed examination of the state of society
|
||||
after the Restoration is required to place in its full light the
|
||||
extraordinary merits of the success that he achieved.
|
||||
|
||||
There was, to begin with, a vehement opposition between town and country.
|
||||
In the country the old ideas of Feudalism, modified by circumstances, but
|
||||
vigorous and deep-rooted, still prevailed. True, the military system of
|
||||
land-tenure had disappeared with the Restoration, but it was not so with
|
||||
the relations of life, and the habits of thought and feeling which the
|
||||
system had created. The features of surviving Feudalism have been
|
||||
inimitably preserved for us in the character of Sir Roger de Coverley.
|
||||
Living in the patriarchal fashion, in the midst of tenants and retainers,
|
||||
who looked up to him as their chief, and for whose welfare and protection
|
||||
he considered himself responsible, the country gentleman valued above all
|
||||
things the principle of Loyalty. To the moneyed classes in the towns he
|
||||
was instinctively opposed; he regarded their interests, both social and
|
||||
commercial, as contrary to his own; he looked with dislike and suspicion
|
||||
on the economical principles of government and conduct on which these
|
||||
classes naturally rely. Even the younger sons of county families had in
|
||||
Addison's day abandoned the custom, common enough in the feudal times, of
|
||||
seeking their fortune in trade. Many a Will Wimble now spent his whole
|
||||
life in the country, training dogs for his neighbours, fishing their
|
||||
streams, making whips for their young heirs, and even garters for their
|
||||
wives and daughters.[1]
|
||||
|
||||
|
||||
|
||||
8
data/sakana_wiki.txt
Normal file
8
data/sakana_wiki.txt
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
Sakana AI Co, Ltd. is a Japanese artificial intelligence company based in Tokyo.
|
||||
Overview
|
||||
|
||||
Sakana AI's main research fields are evolution and collective intelligence of AI. The company's name is derived from the Japanese word さかな (sakana), which means fish. This represents the idea of a school of fish coming together and forming a coherent entity from simple rules, which is an analogy of collective intelligence.[2]
|
||||
|
||||
The company was founded by David Ha, Llion Jones and Ren Ito. Llion Jones co-authored the famous paper "Attention Is All You Need" when he was working for Google in 2017. The company raised $30M in its seed funding round from Lux Capital and Khosla Ventures.[3] The company raised approximately $200M from companies such as Mitsubishi UFJ, SMBC, Mizuho, Itochu, KDDI, Nomura and Nvidia in its series A funding round in 2024.[4]
|
||||
|
||||
In January 2024, Sakana AI developed a method to build new AI models by 'breeding' multiple existing models, which it sees as a means to democratise AI development, as this process does not require large computational resources.[5] Sakana AI is also developing a model called the AI Scientist, which automates the entire process of scientific research.[6] The Nikkei estimated the company's value at 19 billion yen in 2024.[7]
|
||||
620
data/self_generate_qa.py
Normal file
620
data/self_generate_qa.py
Normal file
|
|
@ -0,0 +1,620 @@
|
|||
import argparse
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from datasets import Dataset, load_dataset
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
from ctx_to_lora.data.definitions import (
|
||||
CLOSED_QA_INTX_TEMPLATES,
|
||||
RAW_DATA_DIR,
|
||||
SELF_GEN_DATA_DIR,
|
||||
)
|
||||
from ctx_to_lora.data.processing import (
|
||||
filter_none,
|
||||
get_preprocessing_fn,
|
||||
load_and_process_dataset,
|
||||
tokenize_ctx_text,
|
||||
)
|
||||
from ctx_to_lora.data.self_gen_template import (
|
||||
PRE_CTX,
|
||||
PROMPT_TEMPLATE,
|
||||
QA_PROMPT_TEMPLATE,
|
||||
SELF_GEN_SYSTEM_MSG,
|
||||
SELF_QA_INTX,
|
||||
)
|
||||
from ctx_to_lora.model_loading import get_tokenizer
|
||||
from ctx_to_lora.utils import clear_gpu
|
||||
|
||||
STOP_STRINGS = {
|
||||
"google/gemma-2-2b-it": ["<eos>", "<end_of_turn>"],
|
||||
}
|
||||
|
||||
MODEL_CTX_LEN = {
|
||||
"google/gemma-2-27b-it": 8192,
|
||||
"google/gemma-2-2b-it": 8192,
|
||||
"google/gemma-2-9b-it": 8192,
|
||||
# qwen 4b has 256k ctx length but using lower max lengths is faster
|
||||
"Qwen/Qwen3-4B-Instruct-2507": 2**13 + 2**12,
|
||||
}
|
||||
|
||||
|
||||
def truncate_middle_if_too_long(
|
||||
input_ids: list[int],
|
||||
max_length: int,
|
||||
max_new_tokens: int = 256,
|
||||
) -> list[int]:
|
||||
"""
|
||||
Truncate the middle of a list of tokens to fit within a maximum length.
|
||||
|
||||
Args:
|
||||
tokens: List of token IDs
|
||||
max_length: Maximum length for the truncated tokens
|
||||
|
||||
Returns:
|
||||
List of truncated token IDs
|
||||
"""
|
||||
max_new_tokens_half = max_new_tokens // 2
|
||||
# leave max_new_tokens for generation
|
||||
half = max_length // 2 - max_new_tokens_half
|
||||
if len(input_ids) > max_length:
|
||||
return input_ids[:half] + input_ids[-half:]
|
||||
return input_ids
|
||||
|
||||
|
||||
def get_prompt(context: str, q: str, remove_qa_template: bool) -> str:
|
||||
prompt = QA_PROMPT_TEMPLATE if not remove_qa_template else PROMPT_TEMPLATE
|
||||
return prompt.format(context=context, question=q)
|
||||
|
||||
|
||||
def add_closed_qa_prompt(q: str, closed_qa_prob: float = 0.1) -> str:
|
||||
if random.random() <= closed_qa_prob:
|
||||
q = random.choice(CLOSED_QA_INTX_TEMPLATES).format(input=q)
|
||||
return q
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
|
||||
"""Load dataset names from YAML config file."""
|
||||
with open(config_path) as f:
|
||||
config = yaml.safe_load(f)
|
||||
return config
|
||||
|
||||
|
||||
def get_dataset_configs(
|
||||
ds_names: list[str] | None,
|
||||
config: dict | None,
|
||||
split: str | None,
|
||||
) -> list[tuple[str, str]]:
|
||||
assert not (ds_names and config), "Cannot provide both ds_names and config"
|
||||
if ds_names:
|
||||
assert split, "When using ds_names, --split must be provided"
|
||||
# Validate ds_names format
|
||||
for ds_name in ds_names:
|
||||
if not isinstance(ds_name, str):
|
||||
raise ValueError(f"Invalid dataset name: {ds_name}")
|
||||
return [(ds_name, split) for ds_name in ds_names]
|
||||
|
||||
if config:
|
||||
dataset_configs = []
|
||||
|
||||
# Process train datasets
|
||||
train_ds_names = config.get("train_ds_names", [])
|
||||
# self_gen_train_ds_names = [
|
||||
# (ds_name.split("/")[-1], "train")
|
||||
# for ds_name in train_ds_names
|
||||
# if ds_name.startswith("self_gen/")
|
||||
# ]
|
||||
self_gen_train_ds_names = [
|
||||
(ds_name, "train")
|
||||
for ds_name in train_ds_names
|
||||
if ds_name.startswith("self_gen/")
|
||||
]
|
||||
if not self_gen_train_ds_names:
|
||||
print("No self_gen datasets found in train_ds_names")
|
||||
dataset_configs.extend(self_gen_train_ds_names)
|
||||
|
||||
# Process validation datasets
|
||||
val_ds_names = config.get("val_ds_names", [])
|
||||
self_gen_val_ds_names = [
|
||||
(ds_name, "validation")
|
||||
for ds_name in val_ds_names
|
||||
if ds_name.startswith("self_gen/")
|
||||
]
|
||||
if not self_gen_val_ds_names:
|
||||
print("No self_gen datasets found in val_ds_names")
|
||||
dataset_configs.extend(self_gen_val_ds_names)
|
||||
|
||||
return dataset_configs
|
||||
|
||||
|
||||
def create_messages(
|
||||
ctxs: list[str],
|
||||
questions: list[list[str]],
|
||||
vllm_model: str,
|
||||
system_template: str,
|
||||
remove_qa_template: bool,
|
||||
) -> list[list[dict]]:
|
||||
"""Create chat messages for the model."""
|
||||
# if "gemma" in vllm_model:
|
||||
# gemma models do not support system messages
|
||||
return [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": (
|
||||
system_template + "\n\n\n" + get_prompt(ctx, q, remove_qa_template)
|
||||
).strip(),
|
||||
}
|
||||
]
|
||||
for ctx, q_list in zip(ctxs, questions)
|
||||
for q in q_list
|
||||
]
|
||||
# else:
|
||||
# return [
|
||||
# [
|
||||
# {"role": "system", "content": system_template},
|
||||
# {"role": "user", "content": get_prompt(ctx, q)},
|
||||
# ]
|
||||
# for ctx, q_list in zip(ctxs, questions)
|
||||
# for q in q_list
|
||||
# ]
|
||||
|
||||
|
||||
def self_generate(
|
||||
ds_name: str,
|
||||
split: str,
|
||||
args: argparse.Namespace,
|
||||
llm: LLM,
|
||||
system_template: str,
|
||||
parquet_file: str | None = None,
|
||||
do_truncate: bool = False,
|
||||
) -> None:
|
||||
"""Process a single dataset and generate QA pairs."""
|
||||
|
||||
shard_name = ""
|
||||
|
||||
# Conflict checks for ds_name-derived overrides
|
||||
if ds_name is not None:
|
||||
# temperature & closed_qa already handled later; add new ones
|
||||
if "_temp_" in ds_name and args.temp != 0.0:
|
||||
raise ValueError(
|
||||
f"Multiple sources of truth for temperature: CLI arg --temp={args.temp} and dataset name contains temp specification."
|
||||
)
|
||||
if "_closed_qa_prob_" in ds_name and args.closed_qa_prob != 0.0:
|
||||
raise ValueError(
|
||||
f"Multiple sources of truth for closed_qa_prob: CLI arg --closed_qa_prob={args.closed_qa_prob} and dataset name contains closed_qa_prob specification."
|
||||
)
|
||||
|
||||
# Base values from args
|
||||
temp = args.temp
|
||||
closed_qa_prob = args.closed_qa_prob
|
||||
|
||||
# Overrides from ds_name pattern if present
|
||||
if ds_name is not None:
|
||||
if "_temp_" in ds_name:
|
||||
m = re.search(r"_temp_([\d.]+)", ds_name)
|
||||
if m:
|
||||
temp = float(m.group(1))
|
||||
if "_closed_qa_prob_" in ds_name:
|
||||
m = re.search(r"_closed_qa_prob_([\d.]+)", ds_name)
|
||||
if m:
|
||||
closed_qa_prob = float(m.group(1))
|
||||
|
||||
print(f"Processing dataset: {ds_name}, split: {split}")
|
||||
print(f"Using temperature: {temp}")
|
||||
print(f"Using closed QA prompt probability: {closed_qa_prob}")
|
||||
|
||||
if parquet_file:
|
||||
print(f"Loading dataset from parquet file: {parquet_file}")
|
||||
|
||||
split = "train"
|
||||
ds_name = "/".join(parquet_file.split(RAW_DATA_DIR)[-1].split("/")[:-1])
|
||||
|
||||
shard_name = "_" + os.path.basename(parquet_file).replace(".parquet", "")
|
||||
ds = load_dataset(path="parquet", data_files=[parquet_file], split="train")
|
||||
processing_fn = get_preprocessing_fn(ds_name, is_eval=False)
|
||||
ds = ds.map(processing_fn, num_proc=8)
|
||||
|
||||
else:
|
||||
ds_name = ds_name.split("/")[-1] # Extract just the dataset name
|
||||
|
||||
print(f"Loading dataset: {ds_name} with split: {split}")
|
||||
kwargs = dict(ds_name=ds_name, split=split)
|
||||
|
||||
ds = load_and_process_dataset(**kwargs, num_proc=8, remove_cols=False)
|
||||
print(f"Loaded dataset: {ds_name} with split: {split}")
|
||||
|
||||
if args.debug:
|
||||
ds = ds.take(10)
|
||||
|
||||
ds = ds.filter(filter_none, batched=False, num_proc=8)
|
||||
|
||||
tk = get_tokenizer(args.vllm_model, train=True)
|
||||
|
||||
self_qa_intx_tokens = tk(SELF_QA_INTX, add_special_tokens=False)["input_ids"][1:]
|
||||
if args.remove_qa_template:
|
||||
self_qa_intx_tokens = tk("\n\n", add_special_tokens=False)["input_ids"]
|
||||
n_self_qa_intx_tokens = len(self_qa_intx_tokens)
|
||||
pre_ctx_tokens = tk(PRE_CTX, add_special_tokens=False)["input_ids"]
|
||||
n_pre_ctx_tokens = len(pre_ctx_tokens)
|
||||
sys_tokens = tk(system_template.split("\n")[0], add_special_tokens=False)[
|
||||
"input_ids"
|
||||
][:-1]
|
||||
n_sys_tokens = len(sys_tokens)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
ds = ds.map(
|
||||
tokenize_ctx_text,
|
||||
fn_kwargs={"tokenizer": tk},
|
||||
batched=True,
|
||||
batch_size=50_000,
|
||||
keep_in_memory=True,
|
||||
)
|
||||
|
||||
ctxs = [sample["context"] for sample in ds]
|
||||
questions = [
|
||||
[add_closed_qa_prompt(q, closed_qa_prob) for q in sample["prompts"] if q]
|
||||
for sample in ds
|
||||
]
|
||||
|
||||
questions = [q_list for q_list in ds["prompts"] if len(q_list) > 0]
|
||||
|
||||
print(f"Loaded {len(ctxs)} contexts and {len(questions)} questions")
|
||||
|
||||
k = 16
|
||||
fpath = f"{SELF_GEN_DATA_DIR}/{args.vllm_model}_temp_{temp}_closed_qa_prob_{closed_qa_prob}/{ds_name}/{split}/ds{shard_name}"
|
||||
|
||||
chunk_size = 1_000
|
||||
for chunk_idx, start in enumerate(range(0, len(ctxs), chunk_size)):
|
||||
print(f"Processing chunk {chunk_idx}")
|
||||
|
||||
chunk_ctxs = ctxs[start : start + chunk_size]
|
||||
chunk_questions = questions[start : start + chunk_size]
|
||||
chunk_messages = create_messages(
|
||||
chunk_ctxs,
|
||||
chunk_questions,
|
||||
args.vllm_model,
|
||||
SELF_GEN_SYSTEM_MSG,
|
||||
args.remove_qa_template,
|
||||
)
|
||||
|
||||
if do_truncate:
|
||||
# we should only do this for evaluation data
|
||||
tokenized_contents = tk(
|
||||
[m[0]["content"] for m in chunk_messages],
|
||||
add_special_tokens=False,
|
||||
return_attention_mask=False,
|
||||
)
|
||||
tokenized_contents["input_ids"] = [
|
||||
truncate_middle_if_too_long(
|
||||
ids,
|
||||
max_length=MODEL_CTX_LEN[args.vllm_model],
|
||||
max_new_tokens=args.max_new_tokens,
|
||||
)
|
||||
for ids in tokenized_contents["input_ids"]
|
||||
]
|
||||
contents = tk.batch_decode(
|
||||
tokenized_contents["input_ids"], skip_special_tokens=True
|
||||
)
|
||||
for c, m in zip(contents, chunk_messages):
|
||||
m[0]["content"] = c
|
||||
|
||||
print(f"Generating from {len(chunk_messages)} contexts")
|
||||
|
||||
# Clear GPU memory before processing the next chunk
|
||||
clear_gpu()
|
||||
execute_qa_generation(
|
||||
fpath + f"_{chunk_idx:04d}",
|
||||
args,
|
||||
llm,
|
||||
temp,
|
||||
tk,
|
||||
self_qa_intx_tokens,
|
||||
n_self_qa_intx_tokens,
|
||||
sys_tokens,
|
||||
n_sys_tokens,
|
||||
chunk_ctxs,
|
||||
ds[start : start + chunk_size]["ctx_ids"],
|
||||
chunk_questions,
|
||||
chunk_messages,
|
||||
k,
|
||||
)
|
||||
|
||||
|
||||
def execute_qa_generation(
|
||||
fpath,
|
||||
args,
|
||||
llm,
|
||||
temp,
|
||||
tk,
|
||||
self_qa_intx_tokens,
|
||||
n_self_qa_intx_tokens,
|
||||
sys_tokens,
|
||||
n_sys_tokens,
|
||||
ctxs,
|
||||
ctx_ids,
|
||||
questions,
|
||||
messages,
|
||||
k,
|
||||
):
|
||||
completions = llm.chat(
|
||||
messages,
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=args.max_new_tokens,
|
||||
logprobs=k,
|
||||
temperature=temp,
|
||||
seed=42,
|
||||
spaces_between_special_tokens=False,
|
||||
skip_special_tokens=False,
|
||||
include_stop_str_in_output=True,
|
||||
),
|
||||
)
|
||||
|
||||
self_gen_data = {
|
||||
ctx: {
|
||||
"ctx_ids": ctx_ids,
|
||||
"input_ids": [],
|
||||
"response_start_end": [],
|
||||
"logprobs_vals": [],
|
||||
"logprobs_indices": [],
|
||||
}
|
||||
for ctx, ctx_ids in zip(ctxs, ctx_ids)
|
||||
}
|
||||
c = 0
|
||||
n_skips = 0
|
||||
sys_start = None
|
||||
for ctx, q_list in zip(ctxs, questions):
|
||||
# self_gen_data[ctx]["ctx_ids"] = ctx_ids
|
||||
for i, _ in enumerate(q_list):
|
||||
# response = completions[c + i].outputs[0].text
|
||||
reason = completions[c + i].outputs[0].finish_reason
|
||||
if reason != "stop":
|
||||
# print(f"idx: {c + i}")
|
||||
print(f"finish_reason: {completions[c + i].outputs[0].finish_reason}")
|
||||
print(f"Skipping due to finish_reason={reason} != 'stop'")
|
||||
n_skips += 1
|
||||
continue
|
||||
|
||||
# includes the logprob before the first response token
|
||||
# but excludes the logprob from eos token
|
||||
logp = completions[c + i].outputs[0].logprobs
|
||||
|
||||
# len = num response tokens
|
||||
n_response_tokens = len(completions[c + i].outputs[0].token_ids)
|
||||
|
||||
logp_indices = np.empty((n_response_tokens, k), dtype=np.int32)
|
||||
# float-16 is better for this range
|
||||
logp_vals = np.empty((n_response_tokens, k), dtype=np.float16)
|
||||
assert len(logp) == n_response_tokens, (
|
||||
f"Expected {n_response_tokens} logp entries, got {len(logp)}"
|
||||
)
|
||||
|
||||
for li, info_d in enumerate(logp):
|
||||
for j, (idx, tok_info) in enumerate(info_d.items()):
|
||||
logp_indices[li, j] = idx
|
||||
logp_vals[li, j] = tok_info.logprob
|
||||
|
||||
prompt_ids = completions[c + i].prompt_token_ids # 1d list
|
||||
# token_ids only includes generated tokens, not the prompt
|
||||
response_token_ids = completions[c + i].outputs[0].token_ids # 1d list
|
||||
all_ids = prompt_ids + response_token_ids
|
||||
res_start = len(prompt_ids)
|
||||
res_end = res_start + n_response_tokens
|
||||
|
||||
if sys_start is None:
|
||||
for ii in range(len(prompt_ids) - n_sys_tokens):
|
||||
if prompt_ids[ii : ii + n_sys_tokens] == sys_tokens:
|
||||
# found the start of the system message
|
||||
sys_start = ii
|
||||
break
|
||||
|
||||
q_start = None
|
||||
for ii in range(
|
||||
len(prompt_ids) - n_self_qa_intx_tokens,
|
||||
-1,
|
||||
-1,
|
||||
):
|
||||
if prompt_ids[ii : ii + n_self_qa_intx_tokens] == self_qa_intx_tokens:
|
||||
# found the start of the user input
|
||||
q_start = ii + n_self_qa_intx_tokens
|
||||
break
|
||||
|
||||
# bos + question + eos + start model turn + response + eos
|
||||
input_ids = all_ids[:sys_start] + all_ids[q_start:res_end]
|
||||
|
||||
# relative to the input_ids
|
||||
res_start = res_start - q_start + sys_start
|
||||
res_end = res_start + n_response_tokens
|
||||
|
||||
# arrays will be saved as nested lists of numbers
|
||||
|
||||
self_gen_data[ctx]["input_ids"].append(input_ids)
|
||||
# assume single-turn chat
|
||||
self_gen_data[ctx]["response_start_end"].append((res_start, res_end))
|
||||
self_gen_data[ctx]["logprobs_vals"].append(logp_vals)
|
||||
self_gen_data[ctx]["logprobs_indices"].append(logp_indices)
|
||||
|
||||
c += i + 1
|
||||
|
||||
print(f"Skipped {n_skips} responses due to missing stop strings")
|
||||
samples = [
|
||||
{
|
||||
# "context": ctx,
|
||||
# "prompts": q_list,
|
||||
# "responses": self_gen_data[ctx]["responses"],
|
||||
"ctx_ids": self_gen_data[ctx]["ctx_ids"],
|
||||
"input_ids": self_gen_data[ctx]["input_ids"],
|
||||
"response_start_end": self_gen_data[ctx]["response_start_end"],
|
||||
# "prompt_start_end": self_gen_data[ctx]["prompt_start_end"],
|
||||
"logprobs_vals": self_gen_data[ctx]["logprobs_vals"],
|
||||
"logprobs_indices": self_gen_data[ctx]["logprobs_indices"],
|
||||
}
|
||||
for ctx, q_list in zip(ctxs, questions)
|
||||
]
|
||||
|
||||
if args.debug:
|
||||
for sample in samples:
|
||||
# print(f"context={tk.decode(sample['ctx_ids'])}")
|
||||
print(f"QA={[tk.decode(ids) for ids in sample['input_ids']]}")
|
||||
|
||||
for input_ids, (start, end) in zip(
|
||||
sample["input_ids"], sample["response_start_end"]
|
||||
):
|
||||
print(f"start={start}, end={end}")
|
||||
print(f"response={tk.decode(input_ids[start:end])}")
|
||||
|
||||
print(f"logprobs_vals={[x.shape for x in sample['logprobs_vals']]}")
|
||||
print(f"logprobs_indices={[x.shape for x in sample['logprobs_indices']]}")
|
||||
for indices in sample["logprobs_indices"]:
|
||||
print(f"logprobs_indices={indices[-1]}")
|
||||
print("=" * 80)
|
||||
|
||||
print(f"Generated {len(samples)} samples")
|
||||
# random.shuffle(samples)
|
||||
|
||||
# Save results
|
||||
# df = pd.DataFrame(samples)
|
||||
# ds_out = Dataset.from_pandas(df)
|
||||
ds_out = Dataset.from_list(samples)
|
||||
# fpath = f"{SELF_GEN_DATA_DIR}/{args.vllm_model}_temp_{temp}_closed_qa_prob_{closed_qa_prob}/{ds_name}/{split}/ds{shard_name}"
|
||||
|
||||
if args.debug:
|
||||
fpath += "_debug"
|
||||
os.makedirs(os.path.dirname(fpath), exist_ok=True)
|
||||
|
||||
fpath = f"{fpath}.parquet"
|
||||
ds_out.to_parquet(fpath)
|
||||
print(f"Saved to {fpath}")
|
||||
|
||||
# Cleanup
|
||||
del samples, ds_out, completions, messages, ctxs, questions
|
||||
clear_gpu()
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Generate QA pairs using VLLM")
|
||||
parser.add_argument(
|
||||
"--vllm_model",
|
||||
type=str,
|
||||
required=True,
|
||||
help="VLLM model name (e.g., google/gemma-2-2b-it)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Enable debug mode (process only 10 samples)",
|
||||
)
|
||||
|
||||
# Either config file OR ds_names + split
|
||||
group = parser.add_mutually_exclusive_group(required=True)
|
||||
group.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
help="Path to YAML config file with train_ds_names/val_ds_names",
|
||||
)
|
||||
group.add_argument(
|
||||
"--ds_names",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="List of dataset names/shard patterns",
|
||||
)
|
||||
group.add_argument(
|
||||
"--glob_pattern",
|
||||
type=str,
|
||||
help="Glob pattern to match dataset names (e.g., 'data/raw_datasets/fw_qa_3/*')",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split",
|
||||
type=str,
|
||||
help="Dataset split to use when using --ds_names (required with --ds_names)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temp",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Temperature for sampling (default: 0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--closed_qa_prob",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Probability of using closed QA prompt template (default: 0.0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--do_truncate",
|
||||
action="store_true",
|
||||
help="Truncate contexts to fit model context length",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remove_qa_template",
|
||||
action="store_true",
|
||||
help="Remove QA template formatting from prompts",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_new_tokens",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Maximum number of new tokens to generate (default: 256)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
# Validate arguments
|
||||
if args.ds_names and not args.split:
|
||||
raise ValueError("--split is required when using --ds_names")
|
||||
|
||||
vllm_model = args.vllm_model
|
||||
print(f"Using model: {vllm_model}")
|
||||
|
||||
# Setup model-specific configurations
|
||||
llm_kwargs = dict(
|
||||
model=vllm_model,
|
||||
dtype="bfloat16",
|
||||
enable_prefix_caching=True,
|
||||
enable_chunked_prefill=True,
|
||||
max_model_len=MODEL_CTX_LEN.get(vllm_model),
|
||||
max_num_batched_tokens=16384,
|
||||
max_num_seqs=32, # avoid oom when getting logprobs
|
||||
)
|
||||
|
||||
print(f"{llm_kwargs=}")
|
||||
llm = LLM(**llm_kwargs)
|
||||
|
||||
# Get dataset configs from config or CLI args
|
||||
config = load_config(args.config) if args.config else None
|
||||
if args.ds_names or args.config:
|
||||
dataset_configs = get_dataset_configs(
|
||||
ds_names=args.ds_names,
|
||||
config=config,
|
||||
split=args.split,
|
||||
)
|
||||
|
||||
# Process each dataset
|
||||
for ds_name, split in dataset_configs:
|
||||
print(f"Processing dataset: {ds_name}, split: {split}")
|
||||
self_generate(
|
||||
ds_name, split, args, llm, SELF_GEN_SYSTEM_MSG, None, args.do_truncate
|
||||
)
|
||||
else:
|
||||
assert args.glob_pattern, (
|
||||
"glob_pattern must be provided if no ds_names or config"
|
||||
)
|
||||
files = glob(args.glob_pattern)
|
||||
for file in files:
|
||||
print(f"Processing file: {file}")
|
||||
self_generate(
|
||||
ds_name=None,
|
||||
parquet_file=file,
|
||||
split=args.split,
|
||||
args=args,
|
||||
llm=llm,
|
||||
system_template=SELF_GEN_SYSTEM_MSG,
|
||||
do_truncate=args.do_truncate,
|
||||
)
|
||||
685
demo/app.py
Normal file
685
demo/app.py
Normal file
|
|
@ -0,0 +1,685 @@
|
|||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
|
||||
# Add the src directory to the path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from ctx_to_lora.data.processing import tokenize_ctx_text
|
||||
from ctx_to_lora.model_loading import get_tokenizer
|
||||
from ctx_to_lora.modeling import hypernet
|
||||
|
||||
sys.modules["ctx_to_lora.modeling_utils"] = hypernet
|
||||
|
||||
# Global state
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
modulated_model = None
|
||||
chat_history = []
|
||||
ctx_tokenizer = None
|
||||
base_tokenizer = None
|
||||
|
||||
try:
|
||||
DEFAULT_CONTEXT = Path("data/sakana_wiki.txt").read_text(encoding="utf-8").strip()
|
||||
except FileNotFoundError:
|
||||
DEFAULT_CONTEXT = ""
|
||||
|
||||
WARNING_MESSAGE = (
|
||||
"⚠️ **Caution**: This is an educational proof-of-concept demonstration.\n"
|
||||
"The model may generate inaccurate information or hallucinate facts."
|
||||
)
|
||||
|
||||
FOOTER = """
|
||||
⚠️ This model is an experimental prototype and is only available for educational and research and development purposes. It is not suitable for commercial use or in environments where failure can have significant effects (mission-critical environments).
|
||||
The use of this model is at the user's own risk and its performance and results is not guaranteed in any way.
|
||||
Sakana AI is not responsible for any direct or indirect loss resulting from using this model, regardless of the outcome.
|
||||
"""
|
||||
|
||||
|
||||
def load_custom_chat_template(tokenizer, model_name):
|
||||
if "gemma" in model_name.lower():
|
||||
template_path = "chat_templates/google/gemma-2-2b-it.jinja"
|
||||
if os.path.exists(template_path):
|
||||
with open(template_path) as f:
|
||||
template_content = f.read()
|
||||
tokenizer.chat_template = template_content
|
||||
print(f"Loaded custom chat template from {template_path}")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def get_available_checkpoints():
|
||||
trained_d2l_checkpoints = {
|
||||
str(path)
|
||||
for path in Path().glob("trained_d2l/**/pytorch_model.bin")
|
||||
if path.is_file()
|
||||
}
|
||||
run_output_checkpoints = {
|
||||
str(path)
|
||||
for path in Path().glob("train_outputs/runs/**/pytorch_model.bin")
|
||||
if path.is_file()
|
||||
}
|
||||
checkpoints = sorted(trained_d2l_checkpoints) + sorted(
|
||||
run_output_checkpoints - trained_d2l_checkpoints
|
||||
)
|
||||
return checkpoints if checkpoints else ["No checkpoints found"]
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
checkpoint_path: str,
|
||||
) -> tuple[str, gr.update, gr.update, gr.update, gr.update]:
|
||||
global modulated_model, ctx_tokenizer, base_tokenizer, chat_history
|
||||
|
||||
if not checkpoint_path or checkpoint_path == "No checkpoints found":
|
||||
return (
|
||||
"⚠️ Please select a valid checkpoint",
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
try:
|
||||
print(f"Loading checkpoint: {checkpoint_path}")
|
||||
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
|
||||
|
||||
state_dict = torch.load(checkpoint_path, weights_only=False)
|
||||
modulated_model = ModulatedPretrainedModel.from_state_dict(
|
||||
state_dict,
|
||||
train=False,
|
||||
use_flash_attn=True,
|
||||
use_sequence_packing=False,
|
||||
)
|
||||
modulated_model = modulated_model.to(device).to(torch.bfloat16)
|
||||
modulated_model.eval()
|
||||
|
||||
ctx_encoder_model_name_or_path = (
|
||||
modulated_model.ctx_encoder_args.ctx_encoder_model_name_or_path
|
||||
or modulated_model.base_model.config.name_or_path
|
||||
)
|
||||
ctx_tokenizer = get_tokenizer(ctx_encoder_model_name_or_path, train=False)
|
||||
base_tokenizer = get_tokenizer(
|
||||
modulated_model.base_model.config.name_or_path, train=False
|
||||
)
|
||||
|
||||
load_custom_chat_template(
|
||||
base_tokenizer, modulated_model.base_model.config.name_or_path
|
||||
)
|
||||
|
||||
chat_history = [{"role": "system", "content": ""}]
|
||||
|
||||
model_name = modulated_model.base_model.config.name_or_path
|
||||
success_msg = (
|
||||
f"✅ Successfully loaded checkpoint!\n\nBase Model: {model_name}\n\n"
|
||||
"You can now add context and start chatting."
|
||||
)
|
||||
return (
|
||||
success_msg,
|
||||
gr.update(interactive=True), # msg
|
||||
gr.update(interactive=True), # send_btn
|
||||
gr.update(interactive=True), # system_msg
|
||||
gr.update(interactive=True), # clear_btn
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_msg = (
|
||||
f"❌ Error loading checkpoint:\n{str(e)}\n\n{traceback.format_exc()}"
|
||||
)
|
||||
print(error_msg)
|
||||
return (
|
||||
error_msg,
|
||||
gr.update(interactive=False),
|
||||
gr.update(interactive=False),
|
||||
gr.update(interactive=False),
|
||||
gr.update(interactive=False),
|
||||
)
|
||||
|
||||
|
||||
def process_context(context: str) -> dict:
|
||||
context = context.strip() if context else ""
|
||||
tokenized_contexts = tokenize_ctx_text({"context": [context]}, ctx_tokenizer)
|
||||
ctx_ids = tokenized_contexts["ctx_ids"]
|
||||
ctx_ids = [
|
||||
torch.tensor(ctx_id, dtype=torch.long, device=device) for ctx_id in ctx_ids
|
||||
]
|
||||
ctx_attn_mask = [torch.ones_like(ids) for ids in ctx_ids]
|
||||
ctx_attn_mask = [
|
||||
torch.tensor(mask, dtype=torch.long, device=device) for mask in ctx_attn_mask
|
||||
]
|
||||
ctx_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
ctx_ids,
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
ctx_attn_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
ctx_attn_mask,
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
return {"ctx_ids": ctx_ids, "ctx_attn_mask": ctx_attn_mask}
|
||||
|
||||
|
||||
def add_user_message(message: str, history):
|
||||
if not message.strip():
|
||||
return history, ""
|
||||
return history + [[message, None]], ""
|
||||
|
||||
|
||||
def generate_response(
|
||||
history: list[list[str]],
|
||||
system_msg: str,
|
||||
context: str,
|
||||
context_scaler: float,
|
||||
bias_scaler: float,
|
||||
):
|
||||
global modulated_model, chat_history, ctx_tokenizer, base_tokenizer
|
||||
|
||||
if modulated_model is None:
|
||||
history[-1][1] = "Please load a checkpoint first."
|
||||
yield history
|
||||
return
|
||||
|
||||
if not history or history[-1][0] is None:
|
||||
yield history
|
||||
return
|
||||
|
||||
try:
|
||||
user_message = history[-1][0]
|
||||
|
||||
if system_msg.strip() and chat_history[0]["role"] == "system":
|
||||
chat_history[0]["content"] = system_msg.strip()
|
||||
|
||||
chat_history.append({"role": "user", "content": user_message})
|
||||
|
||||
context = context.strip() if context else ""
|
||||
print(f"Processing single context with scaler: {context_scaler}")
|
||||
print(f"Bias scaler: {bias_scaler}")
|
||||
|
||||
with torch.inference_mode(), torch.amp.autocast(str(device)):
|
||||
ctx_inputs = process_context(context)
|
||||
ctx_ids = ctx_inputs["ctx_ids"].to(device)
|
||||
ctx_attn_mask = ctx_inputs["ctx_attn_mask"].to(device)
|
||||
|
||||
scalers_tensor = torch.tensor(
|
||||
[context_scaler], dtype=torch.float32, device=device
|
||||
)
|
||||
|
||||
model_inputs = base_tokenizer.apply_chat_template(
|
||||
chat_history, return_tensors="pt", add_generation_prompt=True
|
||||
).to(device)
|
||||
|
||||
print(f"Context: {context}")
|
||||
print(f"Chat history: {chat_history}")
|
||||
|
||||
outputs = modulated_model.generate(
|
||||
ctx_ids=ctx_ids,
|
||||
ctx_attn_mask=ctx_attn_mask,
|
||||
n_ctx_chunks=torch.tensor([len(ctx_ids)], device=ctx_ids.device),
|
||||
scalers=scalers_tensor,
|
||||
bias_scaler=bias_scaler,
|
||||
input_ids=model_inputs,
|
||||
max_new_tokens=512,
|
||||
do_sample=False,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
response = base_tokenizer.decode(
|
||||
outputs[0][model_inputs.shape[1] :], skip_special_tokens=True
|
||||
)
|
||||
|
||||
chat_history.append({"role": "assistant", "content": response})
|
||||
|
||||
words = response.split()
|
||||
partial_response = ""
|
||||
for word in words:
|
||||
partial_response += word + " "
|
||||
history[-1][1] = partial_response.strip()
|
||||
yield history
|
||||
|
||||
history[-1][1] = response
|
||||
yield history
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
error_msg = f"❌ Error: {str(e)}"
|
||||
print(f"Error generating response: {str(e)}\n\n{traceback.format_exc()}")
|
||||
history[-1][1] = error_msg
|
||||
yield history
|
||||
|
||||
|
||||
def reset_chat(system_msg: str):
|
||||
global chat_history
|
||||
chat_history = [
|
||||
{"role": "system", "content": system_msg.strip() if system_msg else ""}
|
||||
]
|
||||
return [[None, WARNING_MESSAGE]], "Chat history reset successfully!"
|
||||
|
||||
|
||||
custom_css = """
|
||||
:root {
|
||||
color-scheme: light;
|
||||
}
|
||||
|
||||
.gradio-container {
|
||||
font-family: 'Inter', sans-serif;
|
||||
}
|
||||
|
||||
.chat-container {
|
||||
border-radius: 10px;
|
||||
border: 2px solid #d1d5db;
|
||||
}
|
||||
|
||||
.context-field {
|
||||
border-radius: 8px;
|
||||
padding: 15px;
|
||||
margin-bottom: 10px;
|
||||
border: 2px solid #d1d5db;
|
||||
}
|
||||
|
||||
.status-box {
|
||||
border-radius: 8px;
|
||||
padding: 15px;
|
||||
margin: 10px 0;
|
||||
border: 2px solid #e5e7eb;
|
||||
}
|
||||
|
||||
.primary-button {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
color: white;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.secondary-button {
|
||||
background-color: #607d8b;
|
||||
border: none;
|
||||
border-radius: 6px;
|
||||
color: white;
|
||||
}
|
||||
|
||||
#chatbot {
|
||||
height: 500px;
|
||||
}
|
||||
|
||||
.instruction-text {
|
||||
font-style: italic;
|
||||
color: #666;
|
||||
font-size: 0.9em;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.warning-box {
|
||||
background-color: #fff3cd;
|
||||
border: 2px solid #ffc107;
|
||||
border-radius: 8px;
|
||||
padding: 12px;
|
||||
margin: 10px 0;
|
||||
color: #856404;
|
||||
font-size: 0.95em;
|
||||
}
|
||||
|
||||
.warning-box strong {
|
||||
color: #d97706;
|
||||
}
|
||||
|
||||
.disabled-overlay {
|
||||
background-color: #f5f5f5;
|
||||
border: 3px dashed #999;
|
||||
border-radius: 10px;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
color: #999;
|
||||
}
|
||||
|
||||
.chat-disabled-notice {
|
||||
border: 3px solid #f59e0b;
|
||||
border-radius: 8px;
|
||||
padding: 15px;
|
||||
margin-bottom: 15px;
|
||||
color: #92400e;
|
||||
font-weight: 600;
|
||||
text-align: center;
|
||||
background-color: #fef3c7;
|
||||
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.internalization-banner {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border-radius: 10px;
|
||||
padding: 20px;
|
||||
margin-bottom: 15px;
|
||||
text-align: center;
|
||||
font-weight: 600;
|
||||
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.15);
|
||||
border: 3px solid #5568d3;
|
||||
}
|
||||
|
||||
.internalization-banner h3 {
|
||||
margin: 0 0 10px 0;
|
||||
font-size: 1.2em;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.internalization-banner p {
|
||||
margin: 5px 0;
|
||||
font-size: 0.95em;
|
||||
font-weight: 400;
|
||||
color: white;
|
||||
}
|
||||
|
||||
.context-section-header {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border: 3px solid #5568d3;
|
||||
border-left: 6px solid #4c51bf;
|
||||
padding: 15px;
|
||||
margin-bottom: 15px;
|
||||
border-radius: 6px;
|
||||
color: white;
|
||||
box-shadow: 0 2px 6px rgba(102, 126, 234, 0.3);
|
||||
}
|
||||
|
||||
.context-section-header strong,
|
||||
.context-section-header small {
|
||||
color: white;
|
||||
}
|
||||
|
||||
.chat-section-header {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
border: 3px solid #2563eb;
|
||||
border-left: 6px solid #1d4ed8;
|
||||
padding: 15px;
|
||||
margin-bottom: 15px;
|
||||
border-radius: 6px;
|
||||
color: white;
|
||||
box-shadow: 0 2px 6px rgba(59, 130, 246, 0.3);
|
||||
}
|
||||
|
||||
.chat-section-header strong,
|
||||
.chat-section-header small {
|
||||
color: white;
|
||||
}
|
||||
|
||||
.panel-box {
|
||||
background-color: rgba(249, 250, 251, 0.5);
|
||||
border: 2px solid rgba(209, 213, 219, 0.5);
|
||||
border-radius: 12px;
|
||||
padding: 20px;
|
||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
.chat-panel-box {
|
||||
background-color: rgba(249, 250, 251, 0.5);
|
||||
border: 2px solid rgba(209, 213, 219, 0.5);
|
||||
border-radius: 12px;
|
||||
padding: 20px;
|
||||
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
|
||||
}
|
||||
|
||||
#checkpoint-dropdown {
|
||||
position: relative;
|
||||
z-index: 20;
|
||||
}
|
||||
|
||||
#checkpoint-dropdown [role="listbox"] {
|
||||
z-index: 9999 !important;
|
||||
}
|
||||
|
||||
/* Dark mode support */
|
||||
.dark .panel-box,
|
||||
.dark .chat-panel-box {
|
||||
background-color: rgba(31, 41, 55, 0.5);
|
||||
border: 2px solid rgba(75, 85, 99, 0.5);
|
||||
}
|
||||
|
||||
.dark .context-field,
|
||||
.dark .chat-container {
|
||||
border-color: rgba(75, 85, 99, 0.5);
|
||||
}
|
||||
|
||||
.dark .status-box {
|
||||
border-color: rgba(55, 65, 81, 0.5);
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def create_demo():
|
||||
with gr.Blocks(
|
||||
title="Doc-to-LoRA Chat Interface",
|
||||
theme=gr.themes.Soft(),
|
||||
css=custom_css,
|
||||
) as demo:
|
||||
gr.Markdown(
|
||||
"""
|
||||
# 📜 Doc-to-LoRA Chat Interface
|
||||
|
||||
Load a hypernetwork checkpoint and chat with a context-modulated language model.
|
||||
Add one context with a scaling parameter to influence the model's responses.
|
||||
"""
|
||||
)
|
||||
|
||||
gr.HTML(
|
||||
"""
|
||||
<div class="internalization-banner">
|
||||
<h3>🧠 How Context Internalization Works</h3>
|
||||
<p>📥 Contexts are processed by the hypernetworkto dynamically modulate the base model's parameters</p>
|
||||
<p>🚫 Contexts are NOT passed as text to the base model — they influence behavior internally</p>
|
||||
<p>💬 Only your chat messages (below) are sent to the language model</p>
|
||||
</div>
|
||||
"""
|
||||
)
|
||||
|
||||
gr.Markdown(
|
||||
"""
|
||||
### 📖 Usage Instructions
|
||||
|
||||
1. **Load a Checkpoint**: Select a hypernetwork checkpoint from the dropdown and click "Load Checkpoint"
|
||||
2. **Configure Context**:
|
||||
- Enter your context information in the text field
|
||||
- Adjust the scaling slider to control context influence
|
||||
3. **Set Bias Scaler**: Adjust the bias scaler to control overall model behavior
|
||||
4. **Start Chatting**: Once the model is loaded, type your message and press Shift+Enter or click Send
|
||||
5. **Reset**: Use the "Reset Chat" button to start a new conversation
|
||||
|
||||
💡 **Tip**: You can use context to provide background information or specific knowledge
|
||||
that should influence the model's responses.
|
||||
"""
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_classes="panel-box"):
|
||||
gr.Markdown("### 📦 Load Checkpoint")
|
||||
gr.Markdown(
|
||||
"*Select a trained hypernetwork checkpoint to begin.*",
|
||||
elem_classes="instruction-text",
|
||||
)
|
||||
|
||||
checkpoint_dropdown = gr.Dropdown(
|
||||
choices=get_available_checkpoints(),
|
||||
label="Select Checkpoint",
|
||||
value=None,
|
||||
interactive=True,
|
||||
elem_id="checkpoint-dropdown",
|
||||
)
|
||||
|
||||
load_btn = gr.Button("Load Checkpoint", variant="primary", size="lg")
|
||||
|
||||
status_box = gr.Textbox(
|
||||
label="Status",
|
||||
lines=8,
|
||||
interactive=False,
|
||||
elem_classes="status-box",
|
||||
)
|
||||
|
||||
gr.Markdown("---")
|
||||
|
||||
gr.HTML(
|
||||
"""
|
||||
<div class="context-section-header">
|
||||
<strong>🧠 Context Internalization (Hypernetwork Input)</strong><br>
|
||||
<small>This context modulates the model internally — it is NOT shown to the base model</small>
|
||||
</div>
|
||||
"""
|
||||
)
|
||||
|
||||
context = gr.Textbox(
|
||||
label="🧠 Context (Internalized via Hypernetwork)",
|
||||
placeholder="Enter context to be internalized by the hypernetwork...",
|
||||
lines=4,
|
||||
value=DEFAULT_CONTEXT,
|
||||
)
|
||||
context_scaler = gr.Slider(
|
||||
minimum=-2.0,
|
||||
maximum=2.0,
|
||||
step=0.01,
|
||||
value=1.0,
|
||||
label="Context Scaling",
|
||||
)
|
||||
|
||||
gr.Markdown("---")
|
||||
|
||||
bias_scaler = gr.Slider(
|
||||
minimum=-2.0,
|
||||
maximum=2.0,
|
||||
step=0.01,
|
||||
value=1.0,
|
||||
label="Bias Scaler",
|
||||
info="A single scalar applied to bias parameters (independent of contexts)",
|
||||
)
|
||||
|
||||
with gr.Column(scale=2, elem_classes="chat-panel-box"):
|
||||
gr.HTML(
|
||||
"""
|
||||
<div class="chat-section-header">
|
||||
<strong>💬 Chat Interface (Direct Input to Base Model)</strong><br>
|
||||
<small>Your messages here are the ONLY text the base model sees — contexts above influence it internally</small>
|
||||
</div>
|
||||
"""
|
||||
)
|
||||
|
||||
chat_status_notice = gr.HTML(
|
||||
"""
|
||||
<div class="chat-disabled-notice">
|
||||
🔒 <strong>Chat Disabled:</strong> Please load a checkpoint first to enable chat functionality.
|
||||
</div>
|
||||
""",
|
||||
visible=True,
|
||||
)
|
||||
|
||||
system_msg = gr.Textbox(
|
||||
label="System Message (Optional - Sent to Base Model)",
|
||||
placeholder="Load a checkpoint to enable chat...",
|
||||
lines=2,
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
chatbot = gr.Chatbot(
|
||||
label="Conversation",
|
||||
show_copy_button=True,
|
||||
height=500,
|
||||
elem_id="chatbot",
|
||||
elem_classes="chat-container",
|
||||
value=[
|
||||
[
|
||||
None,
|
||||
"🔒 Chat is currently disabled. Please load a checkpoint from the left panel to begin chatting.",
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
msg = gr.Textbox(
|
||||
label="Your Message (Sent Directly to Base Model)",
|
||||
placeholder="⚠️ Load a checkpoint first to start chatting...",
|
||||
lines=2,
|
||||
scale=4,
|
||||
interactive=False,
|
||||
)
|
||||
send_btn = gr.Button(
|
||||
"🔒 Send (Disabled)",
|
||||
variant="primary",
|
||||
scale=1,
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
clear_btn = gr.Button(
|
||||
"🔒 Reset Chat (Disabled)",
|
||||
variant="secondary",
|
||||
interactive=False,
|
||||
)
|
||||
reset_status = gr.Textbox(label="Reset Status", visible=False)
|
||||
|
||||
load_btn.click(
|
||||
fn=load_checkpoint,
|
||||
inputs=[checkpoint_dropdown],
|
||||
outputs=[status_box, msg, send_btn, system_msg, clear_btn],
|
||||
).then(
|
||||
fn=lambda: (
|
||||
gr.update(visible=False),
|
||||
gr.update(
|
||||
placeholder="Type your message here... (Shift+Enter for new line)"
|
||||
),
|
||||
gr.update(value="Send"),
|
||||
gr.update(value="🔄 Reset Chat"),
|
||||
gr.update(value=[[None, WARNING_MESSAGE]]),
|
||||
),
|
||||
outputs=[chat_status_notice, msg, send_btn, clear_btn, chatbot],
|
||||
)
|
||||
|
||||
msg.submit(
|
||||
fn=add_user_message,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[chatbot, msg],
|
||||
).then(
|
||||
fn=generate_response,
|
||||
inputs=[
|
||||
chatbot,
|
||||
system_msg,
|
||||
context,
|
||||
context_scaler,
|
||||
bias_scaler,
|
||||
],
|
||||
outputs=[chatbot],
|
||||
)
|
||||
|
||||
send_btn.click(
|
||||
fn=add_user_message,
|
||||
inputs=[msg, chatbot],
|
||||
outputs=[chatbot, msg],
|
||||
).then(
|
||||
fn=generate_response,
|
||||
inputs=[
|
||||
chatbot,
|
||||
system_msg,
|
||||
context,
|
||||
context_scaler,
|
||||
bias_scaler,
|
||||
],
|
||||
outputs=[chatbot],
|
||||
)
|
||||
|
||||
clear_btn.click(
|
||||
fn=reset_chat,
|
||||
inputs=[system_msg],
|
||||
outputs=[chatbot, reset_status],
|
||||
)
|
||||
|
||||
gr.Markdown(f"---\n{FOOTER.strip()}")
|
||||
return demo
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo = create_demo()
|
||||
demo.launch(
|
||||
server_name="0.0.0.0",
|
||||
server_port=7861,
|
||||
share=False,
|
||||
debug=True,
|
||||
)
|
||||
41
examples/python_api.py
Normal file
41
examples/python_api.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
# caveat: this interface only supports non-batched inputs
|
||||
# for batched inference please see `src/ctx_to_lora/modeling/hypernet.py`
|
||||
import torch
|
||||
|
||||
from ctx_to_lora.model_loading import get_tokenizer
|
||||
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
|
||||
|
||||
# model loading
|
||||
checkpoint_path = "trained_d2l/gemma_demo/checkpoint-80000/pytorch_model.bin"
|
||||
state_dict = torch.load(checkpoint_path, weights_only=False)
|
||||
model = ModulatedPretrainedModel.from_state_dict(
|
||||
state_dict, train=False, use_sequence_packing=False
|
||||
)
|
||||
model.reset()
|
||||
tokenizer = get_tokenizer(model.base_model.name_or_path)
|
||||
|
||||
# prepare data
|
||||
doc = open("data/sakana_wiki.txt", "r").read()
|
||||
chat = [{"role": "user", "content": "Tell me about Sakana AI."}]
|
||||
chat_ids = tokenizer.apply_chat_template(
|
||||
chat,
|
||||
add_special_tokens=False,
|
||||
return_attention_mask=False,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
|
||||
# calls after internalization will be influenced by internalized info
|
||||
model.internalize(doc)
|
||||
|
||||
outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
|
||||
print(tokenizer.decode(outputs[0]))
|
||||
|
||||
|
||||
# remove internalized info
|
||||
# model.reset()
|
||||
|
||||
# without internalized info, the model will halucinate
|
||||
# outputs = model.generate(input_ids=chat_ids, max_new_tokens=512)
|
||||
# print(tokenizer.decode(outputs[0]))
|
||||
24
install.sh
Executable file
24
install.sh
Executable file
|
|
@ -0,0 +1,24 @@
|
|||
curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
uv self update
|
||||
uv venv --python 3.10 --seed
|
||||
uv pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --torch-backend=cu124
|
||||
uv sync
|
||||
uv pip install tokenizers==0.21.0
|
||||
uv pip install https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
|
||||
uv pip install flashinfer-python==0.2.2 -i https://flashinfer.ai/whl/cu124/torch2.6
|
||||
|
||||
# download squad dataset
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 uv run huggingface-cli download --repo-type dataset rajpurkar/squad --local-dir data/raw_datasets/squad
|
||||
uv run data/build_drop_compact.py
|
||||
uv run data/build_pwc_compact.py
|
||||
uv run data/build_ropes_compact.py
|
||||
uv run data/build_squad_compact.py
|
||||
|
||||
# optional: needed for gated models
|
||||
# uv run huggingface-cli login
|
||||
|
||||
# optional: needed for logging with wandb
|
||||
# wandb login
|
||||
|
||||
# optional: dev
|
||||
# uv run pre-commit install
|
||||
86
pyproject.toml
Normal file
86
pyproject.toml
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
[project]
|
||||
name = "ctx-to-lora"
|
||||
version = "0.0.1"
|
||||
authors = [{ name = "Rujikorn Charakorn" }]
|
||||
description = ""
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.10"
|
||||
dependencies = [
|
||||
"transformers==4.51.3",
|
||||
"deepspeed==0.17.1",
|
||||
"accelerate==1.6.0",
|
||||
"datasets==3.6.0",
|
||||
"setuptools",
|
||||
"peft",
|
||||
"jupyter",
|
||||
"matplotlib",
|
||||
"hf_transfer",
|
||||
"torchmetrics",
|
||||
"inflect",
|
||||
"pre-commit",
|
||||
"tensorboardX",
|
||||
"wandb",
|
||||
"fasttext-wheel",
|
||||
"einops",
|
||||
"jaxtyping",
|
||||
"liger-kernel",
|
||||
"tensorboard",
|
||||
"flask",
|
||||
"gradio>=4.40.0",
|
||||
"pandas",
|
||||
"plotly",
|
||||
"rouge-score",
|
||||
"vllm==0.8.5.post1",
|
||||
"huggingface-hub[hf-transfer]>=0.32.0",
|
||||
"opt-einsum>=3.4.0",
|
||||
"kagglehub[hf-datasets]>=0.3.12",
|
||||
"kaggle>=1.7.4.5",
|
||||
"bitsandbytes>=0.46.1",
|
||||
"google-cloud-storage>=3.2.0",
|
||||
"wonderwords>=2.2.0",
|
||||
"llmlingua>=0.2.2",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=61.0"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.pyright]
|
||||
exclude = [
|
||||
"**/node_modules",
|
||||
"**/__pycache__",
|
||||
"**/.*",
|
||||
".venv",
|
||||
".github",
|
||||
".vscode",
|
||||
"chat_templates",
|
||||
"eval_results",
|
||||
"configs",
|
||||
"EditingLlama",
|
||||
"icae_v2",
|
||||
"lm-evaluation-harness",
|
||||
"llm-comparator",
|
||||
"LongBench",
|
||||
"scripts",
|
||||
"train_outputs",
|
||||
"./data/",
|
||||
"/data/",
|
||||
"generated_tasks",
|
||||
"outputs",
|
||||
"plots",
|
||||
"tmp",
|
||||
"wandb",
|
||||
".wandb",
|
||||
".ruff_cache",
|
||||
"assets",
|
||||
]
|
||||
typeCheckingMode = "off"
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 88
|
||||
select = ["F401"] # remove unused imports
|
||||
ignore = ["E", "F"]
|
||||
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
known_local_folder = ["ctx_to_lora"]
|
||||
180
run_eval.py
Normal file
180
run_eval.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
import logging
|
||||
|
||||
from ctx_to_lora.eval_utils import run_eval
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Evaluate a checkpoint")
|
||||
parser.add_argument(
|
||||
"--model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Evaluate a base model from HuggingFace Hub, without loading checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to the checkpoint to evaluate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--split",
|
||||
type=str,
|
||||
choices=["validation", "test"],
|
||||
default="validation",
|
||||
help="Which split to evaluate on",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--datasets",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help=(
|
||||
"Specific datasets to evaluate on."
|
||||
"If not provided, uses default from args.yaml"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_batch_size",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Eval batch size for teacher forcing",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_batch_size_gen",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Eval batch size for generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_val_samples_per_ds",
|
||||
type=int,
|
||||
default=-1,
|
||||
help=(
|
||||
"Maximum number of validation samples per dataset. "
|
||||
"If -1, uses values from checkpoint config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_test_samples_per_ds",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Maximum number of validation samples per dataset. "
|
||||
"If -1, uses values from checkpoint config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_ctx_chunk_len",
|
||||
type=int,
|
||||
default=-1,
|
||||
help="Maximum length of context chunk for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_new_tokens",
|
||||
type=int,
|
||||
default=256,
|
||||
help="Maximum number of new tokens to generate during evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--remove_context",
|
||||
action="store_true",
|
||||
help="Remove context when evaluating the base model.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cd",
|
||||
action="store_true",
|
||||
help="Use context distillation model for evaluation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cd_update_iterations",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of update iterations for context distillation during evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cd_use_gen_q",
|
||||
action="store_true",
|
||||
help="Use generated queries for context distillation training.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--q_gen_rounds",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of rounds of query generation for context distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cd_batch_size",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Batch size for context distillation.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_iterative_mode",
|
||||
action="store_true",
|
||||
help="Use iterative mode LoRA layer-by-layer generation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_llmlingua",
|
||||
action="store_true",
|
||||
help="Use LLMLingua compression for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llmlingua_compression_rate",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="Compression rate for LLMLingua",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_t2l",
|
||||
action="store_true",
|
||||
help="Use Text-to-LoRA model for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add_ctx_to_input",
|
||||
action="store_true",
|
||||
help="Add ctx to base model's input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--truncate_if_too_long_inp",
|
||||
action="store_true",
|
||||
help="Truncate input sequences that are too long",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--truncate_if_too_long_ctx",
|
||||
action="store_true",
|
||||
help="Truncate ctx sequences that are too long",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gen_lora_scaling",
|
||||
type=float,
|
||||
default=1.0,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flip_ctx_inp",
|
||||
action="store_true",
|
||||
help="Flip the order of context and input",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_generative_adapter",
|
||||
action="store_true",
|
||||
help="Use generative adapter for evaluation",
|
||||
)
|
||||
|
||||
cli_args = vars(parser.parse_args())
|
||||
|
||||
if cli_args["model_name_or_path"]:
|
||||
assert cli_args["max_ctx_chunk_len"] <= 0, (
|
||||
f"Evaluating base model shouldn't be used with `max_ctx_chunk_len`"
|
||||
)
|
||||
|
||||
eval_batch_size_gen = cli_args.pop("eval_batch_size_gen")
|
||||
eval_batch_size = cli_args.pop("eval_batch_size")
|
||||
run_eval(
|
||||
**cli_args,
|
||||
eval_batch_size=eval_batch_size_gen,
|
||||
generative=True,
|
||||
)
|
||||
17
scripts/main_exp/0-download_data.py
Normal file
17
scripts/main_exp/0-download_data.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
from huggingface_hub import snapshot_download
|
||||
|
||||
if __name__ == "__main__":
|
||||
self_gen_data_dir = "./data/raw_datasets/self_gen/"
|
||||
snapshot_download(
|
||||
"SakanaAI/self_gen_qa_d2l",
|
||||
repo_type="dataset",
|
||||
local_dir=self_gen_data_dir,
|
||||
# we can filter based on model by using the `allow_patterns` argument
|
||||
# based on https://huggingface.co/datasets/SakanaAI/self_gen_qa_d2l/tree/main
|
||||
# we can use
|
||||
# - `Qwen` for downloading the data for `Qwen/Qwen3-4B-Instruct-2507`
|
||||
# - `google` for downloading the data for `google/gemma-2-2b-it`
|
||||
# - `mistralai` for downloading the data for `mistralai/Mistral-7B-Instruct-v0.2`
|
||||
#
|
||||
# allow_patterns="google/*", # downloading the data for `google/gemma-2-2b-it`
|
||||
)
|
||||
14
scripts/main_exp/1-train.sh
Executable file
14
scripts/main_exp/1-train.sh
Executable file
|
|
@ -0,0 +1,14 @@
|
|||
#!/bin/bash
|
||||
|
||||
port=29051
|
||||
|
||||
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
|
||||
--num_processes=8 --gpu_ids all train.py \
|
||||
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
|
||||
--model_name_or_path=google/gemma-2-2b-it \
|
||||
--target_modules=down_proj --lora_r=8 \
|
||||
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
|
||||
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
|
||||
--max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
|
||||
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
|
||||
--quantize_ctx_encoder=True
|
||||
30
scripts/main_exp/2-train-chunk.sh
Executable file
30
scripts/main_exp/2-train-chunk.sh
Executable file
|
|
@ -0,0 +1,30 @@
|
|||
#!/bin/bash
|
||||
|
||||
port=29051
|
||||
|
||||
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
|
||||
--num_processes=8 --gpu_ids all train.py \
|
||||
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
|
||||
--model_name_or_path=google/gemma-2-2b-it \
|
||||
--target_modules=down_proj \
|
||||
--lora_r=8 \
|
||||
--eval_strategy=no \
|
||||
--max_qas_len=512 \
|
||||
--max_qas_per_sample=1 \
|
||||
--per_rank_gen=True \
|
||||
--per_layer_processing=True \
|
||||
--gen_lora_l1_reg_coef=0.1 \
|
||||
--max_steps=20000 \
|
||||
--gradient_accumulation_steps=16 \
|
||||
--max_packed_inp_len=1024 \
|
||||
--max_packed_ctx_len=2048 \
|
||||
--use_per_ctx_average_loss=True \
|
||||
--use_kl_loss=True \
|
||||
--quantize_ctx_encoder=True \
|
||||
--torch_empty_cache_steps=10 \
|
||||
--from_pretrained_checkpoint=train_outputs/runs/$RUN_NAME/checkpoint-80000/pytorch_model.bin \
|
||||
--max_ctx_chunk_len=512 \
|
||||
--min_ctx_chunk_len=25 \
|
||||
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
|
||||
--warmup_steps=2000 \
|
||||
--learning_rate=2e-5
|
||||
25
scripts/main_exp/README.md
Normal file
25
scripts/main_exp/README.md
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
# D2L pipeline
|
||||
### Data
|
||||
You can either download the generated data (recommended, ~100 GB for each model) or generate them by youself.
|
||||
Please see [`0-download_data.sh`](0-download_data.sh) for how to do model-specific data download.
|
||||
```bash
|
||||
# download training data for all three models (328GB)
|
||||
uv run bash scripts/main_exp/0-download_data.sh
|
||||
```
|
||||
|
||||
Generating data from scratch can take very long if not parallelized across multiple gpus.
|
||||
```bash
|
||||
# generate training data (takes very long if not parallelized across multiple gpus)
|
||||
# optional: use the command below for generating data from scratch
|
||||
# uv run bash scripts/main_exp/gen_data.sh
|
||||
```
|
||||
|
||||
### Training
|
||||
Simply run the training script once the data is ready.
|
||||
```bash
|
||||
# train
|
||||
uv run bash scripts/main_exp/1-train.sh
|
||||
```
|
||||
|
||||
### Evaluation
|
||||
All evaluation scripts for reproducing the main results in the paper are included in [eval](eval/) directory.
|
||||
8
scripts/main_exp/eval/base_model.sh
Executable file
8
scripts/main_exp/eval/base_model.sh
Executable file
|
|
@ -0,0 +1,8 @@
|
|||
# no truncation
|
||||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1
|
||||
|
||||
# w/ truncation
|
||||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1 --truncate_if_too_long_inp
|
||||
|
||||
# no context
|
||||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen 1 --remove_context
|
||||
6
scripts/main_exp/eval/cd.sh
Executable file
6
scripts/main_exp/eval/cd.sh
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
# qa
|
||||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=4
|
||||
|
||||
|
||||
# longbench
|
||||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=1
|
||||
1
scripts/main_exp/eval/cd_minibatch.sh
Normal file
1
scripts/main_exp/eval/cd_minibatch.sh
Normal file
|
|
@ -0,0 +1 @@
|
|||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets $1 --split test --use_cd --cd_update_iterations 50 --eval_batch_size_gen=1 --truncate_if_too_long_inp --cd_use_gen_q --q_gen_rounds=5 --cd_batch_size=2
|
||||
6
scripts/main_exp/eval/cd_oracle.sh
Executable file
6
scripts/main_exp/eval/cd_oracle.sh
Executable file
|
|
@ -0,0 +1,6 @@
|
|||
# qa
|
||||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp
|
||||
|
||||
|
||||
# longbench
|
||||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets longbench/multifieldqa_en_e longbench/2wikimqa_e longbench/qasper_e --split test --use_cd --cd_update_iterations 300 --eval_batch_size_gen=1 --truncate_if_too_long_inp
|
||||
13
scripts/main_exp/eval/d2l.sh
Executable file
13
scripts/main_exp/eval/d2l.sh
Executable file
|
|
@ -0,0 +1,13 @@
|
|||
# main results
|
||||
# batched
|
||||
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1
|
||||
|
||||
# iterative
|
||||
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --max_ctx_chunk_len 8192 --eval_batch_size_gen 1 --use_iterative_mode
|
||||
|
||||
# query internalization
|
||||
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad --split test --eval_batch_size_gen=1 --flip_ctx_inp
|
||||
|
||||
# replaced squad context
|
||||
WANDB_MODE=disabled uv run python run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad_assistant_ctx_no_passage --split test
|
||||
WANDB_MODE=disabled uv run python run_eval.py --checkpoint_path train_outputs/runs/$RUN_NAME/checkpoint-$step/pytorch_model.bin --datasets squad_negative_no_passage --split test
|
||||
279
scripts/main_exp/eval/imagenette_eval.py
Normal file
279
scripts/main_exp/eval/imagenette_eval.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
from argparse import Namespace
|
||||
from difflib import SequenceMatcher
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from tqdm import tqdm
|
||||
from transformers import AutoProcessor, Gemma3ForConditionalGeneration
|
||||
|
||||
from ctx_to_lora.model_loading import get_tokenizer
|
||||
from ctx_to_lora.modeling.ctx_encoder import PerLayerActivations
|
||||
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
|
||||
from ctx_to_lora.modeling.lora_layer import apply_lora_to_layers
|
||||
from ctx_to_lora.modeling.lora_merger import combine_lora
|
||||
|
||||
CLASS_NAMES = [
|
||||
"tench",
|
||||
"English springer",
|
||||
"cassette player",
|
||||
"chain saw",
|
||||
"church",
|
||||
"French horn",
|
||||
"garbage truck",
|
||||
"gas pump",
|
||||
"golf ball",
|
||||
"parachute",
|
||||
]
|
||||
CLASS_TO_INT = {name: i for i, name in enumerate(CLASS_NAMES)}
|
||||
INPUT_TXT = f"What is in this image? Choose exactly one of the following classes: {', '.join(CLASS_NAMES)}. Response with only the correct class without any other text."
|
||||
RUN_DIR = "train_outputs/runs/Oct16_02-37-04_slurm0-a3nodeset-8_94074_1d62ecb8"
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
|
||||
text = re.sub(r"[^a-z0-9\s]", " ", text.lower())
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
def _normalize_compact(text: str) -> str:
|
||||
return re.sub(r"[^a-z0-9]", "", text.lower())
|
||||
|
||||
|
||||
def _build_alias_map():
|
||||
alias_overrides = {
|
||||
"english springer spaniel": "English springer",
|
||||
"springer spaniel": "English springer",
|
||||
"chainsaw": "chain saw",
|
||||
"dump truck": "garbage truck",
|
||||
"refuse truck": "garbage truck",
|
||||
"garbage lorry": "garbage truck",
|
||||
"fuel pump": "gas pump",
|
||||
"gas station pump": "gas pump",
|
||||
"cassette deck": "cassette player",
|
||||
"cassette recorder": "cassette player",
|
||||
"fish": "tench",
|
||||
"tench fish": "tench",
|
||||
"french horn instrument": "French horn",
|
||||
"golfball": "golf ball",
|
||||
"skydiving": "parachute",
|
||||
"parachutist": "parachute",
|
||||
}
|
||||
|
||||
alias_map = {}
|
||||
|
||||
def register(alias: str, canonical: str):
|
||||
alias = _normalize_text(alias)
|
||||
if alias:
|
||||
alias_map[alias] = canonical
|
||||
alias_map[_normalize_compact(alias)] = canonical
|
||||
|
||||
for name in CLASS_NAMES:
|
||||
register(name, name)
|
||||
register(name.replace(" ", ""), name)
|
||||
register(name.replace(" ", "-"), name)
|
||||
|
||||
for alias, canonical in alias_overrides.items():
|
||||
register(alias, canonical)
|
||||
|
||||
return alias_map
|
||||
|
||||
|
||||
CLASS_ALIAS_MAP = _build_alias_map()
|
||||
|
||||
|
||||
def pred_to_class_id(pred_txt: str) -> int:
|
||||
norm_pred = _normalize_text(pred_txt)
|
||||
compact_pred = _normalize_compact(pred_txt)
|
||||
|
||||
for alias, canonical in CLASS_ALIAS_MAP.items():
|
||||
if alias and (alias in norm_pred or alias in compact_pred):
|
||||
return CLASS_TO_INT[canonical]
|
||||
|
||||
pred_tokens = set(norm_pred.split())
|
||||
best_class = None
|
||||
best_token_hits = -1
|
||||
for name in CLASS_NAMES:
|
||||
class_tokens = set(_normalize_text(name).split())
|
||||
if class_tokens and class_tokens.issubset(pred_tokens):
|
||||
return CLASS_TO_INT[name]
|
||||
hits = sum(token in pred_tokens for token in class_tokens)
|
||||
if hits > best_token_hits:
|
||||
best_token_hits = hits
|
||||
best_class = name
|
||||
|
||||
best_ratio = -1.0
|
||||
for name in CLASS_NAMES:
|
||||
ratio = SequenceMatcher(None, norm_pred, _normalize_text(name)).ratio()
|
||||
if ratio > best_ratio:
|
||||
best_ratio = ratio
|
||||
best_class = name
|
||||
|
||||
return CLASS_TO_INT[best_class]
|
||||
|
||||
|
||||
def load_checkpoint():
|
||||
checkpoint_path = f"{RUN_DIR}/checkpoint-80000/pytorch_model.bin"
|
||||
state_dict = torch.load(checkpoint_path)
|
||||
|
||||
model = ModulatedPretrainedModel.from_state_dict(
|
||||
state_dict,
|
||||
train=False,
|
||||
base_model_kwargs=dict(attn_implementation="flash_attention_2"),
|
||||
use_flash_attn=True,
|
||||
use_sequence_packing=False, # for generation
|
||||
)
|
||||
tokenizer = get_tokenizer("google/gemma-2-2b-it")
|
||||
model.eval()
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def load_ctx_encoder():
|
||||
model_id = "google/gemma-3-4b-it"
|
||||
ctx_model = Gemma3ForConditionalGeneration.from_pretrained(
|
||||
model_id, device_map="auto"
|
||||
).eval()
|
||||
ctx_encoder_config = Namespace(ctx_encoder_last_layer=26, keep_lm_head=True)
|
||||
ctx_model.language_model = PerLayerActivations(
|
||||
ctx_model.language_model, ctx_encoder_config
|
||||
)
|
||||
processor = AutoProcessor.from_pretrained(model_id)
|
||||
return ctx_model, processor
|
||||
|
||||
|
||||
def template_image(img, ctx_processor):
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": [{"type": "text", "text": ""}],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image", "image": img},
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
inputs = ctx_processor.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=True,
|
||||
tokenize=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return inputs
|
||||
|
||||
|
||||
@torch.inference_mode
|
||||
def get_ctx_features(ctx_inputs, ctx_encoder):
|
||||
forward_outputs = ctx_encoder(**ctx_inputs, output_hidden_states=True)
|
||||
ctx_features = torch.stack(forward_outputs.hidden_states, dim=1)
|
||||
return ctx_features
|
||||
|
||||
|
||||
def generate_loras(ctx_inputs, ctx_features):
|
||||
generated_loras, _ = model.hypernet.generate_weights(
|
||||
ctx_features, attn_mask=torch.ones_like(ctx_inputs["input_ids"])
|
||||
)
|
||||
generated_loras = combine_lora(
|
||||
generated_loras,
|
||||
n_chunks=torch.tensor((1,), device=model.device),
|
||||
lora_bias=model.hypernet.get_head_bias()
|
||||
if model.hypernet.config.use_bias
|
||||
else None,
|
||||
)
|
||||
return generated_loras
|
||||
|
||||
|
||||
def apply_loras(model, generated_loras):
|
||||
n_queries = torch.ones(1, dtype=torch.int32, device=model.device)
|
||||
|
||||
apply_lora_to_layers(
|
||||
model.base_model,
|
||||
model.hypernet.layer_indices,
|
||||
generated_loras,
|
||||
n_queries,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model, base_tokenizer = load_checkpoint()
|
||||
ctx_encoder, ctx_processor = load_ctx_encoder()
|
||||
ds = load_dataset("frgfm/imagenette", "full_size", split="validation")
|
||||
# ds = ds.shuffle().select(range(int(0.05 * len(ds))))
|
||||
input_ids = base_tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": INPUT_TXT}],
|
||||
add_special_tokens=False,
|
||||
return_attention_mask=False,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt",
|
||||
).to(model.device)
|
||||
|
||||
preds = []
|
||||
pred_txts = []
|
||||
corrects = []
|
||||
labels = ds["label"]
|
||||
for sample in tqdm(ds):
|
||||
img = sample["image"]
|
||||
ctx_inputs = template_image(img, ctx_processor).to(ctx_encoder.device)
|
||||
ctx_features = get_ctx_features(ctx_inputs, ctx_encoder)
|
||||
generated_loras = generate_loras(ctx_inputs, ctx_features)
|
||||
apply_loras(model, generated_loras)
|
||||
|
||||
model_outputs = model.base_model.generate(
|
||||
input_ids, max_new_tokens=256, do_sample=False
|
||||
)
|
||||
pred_txt = base_tokenizer.decode(
|
||||
model_outputs[0][len(input_ids[0]) :], skip_special_tokens=True
|
||||
)
|
||||
|
||||
pred_txts.append(pred_txt)
|
||||
preds.append(pred_to_class_id(pred_txt))
|
||||
is_correct = preds[-1] == labels[len(preds) - 1]
|
||||
corrects.append(is_correct)
|
||||
print(
|
||||
f"GT: {CLASS_NAMES[labels[len(preds) - 1]]}, Pred: {pred_txt} -> {CLASS_NAMES[preds[-1]]}, Correct: {is_correct}"
|
||||
)
|
||||
|
||||
acc = sum(corrects) / len(corrects)
|
||||
print(f"Final accuracy: {acc:4f}")
|
||||
|
||||
jsonl_path = os.path.join(RUN_DIR, "imagenette_eval.jsonl")
|
||||
meta_path = os.path.join(RUN_DIR, "imagenette_eval.meta.json")
|
||||
|
||||
with open(jsonl_path, "w") as f:
|
||||
for i, (pred_txt, pred_id, label_id) in enumerate(
|
||||
zip(pred_txts, preds, labels)
|
||||
):
|
||||
f.write(
|
||||
json.dumps(
|
||||
{
|
||||
"index": i,
|
||||
"label": int(label_id),
|
||||
"label_name": CLASS_NAMES[label_id],
|
||||
"pred_text": pred_txt,
|
||||
"pred_class_id": int(pred_id),
|
||||
"pred_class_name": CLASS_NAMES[pred_id],
|
||||
"correct": bool(pred_id == label_id),
|
||||
}
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
meta = {
|
||||
"dataset": "frgfm/imagenette",
|
||||
"subset": "full_size",
|
||||
"split": "validation",
|
||||
"run_dir": RUN_DIR,
|
||||
"prompt": INPUT_TXT,
|
||||
"accuracy": float(acc),
|
||||
"num_samples": len(preds),
|
||||
"class_names": CLASS_NAMES,
|
||||
}
|
||||
with open(meta_path, "w") as f:
|
||||
json.dump(meta, f, indent=2)
|
||||
|
||||
print(f"Wrote samples to {jsonl_path}")
|
||||
print(f"Wrote metadata to {meta_path}")
|
||||
5
scripts/main_exp/eval/llmlingua.sh
Executable file
5
scripts/main_exp/eval/llmlingua.sh
Executable file
|
|
@ -0,0 +1,5 @@
|
|||
for dataset in squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e; do
|
||||
for rate in 0.9 0.8 0.6 0.4 0.2 0.1; do
|
||||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets "$dataset" --split test --eval_batch_size_gen=1 --use_llmlingua --llmlingua_compression_rate "$rate" --truncate_if_too_long_ctx
|
||||
done
|
||||
done
|
||||
4
scripts/main_exp/eval/t2l.sh
Executable file
4
scripts/main_exp/eval/t2l.sh
Executable file
|
|
@ -0,0 +1,4 @@
|
|||
# download t2l checkpoint
|
||||
uv run huggingface-cli download SakanaAI/text-to-lora --local-dir . --include "trained_t2l/gemma_2b_t2l"
|
||||
|
||||
WANDB_MODE=disabled uv run run_eval.py --model_name_or_path google/gemma-2-2b-it --datasets squad drop ropes longbench/qasper_e longbench/2wikimqa_e longbench/multifieldqa_en_e --split test --eval_batch_size_gen=1 --use_t2l
|
||||
20
scripts/main_exp/gen_data.sh
Executable file
20
scripts/main_exp/gen_data.sh
Executable file
|
|
@ -0,0 +1,20 @@
|
|||
# download fineweb_edu to `data/raw_datasets/fineweb_edu
|
||||
uv run data/download_fineweb_edu.py
|
||||
|
||||
# generate qa data
|
||||
# run from 000 to 013
|
||||
for shard_id in $(seq -f "%03g" 0 13); do
|
||||
uv run data/generate_fw_edu_qa_v2.py --shard_pattern "${shard_id}_00000" --n_qa_pairs=5 --vllm_model=google/gemma-3-12b-it --max_length=2000 --max_model_length=2048
|
||||
uv run data/generate_fw_edu_qa_v2_repeat.py --shard_pattern "min_0_to_2000/${shard_id}*level_0" --n_qa_pairs=5 --vllm_model=google/gemma-3-12b-it
|
||||
|
||||
# self-generated response QA data
|
||||
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --glob_pattern "data/raw_datasets/fw_qa_v2/min_0_to_2000/${shard_id}*_level_1*" --closed_qa_prob 1.0
|
||||
done
|
||||
|
||||
|
||||
# val split
|
||||
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --glob_pattern 'data/raw_datasets/fw_qa_v2/min_0_to_2000/*_level_0_val.parquet'
|
||||
|
||||
# self-gen data for other ds
|
||||
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --ds_names squad_compact ropes_compact drop_compact --split train --closed_qa_prob 1.0
|
||||
uv run data/self_generate_qa.py --vllm_model google/gemma-2-2b-it --ds_names pwc_compact --split train --closed_qa_prob 0.0
|
||||
29
scripts/main_exp/train-cross-enc-chunk-slurm.sh
Normal file
29
scripts/main_exp/train-cross-enc-chunk-slurm.sh
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --job-name=ctxlora
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --partition=a3
|
||||
#SBATCH --gpus=8
|
||||
#SBATCH --output=slurm_logs/%x-%j.out
|
||||
#SBATCH --error=slurm_logs/%x-%j.out
|
||||
|
||||
port=$((10000 + ($SLURM_JOBID % 50000)))
|
||||
echo "Using port: $port"
|
||||
|
||||
# port=29051
|
||||
|
||||
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
|
||||
--num_processes=8 --gpu_ids all train.py \
|
||||
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
|
||||
--model_name_or_path=google/gemma-2-2b-it \
|
||||
--target_modules=down_proj --lora_r=8 \
|
||||
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
|
||||
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
|
||||
--max_steps=20000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
|
||||
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
|
||||
--quantize_ctx_encoder=True --ctx_encoder_model_name_or_path=google/gemma-3-4b-it \
|
||||
--max_ctx_chunk_len=512 \
|
||||
--min_ctx_chunk_len=25 \
|
||||
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
|
||||
--warmup_steps=2000 \
|
||||
--learning_rate=2e-5 \
|
||||
"$@"
|
||||
24
scripts/main_exp/train-cross-enc-slurm.sh
Normal file
24
scripts/main_exp/train-cross-enc-slurm.sh
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
#!/bin/bash
|
||||
#SBATCH --job-name=ctxlora
|
||||
#SBATCH --nodes=1
|
||||
#SBATCH --partition=a3
|
||||
#SBATCH --gpus=8
|
||||
#SBATCH --output=slurm_logs/%x-%j.out
|
||||
#SBATCH --error=slurm_logs/%x-%j.out
|
||||
|
||||
port=$((10000 + ($SLURM_JOBID % 50000)))
|
||||
echo "Using port: $port"
|
||||
|
||||
# port=29051
|
||||
|
||||
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
|
||||
--num_processes=8 --gpu_ids all train.py \
|
||||
configs/main_exp/self_gen_lv1_closed_qa_1_l2l.yaml \
|
||||
--model_name_or_path=google/gemma-2-2b-it \
|
||||
--target_modules=down_proj --lora_r=8 \
|
||||
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
|
||||
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
|
||||
--max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
|
||||
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
|
||||
--quantize_ctx_encoder=True --ctx_encoder_model_name_or_path=google/gemma-3-4b-it \
|
||||
"$@"
|
||||
15
scripts/main_exp/train_no_qa.sh
Executable file
15
scripts/main_exp/train_no_qa.sh
Executable file
|
|
@ -0,0 +1,15 @@
|
|||
#!/bin/bash
|
||||
|
||||
port=29051
|
||||
|
||||
uv run accelerate launch --config_file accelerate_config.yaml --main_process_port $port \
|
||||
--num_processes=8 --gpu_ids all train.py \
|
||||
configs/main_exp/self_gen_lv1_closed_qa_1_no_qa_l2l.yaml \
|
||||
--model_name_or_path=google/gemma-2-2b-it \
|
||||
--target_modules=down_proj --lora_r=8 \
|
||||
--eval_strategy=no --max_qas_len=2048 --max_qas_per_sample=1 \
|
||||
--per_rank_gen=True --per_layer_processing=True --gen_lora_l1_reg_coef=0.1 \
|
||||
--max_steps=80000 --gradient_accumulation_steps=8 --max_packed_inp_len=4096 \
|
||||
--max_packed_ctx_len=4096 --use_per_ctx_average_loss=True --use_kl_loss=True \
|
||||
--quantize_ctx_encoder=True \
|
||||
"$@"
|
||||
1
scripts/niah/0-gen_data.sh
Executable file
1
scripts/niah/0-gen_data.sh
Executable file
|
|
@ -0,0 +1 @@
|
|||
uv run data/generate_ctx_magic_number.py
|
||||
42
scripts/niah/1-train.sh
Executable file
42
scripts/niah/1-train.sh
Executable file
|
|
@ -0,0 +1,42 @@
|
|||
#!/bin/bash
|
||||
WANDB_MODE=disabled uv run train.py \
|
||||
configs/niah_exp/ctx_magic_number_32_256.yaml \
|
||||
--model_name_or_path=google/gemma-2-2b-it \
|
||||
--num_train_epochs=1 \
|
||||
--per_device_train_batch_size=-1 \
|
||||
--gradient_accumulation_steps=16 \
|
||||
--per_device_eval_batch_size=16 \
|
||||
--exp_setup=hyper_lora \
|
||||
--aggregator_type=perceiver \
|
||||
--target_modules=down_proj \
|
||||
--num_blocks=8 \
|
||||
--num_self_attn_per_block=0 \
|
||||
--num_pre_head_layers=1 \
|
||||
--lora_r=8 \
|
||||
--eval_steps=100 \
|
||||
--logging_steps=10 \
|
||||
--save_steps=1000 \
|
||||
--learning_rate=4e-5 \
|
||||
--lora_dropout=0.0 \
|
||||
--neftune_noise_alpha=0 \
|
||||
--per_rank_gen=True \
|
||||
--per_layer_processing=True \
|
||||
--gen_lora_l1_reg_coef=1.5 \
|
||||
--use_sequence_packing=True \
|
||||
--max_packed_inp_len=4096 \
|
||||
--max_packed_ctx_len=4096 \
|
||||
--dataloader_num_workers=0 \
|
||||
--dataloader_prefetch_factor=None \
|
||||
--eval_on_start=False \
|
||||
--ctx_encoder_type=early_exit \
|
||||
--n_latent_queries=208 \
|
||||
--use_kl_loss=False \
|
||||
--eval_on_start=True \
|
||||
--max_ctx_chunk_len=512 \
|
||||
--min_ctx_chunk_len=25 \
|
||||
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
|
||||
--max_val_samples_per_ds=100 \
|
||||
--seed=1 \
|
||||
--use_per_ctx_average_loss=True \
|
||||
--torch_empty_cache_steps=10 \
|
||||
"$@"
|
||||
1
scripts/niah/2-eval-test.sh
Executable file
1
scripts/niah/2-eval-test.sh
Executable file
|
|
@ -0,0 +1 @@
|
|||
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path CHECKPOINT_PATH --datasets ctx_magic_number_32_1024 ctx_magic_number_1024_2048 ctx_magic_number_3072_4096 ctx_magic_number_7168_8192 ctx_magic_number_15360_16384 ctx_magic_number_28672_32768 ctx_magic_number_57344_65536 ctx_magic_number_122880_131072 --max_ctx_chunk_len=1024 --split test --eval_batch_size_gen=4
|
||||
1
scripts/niah/2-eval.sh
Executable file
1
scripts/niah/2-eval.sh
Executable file
|
|
@ -0,0 +1 @@
|
|||
WANDB_MODE=disabled uv run run_eval.py --checkpoint_path $CHECKPOINT_PATH --datasets ctx_magic_number_32_1024 ctx_magic_number_1024_2048 ctx_magic_number_2048_3072 ctx_magic_number_3072_4096 ctx_magic_number_4096_5120 ctx_magic_number_5120_6144 ctx_magic_number_6144_7168 ctx_magic_number_7168_8192 ctx_magic_number_8192_9216 ctx_magic_number_9216_10240 ctx_magic_number_10240_11264 ctx_magic_number_11264_12288 ctx_magic_number_12288_13312 ctx_magic_number_13312_14336 ctx_magic_number_14336_15360 ctx_magic_number_15360_16384 ctx_magic_number_16384_20480 ctx_magic_number_20480_24576 ctx_magic_number_24576_28672 ctx_magic_number_28672_32768 ctx_magic_number_32768_40960 ctx_magic_number_40960_49152 ctx_magic_number_49152_57344 ctx_magic_number_57344_65536 ctx_magic_number_65536_73728 ctx_magic_number_73728_81920 ctx_magic_number_81920_90112 ctx_magic_number_90112_98304 ctx_magic_number_98304_106496 ctx_magic_number_106496_114688 ctx_magic_number_114688_122880 ctx_magic_number_122880_131072 --max_ctx_chunk_len=1024 --split test --eval_batch_size_gen=4
|
||||
8
scripts/niah/README.md
Normal file
8
scripts/niah/README.md
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
# NIAH experiment
|
||||
```bash
|
||||
# run the scripts in this order
|
||||
# data generation is only needed to be run once
|
||||
uv run bash scripts/niah/0-gen_data.sh
|
||||
uv run bash scripts/niah/1-train.sh
|
||||
uv run bash scripts/niah/2-eval.sh
|
||||
```
|
||||
39
scripts/niah/train_mistral_7b.sh
Normal file
39
scripts/niah/train_mistral_7b.sh
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
WANDB_MODE=disabled run uv run train.py configs/niah_exp/ctx_magic_number_32_256.yaml \
|
||||
--model_name_or_path=mistralai/Mistral-7B-Instruct-v0.2 \
|
||||
--num_train_epochs=1 \
|
||||
--per_device_train_batch_size=-1 \
|
||||
--gradient_accumulation_steps=64 \
|
||||
--per_device_eval_batch_size=16 \
|
||||
--exp_setup=hyper_lora \
|
||||
--aggregator_type=perceiver \
|
||||
--target_modules=down_proj \
|
||||
--num_blocks=8 \
|
||||
--num_self_attn_per_block=0 \
|
||||
--num_pre_head_layers=1 \
|
||||
--lora_r=8 \
|
||||
--eval_steps=100 \
|
||||
--logging_steps=10 \
|
||||
--save_steps=1000 \
|
||||
--learning_rate=4e-5 \
|
||||
--lora_dropout=0.0 \
|
||||
--neftune_noise_alpha=0 \
|
||||
--per_rank_gen=True \
|
||||
--per_layer_processing=True \
|
||||
--gen_lora_l1_reg_coef=2.0 \
|
||||
--use_sequence_packing=True \
|
||||
--max_packed_inp_len=1024 \
|
||||
--max_packed_ctx_len=1024 \
|
||||
--dataloader_num_workers=0 \
|
||||
--dataloader_prefetch_factor=None \
|
||||
--eval_on_start=False \
|
||||
--ctx_encoder_type=early_exit \
|
||||
--n_latent_queries=208 \
|
||||
--use_kl_loss=False \
|
||||
--eval_on_start=True \
|
||||
--max_ctx_chunk_len=512 \
|
||||
--min_ctx_chunk_len=25 \
|
||||
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
|
||||
--max_val_samples_per_ds=100 \
|
||||
--seed=1 \
|
||||
--use_per_ctx_average_loss=True \
|
||||
--torch_empty_cache_steps=10
|
||||
39
scripts/niah/train_qwen3_4b.sh
Normal file
39
scripts/niah/train_qwen3_4b.sh
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
WANDB_MODE=disabled run uv run train.py configs/niah_exp/ctx_magic_number_32_256.yaml \
|
||||
--model_name_or_path=Qwen/Qwen3-4B-Instruct-2507 \
|
||||
--num_train_epochs=1 \
|
||||
--per_device_train_batch_size=-1 \
|
||||
--gradient_accumulation_steps=32 \
|
||||
--per_device_eval_batch_size=16 \
|
||||
--exp_setup=hyper_lora \
|
||||
--aggregator_type=perceiver \
|
||||
--target_modules=down_proj \
|
||||
--num_blocks=8 \
|
||||
--num_self_attn_per_block=0 \
|
||||
--num_pre_head_layers=1 \
|
||||
--lora_r=8 \
|
||||
--eval_steps=100 \
|
||||
--logging_steps=10 \
|
||||
--save_steps=1000 \
|
||||
--learning_rate=4e-5 \
|
||||
--lora_dropout=0.0 \
|
||||
--neftune_noise_alpha=0 \
|
||||
--per_rank_gen=True \
|
||||
--per_layer_processing=True \
|
||||
--gen_lora_l1_reg_coef=0.5 \
|
||||
--use_sequence_packing=True \
|
||||
--max_packed_inp_len=2048 \
|
||||
--max_packed_ctx_len=2048 \
|
||||
--dataloader_num_workers=0 \
|
||||
--dataloader_prefetch_factor=None \
|
||||
--eval_on_start=False \
|
||||
--ctx_encoder_type=early_exit \
|
||||
--n_latent_queries=208 \
|
||||
--use_kl_loss=False \
|
||||
--eval_on_start=True \
|
||||
--max_ctx_chunk_len=512 \
|
||||
--min_ctx_chunk_len=25 \
|
||||
--num_chunk_probs='{"1":"0.5", "2":"0.125", "3":"0.0625", "4":"0.0625", "5":"0.0625", "6":"0.0625", "7":"0.0625", "8":"0.0625"}' \
|
||||
--max_val_samples_per_ds=100 \
|
||||
--seed=1 \
|
||||
--use_per_ctx_average_loss=True \
|
||||
--torch_empty_cache_steps=10
|
||||
12
setup.py
Normal file
12
setup.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
# read the contents of the README file
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import setup
|
||||
|
||||
this_directory = Path(__file__).parent
|
||||
long_description = (this_directory / "README.md").read_text()
|
||||
|
||||
setup(
|
||||
long_description=long_description,
|
||||
long_description_content_type="text/markdown",
|
||||
)
|
||||
0
src/ctx_to_lora/__init__.py
Normal file
0
src/ctx_to_lora/__init__.py
Normal file
596
src/ctx_to_lora/configs.py
Normal file
596
src/ctx_to_lora/configs.py
Normal file
|
|
@ -0,0 +1,596 @@
|
|||
import dataclasses
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Literal, NewType
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from transformers import (
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
HfArgumentParser,
|
||||
TrainingArguments,
|
||||
)
|
||||
|
||||
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
||||
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
||||
|
||||
|
||||
DataClassType = NewType("DataClassType", Any)
|
||||
|
||||
|
||||
class ArgumentParser(HfArgumentParser):
|
||||
def parse_yaml_and_args(
|
||||
self, yaml_arg: str, other_args: list[str] | None = None
|
||||
) -> list[dataclass]:
|
||||
"""
|
||||
Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.
|
||||
|
||||
Args:
|
||||
yaml_arg (`str`):
|
||||
The path to the config file used
|
||||
other_args (`List[str]`, *optional`):
|
||||
A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
|
||||
|
||||
Returns:
|
||||
[`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line
|
||||
"""
|
||||
arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
|
||||
|
||||
outputs = []
|
||||
# strip other args list into dict of key-value pairs
|
||||
other_args = {
|
||||
arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args
|
||||
}
|
||||
used_args = {}
|
||||
|
||||
# overwrite the default/loaded value with the value provided to the command line
|
||||
# adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
|
||||
for data_yaml, data_class in zip(arg_list, self.dataclass_types):
|
||||
keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
|
||||
inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
|
||||
for arg, val in other_args.items():
|
||||
# add only if in keys
|
||||
if arg in keys:
|
||||
if val in ["None", "none", "null", "NULL"]:
|
||||
val = None
|
||||
inputs[arg] = val
|
||||
used_args[arg] = val
|
||||
continue
|
||||
base_type = data_yaml.__dataclass_fields__[arg].type
|
||||
inputs[arg] = val
|
||||
|
||||
# cast type for ints, floats (default to strings)
|
||||
if base_type in [int, float]:
|
||||
inputs[arg] = base_type(val)
|
||||
|
||||
if base_type == list[str]:
|
||||
inputs[arg] = [str(v) for v in val.split(",")]
|
||||
|
||||
# bool of a non-empty string is True, so we manually check for bools
|
||||
if base_type == bool:
|
||||
if val in ["true", "True"]:
|
||||
inputs[arg] = True
|
||||
else:
|
||||
inputs[arg] = False
|
||||
|
||||
if base_type == dict:
|
||||
inputs[arg] = yaml.load(val, Loader=yaml.FullLoader)
|
||||
|
||||
# add to used-args so we can check if double add
|
||||
if arg not in used_args:
|
||||
used_args[arg] = val
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Duplicate argument provided: {arg}, may cause unexpected behavior"
|
||||
)
|
||||
|
||||
obj = data_class(**inputs)
|
||||
outputs.append(obj)
|
||||
for arg in other_args:
|
||||
if arg not in used_args:
|
||||
raise ValueError(f"Argument provided not found in dataclass: {arg}")
|
||||
return outputs
|
||||
|
||||
def parse(self) -> DataClassType | tuple[DataClassType]:
|
||||
if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
|
||||
# If we pass only one argument to the script and it's the path to a YAML file,
|
||||
# let's parse it to get our arguments.
|
||||
output = self.parse_yaml_file(os.path.abspath(sys.argv[1].split("=")[-1]))
|
||||
# parse command line args and yaml file
|
||||
elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
|
||||
output = self.parse_yaml_and_args(
|
||||
os.path.abspath(sys.argv[1].split("=")[-1]), sys.argv[2:]
|
||||
)
|
||||
# parse --config for the yaml path and other command line args
|
||||
elif any([arg.startswith("--config") for arg in sys.argv]):
|
||||
yaml_arg = [
|
||||
arg
|
||||
for arg in sys.argv[1:]
|
||||
if arg.startswith("--config") and arg.endswith(".yaml")
|
||||
][0]
|
||||
other_args = [arg for arg in sys.argv[1:] if arg != yaml_arg]
|
||||
output = self.parse_yaml_and_args(
|
||||
os.path.abspath(yaml_arg.split("=")[-1]), other_args
|
||||
)
|
||||
# parse command line args only
|
||||
else:
|
||||
output = self.parse_args_into_dataclasses()
|
||||
|
||||
if len(output) == 1:
|
||||
output = output[0]
|
||||
return output
|
||||
|
||||
|
||||
class ExperimentSetup(str, Enum):
|
||||
HYPERLORA = "hyper_lora"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingArguments(TrainingArguments):
|
||||
output_dir: str = field(
|
||||
default="",
|
||||
metadata={"help": "Placeholder. Will be overwritten by train.py"},
|
||||
)
|
||||
tf32: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use tf32 precision."},
|
||||
)
|
||||
bf16: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use bf16 precision."},
|
||||
)
|
||||
label_names: list[str] = field(
|
||||
default=("labels",),
|
||||
metadata={
|
||||
"help": "List of strings to specify the label names in the dataset. "
|
||||
"This is used to compute the loss and metrics."
|
||||
},
|
||||
)
|
||||
include_for_metrics: list[str] = field(
|
||||
default=("inputs",),
|
||||
metadata={
|
||||
"help": "List of strings to specify additional data to include in the `compute_metrics` function."
|
||||
"Options: 'inputs', 'loss'."
|
||||
},
|
||||
)
|
||||
per_device_eval_batch_size: int = field(
|
||||
default=64,
|
||||
metadata={
|
||||
"help": "Batch size for evaluation. "
|
||||
"If not set, will use the same as per_device_train_batch_size."
|
||||
},
|
||||
)
|
||||
per_device_train_batch_size: int = field(
|
||||
default=1,
|
||||
metadata={
|
||||
"help": "Batch size for training. "
|
||||
"If not set, will use the same as per_device_eval_batch_size."
|
||||
},
|
||||
)
|
||||
# TODO: use this! (check trainer.py for proper computation)
|
||||
average_tokens_across_devices: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "compute num_items_in_batch across devices."},
|
||||
)
|
||||
# mem leak if use persistent workers
|
||||
# https://github.com/pytorch/pytorch/issues/62066
|
||||
# https://github.com/huggingface/transformers/issues/30943
|
||||
dataloader_persistent_workers: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to keep the workers alive after a dataset has been consumed once."
|
||||
},
|
||||
)
|
||||
dataloader_prefetch_factor: int = field(
|
||||
default=16,
|
||||
metadata={"help": "Number of batches loaded in advance by each worker."},
|
||||
)
|
||||
dataloader_num_workers: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of subprocesses to use for data loading."},
|
||||
)
|
||||
neftune_noise_alpha: float = field(
|
||||
default=5.0,
|
||||
metadata={"help": "Neftune noise alpha for the optimizer."},
|
||||
)
|
||||
learning_rate: float = field(
|
||||
default=4e-5,
|
||||
metadata={"help": "Initial learning rate."},
|
||||
)
|
||||
weight_decay: float = field(
|
||||
default=0.01,
|
||||
metadata={"help": "Weight decay for the optimizer."},
|
||||
)
|
||||
optim: str = field(
|
||||
default="adamw_torch_fused",
|
||||
metadata={"help": "Optimizer."},
|
||||
)
|
||||
adam_beta1: float = field(
|
||||
default=0.9,
|
||||
metadata={"help": "Adam beta 1."},
|
||||
)
|
||||
adam_beta2: float = field(
|
||||
default=0.999,
|
||||
metadata={"help": "Adam beta 2."},
|
||||
)
|
||||
adam_epsilon: float = field(
|
||||
default=1e-8,
|
||||
metadata={"help": "Adam epsilon."},
|
||||
)
|
||||
lr_scheduler_type: str = field(
|
||||
default="cosine_with_min_lr",
|
||||
metadata={"help": "Learning rate scheduler type."},
|
||||
)
|
||||
lr_scheduler_kwargs: dict = field(
|
||||
default=None,
|
||||
metadata={"help": "Learning rate scheduler kwargs."},
|
||||
)
|
||||
warmup_steps: int = field(
|
||||
default=100,
|
||||
metadata={"help": "Number of warmup steps."},
|
||||
)
|
||||
eval_on_start: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to evaluate on the start of training."},
|
||||
)
|
||||
eval_strategy: str = field(
|
||||
default="steps",
|
||||
metadata={"help": "Evaluation strategy."},
|
||||
)
|
||||
eval_steps: int = field(
|
||||
default=1_000,
|
||||
metadata={"help": "Evaluation steps."},
|
||||
)
|
||||
metric_for_best_model: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Metric for best model."},
|
||||
)
|
||||
load_best_model_at_end: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to load the best model at the end of training."},
|
||||
)
|
||||
save_total_limit: int = field(
|
||||
default=2,
|
||||
metadata={"help": "Total number of checkpoints to save."},
|
||||
)
|
||||
save_strategy: str = field(
|
||||
default="steps",
|
||||
)
|
||||
save_steps: int = field(
|
||||
default=5_000,
|
||||
)
|
||||
save_safetensors: bool = field(
|
||||
default=False,
|
||||
)
|
||||
logging_strategy: str = field(
|
||||
default="steps",
|
||||
)
|
||||
logging_steps: int = field(
|
||||
default=100,
|
||||
)
|
||||
use_liger_kernel: bool = field(
|
||||
default=False,
|
||||
)
|
||||
remove_unused_columns: bool = field(
|
||||
default=False,
|
||||
)
|
||||
# needed to avoid OOM by compute the metrics batch by batch
|
||||
# w/o this the trainer stores logits of all sample in memory...
|
||||
batch_eval_metrics: bool = field(
|
||||
default=True,
|
||||
)
|
||||
logging_first_step: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to log the first step."},
|
||||
)
|
||||
ddp_find_unused_parameters: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to find unused parameters in DDP."},
|
||||
)
|
||||
ddp_timeout: int = field(
|
||||
default=2**20,
|
||||
metadata={"help": "Timeout for distributed data parallel training."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelArguments:
|
||||
"""
|
||||
Arguments for the base model.
|
||||
"""
|
||||
|
||||
model_name_or_path: str = field(
|
||||
default=None,
|
||||
metadata={"help": ("Base model name or path.")},
|
||||
)
|
||||
use_flash_attn: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use flash attention."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAArguments:
|
||||
lora_r: int | None = field(
|
||||
default=8,
|
||||
metadata={"help": ("LoRA R value.")},
|
||||
)
|
||||
lora_dropout: float | None = field(
|
||||
default=0.0,
|
||||
metadata={"help": ("LoRA dropout.")},
|
||||
)
|
||||
target_modules: list[str] | None = field(
|
||||
default=None,
|
||||
metadata={"help": ("LoRA target modules.")},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CtxTrainingArguments:
|
||||
exp_setup: ExperimentSetup = field(
|
||||
default=ExperimentSetup.HYPERLORA,
|
||||
metadata={"help": "Experiment setup - LoRA, HyperLoRA, or full finetuning"},
|
||||
)
|
||||
from_pretrained_checkpoint: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the pretrained checkpoint."},
|
||||
)
|
||||
max_base_len: int | None = field(
|
||||
default=2**13,
|
||||
metadata={"help": "Maximum base length for training."},
|
||||
)
|
||||
use_sequence_packing: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether to use sequence packing."},
|
||||
)
|
||||
max_ctx_len: int = field(
|
||||
default=-1,
|
||||
metadata={"help": "Max context length. Overrides ctx tokenizer length."},
|
||||
)
|
||||
max_qas_len: int = field(
|
||||
default=2**11,
|
||||
metadata={
|
||||
"help": "Maximum question-answering token length of each sample for training. "
|
||||
"QA pairs that are longer than this value will be split up into multiple samples."
|
||||
},
|
||||
)
|
||||
max_qas_per_sample: int = field(
|
||||
default=-1,
|
||||
metadata={
|
||||
"help": "Max QA pair per context. If a context has more QA pairs than this value, "
|
||||
"they will be split up into multiple samples."
|
||||
},
|
||||
)
|
||||
num_chunk_probs: dict = field(
|
||||
default=None,
|
||||
metadata={"help": "Probability distribution over chunk nums."},
|
||||
)
|
||||
max_ctx_chunk_len: int = field(
|
||||
default=-1,
|
||||
metadata={
|
||||
"help": "Max context chunk length. If a context is longer than this value, "
|
||||
"it will be split up into multiple chunks."
|
||||
},
|
||||
)
|
||||
min_ctx_chunk_len: int = field(
|
||||
default=-1,
|
||||
metadata={
|
||||
"help": "Min context chunk length. Used only with random chunking training"
|
||||
},
|
||||
)
|
||||
max_ctx_chunk_num: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Max number of context chunks per sample."},
|
||||
)
|
||||
max_packed_inp_len: int | None = field(
|
||||
default=2**14,
|
||||
metadata={"help": "Maximum packed input length for training."},
|
||||
)
|
||||
max_packed_ctx_len: int | None = field(
|
||||
# forward pass of the ctx encoder is cheaper --> longer packed len
|
||||
default=2**15,
|
||||
metadata={"help": "Maximum packed context length for training."},
|
||||
)
|
||||
|
||||
max_new_tokens: int | None = field(
|
||||
default=256,
|
||||
metadata={"help": "Maximum new tokens for generation-based evaluation."},
|
||||
)
|
||||
gen_per_device_eval_batch_size: int | None = field(
|
||||
default=1,
|
||||
metadata={"help": "Per device evaluation batch size for generation."},
|
||||
)
|
||||
notes: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Wandb notes for the experiment."},
|
||||
)
|
||||
use_kl_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use KL loss."},
|
||||
)
|
||||
use_per_ctx_average_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use per-context average loss."},
|
||||
)
|
||||
gen_lora_l1_reg_coef: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "L1 regularization coefficient for generated LoRAs."},
|
||||
)
|
||||
add_negative_prompt: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to add negative prompt training."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataArguments:
|
||||
train_ds_names: list[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Training dataset names."},
|
||||
)
|
||||
|
||||
streaming: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use streaming dataset for training."},
|
||||
)
|
||||
val_ds_names: list[str] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Validation dataset names."},
|
||||
)
|
||||
test_ds_names: list[str] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Test dataset names."},
|
||||
)
|
||||
max_train_samples_per_ds: int | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Maximum number of training samples per dataset."},
|
||||
)
|
||||
max_val_samples_per_ds: int | None = field(
|
||||
default=1000,
|
||||
metadata={"help": "Maximum number of validation samples per dataset."},
|
||||
)
|
||||
max_test_samples_per_ds: int | None = field(
|
||||
default=500,
|
||||
metadata={"help": "Maximum number of test samples per dataset."},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class HypernetArguments:
|
||||
latent_size: int = field(
|
||||
default=512,
|
||||
metadata={"help": "Latent size for HyperLoRA."},
|
||||
)
|
||||
use_light_weight_lora: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use light-weight LoRA."},
|
||||
)
|
||||
light_weight_latent_size: int = field(
|
||||
default=128,
|
||||
metadata={"help": "Latent size for light-weight LoRA."},
|
||||
)
|
||||
dropout_rate: float = field(
|
||||
default=0.0,
|
||||
metadata={"help": "Dropout rate for HyperLoRA."},
|
||||
)
|
||||
extra_modules: list[str] | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Extra modules to train."},
|
||||
)
|
||||
per_rank_gen: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use per-rank generation."},
|
||||
)
|
||||
use_bias: bool = field(
|
||||
default=True, metadata={"help": "Whether to include data-dependent LoRA"}
|
||||
)
|
||||
use_per_rank_bias: bool = field(
|
||||
default=False, metadata={"help": "Whether to use per-rank bias."}
|
||||
)
|
||||
per_layer_processing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use per-layer processing (after preceiver)."},
|
||||
)
|
||||
use_token_mixing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use token mixing block."},
|
||||
)
|
||||
num_pre_head_layers: int = field(
|
||||
default=1, metadata={"help": "# of layers before hypernet head"}
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CtxEncoderArguments:
|
||||
ctx_encoder_model_name_or_path: str = field(
|
||||
default=None,
|
||||
metadata={"help": "Context encoder model name or path."},
|
||||
)
|
||||
ctx_encoder_type: Literal["embed_only", "per_layer_activations", "early_exit"] = (
|
||||
field(
|
||||
default="early_exit",
|
||||
metadata={
|
||||
"help": "Context encoder type. "
|
||||
"Options: 'embed_only', 'per_layer_activations', 'early_exit'."
|
||||
},
|
||||
)
|
||||
)
|
||||
# used only with `early_exit` type
|
||||
layer_idx: int | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Layer index for context encoder. "
|
||||
"Default to L//4 where L is the number of layers of the ctx model. "
|
||||
"Only used when ctx_encoder_type==early_exit"
|
||||
},
|
||||
)
|
||||
quantize_ctx_encoder: bool = field(
|
||||
default=False, metadata={"help": "Wheter to quantize the ctx encoder."}
|
||||
)
|
||||
ctx_encoder_last_layer: int | None = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Maximum number of layers for the context encoder. "
|
||||
"Only used when ctx_encoder_type==per_layer_activations"
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatorArguments:
|
||||
aggregator_type: Literal["pooler", "perceiver"] = field(
|
||||
default="perceiver",
|
||||
metadata={"help": "Aggregator type for HyperLoRA."},
|
||||
)
|
||||
|
||||
# pooler
|
||||
pooling_type: str = field(
|
||||
default="mean",
|
||||
metadata={"help": "Pooling type for HyperLoRA."},
|
||||
)
|
||||
num_latent_factor: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of latent factors for Perceiver."},
|
||||
)
|
||||
n_latent_queries: int = field(
|
||||
default=208, # 26 * 8
|
||||
metadata={"help": "Number of latent queries of Perceiver."},
|
||||
)
|
||||
|
||||
num_blocks: int = field(
|
||||
default=8,
|
||||
metadata={"help": "Number of blocks for Perceiver."},
|
||||
)
|
||||
num_self_attn_per_block: int = field(
|
||||
default=0,
|
||||
metadata={"help": "Number of self-attention layers per block for Perceiver."},
|
||||
)
|
||||
shared_weights: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to share weights across blocks for Perceiver."},
|
||||
)
|
||||
|
||||
|
||||
# needed for loading model from checkpoint
|
||||
# see https://github.com/huggingface/transformers/pull/34632
|
||||
torch.serialization.add_safe_globals(
|
||||
[
|
||||
DataArguments,
|
||||
CtxTrainingArguments,
|
||||
ModelArguments,
|
||||
LoRAArguments,
|
||||
TrainingArguments,
|
||||
HypernetArguments,
|
||||
AggregatorArguments,
|
||||
CtxEncoderArguments,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(ExperimentSetup)
|
||||
print(ExperimentSetup.LORA)
|
||||
print(ExperimentSetup.HYPER_LORA)
|
||||
print(ExperimentSetup.FULL_FINETUNE)
|
||||
0
src/ctx_to_lora/data/__init__.py
Normal file
0
src/ctx_to_lora/data/__init__.py
Normal file
145
src/ctx_to_lora/data/collator.py
Normal file
145
src/ctx_to_lora/data/collator.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from transformers.data import (
|
||||
DataCollatorWithFlattening,
|
||||
default_data_collator,
|
||||
)
|
||||
|
||||
from ctx_to_lora.utils import check_is_iterable, concat_list
|
||||
|
||||
flattener = DataCollatorWithFlattening()
|
||||
|
||||
|
||||
def flatten_if_not_packed(inp_list):
|
||||
# no padding
|
||||
sample = inp_list[0]
|
||||
n = len(inp_list)
|
||||
# training data is packed
|
||||
if "position_ids" in sample:
|
||||
if n == 1:
|
||||
n_queries = sample.pop("n_queries")
|
||||
n_ctx_chunks = sample.pop("n_ctx_chunks")
|
||||
batch = default_data_collator(inp_list, return_tensors="pt")
|
||||
batch["n_queries"] = torch.tensor(n_queries)
|
||||
batch["n_ctx_chunks"] = torch.tensor(n_ctx_chunks)
|
||||
return batch
|
||||
elif n > 1:
|
||||
raise NotImplementedError("Please use batch_size=1 when using packed data")
|
||||
# when batch_size > 1 (never used?)
|
||||
# return default_data_collator(concat_batch(inp_list), return_tensors="pt")
|
||||
|
||||
# for eval data (not packed) during training
|
||||
need_flatten = check_is_iterable(sample["input_ids"][0])
|
||||
assert not need_flatten, f"Validation data should not be nested."
|
||||
|
||||
n_queries = torch.ones(len(inp_list), dtype=torch.int32)
|
||||
n_ctx_chunks = torch.tensor(
|
||||
[len(example["ctx_ids"]) for example in inp_list], dtype=torch.int32
|
||||
)
|
||||
packed_inputs = flattener(inp_list, return_tensors="pt")
|
||||
|
||||
packed_inputs["n_queries"] = n_queries
|
||||
packed_inputs["n_ctx_chunks"] = n_ctx_chunks
|
||||
|
||||
if "ctx_ids" in sample:
|
||||
# HACK: assumes 1 ctx chunk here
|
||||
# assert all(len(ctx_ids) == 1 for ctx_ids in sample["ctx_ids"]), (
|
||||
# "ctx_ids can only have one chunk for eval. "
|
||||
# "Please implement chunked ctx forward pass to handle this."
|
||||
# )
|
||||
ctx_ids = concat_list([example.pop("ctx_ids") for example in inp_list])
|
||||
ctx_position_ids = torch.cat([torch.arange(len(ids)) for ids in ctx_ids])
|
||||
ctx_ids = torch.tensor(concat_list(ctx_ids))
|
||||
|
||||
packed_inputs["ctx_ids"] = ctx_ids.unsqueeze(0)
|
||||
packed_inputs["ctx_position_ids"] = ctx_position_ids.unsqueeze(0)
|
||||
# for eval info
|
||||
if "ctx_ids_len" in sample:
|
||||
packed_inputs["ctx_ids_len"] = [
|
||||
example["ctx_ids_len"] for example in inp_list
|
||||
]
|
||||
|
||||
return packed_inputs
|
||||
|
||||
|
||||
def eval_collator(inp_list, tokenizer):
|
||||
# only used for teacher-forcing eval
|
||||
# input is a list of tokenized sequences
|
||||
padding_kwargs = dict(padding=True, padding_side="right", return_tensors="pt")
|
||||
|
||||
has_ctx_ids = "ctx_ids" in inp_list[0]
|
||||
if has_ctx_ids:
|
||||
# pad to the longest ctx_len in the batch
|
||||
# which can have a different length from the input_ids, attn_mask, labels
|
||||
ctx_ids = [example.pop("ctx_ids") for example in inp_list]
|
||||
ctx_attn_mask = [torch.ones_like(x) for x in ctx_ids]
|
||||
ctx_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
ctx_ids,
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
ctx_attn_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
ctx_attn_mask,
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
|
||||
for inp in inp_list:
|
||||
inp["attention_mask"] = torch.ones_like(inp["input_ids"])
|
||||
|
||||
labels = [x.pop("labels") for x in inp_list]
|
||||
# need to pass the whole inp bc we also track the lengths (with specal keys)
|
||||
padded_seq = tokenizer.pad(inp_list, **padding_kwargs)
|
||||
|
||||
# hacky explicit padding since the labels are not padded by default
|
||||
labels = tokenizer.pad({"input_ids": labels}, **padding_kwargs)["input_ids"]
|
||||
labels = torch.where(padded_seq["attention_mask"] == 0, -100, labels)
|
||||
out = {**padded_seq, "labels": labels}
|
||||
|
||||
if has_ctx_ids:
|
||||
out["ctx_ids"] = ctx_ids
|
||||
out["ctx_attn_mask"] = ctx_attn_mask
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def generation_collator(inp_list, tokenizer):
|
||||
padding_kwargs = dict(padding=True, padding_side="left", return_tensors="pt")
|
||||
input_ids = [torch.tensor(x.pop("input_ids")) for x in inp_list]
|
||||
labels = [x.pop("labels") for x in inp_list]
|
||||
for i, label in enumerate(labels):
|
||||
# we don't include the labels in the output during generation
|
||||
# remove the response tokens
|
||||
idx = np.argmax(label != -100)
|
||||
idx = max(1, idx)
|
||||
input_ids[i] = input_ids[i][:idx]
|
||||
attn_mask = [torch.ones_like(x) for x in input_ids]
|
||||
|
||||
out = tokenizer.pad(
|
||||
{"input_ids": input_ids, "attention_mask": attn_mask}, **padding_kwargs
|
||||
)
|
||||
|
||||
if "ctx_ids" in inp_list[0]:
|
||||
# pad to the longest ctx_len in the batch
|
||||
# which can have a different length from the input_ids, attn_mask, labels
|
||||
ctx_ids = [example.pop("ctx_ids") for example in inp_list]
|
||||
n_chunks = [len(x) for x in ctx_ids]
|
||||
ctx_ids = concat_list(ctx_ids)
|
||||
ctx_ids = [torch.tensor(x) for x in ctx_ids]
|
||||
ctx_attn_mask = [torch.ones_like(x) for x in ctx_ids]
|
||||
|
||||
ctx_ids = torch.nn.utils.rnn.pad_sequence(
|
||||
ctx_ids,
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
ctx_attn_mask = torch.nn.utils.rnn.pad_sequence(
|
||||
ctx_attn_mask,
|
||||
batch_first=True,
|
||||
padding_value=0,
|
||||
)
|
||||
out["ctx_ids"] = ctx_ids
|
||||
out["ctx_attn_mask"] = ctx_attn_mask
|
||||
|
||||
out["n_ctx_chunks"] = torch.tensor(n_chunks, dtype=torch.int32)
|
||||
return out
|
||||
250
src/ctx_to_lora/data/definitions.py
Normal file
250
src/ctx_to_lora/data/definitions.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
IGNORE_INDEX = -100
|
||||
|
||||
TRANSFORMED_DATA_DIR = "data/processed_datasets"
|
||||
RAW_DATA_DIR = "data/raw_datasets/"
|
||||
SELF_GEN_DATA_DIR = f"{RAW_DATA_DIR}/self_gen/"
|
||||
|
||||
# for chunking
|
||||
CTX_AFFIXES = {
|
||||
"google/gemma-2-2b-it": {
|
||||
"prefix": [2, 106, 1645, 110], # <bos><start_of_turn>user\n\n\n
|
||||
"suffix": [107, 108, 106, 2516, 108], # <end_of_turn>\n<start_of_turn>model\n
|
||||
},
|
||||
"mistralai/Mistral-7B-Instruct-v0.2": {
|
||||
"prefix": [1, 733, 16289, 28793, 28705, 13, 13], # `<s> [INST] \n\n`
|
||||
"suffix": [733, 28748, 16289, 28793], # ` [/INST] `
|
||||
},
|
||||
"Qwen/Qwen3-4B-Instruct-2507": {
|
||||
# `<|im_start|>system\n<|im_end|>\n<|im_start|>user\n`
|
||||
"prefix": [151644, 8948, 198, 151645, 198, 151644, 872, 198],
|
||||
# `<|im_end|>\n<|im_start|>assistant\n`
|
||||
"suffix": [151645, 198, 151644, 77091, 198],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
LONGBENCH_TASKS = [
|
||||
"longbench/qasper",
|
||||
"longbench/multifieldqa_en",
|
||||
"longbench/2wikimqa",
|
||||
]
|
||||
|
||||
LONGBENCH_E_TASKS = [
|
||||
"longbench/qasper_e",
|
||||
"longbench/multifieldqa_en_e",
|
||||
"longbench/2wikimqa_e",
|
||||
]
|
||||
|
||||
DS_KWARGS = {
|
||||
"pwc": dict(
|
||||
train=dict(path="sggetao/PwC", split="train"),
|
||||
validation=dict(path="sggetao/PwC", split="test[:900]"),
|
||||
test=dict(path="sggetao/PwC", split="test"),
|
||||
),
|
||||
"pwc_compact": dict(
|
||||
train=dict(
|
||||
path="parquet",
|
||||
data_files="data/raw_datasets/pwc_compact/train/ds.parquet",
|
||||
split="train",
|
||||
),
|
||||
),
|
||||
"pwc_compact_tiny": dict(
|
||||
train=dict(
|
||||
path="parquet",
|
||||
data_files="data/raw_datasets/pwc_compact/train/ds.parquet",
|
||||
# correspond to skipping to first 900 samples in the original dataset
|
||||
split="train[60:200]",
|
||||
),
|
||||
),
|
||||
"pwc_tiny": dict(
|
||||
train=dict(path="sggetao/PwC", split="train[900:2000]"),
|
||||
validation=dict(path="sggetao/PwC", split="train[:900]"),
|
||||
),
|
||||
"squad": dict(
|
||||
train=dict(path="data/raw_datasets/squad", split="train"),
|
||||
validation=dict(path="data/raw_datasets/squad", split="validation[:1000]"),
|
||||
test=dict(path="data/raw_datasets/squad", split="validation"),
|
||||
),
|
||||
"squad_compact": dict(
|
||||
train=dict(
|
||||
path="parquet",
|
||||
data_files="data/raw_datasets/squad_compact/train/ds.parquet",
|
||||
# correspond to skipping to first 900 samples in the original dataset
|
||||
split="train[180:]",
|
||||
),
|
||||
),
|
||||
"squad_negative": dict(
|
||||
test=dict(path="data/raw_datasets/squad", split="validation"),
|
||||
),
|
||||
"squad_assistant_ctx": dict(
|
||||
test=dict(path="data/raw_datasets/squad", split="validation"),
|
||||
),
|
||||
"squad_negative_no_passage": dict(
|
||||
test=dict(path="data/raw_datasets/squad", split="validation"),
|
||||
),
|
||||
"squad_assistant_ctx_no_passage": dict(
|
||||
test=dict(path="data/raw_datasets/squad", split="validation"),
|
||||
),
|
||||
"drop": dict(
|
||||
train=dict(
|
||||
path="ucinlp/drop",
|
||||
split="train",
|
||||
),
|
||||
validation=dict(
|
||||
path="ucinlp/drop",
|
||||
split="validation[:900]",
|
||||
),
|
||||
test=dict(
|
||||
path="ucinlp/drop",
|
||||
split="validation",
|
||||
),
|
||||
),
|
||||
"drop_compact": dict(
|
||||
train=dict(
|
||||
path="parquet",
|
||||
data_files="data/raw_datasets/drop_compact/train/ds.parquet",
|
||||
# correspond to skipping to first 900 samples in the original dataset
|
||||
split="train",
|
||||
)
|
||||
),
|
||||
"ropes": dict(
|
||||
train=dict(
|
||||
path="allenai/ropes",
|
||||
split="train",
|
||||
),
|
||||
validation=dict(
|
||||
path="allenai/ropes",
|
||||
split="validation[:900]",
|
||||
),
|
||||
test=dict(path="allenai/ropes", split="validation"),
|
||||
),
|
||||
"ropes_compact": dict(
|
||||
train=dict(
|
||||
path="parquet",
|
||||
data_files="data/raw_datasets/ropes_compact/train/ds.parquet",
|
||||
split="train",
|
||||
)
|
||||
),
|
||||
}
|
||||
|
||||
# add ctx_numbers
|
||||
tok_bins = [(64, 128), (128, 256), (256, 512)] + [
|
||||
(512 + 256 * i, 512 + 256 * (i + 1)) for i in range(14)
|
||||
]
|
||||
tok_bins += [(32, 128), (128, 256), (256, 512), (512, 1024), (32, 1024)] + [
|
||||
(1024 * i, 1024 * (i + 1)) for i in range(1, 16)
|
||||
]
|
||||
tok_bins += [(2**14 + 2**12 * (i), 2**14 + 2**12 * (i + 1)) for i in range(4)]
|
||||
tok_bins += [(2**15 + 2**13 * (i), 2**15 + 2**13 * (i + 1)) for i in range(12)]
|
||||
for toy_ds_name in ["ctx_numbers", "ctx_kv", "ctx_magic_number"]:
|
||||
for tok_bin in tok_bins:
|
||||
DS_KWARGS[f"{toy_ds_name}_{tok_bin[0]}_{tok_bin[1]}"] = dict(
|
||||
train=dict(
|
||||
path="json",
|
||||
data_files=f"data/raw_datasets/{toy_ds_name}_{tok_bin[0]}_{tok_bin[1]}/train.jsonl",
|
||||
split="train",
|
||||
),
|
||||
validation=dict(
|
||||
path="json",
|
||||
data_files=f"data/raw_datasets/{toy_ds_name}_{tok_bin[0]}_{tok_bin[1]}/val.jsonl",
|
||||
split="train",
|
||||
),
|
||||
test=dict(
|
||||
path="json",
|
||||
data_files=f"data/raw_datasets/{toy_ds_name}_{tok_bin[0]}_{tok_bin[1]}/test.jsonl",
|
||||
split="train",
|
||||
),
|
||||
)
|
||||
|
||||
# LongBench kwargs
|
||||
for ds_name in LONGBENCH_TASKS + LONGBENCH_E_TASKS:
|
||||
DS_KWARGS[ds_name] = dict(
|
||||
test=dict(
|
||||
path="THUDM/LongBench",
|
||||
name=ds_name.split("/")[-1],
|
||||
split="test",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
CLOSED_QA_DATASETS = {
|
||||
"longbench/qasper",
|
||||
"longbench/multifieldqa_en",
|
||||
"longbench/2wikimqa",
|
||||
"squad",
|
||||
"squad_negative",
|
||||
"squad_assistant_ctx",
|
||||
"squad_negative_no_passage",
|
||||
"squad_assistant_ctx_no_passage",
|
||||
"ropes",
|
||||
"drop",
|
||||
}
|
||||
|
||||
MULTI_ANSWER_DATASETS = {
|
||||
"longbench/qasper",
|
||||
"longbench/multifieldqa_en",
|
||||
"longbench/hotpotqa",
|
||||
"longbench/2wikimqa",
|
||||
"squad",
|
||||
"squad_negative",
|
||||
"squad_assistant_ctx",
|
||||
"squad_negative_no_passage",
|
||||
"squad_assistant_ctx_no_passage",
|
||||
"drop",
|
||||
}
|
||||
|
||||
|
||||
for ds_name in list(CLOSED_QA_DATASETS):
|
||||
if ds_name.startswith("longbench/"):
|
||||
CLOSED_QA_DATASETS.add(f"{ds_name}_e")
|
||||
|
||||
|
||||
for ds_name in list(MULTI_ANSWER_DATASETS):
|
||||
if ds_name.startswith("longbench/"):
|
||||
MULTI_ANSWER_DATASETS.add(f"{ds_name}_e")
|
||||
|
||||
# for training closed qa datasets, e.g., hotpot_qa, squad, etc.
|
||||
CLOSED_QA_INTX_TEMPLATES = [
|
||||
"Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
"Answer without any explanation.\n\nQuestion: {input}",
|
||||
"Based on the provided text, what is the answer to the following question? Provide only the answer.\n\nQuestion: {input}",
|
||||
"Extract the answer to the question from the text. Be concise. Do not explain.\n\nQuestion: {input}",
|
||||
"What is the answer to this question, based on the context? Respond with the answer only.\n\nQuestion: {input}",
|
||||
"Provide a direct answer to the question using the given passages. Do not give any explanation.\n\nQuestion: {input}",
|
||||
"Answer the question using only information from the provided text. No extra words.\n\nQuestion: {input}",
|
||||
"From the passages, answer the question. Just the answer, please.\n\nQuestion: {input}",
|
||||
"Give the answer to the question. Do not include any other text.\n\nQuestion: {input}",
|
||||
"The answer to the question is in the text. Find it and state it clearly. No need for explanation.\n\nQuestion: {input}",
|
||||
"Concisely answer the question based on the text provided. Don't include any other words. Just the answer.\n\nQuestion: {input}",
|
||||
"Read the passages and answer the question with the minimal necessary words.\n\nQuestion: {input}",
|
||||
"What is the direct response to the question, according to the text? Avoid explanation.\n\nQuestion: {input}",
|
||||
"Please provide only the answer to the question, derived from the text.\n\nQuestion: {input}",
|
||||
"Using the provided context, answer the question. Output the answer and nothing else.\n\nQuestion: {input}",
|
||||
"Identify the answer in the text and present it without elaboration.\n\nQuestion: {input}",
|
||||
"Answer the following question based on the text. Your answer should be brief and to the point. No explanation.\n\nQuestion: {input}",
|
||||
"Based on the information given, what is the answer to the question? Only state the answer.\n\nQuestion: {input}",
|
||||
"Find the answer to the question in the provided passages and write it down. No explanations.\n\nQuestion: {input}",
|
||||
"The question is: {input}. Provide the answer based on the text, and nothing more.",
|
||||
"Question: {input}\nAnswer directly based on the text provided. No extra words.",
|
||||
"Question: {input}\nPlease provide the answer based on the text. No explanation is needed.",
|
||||
]
|
||||
|
||||
|
||||
EVAL_INTX_TEMPLATES = {
|
||||
# binary (yes/no, a/b) qa given ctx
|
||||
"ropes": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
# short-ctx reasoning
|
||||
"drop": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
# short-ctx extractive qa
|
||||
"squad": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
"squad_negative": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
"squad_assistant_ctx": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
"squad_negative_no_passage": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
"squad_assistant_ctx_no_passage": "Answer the following question. Output only the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
# longbench
|
||||
"longbench/qasper": 'Answer the question as concisely as you can, using a single phrase or sentence if possible.\nIf the question cannot be answered based on the information in the article, write "unanswerable".\nIf the question is a yes/no question, answer "yes", "no", or "unanswerable". Do not provide any explanation.\n\nQuestion: {input}',
|
||||
"longbench/multifieldqa_en": "Answer the following question. Only output the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
"longbench/2wikimqa": "Answer the following question. Only output the answer and do not output any other words.\n\nQuestion: {input}",
|
||||
}
|
||||
for ds_name in LONGBENCH_E_TASKS:
|
||||
EVAL_INTX_TEMPLATES[ds_name] = EVAL_INTX_TEMPLATES[ds_name[:-2]]
|
||||
290
src/ctx_to_lora/data/packing.py
Normal file
290
src/ctx_to_lora/data/packing.py
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
# based on
|
||||
# https://github.com/MeetKai/functionary/blob/aa3dbdd65f7e388f2386622606bdfeec95c2b863/functionary/train/packing/packed_dataset.py
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import pprint
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ctx_to_lora.utils import check_is_iterable, concat_list
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def pack_data_points_by_length(
|
||||
lens: list[list[int]],
|
||||
ctx_lens: list[list[int]],
|
||||
max_packed_inp_len: int,
|
||||
max_packed_ctx_len: int,
|
||||
max_size: int = -1,
|
||||
) -> tuple[list[int], list[int]]:
|
||||
if not lens:
|
||||
return []
|
||||
|
||||
len_arr = np.array([sum(l) for l in lens], dtype=np.long)
|
||||
ctx_len_arr = np.array([sum(l) for l in ctx_lens], dtype=np.long)
|
||||
n = len(len_arr)
|
||||
assert len(ctx_len_arr) == n, "Length of ctx_len_arr must match length of lens"
|
||||
|
||||
if n == 1:
|
||||
return (
|
||||
[0]
|
||||
if len_arr[0] <= max_packed_inp_len and ctx_len_arr[0] <= max_packed_ctx_len
|
||||
else []
|
||||
)
|
||||
|
||||
# Create cumulative sum arrays for efficient range queries
|
||||
cumsum_inp_len = np.cumsum(len_arr)
|
||||
cumsum_ctx_len = np.cumsum(ctx_len_arr)
|
||||
|
||||
idx_pairs = []
|
||||
i = 0
|
||||
|
||||
while i < n:
|
||||
# Find the maximum j such that sum(lens[i:j+1]) <= max_packed_inp_len
|
||||
start_sum_inp = cumsum_inp_len[i - 1] if i > 0 else 0
|
||||
valid_ends_inp = (cumsum_inp_len[i:] - start_sum_inp) <= max_packed_inp_len
|
||||
|
||||
start_sum_ctx = cumsum_ctx_len[i - 1] if i > 0 else 0
|
||||
valid_ends_ctx = (cumsum_ctx_len[i:] - start_sum_ctx) <= max_packed_ctx_len
|
||||
valid_ends = valid_ends_inp & valid_ends_ctx
|
||||
|
||||
if not np.any(valid_ends):
|
||||
# Single item exceeds max_packed_inp_len, skip it
|
||||
logging.debug(
|
||||
f"Skipping item {i} with input length {len_arr[i]} and context length {ctx_len_arr[i]}"
|
||||
)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# Find the last valid index
|
||||
max_valid_idx = i + np.where(valid_ends)[0][-1]
|
||||
|
||||
# Apply max_size constraint
|
||||
if max_size > 0:
|
||||
max_valid_idx = min(max_valid_idx, i + max_size - 1)
|
||||
|
||||
idx_pairs.append((i, max_valid_idx + 1))
|
||||
i = max_valid_idx + 1
|
||||
|
||||
return idx_pairs
|
||||
|
||||
|
||||
def pack_data_points_FA(
|
||||
batch: dict[str, any],
|
||||
) -> dict[str, np.ndarray]:
|
||||
if not batch:
|
||||
raise ValueError("Batch is empty")
|
||||
|
||||
# Pre-allocate lists with known sizes
|
||||
total_ctx_len = sum(len(y) for x in batch["ctx_ids"] for y in x)
|
||||
total_inp_len = sum(len(y) for x in batch["input_ids"] for y in x)
|
||||
|
||||
ctx_ids = np.empty(total_ctx_len, dtype=np.long)
|
||||
ctx_position_ids = np.empty(total_ctx_len, dtype=np.long)
|
||||
input_ids = np.empty(total_inp_len, dtype=np.long)
|
||||
position_ids = np.empty(total_inp_len, dtype=np.long)
|
||||
labels = np.empty(total_inp_len, dtype=np.long)
|
||||
|
||||
has_logprobs = "logprobs_vals" in batch
|
||||
|
||||
if has_logprobs:
|
||||
sequences = zip(
|
||||
batch["input_ids"],
|
||||
batch["labels"],
|
||||
batch["logprobs_vals"],
|
||||
batch["logprobs_indices"],
|
||||
)
|
||||
n_labels = sum(len(y) for x in batch["logprobs_vals"] for y in x)
|
||||
k = len(batch["logprobs_vals"][0][0][0]) # assuming all have same k
|
||||
logprobs_vals = np.empty((n_labels, k), dtype=np.float32)
|
||||
logprobs_indices = np.empty((n_labels, k), dtype=np.int32)
|
||||
logits_offset = 0
|
||||
else:
|
||||
sequences = zip(batch["input_ids"], batch["labels"])
|
||||
|
||||
offset = 0
|
||||
for sample in sequences:
|
||||
input_ids_b, labels_b = sample[:2]
|
||||
inp_start = offset
|
||||
|
||||
# compute position_ids for each sub-list in input_ids_b
|
||||
local_start = inp_start
|
||||
for ids_b in input_ids_b:
|
||||
local_end = local_start + len(ids_b)
|
||||
position_ids[local_start:local_end] = np.arange(len(ids_b), dtype=np.int32)
|
||||
local_start = local_end
|
||||
|
||||
input_ids_b = concat_list(input_ids_b)
|
||||
labels_b = concat_list(labels_b)
|
||||
|
||||
inp_len = len(input_ids_b)
|
||||
inp_end = offset + inp_len
|
||||
|
||||
input_ids[inp_start:inp_end] = input_ids_b
|
||||
labels[inp_start:inp_end] = labels_b
|
||||
offset += inp_len
|
||||
|
||||
if has_logprobs:
|
||||
logprobs_vals_b, logprobs_indices_b = sample[2:]
|
||||
logprobs_vals_b = concat_list(logprobs_vals_b)
|
||||
logprobs_indices_b = concat_list(logprobs_indices_b)
|
||||
logits_len = len(logprobs_vals_b)
|
||||
logprobs_vals[logits_offset : logits_offset + logits_len] = logprobs_vals_b
|
||||
logprobs_indices[logits_offset : logits_offset + logits_len] = (
|
||||
logprobs_indices_b
|
||||
)
|
||||
logits_offset += logits_len
|
||||
|
||||
ctx_offset = 0
|
||||
for ctx_ids_b in batch["ctx_ids"]:
|
||||
local_start = ctx_offset
|
||||
for ctx_ids_b_item in ctx_ids_b:
|
||||
local_end = local_start + len(ctx_ids_b_item)
|
||||
ctx_position_ids[local_start:local_end] = np.arange(
|
||||
len(ctx_ids_b_item), dtype=np.int32
|
||||
)
|
||||
local_start = local_end
|
||||
|
||||
ctx_ids_b = concat_list(ctx_ids_b)
|
||||
ctx_len = len(ctx_ids_b)
|
||||
ctx_start, ctx_end = ctx_offset, ctx_offset + ctx_len
|
||||
ctx_ids[ctx_start:ctx_end] = ctx_ids_b
|
||||
ctx_offset += ctx_len
|
||||
|
||||
out = {
|
||||
"ctx_ids": ctx_ids,
|
||||
"ctx_position_ids": ctx_position_ids,
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"labels": labels,
|
||||
}
|
||||
if has_logprobs:
|
||||
out["logprobs_vals"] = logprobs_vals
|
||||
out["logprobs_indices"] = logprobs_indices
|
||||
return out
|
||||
|
||||
|
||||
def pack_batch(
|
||||
batch: dict[str, any],
|
||||
max_packed_inp_len: int,
|
||||
max_packed_ctx_len: int,
|
||||
max_packed_size: int = -1,
|
||||
metadata_path: str = "",
|
||||
) -> dict[str, any]:
|
||||
need_flatten = check_is_iterable(batch["input_ids"][0][0])
|
||||
assert need_flatten, (
|
||||
f"Packing requires the input_ids to be nested "
|
||||
f"(allowing multiple QAs per sample), but got {batch['input_ids'][0]}"
|
||||
)
|
||||
|
||||
n_queries = [len(x) for x in batch["input_ids"]]
|
||||
n_ctx_chunks = [len(x) for x in batch["ctx_ids"]]
|
||||
inp_lens = [[len(y) for y in x] for x in batch["input_ids"]]
|
||||
inp_count = len(inp_lens)
|
||||
if "ctx_ids" not in batch:
|
||||
raise ValueError("Batch must contain 'ctx_ids' and 'labels' keys")
|
||||
# we do not pad so we can just take the length of the tokens
|
||||
ctx_lens = [[len(y) for y in x] for x in batch["ctx_ids"]]
|
||||
|
||||
# Group indices
|
||||
idx_pairs = pack_data_points_by_length(
|
||||
inp_lens,
|
||||
ctx_lens,
|
||||
max_packed_inp_len,
|
||||
max_packed_ctx_len,
|
||||
max_packed_size,
|
||||
)
|
||||
|
||||
# Pack groups
|
||||
packed_batch = {
|
||||
"ctx_ids": [],
|
||||
"ctx_position_ids": [],
|
||||
"input_ids": [],
|
||||
"position_ids": [],
|
||||
"labels": [],
|
||||
"n_queries": [],
|
||||
"n_ctx_chunks": [],
|
||||
}
|
||||
has_logprobs = "logprobs_vals" in batch
|
||||
if has_logprobs:
|
||||
packed_batch["logprobs_vals"] = []
|
||||
packed_batch["logprobs_indices"] = []
|
||||
|
||||
packing_efficiency_ratios = []
|
||||
ctx_packing_efficiency_ratios = []
|
||||
|
||||
for idx_pair in idx_pairs:
|
||||
start_idx, end_idx = idx_pair[0], idx_pair[1]
|
||||
group_items = {
|
||||
"ctx_ids": batch["ctx_ids"][start_idx:end_idx],
|
||||
"input_ids": batch["input_ids"][start_idx:end_idx],
|
||||
"labels": batch["labels"][start_idx:end_idx],
|
||||
}
|
||||
if has_logprobs:
|
||||
group_items["logprobs_vals"] = batch["logprobs_vals"][start_idx:end_idx]
|
||||
group_items["logprobs_indices"] = batch["logprobs_indices"][
|
||||
start_idx:end_idx
|
||||
]
|
||||
packed_item = pack_data_points_FA(group_items)
|
||||
packed_batch["ctx_ids"].append(packed_item["ctx_ids"])
|
||||
packed_batch["ctx_position_ids"].append(packed_item["ctx_position_ids"])
|
||||
packed_batch["input_ids"].append(packed_item["input_ids"])
|
||||
packed_batch["position_ids"].append(packed_item["position_ids"])
|
||||
packed_batch["labels"].append(packed_item["labels"])
|
||||
packed_batch["n_queries"].append(n_queries[start_idx:end_idx])
|
||||
packed_batch["n_ctx_chunks"].append(n_ctx_chunks[start_idx:end_idx])
|
||||
if has_logprobs:
|
||||
packed_batch["logprobs_vals"].append(packed_item["logprobs_vals"])
|
||||
packed_batch["logprobs_indices"].append(packed_item["logprobs_indices"])
|
||||
|
||||
if max_packed_inp_len > 0:
|
||||
inp_efficiency = len(packed_item["input_ids"]) / max_packed_inp_len
|
||||
packing_efficiency_ratios.append(inp_efficiency)
|
||||
|
||||
if max_packed_ctx_len > 0:
|
||||
ctx_efficiency = len(packed_item["ctx_ids"]) / max_packed_ctx_len
|
||||
ctx_packing_efficiency_ratios.append(ctx_efficiency)
|
||||
|
||||
# Calculate length statistics
|
||||
packed_inp_lens_arr = np.array([len(x) for x in packed_batch["input_ids"]])
|
||||
packed_ctx_lens_arr = np.array([len(x) for x in packed_batch["ctx_ids"]])
|
||||
|
||||
# Log performance statistics
|
||||
avg_inp_packing_efficiency = (
|
||||
np.mean(packing_efficiency_ratios) if packing_efficiency_ratios else 0
|
||||
)
|
||||
avg_ctx_packing_efficiency = (
|
||||
np.mean(ctx_packing_efficiency_ratios) if ctx_packing_efficiency_ratios else 0
|
||||
)
|
||||
|
||||
# Create packing statistics dictionary
|
||||
packing_stats = {
|
||||
"original_samples": inp_count,
|
||||
"packed_samples": len(idx_pairs),
|
||||
"avg_inp_packing_efficiency": float(avg_inp_packing_efficiency),
|
||||
"avg_ctx_packing_efficiency": float(avg_ctx_packing_efficiency),
|
||||
"input_ids_length_stats": {
|
||||
"avg": float(np.mean(packed_inp_lens_arr)),
|
||||
"std": float(np.std(packed_inp_lens_arr)),
|
||||
"min": int(np.min(packed_inp_lens_arr)),
|
||||
"max": int(np.max(packed_inp_lens_arr)),
|
||||
},
|
||||
"context_ids_length_stats": {
|
||||
"avg": float(np.mean(packed_ctx_lens_arr)),
|
||||
"std": float(np.std(packed_ctx_lens_arr)),
|
||||
"min": int(np.min(packed_ctx_lens_arr)),
|
||||
"max": int(np.max(packed_ctx_lens_arr)),
|
||||
},
|
||||
}
|
||||
|
||||
# Save to metadata_path if provided
|
||||
if metadata_path:
|
||||
os.makedirs(os.path.dirname(metadata_path), exist_ok=True)
|
||||
with open(metadata_path, "w") as f:
|
||||
json.dump(packing_stats, f, indent=4)
|
||||
|
||||
logging.debug(f"Packing stats:\n{pprint.pformat(packing_stats, indent=2)}")
|
||||
|
||||
return packed_batch
|
||||
206
src/ctx_to_lora/data/preprocessing_fn.py
Normal file
206
src/ctx_to_lora/data/preprocessing_fn.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
import logging
|
||||
import random
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from ctx_to_lora.data.definitions import CLOSED_QA_INTX_TEMPLATES, EVAL_INTX_TEMPLATES
|
||||
from ctx_to_lora.utils import concat_list
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def closed_qa_prompting(prompt: str):
|
||||
template = random.choice(CLOSED_QA_INTX_TEMPLATES)
|
||||
return template.format(input=prompt)
|
||||
|
||||
|
||||
def chat_to_str(messages: list[dict[str, str]]):
|
||||
return "Below is the chat history from the current user.\n\n" + "\n\n".join(
|
||||
[
|
||||
"Message from: {role}\n{content}".format(
|
||||
**{**m, "role": m["role"].capitalize()}
|
||||
)
|
||||
for m in messages
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def get_preprocessing_fn(
|
||||
ds_name: str,
|
||||
is_eval: bool,
|
||||
) -> Callable[[dict[str, Any]], dict[str, Any]]:
|
||||
"""
|
||||
Get preprocessing function for a specific dataset.
|
||||
|
||||
Args:
|
||||
ds_name: Name of the dataset
|
||||
|
||||
Returns:
|
||||
A preprocessing function that takes and returns a dictionary
|
||||
"""
|
||||
f = lambda x: x
|
||||
if ds_name.startswith("self_gen") or ds_name.endswith("_compact"):
|
||||
# already processed data, do nothing
|
||||
return f
|
||||
|
||||
if "fw_qa_v2" in ds_name:
|
||||
|
||||
def f(sample):
|
||||
# get questions/answers from all levels in the ds
|
||||
q_cols = [col for col in sample.keys() if col.startswith("prompts_level")]
|
||||
r_cols = [col for col in sample.keys() if col.startswith("responses_level")]
|
||||
questions = concat_list([sample[col] for col in q_cols])
|
||||
responses = concat_list([sample[col] for col in r_cols])
|
||||
min_len = min(len(questions), len(responses))
|
||||
|
||||
if min_len == 0:
|
||||
return {
|
||||
"context": None,
|
||||
"prompts": None,
|
||||
"responses": None,
|
||||
}
|
||||
|
||||
return {
|
||||
"context": sample["context"],
|
||||
"prompts": questions[:min_len],
|
||||
"responses": responses[:min_len],
|
||||
}
|
||||
|
||||
elif ds_name.startswith("longbench"):
|
||||
|
||||
def f(sample):
|
||||
return {
|
||||
"context": sample["context"],
|
||||
"prompt": sample["input"],
|
||||
"response": sample["answers"][0],
|
||||
}
|
||||
|
||||
elif ds_name == "pwc" or ds_name == "pwc_tiny":
|
||||
# original pwc
|
||||
def f(sample):
|
||||
return {
|
||||
"context": sample["input"],
|
||||
"prompt": sample["prompt"],
|
||||
"response": sample["answer"],
|
||||
}
|
||||
|
||||
elif ds_name == "squad":
|
||||
# original squad
|
||||
def f(sample):
|
||||
q = sample["question"]
|
||||
prompt = closed_qa_prompting(q) if not is_eval else q
|
||||
return {
|
||||
"context": sample["context"],
|
||||
"prompt": prompt,
|
||||
"response": sample["answers"]["text"][0],
|
||||
}
|
||||
|
||||
elif ds_name == "squad_assistant_ctx":
|
||||
|
||||
def f(sample):
|
||||
return {
|
||||
"context": "You are a useful AI assistant.",
|
||||
"prompt": sample["context"] + "\n\n" + sample["question"],
|
||||
"response": sample["answers"]["text"][0],
|
||||
}
|
||||
|
||||
elif ds_name == "squad_negative":
|
||||
with open("data/gutenburg_sample.txt") as f:
|
||||
gutenburg_sample = f.read()
|
||||
|
||||
def f(sample):
|
||||
return {
|
||||
"context": gutenburg_sample,
|
||||
"prompt": sample["context"] + "\n\n" + sample["question"],
|
||||
"response": sample["answers"]["text"][0],
|
||||
}
|
||||
|
||||
elif ds_name == "squad_negative_no_passage":
|
||||
with open("data/gutenburg_sample.txt") as f:
|
||||
gutenburg_sample = f.read()
|
||||
|
||||
def f(sample):
|
||||
return {
|
||||
"context": gutenburg_sample,
|
||||
"prompt": sample["question"],
|
||||
"response": sample["answers"]["text"][0],
|
||||
}
|
||||
|
||||
elif ds_name == "squad_assistant_ctx_no_passage":
|
||||
|
||||
def f(sample):
|
||||
return {
|
||||
"context": "You are a useful AI assistant.",
|
||||
"prompt": sample["question"],
|
||||
"response": sample["answers"]["text"][0],
|
||||
}
|
||||
|
||||
elif ds_name == "drop":
|
||||
|
||||
def f(sample):
|
||||
q = sample["question"]
|
||||
prompt = closed_qa_prompting(q) if not is_eval else q
|
||||
return {
|
||||
"context": sample["passage"],
|
||||
"prompt": prompt,
|
||||
"response": sample["answers_spans"]["spans"][0],
|
||||
}
|
||||
|
||||
elif ds_name == "ropes":
|
||||
ctx_template = "{background}\n{situation}"
|
||||
|
||||
def f(sample):
|
||||
response = sample["answers"]["text"][0]
|
||||
bg_txt = sample["background"]
|
||||
situation_txt = sample["situation"]
|
||||
ctx = ctx_template.format(background=bg_txt, situation=situation_txt)
|
||||
q = sample["question"]
|
||||
q = closed_qa_prompting(q) if not is_eval else q
|
||||
return {"context": ctx, "prompt": q, "response": response}
|
||||
|
||||
if is_eval and (ds_name in EVAL_INTX_TEMPLATES):
|
||||
prompt_template = EVAL_INTX_TEMPLATES[ds_name]
|
||||
|
||||
def eval_intx_decorator(f):
|
||||
def g(sample):
|
||||
sample = f(sample)
|
||||
assert "prompt" in sample, (
|
||||
f"Expected 'prompt' in sample, got {sample.keys()}"
|
||||
)
|
||||
sample["prompt"] = prompt_template.format(input=sample["prompt"])
|
||||
return sample
|
||||
|
||||
return g
|
||||
|
||||
f = eval_intx_decorator(f)
|
||||
|
||||
def maybe_convert_to_list(f):
|
||||
def g(sample):
|
||||
sample = f(sample)
|
||||
if "prompt" in sample:
|
||||
sample["prompts"] = [sample.pop("prompt")]
|
||||
if "response" in sample:
|
||||
sample["responses"] = [sample.pop("response")]
|
||||
return sample
|
||||
|
||||
return g
|
||||
|
||||
f = maybe_convert_to_list(f)
|
||||
|
||||
if "self_gen" not in ds_name:
|
||||
|
||||
def strip_response(f):
|
||||
def g(sample):
|
||||
sample = f(sample)
|
||||
if "responses" in sample and bool(sample["responses"]):
|
||||
sample["responses"] = [
|
||||
r.strip() if isinstance(r, str) else r
|
||||
for r in sample["responses"]
|
||||
]
|
||||
return sample
|
||||
|
||||
return g
|
||||
|
||||
f = strip_response(f)
|
||||
|
||||
return f
|
||||
1070
src/ctx_to_lora/data/processing.py
Normal file
1070
src/ctx_to_lora/data/processing.py
Normal file
File diff suppressed because it is too large
Load diff
85
src/ctx_to_lora/data/q_generation_template.py
Normal file
85
src/ctx_to_lora/data/q_generation_template.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
STOP_STRINGS = {
|
||||
"google/gemma-3-12b-it": ["<eos>", "<end_of_turn>"],
|
||||
}
|
||||
|
||||
|
||||
Q_GEN_SYSTEM_TEMPLATE = (
|
||||
"You are a creative and helpful assistant.\n"
|
||||
"You are tasked with generating questions for reading comprehension tests.\n"
|
||||
"You will be given a context and you need to generate questions and corresponding answers from the given context.\n"
|
||||
"The questions should be highly specific to the information provided in the context, not general questions that suit any context.\n"
|
||||
"**DO NOT** hallucinate or make up information."
|
||||
)
|
||||
|
||||
Q_GEN_PROMPT_TEMPLATE = (
|
||||
"### Instructions ###\n"
|
||||
"Generate questions and corresponding answers from the given context. The questions should be highly specific to the "
|
||||
"information provided in the context, not general questions that suit any context.\n\n"
|
||||
"### Context ###\n"
|
||||
"{context}\n\n\n"
|
||||
"### Rules ###\n"
|
||||
"Rules to follow when generating the questions:\n"
|
||||
"1. The questions must be specific to the given context and fully answerable from information present in the given context.\n"
|
||||
"2. Ask questions that are fact-seeking based on the information provided.\n"
|
||||
"3. Make sure the questions are clear and unambiguous.\n"
|
||||
"4. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
|
||||
"the questions.\n"
|
||||
"5. The questions should not overlap. They should be diverse, covering many aspects of the context.\n"
|
||||
"6. Do not give away too much information in the questions. For example, ask 'Who is X?' instead of 'Who is X that did Y?' when Y is clear from the context.\n"
|
||||
"7. Ignore the text formatting of the context, e.g., bold, italic, underline, etc.\n"
|
||||
"8. Ignore typos, spacing, and grammatical errors in the context.\n\n"
|
||||
"Rules to follow when generating the answers:\n"
|
||||
"1. The answers must use the (implied) information provided in the context.\n"
|
||||
"2. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
|
||||
"the answers.\n"
|
||||
"3. Do not just copy words from the context. Answer the question in your own words.\n"
|
||||
"4. The answers should be detailed and comprehensive. Please include additional specific details from the context.\n\n"
|
||||
"Respond with {n_qa_pairs} question-answer pairs.\n"
|
||||
"Always use proper grammar and punctuation.\n"
|
||||
"Try to use different question forms and styles.\n"
|
||||
"Use simple words and make sure that the answers are clear and comprehensive.\n\n"
|
||||
"The question-answer pairs should be in the following format:\n"
|
||||
"Question 1: {{question_1}}\n"
|
||||
"Answer 1: {{answer_1}}\n"
|
||||
"Question 2: {{question_2}}\n"
|
||||
"Answer 2: {{answer_2}}\n"
|
||||
"..."
|
||||
)
|
||||
|
||||
Q_GEN_PROMPT_TEMPLATE_REPEAT = (
|
||||
"### Instructions ###\n"
|
||||
"Generate questions and corresponding answers from the given context. The questions should be highly specific to the "
|
||||
"information provided in the context, not general questions that suit any context.\n\n"
|
||||
"### Context ###\n"
|
||||
"{context}\n\n\n"
|
||||
"### Example Question-Answer Pairs ###\n"
|
||||
"{qa_pairs}\n\n\n"
|
||||
"### Rules ###\n"
|
||||
"Rules to follow when generating the questions:\n"
|
||||
"1. The questions must be specific to the given context and fully answerable from information present in *or* implied from the given context.\n"
|
||||
"2. The questions must *not* be redundant with the example questions-answer pairs provided.\n"
|
||||
"3. You should prioritize fact-seeking questions. Consider reversal questions, e.g., asking 'What causes X to happen?' is valid when 'Y causes X' is presented in the context.\n"
|
||||
"4. If all the facts in the context are already covered by the provided examples, you must generate *more complicated* questions that require reasoning beyond simple information retrieval.\nThis includes asking about information that can be inferred, requiring synthesizing information from multiple parts of the text, or understanding relationships between concepts, events, or individuals mentioned in the context. For example, if the context says 'The Eiffel Tower was completed in 1889 after 2 years of construction', you can ask 'When did the construction of the Eiffel Tower begin?'. Here's another example: if the context says 'Alice is Bob's mother. Bob is Charlie's Dad', you can ask 'Who is Charlie's grandmother?'.\n"
|
||||
"5. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
|
||||
"the questions.\n"
|
||||
"6. The questions should not overlap. They should be diverse, covering many aspects of the context.\n"
|
||||
"7. Do not give away too much information in the questions. For example, ask 'Who is X?' instead of 'Who is X that did Y?' when Y is clear from the context.\n"
|
||||
"8. Ignore the text formatting of the context, e.g., bold, italic, underline, etc.\n"
|
||||
"9. Ignore typos, spacing, and grammatical errors in the context.\n\n"
|
||||
"Rules to follow when generating the answers:\n"
|
||||
"1. The answers must use the (implied) information provided in the context.\n"
|
||||
"2. Phrases like 'based on the provided context', 'according to the context', 'in the context', etc., are **NOT ALLOWED** to appear in "
|
||||
"the answers.\n"
|
||||
"3. Do not just copy words from the context. Answer the question in your own words.\n"
|
||||
"4. The answers should be detailed and comprehensive. Please include additional specific details from the context.\n\n"
|
||||
"Respond with {n_qa_pairs} question-answer pairs.\n"
|
||||
"Always use proper grammar and punctuation.\n"
|
||||
"Try to use different question forms and styles.\n"
|
||||
"Use simple words and make sure that the answers are clear and comprehensive.\n\n"
|
||||
"The question-answer pairs should be in the following format:\n"
|
||||
"Question 1: {{question_1}}\n"
|
||||
"Answer 1: {{answer_1}}\n"
|
||||
"Question 2: {{question_2}}\n"
|
||||
"Answer 2: {{answer_2}}\n"
|
||||
"..."
|
||||
)
|
||||
16
src/ctx_to_lora/data/self_gen_template.py
Normal file
16
src/ctx_to_lora/data/self_gen_template.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
SELF_GEN_SYSTEM_MSG = "You are an honest and helpful assistant."
|
||||
|
||||
SELF_QA_INTX = (
|
||||
"# System Instruction\n"
|
||||
"- The information provided is up-to-date information and/or the user instruction.\n"
|
||||
"- When the provided information is not relevant to the question, ***ignore*** it and answer the question based on your knowledge.\n"
|
||||
"- If the provided information is related to the question, incorporate it in your response.\n"
|
||||
"- If the provided information is an instruction, follow the instruction carefully.\n"
|
||||
"\n---\n\n"
|
||||
"# User Input\n"
|
||||
)
|
||||
|
||||
PRE_CTX = "# Provided Information\n"
|
||||
|
||||
QA_PROMPT_TEMPLATE = PRE_CTX + "{context}\n\n---\n\n" + SELF_QA_INTX + "{question}"
|
||||
PROMPT_TEMPLATE = "{context}\n\n{question}"
|
||||
1149
src/ctx_to_lora/eval_utils.py
Normal file
1149
src/ctx_to_lora/eval_utils.py
Normal file
File diff suppressed because it is too large
Load diff
163
src/ctx_to_lora/metrics.py
Normal file
163
src/ctx_to_lora/metrics.py
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from rouge_score import rouge_scorer
|
||||
from transformers import EvalPrediction
|
||||
|
||||
LENGTH_BINS = [
|
||||
# finegrain bins
|
||||
(0, 2**7 - 1),
|
||||
(2**7, 2**8 - 1),
|
||||
(2**8, 2**9 - 1),
|
||||
# coarse bins
|
||||
(0, 2**9 - 1),
|
||||
(2**9, 2**10 - 1),
|
||||
(2**10, 2**11 - 1),
|
||||
(2**11, 2**12 - 1),
|
||||
(2**12, 2**13 - 1),
|
||||
(0, 2**13 - 1),
|
||||
(2**13, 2**14 - 1),
|
||||
(2**14, 2**15 - 1),
|
||||
(2**15, float("inf")),
|
||||
]
|
||||
|
||||
|
||||
def get_length_bin(length: int):
|
||||
"""Get the length bin for a given length."""
|
||||
for i, (start, end) in enumerate(LENGTH_BINS):
|
||||
if start <= length < end:
|
||||
return (start, end)
|
||||
|
||||
|
||||
def compute_rouge(pred_texts, label_texts):
|
||||
out = defaultdict(list)
|
||||
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
|
||||
for pred_text, label_text in zip(pred_texts, label_texts):
|
||||
scores = scorer.score(pred_text, label_text)
|
||||
for k, v in scores.items():
|
||||
out[f"{k}.f1"].append(v.fmeasure)
|
||||
out_mean = dict()
|
||||
for k in out:
|
||||
out_mean[k] = np.mean(out[k])
|
||||
return out_mean, out
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def compute_per_token_acc(shift_logits, shift_labels, valid_masks):
|
||||
indices = torch.where(valid_masks)
|
||||
acc = (shift_logits.argmax(-1) == shift_labels)[indices].float()
|
||||
return {
|
||||
"per_token_accs": acc.flatten().tolist(),
|
||||
"n_per_token_accs": valid_masks.sum().item(),
|
||||
}
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def compute_prefix_matching(shift_logits, shift_labels, valid_masks):
|
||||
lengths = valid_masks.sum(dim=1)
|
||||
|
||||
is_wrong = (shift_logits.argmax(-1) != shift_labels) * valid_masks
|
||||
is_correct = (shift_logits.argmax(-1) == shift_labels) * valid_masks
|
||||
# NOTE: not reliable for multi-turn conversations
|
||||
# ie, all tokens in the following user's turn will be correct
|
||||
# still monotonically correlate with perf though
|
||||
wrong_pos = torch.argmax(is_wrong, dim=1) - torch.argmax(valid_masks, dim=1)
|
||||
perf = wrong_pos / lengths
|
||||
|
||||
# if all tokens are correct, set to 1
|
||||
perf = torch.where(is_correct.sum(dim=1) == lengths, 1, perf)
|
||||
return {
|
||||
"prefix_matchings": perf.tolist(),
|
||||
"n_prefix_matchings": valid_masks.shape[0],
|
||||
}
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def compute_perplexity(shift_logits, shift_labels, valid_masks):
|
||||
return {"perplexities_ph": [1], "n_perplexities_ph": 1}
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, metric_fns: list[Callable]):
|
||||
self.metric_fns = metric_fns
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.accum_metrics = defaultdict(lambda: list((0,)))
|
||||
self.count = defaultdict(lambda: list((0,)))
|
||||
|
||||
def update(self, shift_logits, shift_labels, valid_masks, lengths=None):
|
||||
for metric_fn in self.metric_fns:
|
||||
# overall metric
|
||||
metric = metric_fn(shift_logits, shift_labels, valid_masks)
|
||||
for k, v in metric.items():
|
||||
key = k if not k.startswith("n_") else k[2:]
|
||||
if k.startswith("n_"):
|
||||
# prefix "n_" indicates the count of the metric
|
||||
self.count[key].append(v)
|
||||
else:
|
||||
self.accum_metrics[key] += v
|
||||
for start, end in LENGTH_BINS:
|
||||
key_w_len = f"{key}_len_{start}-{end}"
|
||||
if key_w_len not in self.accum_metrics:
|
||||
# add key here so that it shows up in the output
|
||||
self.accum_metrics[key_w_len] = [0]
|
||||
self.count[key_w_len] = [0]
|
||||
# split samples into length groups, calculate metric for each group
|
||||
if lengths is not None:
|
||||
for start, end in LENGTH_BINS:
|
||||
logits, labels, masks = [], [], []
|
||||
|
||||
for logit, label, m, len in zip(
|
||||
shift_logits, shift_labels, valid_masks, lengths
|
||||
):
|
||||
if isinstance(len, torch.Tensor):
|
||||
len = len.item()
|
||||
if start <= len < end:
|
||||
logits.append(logit)
|
||||
labels.append(label)
|
||||
masks.append(m)
|
||||
|
||||
if not logits:
|
||||
continue
|
||||
|
||||
metric = metric_fn(
|
||||
torch.stack(logits), torch.stack(labels), torch.stack(masks)
|
||||
)
|
||||
for k, v in metric.items():
|
||||
if k.startswith("n_"):
|
||||
key = f"{k[2:]}_len_{start}-{end}"
|
||||
self.count[key].append(v)
|
||||
else:
|
||||
key = f"{k}_len_{start}-{end}"
|
||||
self.accum_metrics[key] += v
|
||||
|
||||
def compute(self):
|
||||
# Get result across entire eval set
|
||||
result = {
|
||||
k: np.sum(v) / np.sum(self.count[k]) if len(v) > 1 else "None"
|
||||
for k, v in self.accum_metrics.items()
|
||||
}
|
||||
# Reset batch statistics
|
||||
self.reset()
|
||||
return result
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_metrics(
|
||||
eval_pred: EvalPrediction,
|
||||
compute_result: bool,
|
||||
evaluator: Evaluator,
|
||||
) -> dict | None:
|
||||
inputs = eval_pred.inputs
|
||||
len_key = "ctx_ids_len" if "ctx_ids_len" in inputs else "input_ids_len"
|
||||
lengths = inputs[len_key]
|
||||
logits, labels = eval_pred.predictions, eval_pred.label_ids
|
||||
shift_logits = logits[..., :-1, :]
|
||||
shift_labels = labels[..., 1:]
|
||||
valid_masks = torch.where(shift_labels != -100, 1, 0)
|
||||
evaluator.update(shift_logits, shift_labels, valid_masks, lengths)
|
||||
if compute_result:
|
||||
return evaluator.compute()
|
||||
183
src/ctx_to_lora/model_loading.py
Normal file
183
src/ctx_to_lora/model_loading.py
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from peft import get_peft_config as _get_peft_config
|
||||
from peft.utils import PeftType
|
||||
from transformers import (
|
||||
AutoModel,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
Gemma3ForConditionalGeneration,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
GEMMA_VISION_MODELS = [
|
||||
"google/gemma-3-4b-it",
|
||||
"google/gemma-3-12b-it",
|
||||
"google/gemma-3-27b-it",
|
||||
]
|
||||
|
||||
|
||||
def check_is_vision_model(model_name):
|
||||
return model_name in GEMMA_VISION_MODELS
|
||||
|
||||
|
||||
def get_model_and_tokenizer(
|
||||
model_name_or_path,
|
||||
train,
|
||||
requires_grad,
|
||||
use_flash_attn=True,
|
||||
peft_config=None,
|
||||
model_kwargs=None,
|
||||
tokenizer_kwargs=None,
|
||||
use_q_lora=False,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
):
|
||||
model = get_model(
|
||||
model_name_or_path,
|
||||
train,
|
||||
requires_grad,
|
||||
use_flash_attn,
|
||||
peft_config,
|
||||
model_kwargs,
|
||||
use_q_lora,
|
||||
device,
|
||||
dtype,
|
||||
)
|
||||
tokenizer = get_tokenizer(model_name_or_path, tokenizer_kwargs, peft_config, train)
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
if getattr(model, "generation_config", None):
|
||||
model.generation_config.pad_token_id = tokenizer.pad_token_id
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_tokenizer(
|
||||
model_name_or_path, tokenizer_kwargs=None, peft_config=None, train=False
|
||||
):
|
||||
padding_side = "left" if not train else "right"
|
||||
truncation_side = "left"
|
||||
|
||||
if tokenizer_kwargs is None:
|
||||
tokenizer_kwargs = {}
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_name_or_path,
|
||||
add_bos_tokens=False,
|
||||
add_eos_tokens=False,
|
||||
padding_side=padding_side,
|
||||
truncation_side=truncation_side,
|
||||
trust_remote_code=True,
|
||||
**tokenizer_kwargs,
|
||||
)
|
||||
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token_id = tokenizer.eos_token_id
|
||||
|
||||
template_path = f"chat_templates/{model_name_or_path}.jinja"
|
||||
if not os.path.exists(template_path):
|
||||
logger.warning(
|
||||
f"Chat template not found at {template_path}. Using default template."
|
||||
)
|
||||
return tokenizer
|
||||
|
||||
logger.info(f"Using chat template from {template_path}")
|
||||
chat_template = open(template_path).read()
|
||||
chat_template = chat_template.replace(" ", "").replace("\n", "")
|
||||
tokenizer.chat_template = chat_template
|
||||
return tokenizer
|
||||
|
||||
|
||||
def get_model(
|
||||
model_name_or_path,
|
||||
train,
|
||||
requires_grad,
|
||||
use_flash_attn=True,
|
||||
peft_config=None,
|
||||
model_kwargs=None,
|
||||
use_q_lora=False,
|
||||
device="cuda",
|
||||
dtype=torch.bfloat16,
|
||||
):
|
||||
model_init_kwargs = dict(
|
||||
pretrained_model_name_or_path=model_name_or_path,
|
||||
device_map=device,
|
||||
torch_dtype=dtype,
|
||||
trust_remote_code=True,
|
||||
attn_implementation="eager",
|
||||
use_cache=None,
|
||||
)
|
||||
is_vision_model = check_is_vision_model(model_name_or_path)
|
||||
if model_kwargs is not None:
|
||||
model_init_kwargs.update(model_kwargs)
|
||||
|
||||
is_bidir_model = (
|
||||
"bert" in model_name_or_path.lower() or "gte" in model_name_or_path.lower()
|
||||
)
|
||||
|
||||
if use_flash_attn:
|
||||
if "gte" not in model_name_or_path:
|
||||
model_init_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
elif "gte" in model_name_or_path:
|
||||
model_init_kwargs["attn_implementation"] = "sdpa"
|
||||
|
||||
if is_vision_model:
|
||||
# always use sdpa for vision models
|
||||
# model_init_kwargs["attn_implementation"] = "sdpa"
|
||||
model_init_kwargs.pop("use_cache")
|
||||
elif is_bidir_model:
|
||||
model_init_kwargs["torch_dtype"] = torch.float32
|
||||
model_init_kwargs.pop("use_cache")
|
||||
|
||||
if use_q_lora:
|
||||
# https://huggingface.co/blog/4bit-transformers-bitsandbytes
|
||||
# https://colab.research.google.com/drive/1VoYNfYDKcKRQRor98Zbf2-9VQTtGJ24k?usp=sharing
|
||||
# see bitsandbytes for the quantization implementation https://github.com/bitsandbytes-foundation/bitsandbytes
|
||||
# see unsloth https://huggingface.co/docs/trl/v0.7.11/en/sft_trainer#accelerate-fine-tuning-2x-using-unsloth
|
||||
# does work currently bc it modifies the forward pass call of Linear
|
||||
bnb_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=torch.bfloat16,
|
||||
)
|
||||
model_init_kwargs["quantization_config"] = bnb_config
|
||||
|
||||
logger.debug(f"Model init kwargs: {model_init_kwargs}")
|
||||
if not is_vision_model:
|
||||
if is_bidir_model:
|
||||
model = AutoModel.from_pretrained(**model_init_kwargs)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(**model_init_kwargs)
|
||||
else:
|
||||
model = Gemma3ForConditionalGeneration.from_pretrained(**model_init_kwargs)
|
||||
model = model.language_model
|
||||
if peft_config is not None:
|
||||
model = PeftModel(model, peft_config)
|
||||
model.train(train)
|
||||
for name, param in model.named_parameters():
|
||||
param.requires_grad = requires_grad
|
||||
return model
|
||||
|
||||
|
||||
def get_lora_config(model_dir, **kwargs):
|
||||
if "target_modules" not in kwargs or kwargs["target_modules"] is None:
|
||||
logger.info("No target modules specified for LoRA.")
|
||||
return None
|
||||
r = kwargs.pop("lora_r", 8)
|
||||
peft_conf_kwargs = dict(
|
||||
r=r,
|
||||
peft_type=PeftType.LORA,
|
||||
base_model_name_or_path=model_dir,
|
||||
task_type="CAUSAL_LM",
|
||||
lora_dropout=kwargs.get("lora_dropout", 0.0),
|
||||
lora_alpha=r ** (3 / 2) * 2,
|
||||
)
|
||||
|
||||
peft_conf_kwargs.update(kwargs)
|
||||
peft_config = _get_peft_config(peft_conf_kwargs)
|
||||
return peft_config
|
||||
0
src/ctx_to_lora/modeling/__init__.py
Normal file
0
src/ctx_to_lora/modeling/__init__.py
Normal file
211
src/ctx_to_lora/modeling/aggregator.py
Normal file
211
src/ctx_to_lora/modeling/aggregator.py
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
import logging
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
from einops import rearrange, repeat, unpack
|
||||
from jaxtyping import Float, Integer
|
||||
from torch import Tensor, nn
|
||||
from transformers import (
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
|
||||
from ctx_to_lora.configs import (
|
||||
AggregatorArguments,
|
||||
)
|
||||
from ctx_to_lora.modeling.idefics2 import Idefics2Perceiver, Idefics2PerceiverConfig
|
||||
from ctx_to_lora.pooling import POOL_FN
|
||||
from ctx_to_lora.utils import (
|
||||
get_num_layers,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class AGGREGATOR_TYPE(str, Enum):
|
||||
POOLER = "pooler"
|
||||
PERCEIVER = "perceiver"
|
||||
|
||||
|
||||
@dataclass
|
||||
class AggregatorConfig:
|
||||
aggregator_type: AGGREGATOR_TYPE
|
||||
num_layers: int
|
||||
num_modules: int
|
||||
num_extra_modules: int
|
||||
output_size: int
|
||||
feature_size: int
|
||||
|
||||
# pooler
|
||||
pooling_type: POOL_FN
|
||||
|
||||
# perceiver
|
||||
num_latent_factor: int
|
||||
lora_r: int
|
||||
per_rank_gen: bool
|
||||
|
||||
n_latent_queries: int
|
||||
num_blocks: int
|
||||
num_self_attn_per_block: int
|
||||
shared_weights: bool
|
||||
layer_to_layer_ctx_encoder: bool
|
||||
|
||||
|
||||
def get_aggregator_config(
|
||||
model: PreTrainedModel,
|
||||
ctx_encoder_model_config: PretrainedConfig,
|
||||
layer_to_layer_ctx_encoder: bool,
|
||||
output_size: int,
|
||||
num_modules: int,
|
||||
num_extra_modules: int,
|
||||
lora_r: int,
|
||||
per_rank_gen: bool,
|
||||
aggregator_args: AggregatorArguments,
|
||||
):
|
||||
return AggregatorConfig(
|
||||
feature_size=ctx_encoder_model_config.hidden_size,
|
||||
output_size=output_size,
|
||||
num_layers=get_num_layers(model),
|
||||
num_modules=num_modules,
|
||||
num_extra_modules=num_extra_modules,
|
||||
lora_r=lora_r,
|
||||
per_rank_gen=per_rank_gen,
|
||||
layer_to_layer_ctx_encoder=layer_to_layer_ctx_encoder,
|
||||
**vars(aggregator_args),
|
||||
)
|
||||
|
||||
|
||||
class Perceiver(nn.Module):
|
||||
"""perceiver w/ bottleneck size = n_modules * n_layers"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
feature_size,
|
||||
output_size,
|
||||
num_layers,
|
||||
num_modules,
|
||||
num_extra_modules,
|
||||
per_rank_gen,
|
||||
lora_r,
|
||||
num_latent_factor, # unused
|
||||
layer_to_layer_ctx_encoder,
|
||||
n_latent_queries,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
assert num_extra_modules == 0
|
||||
self.num_layers = num_layers
|
||||
self.num_modules = num_modules
|
||||
self.num_extra_modules = num_extra_modules
|
||||
self.per_rank_gen = per_rank_gen
|
||||
self.r = lora_r if self.per_rank_gen else 1
|
||||
n_output_queries = num_layers * (num_modules * self.r + num_extra_modules)
|
||||
self.layer_to_layer = layer_to_layer_ctx_encoder
|
||||
if self.layer_to_layer:
|
||||
n_output_queries = num_modules * self.r + num_extra_modules
|
||||
self.config = Idefics2PerceiverConfig(
|
||||
input_size=feature_size,
|
||||
num_blocks=kwargs["num_blocks"],
|
||||
num_self_attn_per_block=kwargs["num_self_attn_per_block"],
|
||||
shared_weights=kwargs["shared_weights"],
|
||||
n_latents=n_latent_queries,
|
||||
intermediate_size_factor=4,
|
||||
hidden_size=output_size,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
self.decoder_config = Idefics2PerceiverConfig(
|
||||
input_size=output_size,
|
||||
num_blocks=1,
|
||||
num_self_attn_per_block=0,
|
||||
shared_weights=False,
|
||||
n_latents=n_output_queries,
|
||||
intermediate_size_factor=4,
|
||||
hidden_size=output_size,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
self.perceiver = Idefics2Perceiver(self.config, self.decoder_config)
|
||||
self.iterative_mode = False
|
||||
|
||||
def enable_iterative_mode(self, x: bool):
|
||||
self.iterative_mode = x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_features: Float[Tensor, "bs seq_len feature_dim"]
|
||||
| Float[Tensor, "bs x seq_len feature_dim"],
|
||||
ctx_attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
|
||||
ctx_position_ids: Integer[Tensor, "bs seq_len"] | None = None,
|
||||
):
|
||||
if self.layer_to_layer and not self.iterative_mode:
|
||||
if ctx_attn_mask is not None:
|
||||
ctx_attn_mask = repeat(
|
||||
ctx_attn_mask,
|
||||
"bs seq_len -> (num_layers bs) seq_len",
|
||||
num_layers=self.num_layers,
|
||||
)
|
||||
ctx_features = rearrange(
|
||||
ctx_features,
|
||||
"bs num_layers seq_len feature_dim -> (num_layers bs) seq_len feature_dim",
|
||||
)
|
||||
if ctx_position_ids is not None:
|
||||
ctx_position_ids = repeat(
|
||||
ctx_position_ids,
|
||||
"1 seq_len -> 1 (num_layers seq_len)",
|
||||
num_layers=self.num_layers,
|
||||
)
|
||||
ctx_features = rearrange(
|
||||
ctx_features,
|
||||
"1 num_layers seq_len feature_dim -> 1 (num_layers seq_len) feature_dim",
|
||||
)
|
||||
|
||||
x = self.perceiver(ctx_features, ctx_attn_mask, ctx_position_ids)
|
||||
|
||||
if self.layer_to_layer and self.iterative_mode:
|
||||
lora_x = rearrange(
|
||||
x,
|
||||
"bs (n_modules r) d -> bs n_modules r d",
|
||||
n_modules=self.num_modules,
|
||||
r=self.r,
|
||||
)
|
||||
return lora_x, None
|
||||
|
||||
if self.layer_to_layer:
|
||||
per_layer_size = self.num_modules * self.r + self.num_extra_modules
|
||||
x = rearrange(
|
||||
x,
|
||||
"(num_layers bs) (per_layer_sz) d -> bs (num_layers per_layer_sz) d",
|
||||
num_layers=self.num_layers,
|
||||
per_layer_sz=per_layer_size,
|
||||
)
|
||||
lora_x, extra_x = unpack(
|
||||
x,
|
||||
[
|
||||
[self.num_layers * self.num_modules * self.r],
|
||||
[self.num_layers * self.num_extra_modules],
|
||||
],
|
||||
"bs * feature_dim",
|
||||
)
|
||||
lora_x = rearrange(
|
||||
lora_x,
|
||||
"bs (n_layers n_modules r) d -> bs n_layers n_modules r d",
|
||||
n_modules=self.num_modules,
|
||||
n_layers=self.num_layers,
|
||||
r=self.r,
|
||||
)
|
||||
if not self.per_rank_gen:
|
||||
lora_x = lora_x.squeeze(3)
|
||||
|
||||
extra_x = rearrange(
|
||||
extra_x,
|
||||
"bs (n_layers n_extra_modules) d -> bs n_layers n_extra_modules d",
|
||||
n_extra_modules=self.num_extra_modules,
|
||||
n_layers=self.num_layers,
|
||||
)
|
||||
|
||||
return lora_x, extra_x
|
||||
|
||||
|
||||
AGGREGATOR_CLS = {
|
||||
AGGREGATOR_TYPE.PERCEIVER: Perceiver,
|
||||
}
|
||||
632
src/ctx_to_lora/modeling/context_distillation.py
Normal file
632
src/ctx_to_lora/modeling/context_distillation.py
Normal file
|
|
@ -0,0 +1,632 @@
|
|||
import gc
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from math import ceil
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from jaxtyping import Integer
|
||||
from peft import PeftModel
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_target_module_exists
|
||||
from torch import Tensor, nn
|
||||
from transformers import PreTrainedModel
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
|
||||
from ctx_to_lora.data.definitions import CTX_AFFIXES
|
||||
from ctx_to_lora.data.q_generation_template import (
|
||||
Q_GEN_PROMPT_TEMPLATE,
|
||||
Q_GEN_PROMPT_TEMPLATE_REPEAT,
|
||||
Q_GEN_SYSTEM_TEMPLATE,
|
||||
STOP_STRINGS,
|
||||
)
|
||||
from ctx_to_lora.data.self_gen_template import SELF_QA_INTX
|
||||
from ctx_to_lora.utils import log_num_train_params
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def get_q_gen_prompt(context, n_qa_pairs):
|
||||
prompt = Q_GEN_PROMPT_TEMPLATE.format(context=context, n_qa_pairs=n_qa_pairs)
|
||||
return prompt
|
||||
|
||||
|
||||
def get_q_gen_prompt_repeat(context, qa_pairs, n_qa_pairs):
|
||||
example_qa_pairs = ""
|
||||
for i, (q, a) in enumerate(qa_pairs, 1):
|
||||
example_qa_pairs += f"Question {i}: {q}\nAnswer {i}: {a}\n"
|
||||
prompt = Q_GEN_PROMPT_TEMPLATE_REPEAT.format(
|
||||
context=context,
|
||||
qa_pairs=example_qa_pairs,
|
||||
n_qa_pairs=n_qa_pairs,
|
||||
)
|
||||
return prompt
|
||||
|
||||
|
||||
def check_should_skip(txt: str, vllm_model: str) -> bool:
|
||||
"""Check if the response should be skipped based on stop strings."""
|
||||
for stop in STOP_STRINGS[vllm_model]:
|
||||
if stop in txt[-len(stop) :]:
|
||||
return (txt.split(stop)[0], False) # Found a valid stop string
|
||||
return (txt, True) # No valid stop string found, skip this response
|
||||
|
||||
|
||||
def postprocess_qa_pairs(res_txt: str):
|
||||
"""
|
||||
Postprocesses the QA pairs from the response text.
|
||||
|
||||
Args:
|
||||
res_txt: The response text.
|
||||
n_qa_pairs: The number of QA pairs.
|
||||
|
||||
Returns:
|
||||
A tuple of two lists, the first containing the questions and the second containing the answers.
|
||||
"""
|
||||
# capture everything after each "Question {number}:" until "Answer"
|
||||
q_pattern = r"Question \d+:(.*?)(?=Answer|$)" # thanks chatgpt
|
||||
questions = re.findall(q_pattern, res_txt, flags=re.S)
|
||||
|
||||
a_pattern = r"Answer \d+:(.*?)(?=Question|$)" # thanks chatgpt
|
||||
answers = re.findall(a_pattern, res_txt, flags=re.S)
|
||||
|
||||
if len(questions) != len(answers):
|
||||
print(f"Warning---number of questions and answers do not match")
|
||||
print(f"Number of questions: {len(questions)}")
|
||||
print(f"Number of answers: {len(answers)}")
|
||||
|
||||
out_q = []
|
||||
out_a = []
|
||||
n_skips = 0
|
||||
if (len(questions) > 0) and (len(answers) > 0):
|
||||
n_gen_pairs = min(len(questions), len(answers))
|
||||
has_left_over = n_gen_pairs < len(questions) or n_gen_pairs < len(answers)
|
||||
for i in range(n_gen_pairs):
|
||||
response = answers[i].strip()
|
||||
question = questions[i].strip()
|
||||
if not response or not question:
|
||||
print(f"Skipping empty question or answer at index {i}")
|
||||
continue
|
||||
if (not has_left_over) and (i == n_gen_pairs - 1):
|
||||
response, skip = check_should_skip(response, "google/gemma-3-12b-it")
|
||||
if skip:
|
||||
print(f"Skipping due to missing stop string")
|
||||
n_skips += 1
|
||||
continue
|
||||
out_q.append(question.strip())
|
||||
out_a.append(response.strip())
|
||||
print(f"Skipped {n_skips} responses due to missing stop strings")
|
||||
|
||||
return out_q, out_a
|
||||
|
||||
|
||||
def build_messages(ctx_text: str, level: int, example_qa_pairs: list = None):
|
||||
messages = [
|
||||
{"role": "system", "content": Q_GEN_SYSTEM_TEMPLATE},
|
||||
{
|
||||
"role": "user",
|
||||
"content": get_q_gen_prompt(ctx_text, 5)
|
||||
if level == 0
|
||||
else get_q_gen_prompt_repeat(ctx_text, example_qa_pairs, 5),
|
||||
},
|
||||
]
|
||||
return messages
|
||||
|
||||
|
||||
def get_shifted_label_pos(labels):
|
||||
pos = torch.where(labels != -100)
|
||||
# (batch_idx, token_idx)
|
||||
return (pos[0], pos[1] - 1)
|
||||
|
||||
|
||||
def logits_at_positions(outputs: ModelOutput, pos) -> Tensor:
|
||||
logits = outputs.logits
|
||||
return logits[pos[0], pos[1]]
|
||||
|
||||
|
||||
def ctx_inp_split(
|
||||
ctx_inp_ids, ctx_inp_sep_seq, pad_token_id, prefix_tokens=None, padding_side="right"
|
||||
):
|
||||
# Split each row in ctx_inp_ids at the first occurrence of ctx_inp_sep_seq
|
||||
# Return the part after the separator for each row
|
||||
batch_size = ctx_inp_ids.size(0)
|
||||
sep_len = ctx_inp_sep_seq.size(0)
|
||||
out_inp = []
|
||||
out_ctx = []
|
||||
for i in range(batch_size):
|
||||
row = ctx_inp_ids[i]
|
||||
# Find where the separator starts
|
||||
for j in range(row.size(0) - sep_len + 1):
|
||||
if torch.equal(row[j : j + sep_len], ctx_inp_sep_seq):
|
||||
out_ctx.append(row[:j])
|
||||
if prefix_tokens is not None:
|
||||
out_inp.append(
|
||||
torch.cat([prefix_tokens, row[j + sep_len :]], axis=-1)
|
||||
)
|
||||
else:
|
||||
out_inp.append(row[j + sep_len :])
|
||||
break
|
||||
else:
|
||||
# If separator not found
|
||||
raise ValueError(f"Separator sequence not found in row {i}")
|
||||
out_inp = torch.nn.utils.rnn.pad_sequence(
|
||||
out_inp, batch_first=True, padding_value=pad_token_id, padding_side=padding_side
|
||||
)
|
||||
out_ctx = torch.nn.utils.rnn.pad_sequence(
|
||||
out_ctx, batch_first=True, padding_value=pad_token_id, padding_side=padding_side
|
||||
)
|
||||
return out_ctx, out_inp
|
||||
|
||||
|
||||
def get_peft_layers(model, peft_config):
|
||||
out = []
|
||||
for module_name, module in model.named_modules():
|
||||
if not check_target_module_exists(peft_config, module_name):
|
||||
continue
|
||||
if not isinstance(module, BaseTunerLayer):
|
||||
continue
|
||||
# support just Linear layer for now
|
||||
# all modules should be a leave module that is Linear layer
|
||||
assert isinstance(module.base_layer, nn.Linear), (
|
||||
"all modules should be a leave module that is Linear layer"
|
||||
)
|
||||
|
||||
# this should always pass
|
||||
name = module_name.split(".")[-1]
|
||||
assert name in peft_config.target_modules
|
||||
out.append(module)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CtxDistillModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_model: PeftModel,
|
||||
prefix_tokens: Integer[Tensor, "n"],
|
||||
ctx_inp_sep_seq: Integer[Tensor, "m"],
|
||||
pad_token_id: int,
|
||||
update_iterations: int,
|
||||
reset: bool = True,
|
||||
tokenizer=None,
|
||||
q_model: PreTrainedModel | None = None,
|
||||
q_tokenizer=None,
|
||||
reprompt_ctx: bool = False,
|
||||
lora_save_dir: str | None = None,
|
||||
save_after_distill: bool = True,
|
||||
q_gen_rounds: int = 4,
|
||||
batch_size: int = 16,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_module("base_model", base_model)
|
||||
self.register_module("q_model", q_model)
|
||||
self.register_buffer("prefix_tokens", prefix_tokens)
|
||||
self.register_buffer("ctx_inp_sep_seq", ctx_inp_sep_seq)
|
||||
self.tokenizer = tokenizer
|
||||
self.q_tokenizer = q_tokenizer
|
||||
self.pad_token_id = pad_token_id
|
||||
self.update_iterations = update_iterations
|
||||
self.reprompt_ctx = reprompt_ctx
|
||||
self.reset = reset
|
||||
self.device = base_model.device
|
||||
self.to(self.device)
|
||||
self.q_gen_rounds = q_gen_rounds
|
||||
# New save options
|
||||
self.lora_save_dir = lora_save_dir
|
||||
self.save_after_distill = save_after_distill
|
||||
|
||||
self.peft_config = base_model.peft_config["default"]
|
||||
self.adapter_name = "default"
|
||||
self.base_model.set_adapter("default")
|
||||
for layer in get_peft_layers(self.base_model, self.peft_config):
|
||||
for name, p in layer.named_parameters():
|
||||
if "lora_A" in name or "lora_B" in name:
|
||||
p.requires_grad = True
|
||||
log_num_train_params(self.base_model)
|
||||
self._init_optim()
|
||||
# Mini-batch size for distillation updates
|
||||
self.batch_size = batch_size
|
||||
|
||||
@property
|
||||
def generation_config(self):
|
||||
return self.base_model.generation_config
|
||||
|
||||
def _init_optim(self):
|
||||
self.optimizer = torch.optim.AdamW(
|
||||
[
|
||||
p
|
||||
for l in get_peft_layers(self.base_model, self.peft_config)
|
||||
for p in l.parameters()
|
||||
if p.requires_grad
|
||||
],
|
||||
lr=1e-4,
|
||||
)
|
||||
|
||||
def reset_lora(self):
|
||||
print("Resetting LoRA")
|
||||
for layer in get_peft_layers(self.base_model, self.peft_config):
|
||||
layer.reset_lora_parameters(self.adapter_name, init_lora_weights=True)
|
||||
self._init_optim()
|
||||
|
||||
def save_lora(self):
|
||||
"""
|
||||
Save current LoRA adapter in PEFT format plus a lightweight JSON summary
|
||||
for easy human inspection/manipulation.
|
||||
"""
|
||||
if self.lora_save_dir is None:
|
||||
return
|
||||
os.makedirs(self.lora_save_dir, exist_ok=True)
|
||||
# Standard PEFT save (produces adapter_config.json + adapter_model.bin / safetensors)
|
||||
self.base_model.save_pretrained(self.lora_save_dir)
|
||||
# Human-readable summary of LoRA parameter shapes
|
||||
summary = {
|
||||
name: list(p.shape)
|
||||
for name, p in self.base_model.named_parameters()
|
||||
if ("lora_A" in name or "lora_B" in name) and p.requires_grad
|
||||
}
|
||||
with open(os.path.join(self.lora_save_dir, "lora_summary.json"), "w") as f:
|
||||
json.dump(summary, f, indent=2)
|
||||
print(f"LoRA adapter saved to {self.lora_save_dir}")
|
||||
|
||||
@torch.enable_grad()
|
||||
def _distill_context(
|
||||
self,
|
||||
ctx_inp_res_ids: Integer[Tensor, "bs ctx_inp_length"],
|
||||
ctx_inp_res_attention_mask: Integer[Tensor, "bs ctx_inp_length"],
|
||||
teacher_labels: Integer[Tensor, "bs ctx_inp_length"],
|
||||
inp_res_ids: Integer[Tensor, "bs inp_length"],
|
||||
inp_res_attention_mask: Integer[Tensor, "bs inp_length"],
|
||||
student_labels: Integer[Tensor, "bs inp_length"],
|
||||
):
|
||||
# Implements KD-style loss by computing teacher (with context) and student (no context)
|
||||
# log-probs locally, using mini-batches for updates.
|
||||
|
||||
was_training = self.training
|
||||
self.train()
|
||||
|
||||
num_samples = ctx_inp_res_ids.size(0)
|
||||
mb = self.batch_size
|
||||
num_batches = ceil(num_samples / mb)
|
||||
|
||||
total_steps = max(self.update_iterations * num_batches, 1)
|
||||
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
||||
self.optimizer, T_max=total_steps, eta_min=0.0
|
||||
)
|
||||
|
||||
print(
|
||||
f"Starting context distillation for {self.update_iterations} epochs, "
|
||||
f"mini-batch size {mb} ({num_batches} batches/epoch), "
|
||||
f"cosine LR schedule over {total_steps} steps"
|
||||
)
|
||||
|
||||
for epoch in range(self.update_iterations):
|
||||
# Shuffle order each epoch for SGD
|
||||
perm = torch.randperm(num_samples, device=self.device)
|
||||
|
||||
epoch_loss = 0.0
|
||||
for b in range(num_batches):
|
||||
start = b * mb
|
||||
end = min(start + mb, num_samples)
|
||||
indices = perm[start:end]
|
||||
|
||||
b_ctx_ids = ctx_inp_res_ids[indices]
|
||||
b_ctx_am = ctx_inp_res_attention_mask[indices]
|
||||
b_teacher_labels = teacher_labels[indices]
|
||||
|
||||
b_inp_ids = inp_res_ids[indices]
|
||||
b_inp_am = inp_res_attention_mask[indices]
|
||||
b_student_labels = student_labels[indices]
|
||||
|
||||
# Compute teacher distribution (top-k) for this mini-batch
|
||||
with torch.no_grad(), self.base_model.disable_adapter():
|
||||
t_pos = get_shifted_label_pos(b_teacher_labels)
|
||||
teacher_outputs = self.base_model(
|
||||
b_ctx_ids, attention_mask=b_ctx_am
|
||||
)
|
||||
teacher_logits = logits_at_positions(teacher_outputs, t_pos)
|
||||
K = 16
|
||||
topk_vals, topk_idx = teacher_logits.topk(K, dim=-1)
|
||||
teacher_denom = torch.logsumexp(
|
||||
teacher_logits.float(), dim=-1, keepdim=True
|
||||
)
|
||||
teacher_p = (topk_vals - teacher_denom).exp().detach() # [N, K]
|
||||
|
||||
# Student forward and update for this mini-batch
|
||||
self.optimizer.zero_grad()
|
||||
s_pos = get_shifted_label_pos(b_student_labels)
|
||||
student_outputs = self.base_model(b_inp_ids, attention_mask=b_inp_am)
|
||||
student_logits = logits_at_positions(student_outputs, s_pos)
|
||||
student_denom = torch.logsumexp(
|
||||
student_logits.float(), dim=-1, keepdim=True
|
||||
)
|
||||
selected_student_logits = student_logits.gather(-1, topk_idx)
|
||||
student_logq = selected_student_logits - student_denom # [N, K]
|
||||
token_losses = -(teacher_p * student_logq).sum(dim=-1) # [N]
|
||||
loss = token_losses.mean()
|
||||
|
||||
loss.backward()
|
||||
self.optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
epoch_loss += loss.detach().item()
|
||||
|
||||
cur_lr = self.optimizer.param_groups[0]["lr"]
|
||||
print(
|
||||
f"Epoch {epoch + 1}/{self.update_iterations}, "
|
||||
f"batch {b + 1}/{num_batches}: loss={epoch_loss / num_batches:.4f}, lr={cur_lr:.6e}"
|
||||
)
|
||||
|
||||
if not was_training:
|
||||
self.eval()
|
||||
|
||||
def generate_questions(self, *args, **kwargs):
|
||||
questions = self.q_model.generate(*args, **kwargs)
|
||||
return questions
|
||||
|
||||
def teacher_generate(self, *args, **kwargs):
|
||||
# rename for separate timing
|
||||
return self.base_model.generate(*args, **kwargs)
|
||||
|
||||
def student_generate(self, *args, **kwargs):
|
||||
# rename for separate timing
|
||||
return self.base_model.generate(*args, **kwargs)
|
||||
|
||||
def get_lora_state(self, clone: bool = True):
|
||||
"""
|
||||
Return a dict of current LoRA parameter tensors.
|
||||
clone=True returns detached cloned tensors (safe to store).
|
||||
"""
|
||||
return {
|
||||
name: (p.detach().clone() if clone else p)
|
||||
for name, p in self.base_model.named_parameters()
|
||||
if ("lora_A" in name or "lora_B" in name)
|
||||
}
|
||||
|
||||
def generate(
|
||||
self,
|
||||
*model_inputs_args: Any,
|
||||
distill_only: bool = False,
|
||||
**model_inputs_kwargs: dict[str, Any],
|
||||
):
|
||||
if self.reset:
|
||||
self.reset_lora()
|
||||
|
||||
# teacher tokens
|
||||
orig_ctx_inp_ids = model_inputs_kwargs.pop("input_ids")
|
||||
ctx_inp_ids = orig_ctx_inp_ids.clone()
|
||||
_, orig_inp_ids = ctx_inp_split(
|
||||
ctx_inp_ids,
|
||||
self.ctx_inp_sep_seq,
|
||||
self.pad_token_id,
|
||||
self.prefix_tokens,
|
||||
padding_side="left",
|
||||
)
|
||||
ctx_inp_attention_mask = model_inputs_kwargs.pop("attention_mask")
|
||||
|
||||
if self.q_model is not None:
|
||||
self.q_model.to(self.base_model.device)
|
||||
# Extract context-only portion after separator (remove prefix tokens from first row)
|
||||
ctx_ids_full, _ = ctx_inp_split(
|
||||
ctx_inp_ids, self.ctx_inp_sep_seq, self.pad_token_id
|
||||
) # [bs, var_len]
|
||||
ctx_ids = ctx_ids_full[0, len(self.prefix_tokens) :]
|
||||
ctx_txt = self.tokenizer.decode(ctx_ids, skip_special_tokens=True)
|
||||
questions = []
|
||||
answers = []
|
||||
# Build multiple instruction variants
|
||||
for lvl in range(self.q_gen_rounds):
|
||||
messages = build_messages(
|
||||
ctx_txt, lvl, zip(questions, answers) if lvl > 0 else None
|
||||
)
|
||||
q_inputs = self.q_tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_special_tokens=False,
|
||||
padding=False,
|
||||
truncation=False,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
q_inputs = {k: v.to(self.q_model.device) for k, v in q_inputs.items()}
|
||||
|
||||
with torch.no_grad():
|
||||
question_outputs = self.generate_questions(
|
||||
input_ids=q_inputs["input_ids"],
|
||||
attention_mask=q_inputs["attention_mask"],
|
||||
max_new_tokens=1024,
|
||||
do_sample=False,
|
||||
temperature=0.0,
|
||||
eos_token_id=106, # <end_of_turn> for gemma-3-12b-it
|
||||
)
|
||||
# Slice off the prompt portion
|
||||
gen_only = question_outputs[:, q_inputs["input_ids"].shape[-1] :]
|
||||
res = self.q_tokenizer.batch_decode(gen_only, skip_special_tokens=False)
|
||||
gen_q_list, gen_a_list = postprocess_qa_pairs(res[0])
|
||||
questions += gen_q_list
|
||||
answers += gen_a_list
|
||||
|
||||
if len(questions) == 0:
|
||||
# when q_model refuses to provide questions
|
||||
# only happens with sample 116 in longbench/multifieldqa_en_e
|
||||
# in this case cd just doesn't work, we fall back to zero-shot answer
|
||||
print(f"Warning---no questions generated, skipping distillation")
|
||||
attention_mask = torch.where(
|
||||
orig_inp_ids != self.pad_token_id, 1, 0
|
||||
).long()
|
||||
return self.student_generate(
|
||||
orig_inp_ids, attention_mask=attention_mask, **model_inputs_kwargs
|
||||
)
|
||||
|
||||
ctx_inp_messages = [
|
||||
[{"role": "user", "content": f"{ctx_txt}\n\n{SELF_QA_INTX}\n\n{q}"}]
|
||||
for q in questions
|
||||
]
|
||||
encoded_ctx_inp = self.tokenizer.apply_chat_template(
|
||||
ctx_inp_messages,
|
||||
tokenize=True,
|
||||
add_special_tokens=False,
|
||||
padding=True,
|
||||
truncation=False,
|
||||
add_generation_prompt=True,
|
||||
return_dict=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
encoded_ctx_inp = {k: v.to(self.device) for k, v in encoded_ctx_inp.items()}
|
||||
ctx_inp_ids = encoded_ctx_inp["input_ids"]
|
||||
ctx_inp_attention_mask = encoded_ctx_inp["attention_mask"]
|
||||
self.q_model.to("cpu")
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
# sample responses first
|
||||
ctx_inp_res_ids = self.teacher_generate(
|
||||
ctx_inp_ids,
|
||||
attention_mask=ctx_inp_attention_mask,
|
||||
**model_inputs_kwargs,
|
||||
)
|
||||
ctx_inp_res_attention_mask = torch.where(
|
||||
ctx_inp_res_ids != self.pad_token_id, 1, 0
|
||||
).long()
|
||||
ctx_inp_res_txt = self.tokenizer.batch_decode(ctx_inp_res_ids)
|
||||
|
||||
bs = ctx_inp_ids.shape[0]
|
||||
res_len = ctx_inp_res_ids.shape[-1] - ctx_inp_ids.shape[-1]
|
||||
res_ids = ctx_inp_res_ids[:, -res_len:] # correct
|
||||
|
||||
pads = torch.full_like(ctx_inp_ids, self.pad_token_id)
|
||||
teacher_labels = torch.cat([pads, res_ids], dim=-1)
|
||||
teacher_labels = torch.where(
|
||||
teacher_labels != self.pad_token_id, teacher_labels, -100
|
||||
)
|
||||
|
||||
# student tokens
|
||||
_, inp_res_ids = ctx_inp_split(
|
||||
ctx_inp_res_ids,
|
||||
self.ctx_inp_sep_seq,
|
||||
self.pad_token_id,
|
||||
self.prefix_tokens,
|
||||
padding_side="left",
|
||||
)
|
||||
inp_res_attention_mask = torch.where(
|
||||
inp_res_ids != self.pad_token_id, 1, 0
|
||||
).long()
|
||||
|
||||
student_labels = inp_res_ids.clone()
|
||||
student_labels[:, :-res_len] = -100
|
||||
student_labels = torch.where(
|
||||
student_labels != self.pad_token_id, student_labels, -100
|
||||
)
|
||||
|
||||
self._distill_context(
|
||||
ctx_inp_res_ids,
|
||||
ctx_inp_res_attention_mask,
|
||||
teacher_labels,
|
||||
inp_res_ids,
|
||||
inp_res_attention_mask,
|
||||
student_labels,
|
||||
)
|
||||
# Save LoRA after distillation if requested
|
||||
if distill_only:
|
||||
return self.get_lora_state()
|
||||
|
||||
model_inputs_kwargs.pop("attention_mask", None)
|
||||
model_inputs_kwargs.pop("input_ids", None)
|
||||
if self.reprompt_ctx:
|
||||
attention_mask = torch.where(orig_ctx_inp_ids != self.pad_token_id, 1, 0)
|
||||
model_outputs = self.student_generate(
|
||||
orig_ctx_inp_ids, attention_mask=attention_mask, **model_inputs_kwargs
|
||||
)
|
||||
else:
|
||||
attention_mask = torch.where(orig_inp_ids != self.pad_token_id, 1, 0).long()
|
||||
model_outputs = self.student_generate(
|
||||
orig_inp_ids, attention_mask=attention_mask, **model_inputs_kwargs
|
||||
)
|
||||
return model_outputs
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ctx_to_lora.data.processing import load_and_process_dataset
|
||||
from ctx_to_lora.model_loading import get_lora_config, get_model_and_tokenizer
|
||||
|
||||
model_name = "google/gemma-2-2b-it"
|
||||
q_model_name = "google/gemma-3-12b-it"
|
||||
peft_config = get_lora_config(
|
||||
model_name, r=8, target_modules=["down_proj"], lora_dropout=0.0
|
||||
)
|
||||
peft_config.lora_alpha = 16
|
||||
model, tokenizer = get_model_and_tokenizer(
|
||||
model_name,
|
||||
train=False,
|
||||
requires_grad=False,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
q_model, q_tokenizer = get_model_and_tokenizer(
|
||||
q_model_name,
|
||||
train=False,
|
||||
requires_grad=False,
|
||||
peft_config=peft_config,
|
||||
)
|
||||
|
||||
ds = load_and_process_dataset("pwc", split="train", num_proc=8)
|
||||
ctx = ds[0]["context"]
|
||||
inp = ds[1]["prompts"][0]
|
||||
# Build a simple context/input pair separated by a unique token sequence
|
||||
sep_text = SELF_QA_INTX
|
||||
# ctx = "# Provided Information\nMy name is Tan."
|
||||
# ctx = "# Provided Info"
|
||||
prompt = f"{ctx}\n\n{sep_text}\n\n{inp}"
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
encoded = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
# encoded = tokenizer(prompt, return_tensors="pt")
|
||||
encoded = {k: v.to(model.device) for k, v in encoded.items()}
|
||||
|
||||
prefix_tokens = CTX_AFFIXES[model_name]["prefix"]
|
||||
prefix_tokens = torch.tensor(prefix_tokens, dtype=torch.long)
|
||||
|
||||
sep_ids = (
|
||||
tokenizer(sep_text.strip("\n"), add_special_tokens=False, return_tensors="pt")
|
||||
.input_ids[0]
|
||||
.to(model.device)
|
||||
)
|
||||
|
||||
cd_model = CtxDistillModel(
|
||||
base_model=model,
|
||||
prefix_tokens=prefix_tokens,
|
||||
ctx_inp_sep_seq=sep_ids,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
update_iterations=100,
|
||||
q_model=q_model,
|
||||
q_tokenizer=q_tokenizer,
|
||||
tokenizer=tokenizer,
|
||||
reprompt_ctx=False,
|
||||
lora_save_dir="./saved_lora_adapter", # example save path
|
||||
save_after_distill=True,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(1):
|
||||
base_model_res = model.generate(
|
||||
input_ids=encoded["input_ids"],
|
||||
attention_mask=encoded["attention_mask"],
|
||||
max_new_tokens=256,
|
||||
do_sample=False,
|
||||
)
|
||||
print(
|
||||
f"Base model response:{tokenizer.batch_decode(base_model_res, skip_special_tokens=False)}"
|
||||
)
|
||||
|
||||
outputs = cd_model.generate(
|
||||
input_ids=encoded["input_ids"],
|
||||
attention_mask=encoded["attention_mask"],
|
||||
max_new_tokens=256,
|
||||
do_sample=False,
|
||||
)
|
||||
print(
|
||||
f"Student response: {tokenizer.batch_decode(outputs, skip_special_tokens=False)}"
|
||||
)
|
||||
158
src/ctx_to_lora/modeling/ctx_encoder.py
Normal file
158
src/ctx_to_lora/modeling/ctx_encoder.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
import logging
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ctx_to_lora.configs import CtxEncoderArguments
|
||||
from ctx_to_lora.utils import get_base_model
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def early_exit(base_model: PreTrainedModel, exit_layer: int):
|
||||
try:
|
||||
layers = base_model.layers
|
||||
base_model.layers = layers[:exit_layer]
|
||||
yield base_model
|
||||
finally:
|
||||
base_model.layers = layers
|
||||
|
||||
|
||||
@contextmanager
|
||||
def maybe_add_batch_dim(kwargs):
|
||||
try:
|
||||
batched_input = False
|
||||
batched_attn_mask = False
|
||||
if (
|
||||
"input_ids" in kwargs
|
||||
and kwargs["input_ids"] is not None
|
||||
and len(kwargs["input_ids"].shape) == 1
|
||||
):
|
||||
kwargs["input_ids"] = kwargs["input_ids"].unsqueeze(0)
|
||||
batched_input = True
|
||||
if (
|
||||
"attention_mask" in kwargs
|
||||
and kwargs["attention_mask"] is not None
|
||||
and isinstance(kwargs["attention_mask"], torch.Tensor)
|
||||
and len(kwargs["attention_mask"].shape) == 1
|
||||
):
|
||||
kwargs["attention_mask"] = kwargs["attention_mask"].unsqueeze(0)
|
||||
batched_attn_mask = True
|
||||
yield batched_input, batched_attn_mask
|
||||
finally:
|
||||
if batched_input:
|
||||
kwargs["input_ids"] = kwargs["input_ids"].squeeze(0)
|
||||
if batched_attn_mask:
|
||||
kwargs["attention_mask"] = kwargs["attention_mask"].squeeze(0)
|
||||
|
||||
|
||||
class EarlyExit(nn.Module):
|
||||
def __init__(self, base_model: PreTrainedModel, config: CtxEncoderArguments):
|
||||
super().__init__()
|
||||
base_model = get_base_model(base_model)
|
||||
if "gte" in base_model.config.name_or_path:
|
||||
base_model.encoder.layer = base_model.encoder.layer[: config.layer_idx]
|
||||
else:
|
||||
base_model.layers = base_model.layers[: config.layer_idx]
|
||||
|
||||
self.base_model = base_model
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.base_model.config
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, **kwargs):
|
||||
model_outputs = self.base_model(**kwargs)
|
||||
return model_outputs.last_hidden_state
|
||||
|
||||
|
||||
class EmbeddingOnly(nn.Module):
|
||||
def __init__(self, base_model: PreTrainedModel, config: CtxEncoderArguments):
|
||||
super().__init__()
|
||||
self.base_model = base_model
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.base_model.config
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, **kwargs):
|
||||
kwargs["output_hidden_states"] = True # Force output of hidden states
|
||||
outputs = self.base_model(**kwargs)
|
||||
# Return the embeddings only
|
||||
return outputs.hidden_states[0] # The first hidden state is the embeddings
|
||||
|
||||
|
||||
class PerLayerActivations(nn.Module):
|
||||
def __init__(self, base_model: PreTrainedModel, config: CtxEncoderArguments):
|
||||
super().__init__()
|
||||
self.keep_lm_head = getattr(config, "keep_lm_head", False)
|
||||
if not self.keep_lm_head:
|
||||
base_model = get_base_model(base_model) # remove lm head
|
||||
else:
|
||||
base_model.lm_head = nn.Identity()
|
||||
|
||||
# -1 to remove last attn block
|
||||
if config.ctx_encoder_last_layer is not None:
|
||||
last_layer = config.ctx_encoder_last_layer - 1
|
||||
else:
|
||||
last_layer = -1
|
||||
|
||||
if self.keep_lm_head:
|
||||
base_model.model.layers = base_model.model.layers[:last_layer]
|
||||
else:
|
||||
base_model.layers = base_model.layers[:last_layer]
|
||||
self.base_model = base_model
|
||||
|
||||
@property
|
||||
def config(self):
|
||||
return self.base_model.config
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.base_model.get_input_embeddings()
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.base_model.set_input_embeddings(value)
|
||||
|
||||
def get_output_embeddings(self):
|
||||
return self.base_model.get_output_embeddings()
|
||||
|
||||
def set_output_embeddings(self, new_embeddings):
|
||||
self.base_model.set_output_embeddings(new_embeddings)
|
||||
|
||||
def set_decoder(self, decoder):
|
||||
self.base_model.set_decoder(decoder)
|
||||
|
||||
def get_decoder(self):
|
||||
return self.base_model.get_decoder()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, **kwargs):
|
||||
kwargs["output_hidden_states"] = True # Force output of hidden states
|
||||
outputs = self.base_model(**kwargs)
|
||||
# Return all layers' activations except the last one
|
||||
# from embeddings to the input of the last attn block
|
||||
# Shape: (batch_size, num_layers, seq_len, hidden_size)
|
||||
|
||||
if self.keep_lm_head:
|
||||
return outputs
|
||||
else:
|
||||
return torch.stack(outputs.hidden_states, dim=1)
|
||||
|
||||
|
||||
class CTX_ENCODER_TYPE(str, Enum):
|
||||
EARLY_EXIT = "early_exit"
|
||||
EMBED_ONLY = "embed_only"
|
||||
PER_LAYER_ACTIVATIONS = "per_layer_activations"
|
||||
|
||||
|
||||
CTX_ENCODER_CLS = {
|
||||
CTX_ENCODER_TYPE.EARLY_EXIT: EarlyExit,
|
||||
CTX_ENCODER_TYPE.EMBED_ONLY: EmbeddingOnly,
|
||||
CTX_ENCODER_TYPE.PER_LAYER_ACTIVATIONS: PerLayerActivations,
|
||||
}
|
||||
70
src/ctx_to_lora/modeling/generative_adapter.py
Normal file
70
src/ctx_to_lora/modeling/generative_adapter.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
import requests
|
||||
import torch
|
||||
|
||||
|
||||
def call_generate(
|
||||
input_txt: str,
|
||||
context_txt: str,
|
||||
window_size: int | None = None,
|
||||
max_new_tokens: int | None = None,
|
||||
host: str = "http://127.0.0.1:8989",
|
||||
timeout: int = 120,
|
||||
) -> torch.Tensor:
|
||||
"""Send the prompt to the API server and return the generated token tensor."""
|
||||
payload: dict[str, object] = {
|
||||
"input_txt": input_txt,
|
||||
"context_txt": context_txt,
|
||||
}
|
||||
if window_size is not None:
|
||||
payload["window_size"] = int(window_size)
|
||||
if max_new_tokens is not None:
|
||||
payload["max_new_tokens"] = int(max_new_tokens)
|
||||
|
||||
response = requests.post(f"{host}/generate", json=payload, timeout=timeout)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if "output" not in data:
|
||||
raise ValueError(f"Unexpected response payload: {data}")
|
||||
return torch.tensor([data["output"]])
|
||||
|
||||
|
||||
def check_server_health(host: str = "http://127.0.0.1:8989", timeout: int = 60) -> None:
|
||||
"""Check if the API server is healthy and responding."""
|
||||
try:
|
||||
response = requests.get(f"{host}/health", timeout=timeout)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if data.get("status") != "ok":
|
||||
raise RuntimeError(f"Server is not healthy: {data}")
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise RuntimeError(
|
||||
f"Cannot connect to server at {host}. Is the server running?"
|
||||
) from e
|
||||
except requests.exceptions.Timeout as e:
|
||||
raise RuntimeError(f"Server health check timed out after {timeout}s") from e
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise RuntimeError(f"Server health check failed: {e}") from e
|
||||
print("Server is healthy.")
|
||||
|
||||
|
||||
class GenerativeAdapter(torch.nn.Module):
|
||||
def __init__(self, model, tokenizer):
|
||||
super().__init__()
|
||||
self.base_model = model # placeholder
|
||||
self.tokenizer = tokenizer
|
||||
check_server_health()
|
||||
|
||||
@property
|
||||
def generation_config(self):
|
||||
return self.base_model.generation_config
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
ctx_ids = kwargs["ctx_ids"]
|
||||
input_ids = kwargs["input_ids"]
|
||||
assert ctx_ids.shape[0] == 1
|
||||
assert input_ids.shape[0] == 1
|
||||
|
||||
context_txt = self.tokenizer.decode(ctx_ids[0])
|
||||
input_txt = self.tokenizer.decode(input_ids[0])
|
||||
outputs = call_generate(input_txt, context_txt)
|
||||
return outputs
|
||||
930
src/ctx_to_lora/modeling/hypernet.py
Normal file
930
src/ctx_to_lora/modeling/hypernet.py
Normal file
|
|
@ -0,0 +1,930 @@
|
|||
import logging
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from math import sqrt
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from einops import unpack
|
||||
from einops.layers.torch import EinMix as Mix
|
||||
from jaxtyping import Float, Integer
|
||||
from peft import (
|
||||
LoraConfig,
|
||||
LoraRuntimeConfig,
|
||||
PeftConfig,
|
||||
PeftModel,
|
||||
)
|
||||
from peft.tuners._buffer_dict import BufferDict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_target_module_exists
|
||||
from peft.utils import PeftType, TaskType
|
||||
from torch import Tensor, nn
|
||||
from transformers import (
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
)
|
||||
from transformers.modeling_outputs import ModelOutput
|
||||
from transformers.models.modernbert.modeling_modernbert import ModernBertModel
|
||||
|
||||
from ctx_to_lora.configs import (
|
||||
AggregatorArguments,
|
||||
CtxEncoderArguments,
|
||||
HypernetArguments,
|
||||
)
|
||||
from ctx_to_lora.data.processing import tokenize_ctx_text
|
||||
from ctx_to_lora.model_loading import (
|
||||
get_model,
|
||||
get_tokenizer,
|
||||
)
|
||||
from ctx_to_lora.modeling.aggregator import (
|
||||
AGGREGATOR_CLS,
|
||||
AggregatorConfig,
|
||||
get_aggregator_config,
|
||||
)
|
||||
from ctx_to_lora.modeling.ctx_encoder import CTX_ENCODER_CLS, CTX_ENCODER_TYPE
|
||||
from ctx_to_lora.modeling.lora_layer import (
|
||||
apply_lora_to_layers,
|
||||
lora_forward,
|
||||
lora_forward_packed,
|
||||
)
|
||||
from ctx_to_lora.modeling.lora_merger import combine_lora
|
||||
from ctx_to_lora.utils import (
|
||||
get_layers,
|
||||
get_num_layers,
|
||||
get_peft_in_out_features,
|
||||
get_peft_modules,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class HypernetConfig:
|
||||
latent_size: int
|
||||
use_light_weight_lora: bool
|
||||
light_weight_latent_size: int
|
||||
per_rank_gen: bool
|
||||
use_per_rank_bias: bool
|
||||
use_bias: bool
|
||||
per_layer_processing: bool
|
||||
use_token_mixing: bool
|
||||
num_pre_head_layers: int
|
||||
dropout_rate: float
|
||||
|
||||
lora_config: LoraConfig
|
||||
extra_modules: list[str] | None
|
||||
base_hidden_size: int
|
||||
|
||||
layer_indices: Iterable[int]
|
||||
feature_sizes: tuple[dict[str, int], dict[str, int]]
|
||||
aggregator_config: AggregatorConfig
|
||||
|
||||
|
||||
def get_hypernet_config(
|
||||
model: PreTrainedModel,
|
||||
ctx_encoder_model_config: PretrainedConfig,
|
||||
hypernet_args: HypernetArguments,
|
||||
aggregator_args: AggregatorArguments,
|
||||
ctx_encoder_args: CtxEncoderArguments,
|
||||
):
|
||||
num_modules = 0
|
||||
lora_config = getattr(model, "peft_config", None)
|
||||
if lora_config is not None:
|
||||
lora_config = lora_config["default"]
|
||||
num_modules += len(lora_config.target_modules)
|
||||
num_extra_modules = len(hypernet_args.extra_modules or [])
|
||||
indices = torch.arange(get_num_layers(model), device=model.device)
|
||||
return HypernetConfig(
|
||||
**vars(hypernet_args),
|
||||
base_hidden_size=model.config.hidden_size,
|
||||
lora_config=lora_config,
|
||||
layer_indices=indices,
|
||||
feature_sizes=get_peft_in_out_features(model, peft_config=lora_config),
|
||||
aggregator_config=get_aggregator_config(
|
||||
model,
|
||||
ctx_encoder_model_config,
|
||||
ctx_encoder_args.ctx_encoder_type == CTX_ENCODER_TYPE.PER_LAYER_ACTIVATIONS,
|
||||
hypernet_args.latent_size,
|
||||
num_modules,
|
||||
num_extra_modules,
|
||||
lora_config.r,
|
||||
hypernet_args.per_rank_gen,
|
||||
aggregator_args,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def get_init_peft_weights(model: PeftModel, peft_config: PeftConfig = None):
|
||||
if peft_config is None:
|
||||
peft_config = model.peft_config["default"]
|
||||
peft_weights = {module_name: dict() for module_name in peft_config.target_modules}
|
||||
adapter_name = "default"
|
||||
for module_name, module in model.named_modules():
|
||||
if not check_target_module_exists(peft_config, module_name):
|
||||
continue
|
||||
if not isinstance(module, BaseTunerLayer):
|
||||
continue
|
||||
# support just Linear layer for now
|
||||
# all modules should be a leave module that is Linear layer
|
||||
assert isinstance(module.base_layer, nn.Linear), (
|
||||
"all modules should be a leave module that is Linear layer"
|
||||
)
|
||||
|
||||
# this should always pass
|
||||
name = module_name.split(".")[-1]
|
||||
assert name in peft_config.target_modules
|
||||
|
||||
for submodule_name, submodule in module.named_modules():
|
||||
if not isinstance(submodule, (nn.ModuleDict, nn.ParameterDict, BufferDict)):
|
||||
continue
|
||||
|
||||
if adapter_name not in submodule:
|
||||
continue
|
||||
|
||||
if submodule_name not in peft_weights[name]:
|
||||
peft_weights[name][submodule_name] = submodule[adapter_name]
|
||||
else:
|
||||
smod1 = peft_weights[name][submodule_name]
|
||||
smod2 = submodule[adapter_name]
|
||||
assert type(smod1) == type(smod2)
|
||||
|
||||
return peft_weights
|
||||
|
||||
|
||||
class ResMLPBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
output_size: int,
|
||||
dropout_rate: float = 0,
|
||||
):
|
||||
super().__init__()
|
||||
layers = []
|
||||
layers = [
|
||||
nn.LayerNorm(input_size),
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(input_size, hidden_size),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(dropout_rate),
|
||||
nn.Linear(hidden_size, output_size),
|
||||
nn.LayerNorm(output_size),
|
||||
]
|
||||
self.mlp = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.mlp(x)
|
||||
|
||||
|
||||
class ResMLPBlockPerLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_layers: int,
|
||||
input_size: int,
|
||||
hidden_size: int,
|
||||
output_size: int,
|
||||
):
|
||||
super().__init__()
|
||||
layers = [
|
||||
nn.LayerNorm(input_size),
|
||||
Mix(
|
||||
"bs n_layers n_modules r d_in -> bs n_layers n_modules r d_hid",
|
||||
weight_shape="n_layers d_in d_hid",
|
||||
bias_shape="n_layers d_hid",
|
||||
n_layers=n_layers,
|
||||
d_in=input_size,
|
||||
d_hid=hidden_size,
|
||||
),
|
||||
nn.SiLU(),
|
||||
Mix(
|
||||
"bs n_layers n_modules r d_hid -> bs n_layers n_modules r d_out",
|
||||
weight_shape="n_layers d_hid d_out",
|
||||
bias_shape="n_layers d_out",
|
||||
n_layers=n_layers,
|
||||
d_hid=hidden_size,
|
||||
d_out=output_size,
|
||||
),
|
||||
nn.LayerNorm(output_size),
|
||||
]
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
return x + self.layers(x)
|
||||
|
||||
|
||||
class HyperLoRA(nn.Module):
|
||||
def __init__(self, config: HypernetConfig):
|
||||
super().__init__()
|
||||
|
||||
# aggregator output [bs, n_layers, n_modules, feature_dim]
|
||||
# by mixing the pooled features with layer embs and module embs (for pooling)
|
||||
# or via a perceiver w/ bottleneck size = n_modules * n_layers
|
||||
self.config = config
|
||||
logger.debug(f"HyperLoRA config: {self.config}")
|
||||
self.iterative_mode = False
|
||||
self._init_model()
|
||||
|
||||
def _init_model(self):
|
||||
self.agg_config = self.config.aggregator_config
|
||||
self.aggregator = AGGREGATOR_CLS[self.agg_config.aggregator_type](
|
||||
**vars(self.agg_config)
|
||||
)
|
||||
|
||||
self.lora_config = self.config.lora_config
|
||||
self.r = self.lora_config.r
|
||||
|
||||
self.target_modules = (
|
||||
tuple(sorted(self.lora_config.target_modules)) if self.lora_config else None
|
||||
)
|
||||
self.num_modules = len(self.target_modules) if self.target_modules else 0
|
||||
self.extra_modules = (
|
||||
self.config.extra_modules if self.config.extra_modules else None
|
||||
)
|
||||
self.num_extra_modules = len(self.extra_modules) if self.extra_modules else 0
|
||||
self.layer_indices = self.config.layer_indices
|
||||
self.n_layers = len(self.layer_indices)
|
||||
|
||||
self.d_in, self.d_out = self.config.feature_sizes
|
||||
self.d_latent = self.config.latent_size
|
||||
|
||||
if self.target_modules:
|
||||
if self.config.per_layer_processing:
|
||||
layers = [
|
||||
ResMLPBlockPerLayer(
|
||||
self.n_layers,
|
||||
self.d_latent,
|
||||
self.d_latent * 4,
|
||||
self.d_latent,
|
||||
)
|
||||
for _ in range(self.config.num_pre_head_layers)
|
||||
]
|
||||
else:
|
||||
layers = [
|
||||
ResMLPBlock(
|
||||
input_size=self.config.latent_size,
|
||||
hidden_size=self.config.latent_size * 4,
|
||||
output_size=self.config.latent_size,
|
||||
dropout_rate=getattr(self.config, "dropout_rate", 0),
|
||||
)
|
||||
for _ in range(self.config.num_pre_head_layers)
|
||||
]
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.d_lora = max(self.d_in[m] + self.d_out[m] for m in self.target_modules)
|
||||
|
||||
self.bias_A = nn.ParameterDict(
|
||||
{
|
||||
m: nn.Parameter(
|
||||
torch.normal(
|
||||
0,
|
||||
0.2 / (self.d_in[m] * self.r) ** 0.5,
|
||||
(self.n_layers, self.r, self.d_in[m]),
|
||||
)
|
||||
)
|
||||
for m in self.target_modules
|
||||
}
|
||||
)
|
||||
self.bias_B = nn.ParameterDict(
|
||||
{
|
||||
m: nn.Parameter(torch.zeros((self.n_layers, self.r, self.d_out[m])))
|
||||
for m in self.target_modules
|
||||
}
|
||||
)
|
||||
|
||||
self.scaler_A = nn.ParameterDict(
|
||||
{
|
||||
m: nn.Parameter(torch.ones((1, self.n_layers, self.r, 1)))
|
||||
for m in self.target_modules
|
||||
}
|
||||
)
|
||||
self.scaler_B = nn.ParameterDict(
|
||||
{
|
||||
m: nn.Parameter(torch.zeros((1, self.n_layers, self.r, 1)))
|
||||
for m in self.target_modules
|
||||
}
|
||||
)
|
||||
|
||||
n_modules = len(self.target_modules)
|
||||
# have to do this otherwise doesnt work with adamw_torch_fused
|
||||
# has something to do with the bias shape (n_modules r d_lora)
|
||||
# when n_modules == 1, adamw_torch_fused complains about device/layout
|
||||
# but when n_modules > 1, it works fine
|
||||
if n_modules == 1:
|
||||
self.head = Mix(
|
||||
"bs n_layers n_modules r d_latent -> bs n_layers n_modules r d_lora",
|
||||
weight_shape="n_layers d_latent d_lora",
|
||||
bias_shape=None, # no bias
|
||||
n_layers=len(self.layer_indices),
|
||||
d_latent=self.config.latent_size,
|
||||
r=self.config.lora_config.r,
|
||||
d_lora=self.d_lora,
|
||||
)
|
||||
else:
|
||||
self.head = Mix(
|
||||
"bs n_layers n_modules r d_latent -> bs n_layers n_modules r d_lora",
|
||||
weight_shape="n_layers n_modules d_latent d_lora",
|
||||
bias_shape=None, # no bias
|
||||
n_layers=len(self.layer_indices),
|
||||
n_modules=n_modules,
|
||||
d_latent=self.config.latent_size,
|
||||
r=self.config.lora_config.r,
|
||||
d_lora=self.d_lora,
|
||||
)
|
||||
|
||||
def get_head_bias(self):
|
||||
bias_dict = dict()
|
||||
for module in self.target_modules:
|
||||
bias_A = self.bias_A[module]
|
||||
bias_B = self.bias_B[module]
|
||||
|
||||
bias_dict[module] = dict(A=bias_A, B=bias_B)
|
||||
return bias_dict
|
||||
|
||||
def _to_lora_dict(
|
||||
self, flat_loras: Float[Tensor, "bs n_layers n_modules r max_io_dim"]
|
||||
) -> dict[str, dict[str, Float[Tensor, "bs n_layers r _"]]]:
|
||||
if self.target_modules is None:
|
||||
return None
|
||||
# list of [bs, n_layers, r, in_d_outim]
|
||||
# and in_d_outim might vary across modules
|
||||
loras = unpack(
|
||||
flat_loras,
|
||||
[[] for _ in range(len(self.target_modules))],
|
||||
"bs n_layers * r max_io_dim",
|
||||
)
|
||||
|
||||
# dict of {module:
|
||||
# {A: [bs, n_layers, r, d_inim],
|
||||
# B: [bs, n_layers, r, d_outim]}}
|
||||
lora_dict = dict()
|
||||
for module, lora in zip(self.target_modules, loras):
|
||||
A, B = unpack(
|
||||
lora[..., : self.d_in[module] + self.d_out[module]],
|
||||
[[self.d_in[module]], [self.d_out[module]]],
|
||||
"bs n_layers r *",
|
||||
)
|
||||
|
||||
# apparently doing A * self.scaler_A is slow due to broadcasting
|
||||
A = torch.einsum("ijkl,ijkl->ijkl", A, self.scaler_A[module])
|
||||
B = torch.einsum("ijkl,ijkl->ijkl", B, self.scaler_B[module])
|
||||
|
||||
lora_dict[module] = dict(A=A, B=B)
|
||||
|
||||
return lora_dict
|
||||
|
||||
def _to_layernorm_dict(
|
||||
self, flat_layernorms: Float[Tensor, "bs n_layers n_modules hidden_size"]
|
||||
) -> dict[str, Float[Tensor, "bs n_layers hidden_size"]]:
|
||||
if self.extra_modules is None:
|
||||
return None
|
||||
layernorms = unpack(
|
||||
flat_layernorms,
|
||||
[[] for _ in range(len(self.extra_modules))],
|
||||
"bs n_layers * hidden_size",
|
||||
)
|
||||
return {k: v for k, v in zip(self.extra_modules, layernorms)}
|
||||
|
||||
def enable_iterative_mode(self, x: bool):
|
||||
self.iterative_mode = x
|
||||
self.aggregator.enable_iterative_mode(x)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
features: Float[Tensor, "bs seq_len feature_dim"],
|
||||
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
|
||||
position_ids: Integer[Tensor, "bs seq_len"] | None = None,
|
||||
n_ctx_chunks: Integer[Tensor, "n_ctx"] | None = None,
|
||||
):
|
||||
# [bs, n_layers, n_total_modules, r, feature_dim]
|
||||
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
|
||||
if self.aggregator.layer_to_layer and self.iterative_mode:
|
||||
# iterative inference
|
||||
# features: [bs num_layers seq_len feature_dim]
|
||||
bs, n_layers = features.shape[0:2]
|
||||
lora_emb = torch.empty(
|
||||
(bs, n_layers, self.num_modules, self.r, self.config.latent_size),
|
||||
device=features.device,
|
||||
)
|
||||
for i in range(n_layers):
|
||||
lora_emb[:, i], _ = self.aggregator(
|
||||
features[:, i], attn_mask, position_ids
|
||||
)
|
||||
|
||||
else:
|
||||
# batched inference
|
||||
lora_emb, _ = self.aggregator(features, attn_mask, position_ids)
|
||||
|
||||
# [bs, n_layers, n_modules, r, max_in_d_outim]
|
||||
flat_loras = None
|
||||
if self.target_modules:
|
||||
lora_emb = self.layers(lora_emb)
|
||||
norm = torch.norm(lora_emb, dim=-1, keepdim=True)
|
||||
norm_lora_emb = lora_emb / norm
|
||||
flat_loras = self.head(norm_lora_emb)
|
||||
|
||||
flat_layernorms = None
|
||||
|
||||
return flat_loras, flat_layernorms
|
||||
|
||||
def generate_weights(
|
||||
self,
|
||||
features: Float[Tensor, "bs seq_len feature_dim"],
|
||||
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
|
||||
position_ids: Integer[Tensor, "bs seq_len"] | None = None,
|
||||
):
|
||||
flat_loras, flat_layernorms = self.forward(features, attn_mask, position_ids)
|
||||
return self._to_lora_dict(flat_loras), self._to_layernorm_dict(flat_layernorms)
|
||||
|
||||
|
||||
class ModulatedPretrainedModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_model: PeftModel,
|
||||
hypernet_config: HypernetConfig,
|
||||
ctx_encoder_args: CtxEncoderArguments,
|
||||
use_base_input_as_ctx: bool = False,
|
||||
# need non-packed inputs for generation
|
||||
use_sequence_packing: bool = True,
|
||||
user_defined_scaling: float = 1,
|
||||
inp_compressor=None,
|
||||
):
|
||||
assert not use_base_input_as_ctx
|
||||
super().__init__()
|
||||
self.device = base_model.device
|
||||
self.peft_config = base_model.peft_config["default"]
|
||||
self.hypernet_config = hypernet_config
|
||||
self.ctx_encoder_args = ctx_encoder_args
|
||||
self.use_base_input_as_ctx = use_base_input_as_ctx
|
||||
self.use_sequence_packing = use_sequence_packing
|
||||
self.user_defined_scaling = user_defined_scaling
|
||||
self.inp_compressor = inp_compressor
|
||||
self.model_accepts_loss_kwargs = True
|
||||
self.generated_loras = None
|
||||
|
||||
self.register_module("base_model", base_model)
|
||||
self._init_model()
|
||||
self._bias_hyper_init()
|
||||
|
||||
@classmethod
|
||||
def from_state_dict(
|
||||
cls,
|
||||
state_dict: dict,
|
||||
train: bool = True,
|
||||
base_model_kwargs: dict = None,
|
||||
use_flash_attn: bool = True,
|
||||
**kwargs: Any,
|
||||
):
|
||||
lora_config = state_dict["hypernet_config"].lora_config
|
||||
print(f"lora_config: {lora_config}")
|
||||
model_name_or_path = state_dict["base_model_name_or_path"]
|
||||
base_model = get_model(
|
||||
model_name_or_path,
|
||||
train=train,
|
||||
requires_grad=False,
|
||||
peft_config=lora_config,
|
||||
model_kwargs=base_model_kwargs,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
hypernet_config = state_dict["hypernet_config"]
|
||||
if getattr(hypernet_config, "num_pre_head_layers", None) is None:
|
||||
hypernet_config.num_pre_head_layers = 4
|
||||
if getattr(hypernet_config, "use_per_rank_bias", None) is None:
|
||||
hypernet_config.use_per_rank_bias = False
|
||||
if getattr(hypernet_config, "use_bias", None) is None:
|
||||
hypernet_config.use_bias = True
|
||||
ctx_encoder_args = state_dict["ctx_encoder_args"]
|
||||
model = cls(base_model, hypernet_config, ctx_encoder_args, **kwargs)
|
||||
model.load_state_dict(state_dict)
|
||||
return model
|
||||
|
||||
def patch_lora_forward(self):
|
||||
layers = get_layers(self.base_model)
|
||||
|
||||
lora_forward_fn = (
|
||||
lora_forward_packed if self.use_sequence_packing else lora_forward
|
||||
)
|
||||
for layer_idx in self.hypernet.layer_indices:
|
||||
for module_info in get_peft_modules(layers[layer_idx], self.peft_config):
|
||||
name = module_info["name"]
|
||||
module = module_info["module"]
|
||||
if getattr(module, "patched_forward", False):
|
||||
continue
|
||||
logger.debug(f"Applying LoRA forward to {name}")
|
||||
module.forward_orig = module.forward
|
||||
module.patched_forward = True
|
||||
module.forward = partial(
|
||||
lora_forward_fn,
|
||||
self=module,
|
||||
lora_dropout_p=self.peft_config.lora_dropout,
|
||||
scaling=self.peft_config.lora_alpha,
|
||||
)
|
||||
|
||||
def _init_model(self):
|
||||
# disable adapter of the base model
|
||||
# this only works with LoRA(?)
|
||||
# we disable to avoid peft lora computation
|
||||
self.base_model.disable_adapter_layers()
|
||||
|
||||
self.hypernet = (
|
||||
HyperLoRA(self.hypernet_config).to(self.device).to(torch.float32)
|
||||
)
|
||||
|
||||
self.patch_lora_forward()
|
||||
|
||||
ctx_model_name = self.ctx_encoder_args.ctx_encoder_model_name_or_path
|
||||
if ctx_model_name is None:
|
||||
ctx_model_name = self.base_model.config.name_or_path
|
||||
# use an explicit copy of the base model
|
||||
# for using with "modules_to_save"
|
||||
base_model_attn_impl = self.base_model.config._attn_implementation
|
||||
logger.debug(f"ctx_model_name: {ctx_model_name}")
|
||||
logger.debug(f"base_model.config._attn_implementation: {base_model_attn_impl}")
|
||||
encoder_model = get_model(
|
||||
ctx_model_name,
|
||||
train=self.base_model.training,
|
||||
requires_grad=False,
|
||||
use_flash_attn=base_model_attn_impl == "flash_attention_2",
|
||||
use_q_lora=self.ctx_encoder_args.quantize_ctx_encoder,
|
||||
)
|
||||
self.ctx_encoder = CTX_ENCODER_CLS[self.ctx_encoder_args.ctx_encoder_type](
|
||||
encoder_model, self.ctx_encoder_args
|
||||
)
|
||||
|
||||
# delegate to base_model
|
||||
@property
|
||||
def config(self):
|
||||
return self.base_model.config
|
||||
|
||||
@property
|
||||
def generation_config(self):
|
||||
return self.base_model.generation_config
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.base_model.vocab_size
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.base_model.get_input_embeddings()
|
||||
|
||||
@torch.no_grad()
|
||||
def _bias_hyper_init(self):
|
||||
if self.hypernet.extra_modules:
|
||||
self.hypernet.extra_head.weight.data[:] = 0
|
||||
self.hypernet.extra_head.bias.data[:] = 0
|
||||
if self.hypernet.target_modules:
|
||||
peft_weights = get_init_peft_weights(
|
||||
self.base_model, self.hypernet.lora_config
|
||||
)
|
||||
logger.debug(f"peft_weights: {peft_weights}")
|
||||
r = self.hypernet_config.lora_config.r
|
||||
nn.init.normal_(
|
||||
self.hypernet.head.weight,
|
||||
mean=0,
|
||||
std=0.5
|
||||
/ sqrt(self.hypernet.config.latent_size + self.hypernet.d_lora * r),
|
||||
# the head outputs per rank lora --> divide by r to scale down grad
|
||||
)
|
||||
|
||||
def state_dict(self, *args, **kwargs):
|
||||
# we assume ctx_encoder and base model is frozen here
|
||||
if len([p for p in self.ctx_encoder.parameters() if p.requires_grad]):
|
||||
raise ValueError("ctx_encoder contains trainable parameters")
|
||||
if len([p for p in self.base_model.parameters() if p.requires_grad]):
|
||||
raise ValueError("base model contains trainable parameters")
|
||||
|
||||
state_dict = self.hypernet.state_dict(*args, **kwargs)
|
||||
state_dict["base_model_name_or_path"] = self.base_model.name_or_path
|
||||
state_dict["hypernet_config"] = self.hypernet_config
|
||||
state_dict["ctx_encoder_args"] = self.ctx_encoder_args
|
||||
return state_dict
|
||||
|
||||
def load_state_dict(self, state_dict: dict, *args, **kwargs):
|
||||
self.base_model_name_or_path = state_dict.pop("base_model_name_or_path")
|
||||
self.hypernet_config = state_dict.pop("hypernet_config")
|
||||
self.ctx_encoder_args = state_dict.pop("ctx_encoder_args")
|
||||
if self.base_model_name_or_path != self.base_model.name_or_path:
|
||||
raise ValueError(
|
||||
f"Base model name or path mismatch. "
|
||||
f"The base model given is: {self.base_model.name_or_path}, "
|
||||
f"but the loaded name is: {self.base_model_name_or_path}"
|
||||
)
|
||||
self._init_model()
|
||||
|
||||
def remove_compile_prefix(sd: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
COMPILED_PREFIX = "_orig_mod."
|
||||
for k in list(sd.keys()):
|
||||
if k.startswith(COMPILED_PREFIX):
|
||||
sd[k[len(COMPILED_PREFIX) :]] = sd.pop(k)
|
||||
return sd
|
||||
|
||||
load_result = self.hypernet.load_state_dict(
|
||||
remove_compile_prefix(state_dict),
|
||||
strict=True, # , *args, **kwargs
|
||||
)
|
||||
logger.info(f"load result: {load_result}")
|
||||
return load_result
|
||||
|
||||
def generate_weights(
|
||||
self,
|
||||
ctx_ids: Integer[Tensor, "bs ctx_len"],
|
||||
ctx_attn_mask: Integer[Tensor, "bs ctx_len"] | None = None,
|
||||
ctx_position_ids: Integer[Tensor, "bs ctx_len"] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
with torch.no_grad():
|
||||
ctx_encoder_kwargs = dict(
|
||||
input_ids=ctx_ids,
|
||||
attention_mask=ctx_attn_mask,
|
||||
position_ids=ctx_position_ids,
|
||||
)
|
||||
if isinstance(self.ctx_encoder.base_model, ModernBertModel):
|
||||
position_ids = ctx_position_ids.flatten()
|
||||
indices = torch.arange(
|
||||
position_ids.size(0), device=position_ids.device, dtype=torch.int32
|
||||
)
|
||||
# [bsz + 1]
|
||||
cu_seqlens = torch.cat(
|
||||
(
|
||||
indices[position_ids == 0],
|
||||
torch.tensor(
|
||||
position_ids.size(),
|
||||
device=position_ids.device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
)
|
||||
ctx_encoder_kwargs = dict(
|
||||
input_ids=ctx_ids.squeeze(0),
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=position_ids.max() + 1,
|
||||
attention_mask=-1,
|
||||
seq_len=-1,
|
||||
batch_size=-1,
|
||||
)
|
||||
|
||||
ctx_features = self.ctx_encoder(**ctx_encoder_kwargs, **kwargs)
|
||||
|
||||
if isinstance(self.ctx_encoder.base_model, ModernBertModel):
|
||||
ctx_features = ctx_features.unsqueeze(0)
|
||||
if self.user_defined_scaling == 1:
|
||||
return self.hypernet.generate_weights(
|
||||
ctx_features, ctx_attn_mask, ctx_position_ids
|
||||
)
|
||||
|
||||
lora_dict, _ = self.hypernet.generate_weights(
|
||||
ctx_features, ctx_attn_mask, ctx_position_ids
|
||||
)
|
||||
for module in lora_dict:
|
||||
lora_dict[module]["A"] = lora_dict[module]["A"] * self.user_defined_scaling
|
||||
lora_dict[module]["B"] = lora_dict[module]["B"] * self.user_defined_scaling
|
||||
return lora_dict, None
|
||||
|
||||
def enable_iterative_mode(self, x: bool):
|
||||
self.hypernet.enable_iterative_mode(x)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
ctx_ids: Integer[Tensor, "n_ctx ctx_len"] | None = None,
|
||||
ctx_attn_mask: Integer[Tensor, "n_ctx ctx_len"] | None = None,
|
||||
ctx_position_ids: Integer[Tensor, "n_ctx ctx_len"] | None = None,
|
||||
n_ctx_chunks: Integer[Tensor, "n_ctx"] | None = None,
|
||||
n_queries: Integer[Tensor, "n_ctx"] | None = None,
|
||||
return_generated_lora: bool | None = False,
|
||||
*model_inputs_args: Any,
|
||||
**model_inputs_kwargs: dict[str, Any],
|
||||
) -> tuple | ModelOutput:
|
||||
"""Forward pass of the modulated model."""
|
||||
generated_loras = None
|
||||
generated_layernorms = None
|
||||
if ctx_ids is None and not self.use_base_input_as_ctx:
|
||||
logger.warning(
|
||||
(
|
||||
"*" * 100,
|
||||
"\n\nNo ctx_features provided, using the base model for forward pass\n\n",
|
||||
"*" * 100,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
if self.use_base_input_as_ctx:
|
||||
ctx_ids = (
|
||||
model_inputs_kwargs["input_ids"]
|
||||
if "input_ids" in model_inputs_kwargs
|
||||
else model_inputs_args[0]
|
||||
)
|
||||
ctx_attn_mask = (
|
||||
model_inputs_kwargs["attention_mask"]
|
||||
if "attention_mask" in model_inputs_kwargs
|
||||
else None
|
||||
)
|
||||
ctx_position_ids = (
|
||||
model_inputs_kwargs["position_ids"]
|
||||
if "position_ids" in model_inputs_kwargs
|
||||
else None
|
||||
)
|
||||
generated_loras, generated_layernorms = self.generate_weights(
|
||||
ctx_ids, ctx_attn_mask, ctx_position_ids
|
||||
)
|
||||
|
||||
if generated_loras is not None:
|
||||
generated_loras = combine_lora(
|
||||
generated_loras,
|
||||
n_ctx_chunks,
|
||||
lora_bias=self.hypernet.get_head_bias()
|
||||
if self.hypernet.config.use_bias
|
||||
else None,
|
||||
)
|
||||
|
||||
# input_ids in model_inputs_kwargs contains only
|
||||
# prompt + response (for hypernet training)
|
||||
position_ids = (
|
||||
model_inputs_kwargs["position_ids"]
|
||||
if "position_ids" in model_inputs_kwargs
|
||||
else None
|
||||
)
|
||||
|
||||
if n_queries is None:
|
||||
if ctx_position_ids is None:
|
||||
n_queries = torch.ones(
|
||||
ctx_ids.shape[0], dtype=torch.int32, device=self.device
|
||||
)
|
||||
else:
|
||||
# quite redundant (we do cu_seqlens many places)
|
||||
# TODO: compute cu_seqlens here and propagate that
|
||||
n_queries = torch.ones(
|
||||
(ctx_position_ids == 0).sum(),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
apply_lora_to_layers(
|
||||
self.base_model,
|
||||
self.hypernet.layer_indices,
|
||||
generated_loras,
|
||||
n_queries,
|
||||
position_ids,
|
||||
)
|
||||
model_outputs = self.base_model(*model_inputs_args, **model_inputs_kwargs)
|
||||
|
||||
if return_generated_lora:
|
||||
return model_outputs, (generated_loras, generated_layernorms)
|
||||
else:
|
||||
return model_outputs
|
||||
|
||||
def combine_lora(self, *args, **kwargs):
|
||||
# for timing
|
||||
return combine_lora(*args, **kwargs)
|
||||
|
||||
def apply_lora_to_layers(self, *args, **kwargs):
|
||||
# for timing
|
||||
return apply_lora_to_layers(*args, **kwargs)
|
||||
|
||||
# for simple api usage
|
||||
def internalize(self, ctx_str: str):
|
||||
ctx_tokenizer = get_tokenizer(self.ctx_encoder.base_model.name_or_path)
|
||||
ctx_ids = tokenize_ctx_text(dict(context=[ctx_str]), ctx_tokenizer)["ctx_ids"]
|
||||
return self._internalize_from_ids(torch.tensor(ctx_ids, device=self.device))
|
||||
|
||||
def _internalize_from_ids(
|
||||
self,
|
||||
ctx_ids: Integer[Tensor, "n_ctx ctx_len"] | None = None,
|
||||
ctx_attn_mask: Integer[Tensor, "n_ctx ctx_len"] | None = None,
|
||||
ctx_position_ids: Integer[Tensor, "n_ctx ctx_len"] | None = None,
|
||||
):
|
||||
self.patch_lora_forward()
|
||||
if ctx_attn_mask is None and ctx_position_ids is None:
|
||||
assert ctx_ids.shape[0] == 1
|
||||
ctx_attn_mask = torch.ones_like(ctx_ids)
|
||||
generated_loras, generated_layernorms = self.generate_weights(
|
||||
ctx_ids, ctx_attn_mask, ctx_position_ids
|
||||
)
|
||||
self.generated_loras = generated_loras
|
||||
|
||||
def reset(self):
|
||||
self.generated_loras = None
|
||||
layers = get_layers(self.base_model)
|
||||
for layer_idx in self.hypernet.layer_indices:
|
||||
for module_info in get_peft_modules(layers[layer_idx], self.peft_config):
|
||||
name = module_info["name"]
|
||||
module = module_info["module"]
|
||||
logger.debug(f"Resetting forward for {name}")
|
||||
module.forward = module.forward_orig
|
||||
module.patched_forward = False
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
ctx_ids: Integer[Tensor, "n_chunks ctx_length"] | None = None,
|
||||
ctx_attn_mask: Integer[Tensor, "n_chunks ctx_length"] | None = None,
|
||||
ctx_position_ids: Integer[Tensor, "n_chunks ctx_length"] | None = None,
|
||||
n_ctx_chunks: Integer[Tensor, "n_ctx"] | None = None,
|
||||
n_queries: Integer[Tensor, "n_ctx"] | None = None,
|
||||
scalers: Float[Tensor, "n_ctx"] | None = None,
|
||||
bias_scaler: float | None = None,
|
||||
*model_inputs_args: Any,
|
||||
**model_inputs_kwargs: dict[str, Any],
|
||||
):
|
||||
generated_loras = None
|
||||
generated_layernorms = None
|
||||
if (
|
||||
ctx_ids is None
|
||||
and not self.generated_loras
|
||||
and not self.use_base_input_as_ctx
|
||||
):
|
||||
print(
|
||||
"*" * 100
|
||||
+ "\n\nNo ctx_ids provided, using the base model for generation\n\n"
|
||||
+ "*" * 100
|
||||
)
|
||||
elif ctx_ids is None and self.generated_loras:
|
||||
generated_loras = self.generated_loras
|
||||
if n_ctx_chunks is None:
|
||||
n_ctx_chunks = torch.tensor((1,), device=self.device)
|
||||
print(
|
||||
"*" * 100
|
||||
+ "\n\nUsing internalized LoRAs for generation\n\n"
|
||||
+ "*" * 100
|
||||
)
|
||||
else:
|
||||
if self.use_base_input_as_ctx:
|
||||
ctx_ids = (
|
||||
model_inputs_kwargs["input_ids"]
|
||||
if "input_ids" in model_inputs_kwargs
|
||||
else model_inputs_args[0]
|
||||
)
|
||||
ctx_attn_mask = (
|
||||
model_inputs_kwargs["attention_mask"]
|
||||
if "attention_mask" in model_inputs_kwargs
|
||||
else None
|
||||
)
|
||||
ctx_position_ids = (
|
||||
model_inputs_kwargs["position_ids"]
|
||||
if "position_ids" in model_inputs_kwargs
|
||||
else None
|
||||
)
|
||||
generated_loras, generated_layernorms = self.generate_weights(
|
||||
ctx_ids, ctx_attn_mask, ctx_position_ids
|
||||
)
|
||||
|
||||
if generated_loras is not None:
|
||||
generated_loras = self.combine_lora(
|
||||
generated_loras,
|
||||
n_ctx_chunks,
|
||||
lora_bias=self.hypernet.get_head_bias()
|
||||
if self.hypernet.config.use_bias
|
||||
else None,
|
||||
scalers=scalers,
|
||||
bias_scaler=bias_scaler,
|
||||
)
|
||||
|
||||
# apply lora hook to the base model
|
||||
# TODO: we dont this position_ids for generation?
|
||||
position_ids = (
|
||||
model_inputs_kwargs["position_ids"]
|
||||
if "position_ids" in model_inputs_kwargs
|
||||
else None
|
||||
)
|
||||
if n_queries is None:
|
||||
if ctx_position_ids is None:
|
||||
n_queries = torch.ones(
|
||||
model_inputs_kwargs["input_ids"].shape[0],
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
# quite redundant (we do cu_seqlens many places)
|
||||
# TODO: compute cu_seqlens here and propagate that
|
||||
n_queries = torch.ones(
|
||||
(ctx_position_ids == 0).sum(),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
apply_lora_to_layers(
|
||||
self.base_model,
|
||||
self.hypernet.layer_indices,
|
||||
generated_loras,
|
||||
n_queries,
|
||||
position_ids,
|
||||
)
|
||||
|
||||
model_outputs = self.base_model.generate(
|
||||
*model_inputs_args, **model_inputs_kwargs
|
||||
)
|
||||
return model_outputs
|
||||
|
||||
|
||||
# needed for loading model from checkpoint
|
||||
# see https://github.com/huggingface/transformers/pull/34632
|
||||
torch.serialization.add_safe_globals(
|
||||
[
|
||||
AggregatorConfig,
|
||||
LoraConfig,
|
||||
HypernetConfig,
|
||||
PeftType,
|
||||
TaskType,
|
||||
LoraRuntimeConfig,
|
||||
set, # for real?
|
||||
]
|
||||
)
|
||||
765
src/ctx_to_lora/modeling/idefics2.py
Normal file
765
src/ctx_to_lora/modeling/idefics2.py
Normal file
|
|
@ -0,0 +1,765 @@
|
|||
# Copyright 2024 the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""PyTorch Idefics2 model."""
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.cache_utils import Cache
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.models.idefics2.configuration_idefics2 import Idefics2Config
|
||||
from transformers.utils import (
|
||||
add_start_docstrings,
|
||||
is_flash_attn_2_available,
|
||||
is_flash_attn_greater_or_equal_2_10,
|
||||
logging,
|
||||
)
|
||||
|
||||
if is_flash_attn_2_available():
|
||||
from flash_attn.bert_padding import unpad_input
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Idefics2PerceiverConfig(PretrainedConfig):
|
||||
r"""
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the perceiver block.
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
n_latents (`int`, *optional*, defaults to 64):
|
||||
Number of latent embeddings to resample ("compress") the input sequence to (usually < 128).
|
||||
resampler_depth (`int`, *optional*, defaults to 3):
|
||||
Depth of the Perceiver Resampler (Transformer w/ cross attention). Should be shallow (<= 3).
|
||||
n_heads (`int`, *optional*, defaults to 16):
|
||||
Number of heads in each Transformer block (for multi-headed self-attention).
|
||||
head_dim (`int`, *optional*, defaults to 96):
|
||||
Dimensionality of each head projection in the Transformer block.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 4):
|
||||
Number of key-value heads in the perceiver attention block.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
"""
|
||||
|
||||
model_type = "idefics2_perceiver"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
num_blocks: int,
|
||||
num_self_attn_per_block: int,
|
||||
shared_weights: bool,
|
||||
intermediate_size_factor: int,
|
||||
hidden_act="silu",
|
||||
hidden_size=4096,
|
||||
rms_norm_eps=1e-06,
|
||||
n_latents=64,
|
||||
n_heads=16,
|
||||
head_dim=128,
|
||||
num_key_value_heads=4,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_blocks = num_blocks
|
||||
self.num_self_attn_per_block = num_self_attn_per_block
|
||||
self.shared_weights = shared_weights
|
||||
|
||||
self.input_size = input_size
|
||||
self.intermediate_size_factor = intermediate_size_factor
|
||||
# for perceiver
|
||||
self.hidden_act = hidden_act
|
||||
self.hidden_size = hidden_size
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.n_latents = n_latents
|
||||
self.n_heads = n_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.attention_dropout = attention_dropout
|
||||
if self.num_key_value_heads > self.n_heads:
|
||||
raise ValueError(
|
||||
f"num_key_value_heads={self.num_key_value_heads} must be less than or equal to"
|
||||
f" n_heads={self.n_heads}"
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Idefics2MLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
output_size: int,
|
||||
hidden_act: str,
|
||||
):
|
||||
super().__init__()
|
||||
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
|
||||
self.down_proj = nn.Linear(intermediate_size, output_size, bias=False)
|
||||
self.act_fn = ACT2FN[hidden_act]
|
||||
|
||||
def forward(self, x):
|
||||
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
|
||||
|
||||
|
||||
IDEFICS2_START_DOCSTRING = r"""
|
||||
This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
|
||||
etc.)
|
||||
|
||||
This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
|
||||
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
|
||||
and behavior.
|
||||
|
||||
Parameters:
|
||||
config ([`Idefics2Config`] or [`Idefics2VisionConfig`]):
|
||||
Model configuration class with all the parameters of the model. Initializing with a config file does not
|
||||
load the weights associated with the model, only the configuration. Check out the
|
||||
[`~PreTrainedModel.from_pretrained`] method to load the model weights.
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare Idefics2 Model outputting raw hidden-states without any specific head on top.",
|
||||
IDEFICS2_START_DOCSTRING,
|
||||
)
|
||||
class Idefics2PreTrainedModel(PreTrainedModel):
|
||||
config_class = Idefics2Config
|
||||
base_model_prefix = "model"
|
||||
supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"Idefics2VisionAttention",
|
||||
"Idefics2MLP",
|
||||
"Idefics2PerceiverLayer",
|
||||
"Idefics2DecoderLayer",
|
||||
]
|
||||
_skip_keys_device_placement = "past_key_values"
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
_supports_cache_class = True
|
||||
|
||||
def _init_weights(self, module):
|
||||
std = (
|
||||
self.config.initializer_range
|
||||
if hasattr(self.config, "initializer_range")
|
||||
else 0.02
|
||||
)
|
||||
|
||||
if hasattr(module, "class_embedding"):
|
||||
module.class_embedding.data.normal_(mean=0.0, std=std)
|
||||
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.Embedding):
|
||||
module.weight.data.normal_(mean=0.0, std=std)
|
||||
if module.padding_idx is not None:
|
||||
module.weight.data[module.padding_idx].zero_()
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.repeat_kv
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||
)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Idefics2
|
||||
class Idefics2RMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Idefics2RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
|
||||
|
||||
|
||||
class Idefics2PerceiverAttention(nn.Module):
|
||||
def __init__(self, config, layer_idx: int | None = None) -> None:
|
||||
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.layer_idx = None
|
||||
self.hidden_size = config.hidden_size
|
||||
self.num_heads = config.n_heads
|
||||
self.head_dim = config.head_dim
|
||||
self.num_key_value_heads = config.num_key_value_heads
|
||||
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
||||
self.attention_dropout = config.attention_dropout
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, self.num_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.num_heads * self.head_dim, self.hidden_size, bias=False
|
||||
)
|
||||
|
||||
self.is_causal = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
context: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: tuple[torch.Tensor] | None = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
||||
"""
|
||||
Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension!
|
||||
|
||||
Args:
|
||||
latents (`torch.Tensor`): Tensor of shape [bsz, n_latents, embed_dim] representing fixed length latents to compress to.
|
||||
context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample.
|
||||
attention_mask (`torch.Tensor`, *optional*): Tensor of shape [bsz, 1, seq, n_latents] representing attention mask.
|
||||
position_ids (`torch.LongTensor`, *optional*): Tensor of shape [bsz, seq] representing position indices of each input token.
|
||||
past_key_value (`Tuple[torch.Tensor]`, *optional*): Tuple of tensors containing cached key and value states.
|
||||
output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights.
|
||||
use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_value for caching.
|
||||
"""
|
||||
bsz, q_len, _ = latents.size()
|
||||
kv_seq_len = q_len + context.size()[1]
|
||||
|
||||
hidden_states = torch.concat([context, latents], dim=-2)
|
||||
|
||||
query_states = self.q_proj(latents)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, kv_seq_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, kv_seq_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
past_key_value = getattr(self, "past_key_value", past_key_value)
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx
|
||||
)
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# NO LONGER EXIST Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
|
||||
# TODO cyril: modular
|
||||
class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
|
||||
"""
|
||||
Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays
|
||||
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
|
||||
flash attention and deal with padding tokens in case the input contains any of them.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
|
||||
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
|
||||
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
|
||||
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
|
||||
|
||||
# Ignore copy
|
||||
def forward(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
is_cross_attn: bool,
|
||||
context: torch.Tensor | None = None,
|
||||
attention_mask: torch.LongTensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: Cache | None = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
|
||||
bsz, q_len, _ = latents.size()
|
||||
query_states = self.q_proj(latents)
|
||||
if is_cross_attn:
|
||||
kv_inp = context
|
||||
else:
|
||||
kv_inp = latents
|
||||
|
||||
key_states = self.k_proj(kv_inp)
|
||||
value_states = self.v_proj(kv_inp)
|
||||
|
||||
# query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
|
||||
query_states = query_states.view(
|
||||
*latents.shape[:2], self.num_heads, self.head_dim
|
||||
)
|
||||
|
||||
key_states = key_states.view(
|
||||
*kv_inp.shape[:2], self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
*kv_inp.shape[:2], self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
dropout_rate = 0.0 if not self.training else self.attention_dropout
|
||||
|
||||
# In PEFT, usually we cast the layer norms in float32 for training stability reasons
|
||||
# therefore the input hidden states gets silently casted in float32. Hence, we need
|
||||
# cast them back in float16 just to be sure everything works as expected.
|
||||
input_dtype = query_states.dtype
|
||||
if input_dtype == torch.float32:
|
||||
if torch.is_autocast_enabled():
|
||||
target_dtype = torch.get_autocast_gpu_dtype()
|
||||
# Handle the case where the model is quantized
|
||||
elif hasattr(self.config, "_pre_quantization_dtype"):
|
||||
target_dtype = self.config._pre_quantization_dtype
|
||||
else:
|
||||
target_dtype = self.q_proj.weight.dtype
|
||||
|
||||
logger.warning_once(
|
||||
f"The input hidden states seems to be silently casted in float32, this might be related to"
|
||||
f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
|
||||
f" {target_dtype}."
|
||||
)
|
||||
|
||||
query_states = query_states.to(target_dtype)
|
||||
key_states = key_states.to(target_dtype)
|
||||
value_states = value_states.to(target_dtype)
|
||||
|
||||
# Reashape to the expected shape for Flash Attention
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
|
||||
attn_output = _flash_attention_forward(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
q_len,
|
||||
dropout=dropout_rate,
|
||||
position_ids=position_ids,
|
||||
sliding_window=None,
|
||||
is_causal=self.is_causal,
|
||||
use_top_left_mask=self._flash_attn_uses_top_left_mask,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attn_output = attn_output.reshape(
|
||||
bsz, q_len, self.num_heads * self.head_dim
|
||||
).contiguous()
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
IDEFICS2_PERCEIVER_ATTENTION_CLASSES = {
|
||||
# "eager": Idefics2PerceiverAttention,
|
||||
"flash_attention_2": Idefics2PerceiverFlashAttention2,
|
||||
}
|
||||
|
||||
|
||||
class Idefics2PerceiverLayer(nn.Module):
|
||||
def __init__(self, config, is_cross_attn: bool):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
self.n_latents = config.n_latents
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
self.is_cross_attn = is_cross_attn
|
||||
|
||||
self.input_latents_layernorm = Idefics2RMSNorm(
|
||||
self.hidden_size, eps=self.rms_norm_eps
|
||||
)
|
||||
self.input_context_layernorm = (
|
||||
Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
|
||||
if self.is_cross_attn
|
||||
else torch.nn.Identity()
|
||||
)
|
||||
self.self_attn = IDEFICS2_PERCEIVER_ATTENTION_CLASSES[
|
||||
config._attn_implementation
|
||||
](config)
|
||||
self.post_attention_layernorm = Idefics2RMSNorm(
|
||||
self.hidden_size, eps=self.rms_norm_eps
|
||||
)
|
||||
self.pre_ff_layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
|
||||
self.post_ff_layernorm = Idefics2RMSNorm(
|
||||
self.hidden_size, eps=self.rms_norm_eps
|
||||
)
|
||||
self.mlp = Idefics2MLP(
|
||||
hidden_size=config.hidden_size,
|
||||
intermediate_size=config.hidden_size * 4,
|
||||
output_size=config.hidden_size,
|
||||
hidden_act=config.hidden_act,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latents: torch.Tensor,
|
||||
context: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
past_key_value: tuple[torch.Tensor] | None = None,
|
||||
output_attentions: bool | None = False,
|
||||
use_cache: bool | None = False,
|
||||
**kwargs,
|
||||
) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
|
||||
"""
|
||||
Args:
|
||||
latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
context (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
"""
|
||||
residual = latents
|
||||
|
||||
latents = self.input_latents_layernorm(latents)
|
||||
context = self.input_context_layernorm(context)
|
||||
|
||||
latents, self_attn_weights, present_key_value = self.self_attn(
|
||||
latents=latents,
|
||||
is_cross_attn=self.is_cross_attn,
|
||||
context=context,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
**kwargs,
|
||||
)
|
||||
latents = self.post_attention_layernorm(latents)
|
||||
latents = residual + latents
|
||||
residual = latents
|
||||
|
||||
# latents = self.post_attention_layernorm(latents)
|
||||
latents = self.pre_ff_layernorm(latents)
|
||||
latents = self.mlp(latents)
|
||||
latents = self.post_ff_layernorm(latents)
|
||||
latents = residual + latents
|
||||
|
||||
outputs = (latents,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
IDEFICS2_INPUTS_DOCSTRING = r"""
|
||||
Args:
|
||||
context (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_dim)`):
|
||||
The hidden states of the image after vision encoder and modality projection.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
[What are attention masks?](../glossary#attention-mask)
|
||||
"""
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"Idefics2 perceiver resampler model that performs `depth` blocks of cross-attention with a fixed ",
|
||||
"`n_latents` inputs to decrease embedding sequence length. The Resampler acts as a form of learned pooling and ",
|
||||
"is derived from [Perceiver: General Perception with Iterative Attention](https://arxiv.org/abs/2103.03206)",
|
||||
IDEFICS2_START_DOCSTRING,
|
||||
)
|
||||
class Idefics2PerceiverResampler(Idefics2PreTrainedModel):
|
||||
_supports_sdpa = False
|
||||
config_class = Idefics2PerceiverConfig
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
super().__init__(config)
|
||||
self.num_blocks = config.num_blocks
|
||||
self.num_self_attn_per_block = config.num_self_attn_per_block
|
||||
self.shared_weights = config.shared_weights
|
||||
self.hidden_size = config.hidden_size
|
||||
self.hidden_act = config.hidden_act
|
||||
self.n_latents = config.n_latents
|
||||
self.rms_norm_eps = config.rms_norm_eps
|
||||
|
||||
# Create Latents for Perceiver
|
||||
self.latents_q = nn.Parameter(torch.randn(self.n_latents, self.hidden_size))
|
||||
|
||||
# First block
|
||||
assert config.num_blocks > 0
|
||||
first_x_attn = [Idefics2PerceiverLayer(config, is_cross_attn=True)]
|
||||
first_self_attn_block = [
|
||||
Idefics2PerceiverLayer(config, is_cross_attn=False)
|
||||
for _ in range(config.num_self_attn_per_block)
|
||||
]
|
||||
|
||||
self.layers = nn.ModuleList(first_x_attn + first_self_attn_block)
|
||||
for layer_idx in range(1, config.num_blocks):
|
||||
# cross-attention at the beginning of each block
|
||||
if self.shared_weights:
|
||||
if layer_idx == 1:
|
||||
second_x_attn = Idefics2PerceiverLayer(config, is_cross_attn=True)
|
||||
x_attn = second_x_attn
|
||||
else:
|
||||
x_attn = Idefics2PerceiverLayer(config, is_cross_attn=True)
|
||||
self.layers.append(x_attn)
|
||||
|
||||
# self-attention
|
||||
for i in range(config.num_self_attn_per_block):
|
||||
if self.shared_weights:
|
||||
self_attn = first_self_attn_block[i]
|
||||
else:
|
||||
self_attn = Idefics2PerceiverLayer(config, is_cross_attn=False)
|
||||
self.layers.append(self_attn)
|
||||
|
||||
self.layernorm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
|
||||
|
||||
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
||||
assert self._use_flash_attention_2
|
||||
|
||||
def forward(
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
# seq embed -> bsz seq embed
|
||||
if position_ids is None:
|
||||
bsz = context.shape[0]
|
||||
else:
|
||||
# flattened packed sequence
|
||||
bsz = torch.where(position_ids == 0, 1, 0).sum()
|
||||
|
||||
latents = self.latents_q.unsqueeze(0).expand((bsz, *self.latents_q.size()))
|
||||
|
||||
attention_mask = (
|
||||
_prepare_4d_attention_mask(
|
||||
attention_mask, latents.dtype, tgt_len=self.n_latents
|
||||
)
|
||||
if not self._use_flash_attention_2
|
||||
else attention_mask
|
||||
)
|
||||
|
||||
compressed_context = latents
|
||||
|
||||
cu_seq_lens_q = torch.tensor(
|
||||
[self.n_latents] * (bsz + 1), device=context.device, dtype=torch.int32
|
||||
) * torch.arange(bsz + 1, device=context.device, dtype=torch.int32)
|
||||
max_length_q = self.n_latents
|
||||
# cu_seq_lens_k = None
|
||||
# max_length_k = None
|
||||
if attention_mask is not None:
|
||||
logger.warning_once("Using attention mask for resampler")
|
||||
context, _, cu_seq_lens_k, max_length_k, _ = unpad_input(
|
||||
context, attention_mask
|
||||
)
|
||||
context = context.unsqueeze(0)
|
||||
position_ids = True # goes down flash attn path that uses cu_seq_lens
|
||||
|
||||
elif position_ids is not None:
|
||||
logger.warning_once("Using position ids for resampler")
|
||||
|
||||
position_ids = position_ids.flatten()
|
||||
indices = torch.arange(
|
||||
position_ids.size(0), device=position_ids.device, dtype=torch.int32
|
||||
)
|
||||
# [bsz + 1]
|
||||
cu_seq_lens_k = torch.cat(
|
||||
(
|
||||
indices[position_ids == 0],
|
||||
torch.tensor(
|
||||
position_ids.size(),
|
||||
device=position_ids.device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
max_length_k = position_ids.max() + 1
|
||||
|
||||
else:
|
||||
raise ValueError("either position_ids or attention_mask is required")
|
||||
x_attn_kwargs = dict(
|
||||
position_ids=position_ids,
|
||||
cu_seq_lens_q=cu_seq_lens_q,
|
||||
cu_seq_lens_k=cu_seq_lens_k,
|
||||
max_length_q=max_length_q,
|
||||
max_length_k=max_length_k,
|
||||
)
|
||||
self_attn_position_ids = torch.arange(
|
||||
self.n_latents, device=context.device, dtype=torch.int32
|
||||
).repeat(1, bsz)
|
||||
self_attn_kwargs = dict(
|
||||
# attention_mask=self_attn_mask,
|
||||
position_ids=self_attn_position_ids,
|
||||
cu_seq_lens_q=cu_seq_lens_q,
|
||||
cu_seq_lens_k=cu_seq_lens_q,
|
||||
max_length_q=max_length_q,
|
||||
max_length_k=max_length_q,
|
||||
)
|
||||
for i, layer in enumerate(self.layers):
|
||||
inp_kwargs = dict(
|
||||
latents=compressed_context,
|
||||
context=context,
|
||||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
use_cache=False,
|
||||
)
|
||||
if layer.is_cross_attn:
|
||||
attn_kwargs = {**inp_kwargs, **x_attn_kwargs}
|
||||
else:
|
||||
attn_kwargs = {**inp_kwargs, **self_attn_kwargs}
|
||||
|
||||
layer_outputs = layer(**attn_kwargs)
|
||||
compressed_context = layer_outputs[0]
|
||||
|
||||
compressed_context = self.layernorm(compressed_context)
|
||||
|
||||
return compressed_context
|
||||
|
||||
|
||||
class Idefics2Perceiver(Idefics2PreTrainedModel):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_config: Idefics2PerceiverConfig,
|
||||
decoder_config: Idefics2PerceiverConfig,
|
||||
):
|
||||
super().__init__(encoder_config)
|
||||
self.modality_projection = Idefics2MLP(
|
||||
hidden_size=encoder_config.input_size,
|
||||
intermediate_size=encoder_config.intermediate_size_factor
|
||||
* encoder_config.input_size,
|
||||
output_size=encoder_config.hidden_size,
|
||||
hidden_act=encoder_config.hidden_act,
|
||||
)
|
||||
self.encoder = Idefics2PerceiverResampler._from_config(encoder_config)
|
||||
self.decoder = Idefics2PerceiverResampler._from_config(decoder_config)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
context: torch.Tensor,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
position_ids: torch.LongTensor | None = None,
|
||||
):
|
||||
if position_ids is None:
|
||||
bsz = context.shape[0]
|
||||
else:
|
||||
bsz = torch.where(position_ids == 0, 1, 0).sum()
|
||||
projected_inputs = self.modality_projection(context)
|
||||
|
||||
# [bsz, n_latents, dim]
|
||||
latents = self.encoder(
|
||||
context=projected_inputs,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
latent_position_ids = torch.arange(
|
||||
self.encoder.n_latents, device=context.device
|
||||
).unsqueeze(0)
|
||||
latent_position_ids = torch.tile(latent_position_ids, (1, bsz))
|
||||
outputs = self.decoder(latents, position_ids=latent_position_ids)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Idefics2Perceiver",
|
||||
]
|
||||
143
src/ctx_to_lora/modeling/llm_lingua.py
Normal file
143
src/ctx_to_lora/modeling/llm_lingua.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
import torch
|
||||
from llmlingua import PromptCompressor
|
||||
from torch import nn
|
||||
|
||||
from ctx_to_lora.data.definitions import CTX_AFFIXES
|
||||
|
||||
|
||||
class LLMLinguaModel(nn.Module):
|
||||
def __init__(self, model, tokenizer, compression_rate):
|
||||
super().__init__()
|
||||
self.base_model = model
|
||||
self.compressor = PromptCompressor(
|
||||
model_name="microsoft/llmlingua-2-xlm-roberta-large-meetingbank",
|
||||
use_llmlingua2=True, # Whether to use llmlingua-2
|
||||
)
|
||||
model_name = self.base_model.name_or_path
|
||||
self.register_buffer("prefix", torch.tensor(CTX_AFFIXES[model_name]["prefix"]))
|
||||
self.register_buffer("suffix", torch.tensor(CTX_AFFIXES[model_name]["suffix"]))
|
||||
self.len_prefix = len(self.prefix)
|
||||
self.len_suffix = len(self.suffix)
|
||||
self.tokenizer = tokenizer
|
||||
self.compression_rate = compression_rate
|
||||
|
||||
@property
|
||||
def generation_config(self):
|
||||
return self.base_model.generation_config
|
||||
|
||||
def compress(self, prompt_txt: str, rate: float):
|
||||
return self.compressor.compress_prompt(
|
||||
prompt_txt, rate=rate, force_tokens=["\n", "?"]
|
||||
)
|
||||
|
||||
def compress_tokens(self, input_ids, query_text):
|
||||
bs = input_ids.shape[0]
|
||||
txt = self.tokenizer.batch_decode(
|
||||
input_ids[:, self.len_prefix : -self.len_suffix]
|
||||
)
|
||||
q_start_idx = txt[0].rfind(query_text)
|
||||
ctx_txt = txt[0][:q_start_idx]
|
||||
q_txt = txt[0][q_start_idx:]
|
||||
compressed_txt = self.compress(ctx_txt, rate=self.compression_rate)
|
||||
compressed_ids = self.tokenizer(
|
||||
compressed_txt["compressed_prompt"] + "\n\n" + q_txt,
|
||||
return_attention_mask=False,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
)["input_ids"].to(self.base_model.device)
|
||||
|
||||
out = torch.cat(
|
||||
[self.prefix.expand(bs, -1), compressed_ids, self.suffix.expand(bs, -1)],
|
||||
dim=-1,
|
||||
)
|
||||
return out
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
# take ctx_ids
|
||||
# strip prefix and suffix
|
||||
# ctx_ids is left padded
|
||||
ctx_ids = kwargs["ctx_ids"][:, self.len_prefix : -self.len_suffix]
|
||||
# decode ctx_ids to ctx_txt
|
||||
ctx_txt = self.tokenizer.batch_decode(ctx_ids)
|
||||
compressed_ctx_txt = self.compress(ctx_txt, rate=self.compression_rate)
|
||||
compressed_ctx_ids = self.tokenizer(
|
||||
compressed_ctx_txt["compressed_prompt"] + "\n\n",
|
||||
return_attention_mask=False,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt",
|
||||
).to(self.base_model.device)
|
||||
|
||||
bs = ctx_ids.shape[0]
|
||||
ctx_inp_ids = torch.cat(
|
||||
[
|
||||
self.prefix.expand(bs, -1),
|
||||
compressed_ctx_ids["input_ids"],
|
||||
kwargs["input_ids"][:, self.len_prefix :],
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
attn_mask = torch.ones_like(ctx_inp_ids)
|
||||
for k in [
|
||||
"ctx_ids",
|
||||
"ctx_attn_mask",
|
||||
"n_ctx_chunks",
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
]:
|
||||
kwargs.pop(k, None)
|
||||
return self.base_model.generate(ctx_inp_ids, attention_mask=attn_mask, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ctx_to_lora.model_loading import get_model_and_tokenizer
|
||||
|
||||
model, tokenizer = get_model_and_tokenizer(
|
||||
"google/gemma-2-2b-it",
|
||||
train=False,
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
# Demo: build wrapper, create a toy context + prompt, run compression + generation.
|
||||
device = "cuda"
|
||||
llm = LLMLinguaModel(model, tokenizer).to(device)
|
||||
|
||||
# Toy context and user prompt
|
||||
context_text = (
|
||||
"This is a short illustrative context about large language models and compression. "
|
||||
"They can reduce prompt length while preserving meaning."
|
||||
)
|
||||
user_prompt = (
|
||||
"Summarize the context in one concise sentence." # what we want model to do
|
||||
)
|
||||
|
||||
# Tokenize raw context (core) without special tokens
|
||||
core_ctx_ids = tokenizer.apply_chat_template(
|
||||
[[{"role": "user", "content": context_text}]],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_attention_mask=False,
|
||||
padding=False,
|
||||
truncation=False,
|
||||
return_tensors="pt",
|
||||
add_special_tokens=False,
|
||||
).to(device)
|
||||
|
||||
# Input prompt tokens (what follows the contextual block)
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
[[{"role": "user", "content": user_prompt}]],
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_attention_mask=False,
|
||||
padding=False,
|
||||
truncation=False,
|
||||
return_tensors="pt",
|
||||
add_special_tokens=False,
|
||||
).to(device)
|
||||
|
||||
print("Original context length (chars):", len(context_text))
|
||||
|
||||
# Run generation (may vary depending on model capabilities)
|
||||
output_ids = llm.generate(ctx_ids=core_ctx_ids, input_ids=input_ids)
|
||||
# Decode only the tail beyond supplied input for readability
|
||||
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=False)
|
||||
print(f"\nFull generated text:\n{generated_text}")
|
||||
107
src/ctx_to_lora/modeling/lora_layer.py
Normal file
107
src/ctx_to_lora/modeling/lora_layer.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from operator import attrgetter
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import einsum
|
||||
from jaxtyping import Float, Integer
|
||||
from torch import Tensor
|
||||
|
||||
from ctx_to_lora.utils import get_layers
|
||||
|
||||
|
||||
def lora_forward(
|
||||
x: Float[Tensor, "tot_q seq_len d_in"],
|
||||
n_qs: Integer[Tensor, "n_ctx"],
|
||||
tot_q: int,
|
||||
A: Float[Tensor, "n_ctx r d_in"],
|
||||
B: Float[Tensor, "n_ctx r d_out"],
|
||||
lora_dropout_p: float,
|
||||
scaling: float,
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Float[Tensor, "tot_q seq_len d_out"]:
|
||||
# A: [n_ctx, r, d_in] -> [tot_q, r, d_in]
|
||||
A = A.repeat_interleave(n_qs, dim=0, output_size=tot_q)
|
||||
# B: [n_ctx, d_out, r] -> [tot_q, d_out, r]
|
||||
B = B.repeat_interleave(n_qs, dim=0, output_size=tot_q)
|
||||
|
||||
base_out = torch.nn.Linear.forward(self, x, *args, **kwargs)
|
||||
x = x.to(A.dtype)
|
||||
delta_x = F.dropout(x, p=lora_dropout_p, training=self.training)
|
||||
delta_x = einsum(A, delta_x, "tot_q r d_in, tot_q s_len d_in -> tot_q s_len r")
|
||||
delta_x = einsum(B, delta_x, "tot_q r d_out, tot_q s_len r -> tot_q s_len d_out")
|
||||
delta_x = delta_x * scaling
|
||||
return (base_out + delta_x).to(base_out.dtype)
|
||||
|
||||
|
||||
def lora_forward_packed(
|
||||
x: Float[Tensor, "1 tot_len d_in"],
|
||||
n_qs: Integer[Tensor, "n_ctx"],
|
||||
tot_q: int,
|
||||
seq_lens: Integer[Tensor, "tot_q"],
|
||||
tot_len: int,
|
||||
A: Float[Tensor, "n_ctx r d_in"],
|
||||
B: Float[Tensor, "n_ctx r d_out"],
|
||||
lora_dropout_p: float,
|
||||
scaling: float,
|
||||
self,
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Float[Tensor, "1 tot_len d_out"]:
|
||||
# bs of x should be 1 in this case
|
||||
base_out = torch.nn.Linear.forward(self, x, *args, **kwargs)
|
||||
x = x.to(A.dtype)
|
||||
delta_x = F.dropout(x, p=lora_dropout_p, training=self.training)
|
||||
repeated_A = A.repeat_interleave(n_qs, dim=0, output_size=tot_q)
|
||||
repeated_A = repeated_A.repeat_interleave(seq_lens, dim=0, output_size=tot_len)
|
||||
|
||||
repeated_B = B.repeat_interleave(n_qs, dim=0, output_size=tot_q)
|
||||
repeated_B = repeated_B.repeat_interleave(seq_lens, dim=0, output_size=tot_len)
|
||||
|
||||
delta_x = einsum(
|
||||
repeated_A, delta_x, "tot_len r d_in, bs tot_len d_in -> bs tot_len r"
|
||||
)
|
||||
delta_x = einsum(
|
||||
repeated_B, delta_x, "tot_len r d_out, bs tot_len r -> bs tot_len d_out"
|
||||
)
|
||||
delta_x = delta_x * scaling
|
||||
|
||||
return (base_out + delta_x).to(base_out.dtype)
|
||||
|
||||
|
||||
def apply_lora_to_layers(
|
||||
model: torch.nn.Module,
|
||||
layer_indices: Iterable[int],
|
||||
generated_loras: dict[str, dict[str, Float[Tensor, "n_ctx n_layers r _"]]],
|
||||
n_qs: Integer[Tensor, "n_ctx"],
|
||||
position_ids: Integer[Tensor, "bs seq_len"] = None,
|
||||
) -> None:
|
||||
layers = get_layers(model)
|
||||
if position_ids is not None:
|
||||
position_ids = position_ids.squeeze(0)
|
||||
seq_lens = position_ids[torch.where(position_ids == 0)[0][1:] - 1]
|
||||
seq_lens = torch.cat(
|
||||
[seq_lens, torch.tensor([position_ids[-1]], device=seq_lens.device)]
|
||||
)
|
||||
seq_lens += 1
|
||||
tot_len = seq_lens.sum().item()
|
||||
tot_q = n_qs.sum().item()
|
||||
for layer_idx in layer_indices:
|
||||
layer = layers[layer_idx]
|
||||
|
||||
for mname in generated_loras:
|
||||
if mname in ["q_proj", "k_proj", "v_proj", "o_proj", "qkv_proj"]:
|
||||
long_mname = f"self_attn.{mname}"
|
||||
elif mname in ["down_proj", "up_proj", "gate_proj"]:
|
||||
long_mname = f"mlp.{mname}"
|
||||
module = attrgetter(long_mname)(layer)
|
||||
A = generated_loras[mname]["A"][:, layer_idx]
|
||||
B = generated_loras[mname]["B"][:, layer_idx]
|
||||
module.forward = partial(module.forward, n_qs=n_qs, tot_q=tot_q, A=A, B=B)
|
||||
if position_ids is not None:
|
||||
module.forward = partial(
|
||||
module.forward, seq_lens=seq_lens, tot_len=tot_len
|
||||
)
|
||||
78
src/ctx_to_lora/modeling/lora_merger.py
Normal file
78
src/ctx_to_lora/modeling/lora_merger.py
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
"""
|
||||
Utilities for merging / aggregating LoRA adapters coming from multiple chunks.
|
||||
"""
|
||||
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from jaxtyping import Float, Integer
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def compute_rank(n_lora, rank):
|
||||
return (n_lora + 1) * rank
|
||||
|
||||
|
||||
def combine_lora(
|
||||
generated_loras: dict[str, dict[str, Tensor]],
|
||||
n_chunks: Integer[Tensor, "n_ctx"],
|
||||
lora_bias: dict[str, dict[str, Tensor]] | None = None,
|
||||
scalers: Float[Tensor, "n_ctx"] | None = None,
|
||||
bias_scaler: float | None = None,
|
||||
) -> dict[str, dict[str, Tensor]]:
|
||||
total_chunks = int(n_chunks.sum())
|
||||
if bias_scaler is None:
|
||||
bias_scaler = 1
|
||||
# Assume all modules share same base rank r
|
||||
first_module = next(iter(generated_loras))
|
||||
sampled_lora = generated_loras[first_module]["A"]
|
||||
base_rank = sampled_lora.shape[-2]
|
||||
device = sampled_lora.device
|
||||
dtype = sampled_lora.dtype
|
||||
max_rank_needed = int(compute_rank(n_chunks.max(), base_rank))
|
||||
|
||||
combined_loras: dict[str, dict[str, Tensor]] = {
|
||||
module: {"A": None, "B": None} for module in generated_loras.keys()
|
||||
}
|
||||
rank_dim = 2
|
||||
num_groups = len(n_chunks)
|
||||
rank_per_group = (n_chunks * base_rank).tolist()
|
||||
bias_tensor = None
|
||||
for module_name, module_loras in generated_loras.items():
|
||||
for matrix_key in ("A", "B"):
|
||||
if lora_bias is not None:
|
||||
bias_tensor = lora_bias[module_name][matrix_key]
|
||||
loras = module_loras[matrix_key]
|
||||
if (scalers is not None) and (matrix_key == "A"):
|
||||
loras = loras * scalers[:, None, None, None]
|
||||
|
||||
flat_loras = rearrange(
|
||||
loras, "tot_chunks n_layers r dim -> 1 n_layers (tot_chunks r) dim"
|
||||
)
|
||||
per_group_deltas = flat_loras.split(rank_per_group, dim=rank_dim)
|
||||
|
||||
combined_shape = [num_groups, *per_group_deltas[0].shape[1:]]
|
||||
combined_shape[rank_dim] = max_rank_needed
|
||||
|
||||
combined = torch.zeros(*combined_shape, device=device, dtype=dtype)
|
||||
|
||||
for g, deltas in enumerate(per_group_deltas):
|
||||
combined_rank = deltas.shape[rank_dim]
|
||||
|
||||
# Build slice pattern, slice up to combined_rank.
|
||||
# slice_pattern = [g, slice(None), slice(None), slice(None)]
|
||||
# slice_pattern[rank_dim] = slice(combined_rank)
|
||||
|
||||
combined[g, :, :combined_rank, :] = deltas
|
||||
|
||||
if bias_tensor is not None:
|
||||
# bias_slice_pattern = [g, slice(None), slice(None), slice(None)]
|
||||
# bias_slice_pattern[rank_dim] = slice(
|
||||
# combined_rank, combined_rank + base_rank
|
||||
# )
|
||||
combined[g, :, combined_rank : combined_rank + base_rank, :] = (
|
||||
bias_tensor * bias_scaler
|
||||
)
|
||||
|
||||
combined_loras[module_name][matrix_key] = combined
|
||||
|
||||
return combined_loras
|
||||
160
src/ctx_to_lora/modeling/text_to_lora.py
Normal file
160
src/ctx_to_lora/modeling/text_to_lora.py
Normal file
|
|
@ -0,0 +1,160 @@
|
|||
from functools import partial
|
||||
|
||||
import torch
|
||||
from peft import PeftConfig
|
||||
from torch import nn
|
||||
|
||||
from ctx_to_lora.modeling.lora_layer import apply_lora_to_layers, lora_forward
|
||||
from ctx_to_lora.modeling.text_to_lora_impl import (
|
||||
embed_texts,
|
||||
get_layers,
|
||||
get_peft_config,
|
||||
load_hypermod,
|
||||
)
|
||||
from ctx_to_lora.utils import get_peft_modules
|
||||
|
||||
|
||||
class TextToLoRA(nn.Module):
|
||||
def __init__(self, model_name_or_path, prefix_tokens, device):
|
||||
assert model_name_or_path == "google/gemma-2-2b-it"
|
||||
super().__init__()
|
||||
hypermod_dir = "trained_t2l/gemma_2b_t2l"
|
||||
peft_config = get_peft_config(
|
||||
PeftConfig.from_json_file(f"{hypermod_dir}/adapter_config.json")
|
||||
)
|
||||
|
||||
# ours lora forward pass uses alpha directly
|
||||
peft_config.lora_alpha = peft_config.lora_alpha / peft_config.r
|
||||
self.prefix_tokens = prefix_tokens
|
||||
self.device = device
|
||||
(
|
||||
_,
|
||||
self.t2l_model,
|
||||
self.base_model,
|
||||
self.tokenizer,
|
||||
self.emb_model,
|
||||
self.emb_tokenizer,
|
||||
self.task_desc_format_fn,
|
||||
self.pooling_fn,
|
||||
) = load_hypermod(hypermod_dir, device)
|
||||
layer_indices = range(len(get_layers(self.base_model)))
|
||||
|
||||
self.layer_indices = torch.tensor(
|
||||
layer_indices, dtype=torch.long, device=device
|
||||
)
|
||||
# patch base model forward pass to use lora
|
||||
layers = get_layers(self.base_model)
|
||||
lora_forward_fn = lora_forward
|
||||
|
||||
for layer_idx in self.layer_indices:
|
||||
for module_info in get_peft_modules(layers[layer_idx], peft_config):
|
||||
module = module_info["module"]
|
||||
module.forward = partial(
|
||||
lora_forward_fn,
|
||||
self=module,
|
||||
lora_dropout_p=peft_config.lora_dropout,
|
||||
scaling=peft_config.lora_alpha,
|
||||
)
|
||||
|
||||
@property
|
||||
def generation_config(self):
|
||||
return self.base_model.generation_config
|
||||
|
||||
def generate_weights(self, ctx_txt: str):
|
||||
# generate loras
|
||||
ctx_emb = embed_texts(
|
||||
[ctx_txt],
|
||||
self.emb_model,
|
||||
self.emb_tokenizer,
|
||||
self.task_desc_format_fn,
|
||||
self.pooling_fn,
|
||||
self.device,
|
||||
)
|
||||
encoder_out = self.t2l_model.task_encoder(ctx_emb)
|
||||
encoded_task_emb = encoder_out["encoded_task_emb"].detach()
|
||||
|
||||
lora_A, lora_B = dict(), dict()
|
||||
lora_dict = dict()
|
||||
for target_module in self.t2l_model.target_modules:
|
||||
factorized_delta_w = self.t2l_model.get_delta_weights(
|
||||
self.layer_indices,
|
||||
target_module,
|
||||
encoded_task_emb.expand(self.layer_indices.shape[0], -1),
|
||||
factorized=True,
|
||||
)
|
||||
# lora_A[target_module]: [n_layers, r, d_in]
|
||||
# lora_A[target_module]: [n_layers, d_out, r]
|
||||
lora_A[target_module], lora_B[target_module] = factorized_delta_w
|
||||
|
||||
# convert to lora format used by lora_forward
|
||||
# dict of {module:
|
||||
# {A: [bs, n_layers, r, d_inim],
|
||||
# B: [bs, n_layers, r, d_outim]}}
|
||||
lora_dict[target_module] = dict(
|
||||
A=lora_A[target_module].unsqueeze(0),
|
||||
B=lora_B[target_module].transpose(-1, -2).unsqueeze(0),
|
||||
)
|
||||
return lora_dict
|
||||
|
||||
def generate(self, *args, **kwargs):
|
||||
ctx_ids_full = kwargs["ctx_ids"]
|
||||
ctx_txt = self.tokenizer.decode(
|
||||
ctx_ids_full[0, len(self.prefix_tokens) :], skip_special_tokens=True
|
||||
)
|
||||
generated_loras = self.generate_weights(ctx_txt)
|
||||
apply_lora_to_layers(
|
||||
self.base_model,
|
||||
self.layer_indices,
|
||||
generated_loras,
|
||||
n_qs=torch.tensor([1], device=self.device),
|
||||
position_ids=None,
|
||||
)
|
||||
kwargs.pop("ctx_ids", None)
|
||||
kwargs.pop("ctx_attn_mask", None)
|
||||
kwargs.pop("n_ctx_chunks", None)
|
||||
return self.base_model.generate(*args, **kwargs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from ctx_to_lora.data.definitions import CTX_AFFIXES
|
||||
from ctx_to_lora.data.processing import load_and_process_dataset
|
||||
|
||||
model_name = "google/gemma-2-2b-it"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
||||
|
||||
ds = load_and_process_dataset("pwc", split="train", num_proc=8)
|
||||
ctx = ds[0]["context"]
|
||||
inp = ds[1]["prompts"][0]
|
||||
ctx_ids = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": ctx}], return_tensors="pt", return_dict=True
|
||||
)
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": ctx + "\n\n" + inp}],
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
)
|
||||
ctx_ids = {k: v.to("cuda") for k, v in ctx_ids.items()}
|
||||
input_ids = {k: v.to("cuda") for k, v in input_ids.items()}
|
||||
|
||||
prefix_tokens = CTX_AFFIXES[model_name]["prefix"]
|
||||
prefix_tokens = torch.tensor(prefix_tokens, dtype=torch.long)
|
||||
|
||||
t2l_model = TextToLoRA(
|
||||
model_name,
|
||||
prefix_tokens,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
for _ in range(1):
|
||||
outputs = t2l_model.generate(
|
||||
**input_ids,
|
||||
ctx_ids=ctx_ids["input_ids"],
|
||||
max_new_tokens=256,
|
||||
do_sample=False,
|
||||
)
|
||||
print(
|
||||
f"Student response: {tokenizer.batch_decode(outputs, skip_special_tokens=False)}"
|
||||
)
|
||||
1256
src/ctx_to_lora/modeling/text_to_lora_impl.py
Normal file
1256
src/ctx_to_lora/modeling/text_to_lora_impl.py
Normal file
File diff suppressed because it is too large
Load diff
54
src/ctx_to_lora/pooling.py
Normal file
54
src/ctx_to_lora/pooling.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from enum import Enum
|
||||
|
||||
import torch
|
||||
from jaxtyping import Float, Integer
|
||||
from torch import Tensor
|
||||
|
||||
POOL_FN = Enum("POOL_FN", ["MEAN", "MAX", "LAST_TOKEN"])
|
||||
|
||||
|
||||
def inv_bool_mask(m: Integer[Tensor, "bs seq_len"]) -> Integer[Tensor, "bs seq_len 1"]:
|
||||
return (m - 1).bool().unsqueeze(-1)
|
||||
|
||||
|
||||
def get_pooling_fn(pooling_type: str):
|
||||
if pooling_type == POOL_FN.MEAN:
|
||||
return mean_pool
|
||||
elif pooling_type == POOL_FN.MAX:
|
||||
return max_pool
|
||||
elif pooling_type == POOL_FN.LAST_TOKEN:
|
||||
return last_token_pool
|
||||
|
||||
|
||||
def mean_pool(
|
||||
features: Float[Tensor, "bs seq_len feature_dim"],
|
||||
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
|
||||
) -> Float[Tensor, "bs 1 feature_dim"]:
|
||||
if attn_mask is not None:
|
||||
features = features.masked_fill(inv_bool_mask(attn_mask), 0)
|
||||
return features.sum(dim=1) / attn_mask.sum(dim=1).unsqueeze(1)
|
||||
|
||||
|
||||
def max_pool(
|
||||
features: Float[Tensor, "bs seq_len feature_dim"],
|
||||
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
|
||||
) -> Float[Tensor, "bs 1 feature_dim"]:
|
||||
if attn_mask is not None:
|
||||
features = features.masked_fill(inv_bool_mask(attn_mask), -float("inf"))
|
||||
return torch.max(features, dim=1)
|
||||
|
||||
|
||||
def last_token_pool(
|
||||
features: Float[Tensor, "bs seq_len feature_dim"],
|
||||
attn_mask: Integer[Tensor, "bs seq_len"] | None = None,
|
||||
) -> Float[Tensor, "bs feature_dim"]:
|
||||
left_padding = attn_mask[:, -1].sum() == attn_mask.shape[0]
|
||||
if left_padding:
|
||||
return features[:, -1]
|
||||
else:
|
||||
sequence_lengths = attn_mask.sum(dim=1) - 1
|
||||
batch_size = features.shape[0]
|
||||
return features[
|
||||
torch.arange(batch_size, device=features.device),
|
||||
sequence_lengths,
|
||||
]
|
||||
0
src/ctx_to_lora/tracker/__init__.py
Normal file
0
src/ctx_to_lora/tracker/__init__.py
Normal file
353
src/ctx_to_lora/tracker/cuda_memory_tracker.py
Normal file
353
src/ctx_to_lora/tracker/cuda_memory_tracker.py
Normal file
|
|
@ -0,0 +1,353 @@
|
|||
"""Lightweight CUDA memory tracking utilities for measuring per-method peak memory usage.
|
||||
|
||||
Usage:
|
||||
x = SomeClass()
|
||||
add_memory_tracker(x.some_method, "some_method") # wraps the bound method in-place
|
||||
x.some_method(...)
|
||||
print_aggregate_memory_stats("some_method")
|
||||
|
||||
Design notes:
|
||||
- Mirrors the API of timer.py but records CUDA memory (peak increase in bytes) per call.
|
||||
- add_memory_tracker mutates the instance method with a wrapper (idempotent: double wrap avoided).
|
||||
- Global registry: { name: [int, ...] } storing per-call peak memory increase (bytes).
|
||||
- print_aggregate_memory_stats prints summary stats (count, total, mean, median, min, max, p95, last).
|
||||
- If CUDA or torch is unavailable, wrappers degrade gracefully (no measurements recorded).
|
||||
|
||||
Metrics collected per call (if CUDA available):
|
||||
- peak_increase_bytes: (torch.cuda.max_memory_allocated() - start_allocated)
|
||||
This captures the maximum additional memory pressure during the call.
|
||||
|
||||
Caveats:
|
||||
- Rapid allocate/free patterns entirely inside the call still reflect peak transient usage.
|
||||
- Asynchronous CUDA ops: we synchronize before and after to improve accuracy. This may
|
||||
slightly affect performance timings but is necessary for memory correctness.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from statistics import mean, median, stdev
|
||||
from typing import Any
|
||||
|
||||
try: # Optional dependency handling
|
||||
import torch # type: ignore
|
||||
except Exception: # pragma: no cover - torch absence path
|
||||
torch = None # type: ignore
|
||||
|
||||
# Global memory registry: name -> list of peak memory increases (bytes)
|
||||
MEMORY_REGISTRY: dict[str, list[int]] = {}
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
return bool(torch is not None and torch.cuda.is_available())
|
||||
|
||||
|
||||
def add_memory_tracker(func: Callable, name: str) -> None:
|
||||
"""Attach a CUDA memory tracking wrapper to a bound method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
A *bound* instance method (instance.method). Raises ValueError if unbound.
|
||||
name : str
|
||||
Key under which memory stats are recorded in MEMORY_REGISTRY.
|
||||
"""
|
||||
if not hasattr(func, "__self__") or getattr(func, "__self__") is None:
|
||||
if getattr(func, "__is_memory_wrapper__", False): # already wrapped
|
||||
return
|
||||
raise ValueError(
|
||||
"add_memory_tracker expects a bound method: call with instance.method"
|
||||
)
|
||||
|
||||
instance = func.__self__
|
||||
method_name = getattr(func, "__name__", None)
|
||||
if method_name is None:
|
||||
raise ValueError("Cannot determine method name for provided callable")
|
||||
|
||||
existing = getattr(instance, method_name, None)
|
||||
if getattr(existing, "__is_memory_wrapper__", False): # idempotent
|
||||
return
|
||||
|
||||
orig_bound = func
|
||||
|
||||
def tracked(*args: Any, **kwargs: Any): # noqa: D401 - simple wrapper
|
||||
if not _cuda_available():
|
||||
return orig_bound(*args, **kwargs)
|
||||
# Synchronize to get a clean baseline
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
start_alloc = torch.cuda.memory_allocated()
|
||||
try:
|
||||
return orig_bound(*args, **kwargs)
|
||||
finally:
|
||||
torch.cuda.synchronize()
|
||||
peak_alloc = torch.cuda.max_memory_allocated()
|
||||
peak_increase = peak_alloc - start_alloc
|
||||
# Record only if positive (avoid negative due to potential race, though improbable)
|
||||
if peak_increase < 0:
|
||||
peak_increase = 0
|
||||
MEMORY_REGISTRY.setdefault(name, []).append(int(peak_increase))
|
||||
|
||||
tracked.__name__ = method_name
|
||||
tracked.__doc__ = getattr(orig_bound, "__doc__")
|
||||
tracked.__qualname__ = getattr(orig_bound, "__qualname__", method_name)
|
||||
tracked.__is_memory_wrapper__ = True # type: ignore[attr-defined]
|
||||
tracked.__wrapped__ = orig_bound # type: ignore[attr-defined]
|
||||
tracked.__memory_name__ = name # type: ignore[attr-defined]
|
||||
|
||||
setattr(instance, method_name, tracked)
|
||||
|
||||
|
||||
def _format_bytes(num_bytes: float) -> str:
|
||||
"""Human-readable byte formatting (base-2)."""
|
||||
if num_bytes < 1024:
|
||||
return f"{int(num_bytes):5d}B"
|
||||
units = ["KiB", "MiB", "GiB", "TiB"]
|
||||
value = float(num_bytes)
|
||||
for u in units:
|
||||
value /= 1024.0
|
||||
if value < 1024.0:
|
||||
return f"{value:8.3f}{u}"
|
||||
return f"{value:8.3f}PiB" # Extremely unlikely
|
||||
|
||||
|
||||
def compute_aggregate_memory_stats(
|
||||
name: str | None = None,
|
||||
) -> dict[str, dict[str, float]] | None:
|
||||
"""Compute aggregate CUDA memory statistics for specific trackers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : Optional[str]
|
||||
Specific tracker name to compute stats for. If None, all trackers are computed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[Dict[str, Dict[str, float]]]
|
||||
None if no data, else dict mapping tracker names to their stats.
|
||||
Each stats dict contains: count, total, mean, median, min, max, p95, last, std.
|
||||
"""
|
||||
if not MEMORY_REGISTRY:
|
||||
return None
|
||||
|
||||
keys = [name] if name else sorted(MEMORY_REGISTRY.keys())
|
||||
valid_keys = [k for k in keys if k in MEMORY_REGISTRY and MEMORY_REGISTRY[k]]
|
||||
if not valid_keys:
|
||||
return None
|
||||
|
||||
result = {}
|
||||
for k in valid_keys:
|
||||
data = MEMORY_REGISTRY[k]
|
||||
data_sorted = sorted(data)
|
||||
cnt = len(data)
|
||||
total = sum(data)
|
||||
avg = mean(data)
|
||||
med = median(data)
|
||||
std = stdev(data) if cnt > 1 else 0.0
|
||||
mn = data_sorted[0]
|
||||
mx = data_sorted[-1]
|
||||
p95_index = int(0.95 * (cnt - 1))
|
||||
p95 = data_sorted[p95_index]
|
||||
last = data[-1]
|
||||
result[k] = {
|
||||
"count": float(cnt),
|
||||
"total": float(total),
|
||||
"mean": float(avg),
|
||||
"median": float(med),
|
||||
"min": float(mn),
|
||||
"max": float(mx),
|
||||
"p95": float(p95),
|
||||
"last": float(last),
|
||||
"std": float(std),
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def save_memory_stats_csv(file_path: str, name: str | None = None) -> None:
|
||||
"""Save aggregate CUDA memory statistics to a CSV file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path : str
|
||||
Path where the CSV file will be saved.
|
||||
name : Optional[str]
|
||||
Specific tracker name to export. If None, all trackers are exported.
|
||||
"""
|
||||
import csv
|
||||
|
||||
stats = compute_aggregate_memory_stats(name)
|
||||
if stats is None:
|
||||
raise ValueError("No memory data available to export")
|
||||
|
||||
with open(file_path, "w", newline="") as csvfile:
|
||||
fieldnames = [
|
||||
"name",
|
||||
"count",
|
||||
"total",
|
||||
"mean",
|
||||
"median",
|
||||
"min",
|
||||
"max",
|
||||
"p95",
|
||||
"last",
|
||||
"std",
|
||||
]
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
|
||||
for tracker_name, data in stats.items():
|
||||
row = {"name": tracker_name}
|
||||
row.update(data)
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
def print_aggregate_memory_stats(name: str | None = None) -> None:
|
||||
"""Print aggregate CUDA memory stats for one or all tracked names.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : Optional[str]
|
||||
Specific name to report; if None, report all.
|
||||
"""
|
||||
stats = compute_aggregate_memory_stats(name)
|
||||
if stats is None:
|
||||
print("[mem] No memory data collected.")
|
||||
return
|
||||
|
||||
keys = [name] if name else sorted(MEMORY_REGISTRY.keys())
|
||||
missing = [k for k in keys if k not in MEMORY_REGISTRY or not MEMORY_REGISTRY[k]]
|
||||
if missing:
|
||||
print(f"[mem] No data for: {', '.join(missing)}")
|
||||
|
||||
if not stats:
|
||||
return
|
||||
|
||||
header = (
|
||||
f"{'name':20} {'count':>6} {'total':>12} {'mean':>12} {'median':>12} "
|
||||
f"{'min':>12} {'max':>12} {'p95':>12} {'std':>12} {'last':>12}"
|
||||
)
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
|
||||
for k, data in stats.items():
|
||||
print(
|
||||
f"{k:20} {int(data['count']):6d} {_format_bytes(data['total']):>12} "
|
||||
f"{_format_bytes(data['mean']):>12} {_format_bytes(data['median']):>12} "
|
||||
f"{_format_bytes(data['min']):>12} {_format_bytes(data['max']):>12} "
|
||||
f"{_format_bytes(data['p95']):>12} {_format_bytes(data['std']):>12} "
|
||||
f"{_format_bytes(data['last']):>12}"
|
||||
)
|
||||
|
||||
|
||||
def compute_global_memory_stats() -> dict[str, float] | None:
|
||||
"""Compute aggregate stats across all recorded memory entries.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[Dict[str, float]]
|
||||
None if no data; else dict with count, total, mean, median, min, max, p95, std.
|
||||
"""
|
||||
if not MEMORY_REGISTRY:
|
||||
return None
|
||||
all_values: list[int] = []
|
||||
for lst in MEMORY_REGISTRY.values():
|
||||
all_values.extend(lst)
|
||||
if not all_values:
|
||||
return None
|
||||
data_sorted = sorted(all_values)
|
||||
cnt = len(all_values)
|
||||
total = float(sum(all_values))
|
||||
avg = mean(all_values)
|
||||
med = median(all_values)
|
||||
std = stdev(all_values) if cnt > 1 else 0.0
|
||||
mn = float(data_sorted[0])
|
||||
mx = float(data_sorted[-1])
|
||||
p95_index = int(0.95 * (cnt - 1))
|
||||
p95 = float(data_sorted[p95_index])
|
||||
return {
|
||||
"count": float(cnt),
|
||||
"total": total,
|
||||
"mean": float(avg),
|
||||
"median": float(med),
|
||||
"min": mn,
|
||||
"max": mx,
|
||||
"p95": p95,
|
||||
"std": float(std),
|
||||
}
|
||||
|
||||
|
||||
def print_global_memory_stats() -> None:
|
||||
"""Pretty-print global stats across all memory registry entries."""
|
||||
stats = compute_global_memory_stats()
|
||||
if stats is None:
|
||||
print("[mem] No memory data collected.")
|
||||
return
|
||||
header = (
|
||||
f"{'scope':20} {'count':>6} {'total':>12} {'mean':>12} {'median':>12} "
|
||||
f"{'min':>12} {'max':>12} {'p95':>12} {'std':>12}"
|
||||
)
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
print(
|
||||
f"{'<ALL>':20} {int(stats['count']):6d} {_format_bytes(stats['total']):>12} "
|
||||
f"{_format_bytes(stats['mean']):>12} {_format_bytes(stats['median']):>12} "
|
||||
f"{_format_bytes(stats['min']):>12} {_format_bytes(stats['max']):>12} "
|
||||
f"{_format_bytes(stats['p95']):>12} {_format_bytes(stats['std']):>12}"
|
||||
)
|
||||
|
||||
|
||||
def reset_memory_trackers() -> None:
|
||||
"""Clear all recorded memory tracking data."""
|
||||
MEMORY_REGISTRY.clear()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"MEMORY_REGISTRY",
|
||||
"add_memory_tracker",
|
||||
"compute_aggregate_memory_stats",
|
||||
"save_memory_stats_csv",
|
||||
"print_aggregate_memory_stats",
|
||||
"compute_global_memory_stats",
|
||||
"print_global_memory_stats",
|
||||
"reset_memory_trackers",
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__": # Simple demonstration
|
||||
|
||||
class Demo:
|
||||
def __init__(self, device: str | None = None):
|
||||
self.device = device or ("cuda" if _cuda_available() else "cpu")
|
||||
|
||||
def allocate(self, n: int = 1_000_000) -> int:
|
||||
if not _cuda_available():
|
||||
# Fallback: just create a CPU tensor
|
||||
_ = [0] * n # noqa: F841
|
||||
return n
|
||||
import torch # local import to avoid mypy confusion
|
||||
|
||||
t = torch.empty(n, dtype=torch.float32, device=self.device)
|
||||
# Perform an op to ensure allocation
|
||||
t.uniform_() # noqa: F841
|
||||
return t.numel()
|
||||
|
||||
def noalloc(self): # method with negligible allocation
|
||||
return 42
|
||||
|
||||
demo = Demo()
|
||||
add_memory_tracker(demo.allocate, "alloc")
|
||||
add_memory_tracker(demo.noalloc, "noalloc")
|
||||
add_memory_tracker(demo.allocate, "alloc") # idempotent
|
||||
|
||||
for _ in range(5):
|
||||
demo.allocate(200_000)
|
||||
demo.noalloc()
|
||||
|
||||
print("\nAll memory stats:\n")
|
||||
print_aggregate_memory_stats()
|
||||
|
||||
print("\nSingle (alloc):\n")
|
||||
print_aggregate_memory_stats("alloc")
|
||||
|
||||
print("\nGlobal memory stats:\n")
|
||||
print_global_memory_stats()
|
||||
332
src/ctx_to_lora/tracker/timer.py
Normal file
332
src/ctx_to_lora/tracker/timer.py
Normal file
|
|
@ -0,0 +1,332 @@
|
|||
"""Lightweight timing utilities for attaching runtime measurement to object methods.
|
||||
|
||||
Usage:
|
||||
x = SomeClass()
|
||||
add_timer(x.some_method, "some_method") # wraps the bound method in-place
|
||||
x.some_method(...)
|
||||
print_aggregate_timer_stats("some_method")
|
||||
|
||||
Design notes:
|
||||
- add_timer mutates the instance by replacing the bound method with a timing wrapper.
|
||||
- Multiple calls to add_timer on the same (already wrapped) method are ignored to avoid double timing.
|
||||
- Global registry: { name: [float, ...] } storing individual call durations.
|
||||
- print_aggregate_timer_stats prints summary stats (count, total, mean, median, min, max, p95, last).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from statistics import mean, median, stdev
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
|
||||
# Global timer registry: name -> list of durations (seconds)
|
||||
TIMER_REGISTRY: dict[str, list[float]] = {}
|
||||
|
||||
|
||||
class _TimerWrapperMarker:
|
||||
"""Mixin marker to identify already wrapped callables."""
|
||||
|
||||
__slots__ = ("__wrapped_name__",)
|
||||
|
||||
def __init__(self, wrapped_name: str):
|
||||
self.__wrapped_name__ = wrapped_name
|
||||
|
||||
|
||||
def add_timer(func: Callable, name: str) -> None:
|
||||
"""Attach a timing wrapper to a bound method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
A *bound* method (e.g., instance.method). If an unbound function is provided
|
||||
it will raise a ValueError (explicitness keeps behavior predictable).
|
||||
name : str
|
||||
Key under which durations are recorded in TIMER_REGISTRY.
|
||||
"""
|
||||
# Basic validation (permit already wrapped functions for idempotency even though
|
||||
# they sit as plain functions on the instance dict and thus lack __self__).
|
||||
if not hasattr(func, "__self__") or getattr(func, "__self__") is None:
|
||||
if getattr(func, "__is_timer_wrapper__", False): # already wrapped, no-op
|
||||
return
|
||||
raise ValueError("add_timer expects a bound method: call with instance.method")
|
||||
|
||||
instance = func.__self__ # The object instance
|
||||
method_name = getattr(func, "__name__", None)
|
||||
if method_name is None:
|
||||
raise ValueError("Cannot determine method name for provided callable")
|
||||
|
||||
# Prevent double-wrapping (idempotent behavior)
|
||||
existing = getattr(instance, method_name, None)
|
||||
if getattr(existing, "__is_timer_wrapper__", False): # Already wrapped
|
||||
return
|
||||
|
||||
orig_bound = func # capture original bound method
|
||||
|
||||
def timed(*args: Any, **kwargs: Any): # noqa: D401 - simple wrapper
|
||||
start = perf_counter()
|
||||
try:
|
||||
return orig_bound(*args, **kwargs)
|
||||
finally:
|
||||
elapsed = perf_counter() - start
|
||||
TIMER_REGISTRY.setdefault(name, []).append(elapsed)
|
||||
|
||||
# Mark wrapper to avoid double wrapping; preserve introspection hints.
|
||||
timed.__name__ = method_name
|
||||
timed.__doc__ = getattr(orig_bound, "__doc__")
|
||||
timed.__qualname__ = getattr(orig_bound, "__qualname__", method_name)
|
||||
timed.__is_timer_wrapper__ = True # type: ignore[attr-defined]
|
||||
timed.__wrapped__ = orig_bound # type: ignore[attr-defined]
|
||||
timed.__timer_name__ = name # type: ignore[attr-defined]
|
||||
|
||||
setattr(instance, method_name, timed)
|
||||
|
||||
|
||||
def _format_seconds(sec: float) -> str:
|
||||
if sec >= 1:
|
||||
return f"{sec:8.3f}s"
|
||||
if sec >= 1e-3:
|
||||
return f"{sec * 1e3:8.3f}ms"
|
||||
if sec >= 1e-6:
|
||||
return f"{sec * 1e6:8.3f}µs"
|
||||
return f"{sec * 1e9:8.3f}ns"
|
||||
|
||||
|
||||
def compute_aggregate_timer_stats(
|
||||
name: str | None = None,
|
||||
) -> dict[str, dict[str, float]] | None:
|
||||
"""Compute aggregate timing statistics for specific timers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : Optional[str]
|
||||
Specific timer name to compute stats for. If None, all timers are computed.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[Dict[str, Dict[str, float]]]
|
||||
None if no data, else dict mapping timer names to their stats.
|
||||
Each stats dict contains: count, total, mean, median, min, max, p95, last, std.
|
||||
"""
|
||||
if not TIMER_REGISTRY:
|
||||
return None
|
||||
|
||||
keys = [name] if name else sorted(TIMER_REGISTRY.keys())
|
||||
valid_keys = [k for k in keys if k in TIMER_REGISTRY and TIMER_REGISTRY[k]]
|
||||
if not valid_keys:
|
||||
return None
|
||||
|
||||
result = {}
|
||||
for k in valid_keys:
|
||||
data = TIMER_REGISTRY[k]
|
||||
data_sorted = sorted(data)
|
||||
cnt = len(data)
|
||||
total = sum(data)
|
||||
avg = mean(data)
|
||||
med = median(data)
|
||||
std = stdev(data) if cnt > 1 else 0.0
|
||||
mn = data_sorted[0]
|
||||
mx = data_sorted[-1]
|
||||
p95_index = int(0.95 * (cnt - 1))
|
||||
p95 = data_sorted[p95_index]
|
||||
last = data[-1]
|
||||
result[k] = {
|
||||
"count": float(cnt),
|
||||
"total": total,
|
||||
"mean": avg,
|
||||
"median": med,
|
||||
"min": mn,
|
||||
"max": mx,
|
||||
"p95": p95,
|
||||
"last": last,
|
||||
"std": std,
|
||||
}
|
||||
return result
|
||||
|
||||
|
||||
def save_timer_stats_csv(file_path: str, name: str | None = None) -> None:
|
||||
"""Save aggregate timing statistics to a CSV file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path : str
|
||||
Path where the CSV file will be saved.
|
||||
name : Optional[str]
|
||||
Specific timer name to export. If None, all timers are exported.
|
||||
"""
|
||||
import csv
|
||||
|
||||
stats = compute_aggregate_timer_stats(name)
|
||||
if stats is None:
|
||||
raise ValueError("No timing data available to export")
|
||||
|
||||
with open(file_path, "w", newline="") as csvfile:
|
||||
fieldnames = [
|
||||
"name",
|
||||
"count",
|
||||
"total",
|
||||
"mean",
|
||||
"median",
|
||||
"min",
|
||||
"max",
|
||||
"p95",
|
||||
"last",
|
||||
"std",
|
||||
]
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
|
||||
for timer_name, data in stats.items():
|
||||
row = {"name": timer_name}
|
||||
row.update(data)
|
||||
writer.writerow(row)
|
||||
|
||||
|
||||
def print_aggregate_timer_stats(name: str | None = None) -> None:
|
||||
"""Print aggregate timing statistics.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : Optional[str]
|
||||
Specific timer name to report. If None, all timers are reported.
|
||||
"""
|
||||
stats = compute_aggregate_timer_stats(name)
|
||||
if stats is None:
|
||||
print("[timer] No timing data collected.")
|
||||
return
|
||||
|
||||
keys = [name] if name else sorted(TIMER_REGISTRY.keys())
|
||||
missing = [k for k in keys if k not in TIMER_REGISTRY or not TIMER_REGISTRY[k]]
|
||||
if missing:
|
||||
print(f"[timer] No data for: {', '.join(missing)}")
|
||||
|
||||
if not stats:
|
||||
return
|
||||
|
||||
header = (
|
||||
f"{'name':20} {'count':>6} {'total':>10} {'mean':>10} {'median':>10} "
|
||||
f"{'min':>10} {'max':>10} {'p95':>10} {'std':>10} {'last':>10}"
|
||||
)
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
|
||||
for k, data in stats.items():
|
||||
print(
|
||||
f"{k:20} {int(data['count']):6d} {_format_seconds(data['total']):>10} "
|
||||
f"{_format_seconds(data['mean']):>10} {_format_seconds(data['median']):>10} "
|
||||
f"{_format_seconds(data['min']):>10} {_format_seconds(data['max']):>10} "
|
||||
f"{_format_seconds(data['p95']):>10} {_format_seconds(data['std']):>10} "
|
||||
f"{_format_seconds(data['last']):>10}"
|
||||
)
|
||||
|
||||
|
||||
def compute_global_timer_stats() -> dict[str, float] | None:
|
||||
"""Compute aggregate statistics across all recorded timer values.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[Dict[str, float]]
|
||||
None if no data recorded, else a dict with keys:
|
||||
count, total, mean, median, min, max, p95, std.
|
||||
"""
|
||||
if not TIMER_REGISTRY:
|
||||
return None
|
||||
all_values: list[float] = []
|
||||
for lst in TIMER_REGISTRY.values():
|
||||
all_values.extend(lst)
|
||||
if not all_values:
|
||||
return None
|
||||
data_sorted = sorted(all_values)
|
||||
cnt = len(all_values)
|
||||
total = sum(all_values)
|
||||
avg = mean(all_values)
|
||||
med = median(all_values)
|
||||
std = stdev(all_values) if cnt > 1 else 0.0
|
||||
mn = data_sorted[0]
|
||||
mx = data_sorted[-1]
|
||||
p95_index = int(0.95 * (cnt - 1))
|
||||
p95 = data_sorted[p95_index]
|
||||
return {
|
||||
"count": float(cnt), # keep uniform numeric type
|
||||
"total": total,
|
||||
"mean": avg,
|
||||
"median": med,
|
||||
"min": mn,
|
||||
"max": mx,
|
||||
"p95": p95,
|
||||
"std": std,
|
||||
}
|
||||
|
||||
|
||||
def print_global_timer_stats() -> None:
|
||||
"""Pretty-print global aggregate stats across all timer entries."""
|
||||
stats = compute_global_timer_stats()
|
||||
if stats is None:
|
||||
print("[timer] No timing data collected.")
|
||||
return
|
||||
header = (
|
||||
f"{'scope':20} {'count':>6} {'total':>10} {'mean':>10} {'median':>10} "
|
||||
f"{'min':>10} {'max':>10} {'p95':>10} {'std':>10}"
|
||||
)
|
||||
print(header)
|
||||
print("-" * len(header))
|
||||
print(
|
||||
f"{'<ALL>':20} {int(stats['count']):6d} {_format_seconds(stats['total']):>10} "
|
||||
f"{_format_seconds(stats['mean']):>10} {_format_seconds(stats['median']):>10} "
|
||||
f"{_format_seconds(stats['min']):>10} {_format_seconds(stats['max']):>10} "
|
||||
f"{_format_seconds(stats['p95']):>10} {_format_seconds(stats['std']):>10}"
|
||||
)
|
||||
|
||||
|
||||
def reset_timers() -> None:
|
||||
"""Reset (clear) all recorded timing data."""
|
||||
TIMER_REGISTRY.clear()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TIMER_REGISTRY",
|
||||
"add_timer",
|
||||
"compute_aggregate_timer_stats",
|
||||
"save_timer_stats_csv",
|
||||
"print_aggregate_timer_stats",
|
||||
"compute_global_timer_stats",
|
||||
"print_global_timer_stats",
|
||||
"reset_timers",
|
||||
]
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Example usage / simple self-test.
|
||||
import random
|
||||
import time
|
||||
|
||||
class Demo:
|
||||
def f1(self, n: int = 20_000) -> int:
|
||||
# CPU-bound work
|
||||
s = 0
|
||||
for i in range(n):
|
||||
s += i * i
|
||||
return s
|
||||
|
||||
def f2(self) -> None:
|
||||
# Simulate I/O or waiting
|
||||
time.sleep(random.uniform(0.001, 0.003))
|
||||
|
||||
demo = Demo()
|
||||
|
||||
# Attach timers (idempotent: calling again does nothing harmful)
|
||||
add_timer(demo.f1, "f1")
|
||||
add_timer(demo.f2, "f2")
|
||||
add_timer(demo.f1, "f1") # demonstrate double-wrap prevention
|
||||
|
||||
for _ in range(5):
|
||||
demo.f1(10_000)
|
||||
demo.f2()
|
||||
|
||||
print("\nAll timers:\n")
|
||||
print_aggregate_timer_stats()
|
||||
|
||||
print("\nSingle timer (f1):\n")
|
||||
print_aggregate_timer_stats("f1")
|
||||
|
||||
print("\nGlobal aggregated stats:\n")
|
||||
print_global_timer_stats()
|
||||
303
src/ctx_to_lora/tracker/tracker.py
Normal file
303
src/ctx_to_lora/tracker/tracker.py
Normal file
|
|
@ -0,0 +1,303 @@
|
|||
"""Unified tracking interface combining timing and CUDA memory usage.
|
||||
|
||||
Primary API
|
||||
-----------
|
||||
add_tracker(bound_method, name)
|
||||
Wraps a bound instance method so that each invocation records:
|
||||
- wall-clock duration (seconds) in timer.TIMER_REGISTRY[name]
|
||||
- CUDA peak memory increase (bytes) in cuda_memory_tracker.MEMORY_REGISTRY[name]
|
||||
(only if CUDA + torch available; otherwise memory list may stay empty / absent)
|
||||
|
||||
print_tracker_stats(name=None)
|
||||
Convenience printer that delegates to time + memory aggregate printers.
|
||||
|
||||
Design
|
||||
------
|
||||
We implement a single wrapper (instead of nesting the individual timer & memory
|
||||
wrappers) to avoid multiple layers of indirection and to ensure the measured
|
||||
CUDA memory footprint reflects only the original method's body (excluding the
|
||||
separate timing wrapper's slight overhead). The wrapper is idempotent: repeated
|
||||
calls to add_tracker on the same method are ignored.
|
||||
|
||||
This file depends on sibling modules:
|
||||
- tracker.timer
|
||||
- tracker.cuda_memory_tracker
|
||||
|
||||
Both registries remain the single source of truth; no additional registry is introduced.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from time import perf_counter
|
||||
from typing import Any
|
||||
|
||||
# Support both package (relative) and direct script execution.
|
||||
try: # Package / normal import path
|
||||
from .cuda_memory_tracker import ( # type: ignore
|
||||
MEMORY_REGISTRY,
|
||||
compute_aggregate_memory_stats,
|
||||
print_aggregate_memory_stats,
|
||||
print_global_memory_stats,
|
||||
reset_memory_trackers,
|
||||
save_memory_stats_csv,
|
||||
)
|
||||
from .timer import ( # type: ignore
|
||||
TIMER_REGISTRY,
|
||||
compute_aggregate_timer_stats,
|
||||
print_aggregate_timer_stats,
|
||||
print_global_timer_stats,
|
||||
reset_timers,
|
||||
save_timer_stats_csv,
|
||||
)
|
||||
except Exception: # pragma: no cover - fallback when executed directly
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
_this_file = pathlib.Path(__file__).resolve()
|
||||
# project root is two levels up from tracker/ (i.e., .../src)
|
||||
_src_root = _this_file.parents[2]
|
||||
if str(_src_root) not in sys.path:
|
||||
sys.path.insert(0, str(_src_root))
|
||||
try:
|
||||
from ctx_to_lora.tracker.cuda_memory_tracker import ( # type: ignore
|
||||
MEMORY_REGISTRY,
|
||||
compute_aggregate_memory_stats,
|
||||
print_aggregate_memory_stats,
|
||||
print_global_memory_stats,
|
||||
reset_memory_trackers,
|
||||
save_memory_stats_csv,
|
||||
)
|
||||
from ctx_to_lora.tracker.timer import ( # type: ignore
|
||||
TIMER_REGISTRY,
|
||||
compute_aggregate_timer_stats,
|
||||
print_aggregate_timer_stats,
|
||||
print_global_timer_stats,
|
||||
reset_timers,
|
||||
save_timer_stats_csv,
|
||||
)
|
||||
except Exception as e: # If still failing, raise a clearer error.
|
||||
raise ImportError(
|
||||
f"Failed to import tracking dependencies; ensure project root on PYTHONPATH. Original: {e}"
|
||||
)
|
||||
|
||||
try: # Optional torch import (lazy fallback if unavailable)
|
||||
import torch # type: ignore
|
||||
except Exception: # pragma: no cover - torch absence path
|
||||
torch = None # type: ignore
|
||||
|
||||
__all__ = [
|
||||
"add_tracker",
|
||||
"compute_tracker_stats",
|
||||
"save_tracker_stats_csv",
|
||||
"print_tracker_stats",
|
||||
"print_global_tracker_stats",
|
||||
"reset_trackers",
|
||||
]
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
return bool(torch is not None and torch.cuda.is_available())
|
||||
|
||||
|
||||
def add_tracker(func: Callable, name: str) -> None:
|
||||
"""Attach a combined time + CUDA memory tracking wrapper to a bound method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
A bound instance method (instance.method). Raises ValueError if unbound.
|
||||
name : str
|
||||
Registry key under which metrics are stored.
|
||||
"""
|
||||
if not hasattr(func, "__self__") or getattr(func, "__self__") is None:
|
||||
# Permit idempotent re-calls if already wrapped.
|
||||
if getattr(func, "__is_tracker_wrapper__", False):
|
||||
return
|
||||
raise ValueError(
|
||||
"add_tracker expects a bound method: call with instance.method"
|
||||
)
|
||||
|
||||
instance = func.__self__ # underlying object
|
||||
method_name = getattr(func, "__name__", None)
|
||||
if method_name is None:
|
||||
raise ValueError("Cannot determine method name for provided callable")
|
||||
|
||||
existing = getattr(instance, method_name, None)
|
||||
if getattr(
|
||||
existing, "__is_tracker_wrapper__", False
|
||||
): # Already wrapped via unified tracker
|
||||
return
|
||||
|
||||
# If already individually wrapped by timer or memory tracker, we still wrap only once more;
|
||||
# future calls to add_tracker will become no-ops.
|
||||
orig_bound = existing if existing is not None else func
|
||||
|
||||
def tracked(*args: Any, **kwargs: Any): # noqa: D401 - combined wrapper
|
||||
use_cuda = _cuda_available()
|
||||
if use_cuda:
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
start_alloc = torch.cuda.memory_allocated()
|
||||
start_time = perf_counter()
|
||||
try:
|
||||
return orig_bound(*args, **kwargs)
|
||||
finally:
|
||||
elapsed = perf_counter() - start_time
|
||||
TIMER_REGISTRY.setdefault(name, []).append(elapsed)
|
||||
if use_cuda:
|
||||
torch.cuda.synchronize()
|
||||
peak_alloc = torch.cuda.max_memory_allocated()
|
||||
peak_increase = peak_alloc - start_alloc
|
||||
if peak_increase < 0: # Safety guard (should not happen)
|
||||
peak_increase = 0
|
||||
MEMORY_REGISTRY.setdefault(name, []).append(int(peak_increase))
|
||||
|
||||
# Introspection / idempotency markers
|
||||
tracked.__name__ = method_name
|
||||
tracked.__qualname__ = getattr(orig_bound, "__qualname__", method_name)
|
||||
tracked.__doc__ = getattr(orig_bound, "__doc__")
|
||||
tracked.__wrapped__ = orig_bound # type: ignore[attr-defined]
|
||||
tracked.__is_tracker_wrapper__ = True # type: ignore[attr-defined]
|
||||
tracked.__is_timer_wrapper__ = True # type: ignore[attr-defined]
|
||||
tracked.__is_memory_wrapper__ = True # type: ignore[attr-defined]
|
||||
tracked.__tracker_name__ = name # type: ignore[attr-defined]
|
||||
|
||||
setattr(instance, method_name, tracked)
|
||||
|
||||
|
||||
def compute_tracker_stats(
|
||||
name: str | None = None,
|
||||
) -> dict[str, dict[str, Any]] | None:
|
||||
"""Compute both timing and memory stats for a given name (or all if None).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : Optional[str]
|
||||
Specific tracker name; if None, computes all.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Optional[Dict[str, Dict[str, Any]]]
|
||||
None if no data, else dict with 'timing' and 'memory' keys containing
|
||||
their respective aggregate statistics.
|
||||
"""
|
||||
timer_stats = compute_aggregate_timer_stats(name)
|
||||
memory_stats = compute_aggregate_memory_stats(name)
|
||||
|
||||
if timer_stats is None and memory_stats is None:
|
||||
return None
|
||||
|
||||
return {
|
||||
"timing": timer_stats or {},
|
||||
"memory": memory_stats or {},
|
||||
}
|
||||
|
||||
|
||||
def save_tracker_stats_csv(file_path: str, name: str | None = None) -> None:
|
||||
"""Save both timing and memory stats to separate CSV files.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_path : str
|
||||
Base path for CSV files. Will create file_path_timing.csv and file_path_memory.csv
|
||||
name : Optional[str]
|
||||
Specific tracker name to export. If None, all trackers are exported.
|
||||
"""
|
||||
import os
|
||||
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
base_path = os.path.splitext(file_path)[0]
|
||||
timer_path = f"{base_path}_timing.csv"
|
||||
memory_path = f"{base_path}_memory.csv"
|
||||
|
||||
# Save timing stats if available
|
||||
timer_stats = compute_aggregate_timer_stats(name)
|
||||
if timer_stats is not None:
|
||||
save_timer_stats_csv(timer_path, name)
|
||||
|
||||
# Save memory stats if available
|
||||
memory_stats = compute_aggregate_memory_stats(name)
|
||||
if memory_stats is not None:
|
||||
save_memory_stats_csv(memory_path, name)
|
||||
|
||||
# If no data at all, raise an error
|
||||
if timer_stats is None and memory_stats is None:
|
||||
print("No tracking data available to export")
|
||||
|
||||
|
||||
def print_tracker_stats(name: str | None = None) -> None:
|
||||
"""Print both timing and memory stats for a given name (or all if None).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : Optional[str]
|
||||
Specific tracker name; if None, prints all.
|
||||
"""
|
||||
print("[tracker] Timing stats:")
|
||||
print_aggregate_timer_stats(name)
|
||||
print("\n[tracker] CUDA memory stats:")
|
||||
print_aggregate_memory_stats(name)
|
||||
|
||||
|
||||
def print_global_tracker_stats() -> None:
|
||||
"""Print global aggregate timing and memory stats."""
|
||||
print("[tracker] Global timing stats:")
|
||||
print_global_timer_stats()
|
||||
print("\n[tracker] Global CUDA memory stats:")
|
||||
print_global_memory_stats()
|
||||
|
||||
|
||||
def reset_trackers() -> None:
|
||||
"""Reset all timer and memory tracking data."""
|
||||
reset_timers()
|
||||
reset_memory_trackers()
|
||||
|
||||
|
||||
if __name__ == "__main__": # Demonstration
|
||||
import random
|
||||
import time
|
||||
|
||||
class Demo:
|
||||
def compute(self, n: int = 25_000) -> int:
|
||||
# CPU-bound work
|
||||
s = 0
|
||||
for i in range(n):
|
||||
s += i * i
|
||||
return s
|
||||
|
||||
def gpu_alloc(self, n: int = 500_000):
|
||||
if not _cuda_available():
|
||||
# Simulate light wait to differentiate timing
|
||||
time.sleep(random.uniform(0.01, 0.05))
|
||||
return None
|
||||
t = torch.empty(n, dtype=torch.float32, device="cuda")
|
||||
t.uniform_() # ensure usage
|
||||
return t.sum().item()
|
||||
|
||||
demo = Demo()
|
||||
|
||||
add_tracker(demo.compute, "compute")
|
||||
add_tracker(demo.gpu_alloc, "gpu_alloc")
|
||||
# Idempotent re-call
|
||||
add_tracker(demo.compute, "compute")
|
||||
|
||||
for _ in range(5):
|
||||
demo.compute(15_000)
|
||||
demo.gpu_alloc(300_000)
|
||||
|
||||
print_tracker_stats()
|
||||
print("\n--- Global Combined Stats ---\n")
|
||||
print_global_tracker_stats()
|
||||
|
||||
# Demonstrate CSV export
|
||||
print("\n[tracker] Saving stats to CSV files...\n")
|
||||
csv_path = "/tmp/tracker_demo_stats.csv"
|
||||
save_tracker_stats_csv(csv_path)
|
||||
print(f"Exported timing stats to: {csv_path.replace('.csv', '_timing.csv')}")
|
||||
print(f"Exported memory stats to: {csv_path.replace('.csv', '_memory.csv')}")
|
||||
|
||||
print("\n[tracker] Resetting registries...\n")
|
||||
reset_trackers()
|
||||
print_tracker_stats()
|
||||
458
src/ctx_to_lora/trainer.py
Normal file
458
src/ctx_to_lora/trainer.py
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
import logging
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers import Trainer
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
from transformers.trainer_utils import IntervalStrategy
|
||||
|
||||
from ctx_to_lora.modeling.hypernet import ModulatedPretrainedModel
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def per_ctx_loss_ce(inputs, labels, loss):
|
||||
# loss still has masked out elem (0 at labels=-100)
|
||||
n_queries_per_ctx = inputs["n_queries"].tolist()
|
||||
|
||||
position_ids = inputs["position_ids"].squeeze(0)
|
||||
# account only label positions
|
||||
label_mask = labels.squeeze(0) != -100
|
||||
label_pos_ids = label_mask * position_ids
|
||||
label_pos_ids_diff = label_pos_ids.diff(
|
||||
append=torch.tensor([0], device=position_ids.device)
|
||||
)
|
||||
|
||||
# assumes the input starts with non-assistant tokens
|
||||
start_label_pos = torch.where((label_pos_ids_diff > 0) * ~label_mask)[0]
|
||||
end_label_pos = torch.where((label_pos_ids_diff < 0) * label_mask)[0]
|
||||
|
||||
label_seq_lens = end_label_pos - start_label_pos
|
||||
|
||||
# these stack and split can be optimized but let's keep it simple
|
||||
# mean across tokens of each q
|
||||
qa_losses = torch.stack(
|
||||
[
|
||||
loss[start : start + llen].mean()
|
||||
for start, llen in zip(start_label_pos, label_seq_lens)
|
||||
]
|
||||
)
|
||||
|
||||
# mean across queries of each ctx
|
||||
per_ctx_losses = [ql.mean() for ql in torch.split(qa_losses, n_queries_per_ctx)]
|
||||
|
||||
# per-ctx loss
|
||||
loss = torch.stack(per_ctx_losses)
|
||||
return loss
|
||||
|
||||
|
||||
def per_ctx_loss_kl(inputs, labels, loss):
|
||||
# loss is compact (label indices selected)
|
||||
n_queries_per_ctx = inputs["n_queries"].tolist()
|
||||
|
||||
position_ids = inputs["position_ids"].squeeze(0)
|
||||
# account only label positions
|
||||
label_mask = labels.squeeze(0) != -100
|
||||
label_pos_ids = label_mask * position_ids
|
||||
label_pos_ids_diff = label_pos_ids.diff(
|
||||
append=torch.tensor([0], device=position_ids.device)
|
||||
)
|
||||
# assumes the input starts with non-assistant tokens
|
||||
start_label_pos = torch.where((label_pos_ids_diff > 0) * ~label_mask)[0]
|
||||
end_label_pos = torch.where((label_pos_ids_diff < 0) * label_mask)[0]
|
||||
|
||||
label_seq_lens = end_label_pos - start_label_pos
|
||||
|
||||
# find equiv start indices in the already sliced loss vector
|
||||
cu_label_seq_lens = torch.cumsum(label_seq_lens, dim=0)
|
||||
start_indices = torch.cat(
|
||||
(
|
||||
torch.tensor([0], device=cu_label_seq_lens.device),
|
||||
cu_label_seq_lens[:-1],
|
||||
)
|
||||
)
|
||||
|
||||
# these stack and split can be optimized but let's keep it simple
|
||||
# mean across tokens of each q
|
||||
qa_losses = torch.stack(
|
||||
[loss[start:end].mean() for start, end in zip(start_indices, cu_label_seq_lens)]
|
||||
)
|
||||
|
||||
# mean across queries of each ctx
|
||||
per_ctx_losses = [ql.mean() for ql in torch.split(qa_losses, n_queries_per_ctx)]
|
||||
|
||||
# per-ctx loss
|
||||
loss = torch.stack(per_ctx_losses)
|
||||
return loss
|
||||
|
||||
|
||||
class ModulatedModelTrainer(Trainer):
|
||||
# modified from the base Trainer to support per-context average loss
|
||||
def get_batch_samples(self, epoch_iterator, num_batches, device):
|
||||
# only used with `use_per_ctx_average_loss=True`
|
||||
batch_samples = []
|
||||
num_items_in_batch = None
|
||||
|
||||
for _ in range(num_batches):
|
||||
try:
|
||||
batch_samples.append(next(epoch_iterator))
|
||||
except StopIteration:
|
||||
break
|
||||
|
||||
count_num_items_in_batch = (
|
||||
len(batch_samples) > 0
|
||||
and "labels" in batch_samples[0]
|
||||
and "n_ctx_chunks" in batch_samples[0]
|
||||
)
|
||||
|
||||
if count_num_items_in_batch:
|
||||
num_items_in_batch = dict()
|
||||
num_items_in_batch["ctx"] = torch.tensor(
|
||||
sum([batch["n_ctx_chunks"].numel() for batch in batch_samples])
|
||||
).to(device)
|
||||
# should we avg over num chunks?
|
||||
# num_items_in_batch["ctx"] = sum(
|
||||
# [(batch["ctx_position_ids"] == 0).sum() for batch in batch_samples]
|
||||
# )
|
||||
num_items_in_batch["labels"] = sum(
|
||||
[(batch["labels"].ne(-100)).sum() for batch in batch_samples]
|
||||
).to(device)
|
||||
|
||||
if num_items_in_batch is not None:
|
||||
if self.args.average_tokens_across_devices:
|
||||
for k in num_items_in_batch:
|
||||
num_items_in_batch[k] = self.accelerator.gather(
|
||||
num_items_in_batch[k]
|
||||
).sum()
|
||||
|
||||
if torch.is_tensor(num_items_in_batch):
|
||||
num_items_in_batch = num_items_in_batch.to(device)
|
||||
|
||||
if self.args.n_gpu > 1 and num_items_in_batch.dim() == 0:
|
||||
# In the DataParallel case, convert the scalar tensor into a 1-dim tensor
|
||||
num_items_in_batch = num_items_in_batch.unsqueeze(0)
|
||||
|
||||
return batch_samples, num_items_in_batch
|
||||
|
||||
|
||||
class DistillationTrainer(ModulatedModelTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.gen_lora_l1_reg_coef = kwargs.pop("gen_lora_l1_reg_coef", 0.0)
|
||||
self.use_per_ctx_average_loss = kwargs.pop("use_per_ctx_average_loss", False)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
# NOTE: the loss output from this fn will be ***added***
|
||||
# meaning that we should always scale the loss wrt `num_items_in_batch`
|
||||
# (average over the number of items in the accumulated batch)
|
||||
|
||||
is_train = num_items_in_batch is not None
|
||||
labels = inputs.pop("labels", None)
|
||||
label_pos = torch.where(labels != -100)
|
||||
outputs, (gen_loras, _) = model(**inputs, return_generated_lora=True)
|
||||
|
||||
if "logprobs_vals" not in inputs:
|
||||
return (torch.tensor(0.0), outputs) if return_outputs else torch.tensor(0.0)
|
||||
|
||||
target_logp = inputs.pop("logprobs_vals").squeeze(0)
|
||||
indices = inputs.pop("logprobs_indices").squeeze(0)
|
||||
|
||||
assert label_pos[0].shape[0] == target_logp.shape[0], (
|
||||
"Label positions and target log probabilities should have the same # tokens."
|
||||
f"Got : {label_pos[0].shape[0]=} and {target_logp.shape[0]=}"
|
||||
)
|
||||
|
||||
##### KL loss
|
||||
outputs_logits = outputs.logits[label_pos[0], label_pos[1] - 1] # shift back 1
|
||||
|
||||
logq_full_denom = torch.logsumexp(outputs_logits, dim=-1, keepdim=True) # (N,1)
|
||||
selected_logits = outputs_logits.gather(1, indices) # (N,K)
|
||||
# log softmax at selected indices
|
||||
logq_selected = selected_logits - logq_full_denom
|
||||
p = target_logp.exp()
|
||||
loss = -(p * logq_selected).sum(dim=-1)
|
||||
|
||||
# teacher_logp = torch.full_like(outputs_logits, -torch.inf)
|
||||
# teacher_logp.scatter_(1, indices, target_logp)
|
||||
# # reduction = "batchmean" if num_items_in_batch is None else "sum"
|
||||
# p = teacher_logp.exp()
|
||||
# logq = nn.functional.log_softmax(outputs_logits, dim=-1)
|
||||
# loss = -torch.sum(p * logq, dim=-1)
|
||||
|
||||
if self.use_per_ctx_average_loss:
|
||||
loss = per_ctx_loss_kl(inputs, labels, loss)
|
||||
|
||||
if is_train:
|
||||
if self.use_per_ctx_average_loss:
|
||||
loss = loss.sum() / num_items_in_batch["ctx"]
|
||||
else:
|
||||
loss = loss.sum() / num_items_in_batch["labels"]
|
||||
else:
|
||||
# eval
|
||||
loss = loss.mean()
|
||||
|
||||
# if reduction == "batchmean":
|
||||
# loss = loss.mean()
|
||||
# elif reduction == "sum":
|
||||
# # loss does not scale with grad acc
|
||||
# # num_items_in_batch does
|
||||
# # this works for both token-avg and ctx-avg
|
||||
# # loss = loss.sum() / num_items_in_batch
|
||||
|
||||
# `num_items_in_batch` is # tokens if `args.use_ctx_average_loss=False``
|
||||
# loss = loss.sum() / num_items_in_batch
|
||||
#####
|
||||
|
||||
##### unpack gen lora dict and compute regularization loss
|
||||
l1_norm = 0
|
||||
n_modules = len(gen_loras)
|
||||
for module, lora in gen_loras.items():
|
||||
l1_norm += lora["A"].abs().sum(0).mean() + lora["B"].abs().sum(0).mean()
|
||||
l1_norm /= n_modules
|
||||
if is_train:
|
||||
# during eval `num_items_in_batch` will be None
|
||||
l1_norm /= num_items_in_batch["ctx"]
|
||||
|
||||
total_loss = loss + self.gen_lora_l1_reg_coef * l1_norm
|
||||
#####
|
||||
|
||||
scaler = self.args.gradient_accumulation_steps if is_train else 1
|
||||
if self.args.average_tokens_across_devices and is_train:
|
||||
total_loss *= self.accelerator.num_processes
|
||||
scaler *= self.accelerator.num_processes
|
||||
|
||||
# rough estimate of the losses (we only log the values from one step)
|
||||
if (self.state.global_step == 1 and self.args.logging_first_step) or (
|
||||
self.args.logging_strategy == IntervalStrategy.STEPS
|
||||
and self.state.global_step % self.state.logging_steps == 0
|
||||
):
|
||||
# compensate `num_items_in_batch` division
|
||||
self.log(
|
||||
{
|
||||
"kl_loss": loss.item() * scaler,
|
||||
"gen_lora_l1_norm": l1_norm.item() * scaler,
|
||||
}
|
||||
)
|
||||
|
||||
return (total_loss, outputs) if return_outputs else total_loss
|
||||
|
||||
|
||||
def causal_lm_ce_loss(
|
||||
logits,
|
||||
labels,
|
||||
vocab_size: int,
|
||||
num_items_in_batch: torch.Tensor | None = None,
|
||||
ignore_index: int = -100,
|
||||
shift_labels: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# Upcast to float if we need to compute the loss to avoid potential precision issues
|
||||
logits = logits.float()
|
||||
|
||||
if shift_labels is None:
|
||||
# Shift so that tokens < n predict n
|
||||
labels = nn.functional.pad(labels, (0, 1), value=ignore_index)
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
|
||||
# Flatten the tokens
|
||||
logits = logits.view(-1, vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(logits.device)
|
||||
# loss = fixed_cross_entropy(
|
||||
# logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
|
||||
# )
|
||||
loss = nn.functional.cross_entropy(logits, shift_labels, reduction="none")
|
||||
|
||||
return loss
|
||||
|
||||
|
||||
class CrossEntropyTrainer(ModulatedModelTrainer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.gen_lora_l1_reg_coef = kwargs.pop("gen_lora_l1_reg_coef", 0.0)
|
||||
self.use_per_ctx_average_loss = kwargs.pop("use_per_ctx_average_loss", False)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def compute_loss(
|
||||
self, model, inputs, return_outputs=False, num_items_in_batch=None
|
||||
):
|
||||
"""
|
||||
How the loss is computed by Trainer.
|
||||
By default, all models return the loss in the first element.
|
||||
Subclass and override for custom behavior.
|
||||
"""
|
||||
|
||||
is_train = num_items_in_batch is not None
|
||||
labels = inputs.pop("labels", None)
|
||||
outputs, (gen_loras, _) = model(**inputs, return_generated_lora=True)
|
||||
# [1, tot_seq_len]
|
||||
logits = outputs.logits
|
||||
|
||||
# [tot_seq_len]
|
||||
loss = causal_lm_ce_loss(logits, labels, self.model.vocab_size)
|
||||
|
||||
if self.use_per_ctx_average_loss:
|
||||
loss = per_ctx_loss_ce(inputs, labels, loss)
|
||||
|
||||
if is_train:
|
||||
if self.use_per_ctx_average_loss:
|
||||
loss = loss.sum() / num_items_in_batch["ctx"]
|
||||
else:
|
||||
loss = loss.sum() / num_items_in_batch["labels"]
|
||||
else:
|
||||
# eval
|
||||
loss = loss.mean()
|
||||
|
||||
#####
|
||||
# if is_train:
|
||||
# if self.use_per_ctx_average_loss:
|
||||
# loss_kwargs["num_items_in_batch"] = num_items_in_batch["ctx"]
|
||||
# else:
|
||||
# loss_kwargs["num_items_in_batch"] = num_items_in_batch["labels"]
|
||||
# inputs = {**inputs, **loss_kwargs}
|
||||
# outputs, (gen_loras, _) = model(**inputs, return_generated_lora=True)
|
||||
|
||||
# # Save past state if it exists
|
||||
# if self.args.past_index >= 0:
|
||||
# self._past = outputs[self.args.past_index]
|
||||
|
||||
# if labels is not None:
|
||||
# unwrapped_model = self.accelerator.unwrap_model(model)
|
||||
# if _is_peft_model(unwrapped_model):
|
||||
# model_name = unwrapped_model.base_model.model._get_name()
|
||||
# else:
|
||||
# model_name = unwrapped_model._get_name()
|
||||
# # User-defined compute_loss function
|
||||
# if self.compute_loss_func is not None:
|
||||
# loss = self.compute_loss_func(
|
||||
# outputs, labels, num_items_in_batch=num_items_in_batch["labels"]
|
||||
# )
|
||||
# elif model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
|
||||
# loss = self.label_smoother(outputs, labels, shift_labels=True)
|
||||
# else:
|
||||
# loss = self.label_smoother(outputs, labels)
|
||||
# else:
|
||||
# if isinstance(outputs, dict) and "loss" not in outputs:
|
||||
# raise ValueError(
|
||||
# "The model did not return a loss from the inputs, "
|
||||
# "only the following keys: "
|
||||
# f"{','.join(outputs.keys())}. "
|
||||
# "For reference, the inputs it received are "
|
||||
# f"{','.join(inputs.keys())}."
|
||||
# )
|
||||
# # We don't use .loss here since the model may return tuples instead of ModelOutput.
|
||||
# loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
|
||||
#####
|
||||
|
||||
##### unpack gen lora dict and compute regularization loss
|
||||
l1_norm = 0
|
||||
n_modules = len(gen_loras)
|
||||
for module, lora in gen_loras.items():
|
||||
l1_norm += lora["A"].abs().sum(0).mean() + lora["B"].abs().sum(0).mean()
|
||||
l1_norm /= n_modules
|
||||
if is_train:
|
||||
# during eval `num_items_in_batch` will be None
|
||||
l1_norm /= num_items_in_batch["ctx"]
|
||||
|
||||
total_loss = loss + self.gen_lora_l1_reg_coef * l1_norm
|
||||
#####
|
||||
|
||||
scaler = self.args.gradient_accumulation_steps if is_train else 1
|
||||
if self.args.average_tokens_across_devices and is_train:
|
||||
total_loss *= self.accelerator.num_processes
|
||||
scaler *= self.accelerator.num_processes
|
||||
|
||||
# rough estimate of the losses (we only log the values from one step)
|
||||
if (self.state.global_step == 1 and self.args.logging_first_step) or (
|
||||
self.args.logging_strategy == IntervalStrategy.STEPS
|
||||
and self.state.global_step % self.state.logging_steps == 0
|
||||
):
|
||||
# compensate `num_items_in_batch` division
|
||||
self.log(
|
||||
{
|
||||
"ce_loss": loss.item() * scaler,
|
||||
"gen_lora_l1_norm": l1_norm.item() * scaler,
|
||||
}
|
||||
)
|
||||
|
||||
return (total_loss, outputs) if return_outputs else total_loss
|
||||
|
||||
|
||||
def get_decay_parameter_names(model) -> list[str]:
|
||||
"""
|
||||
Get all parameter names that weight decay will be applied to.
|
||||
|
||||
This function filters out parameters in two ways:
|
||||
1. By layer type (nn.Embedding)
|
||||
2. By parameter name patterns (containing 'bias', 'layernorm', 'rmsnorm'
|
||||
or 'latents_q' [perceiver's latent queries]).
|
||||
"""
|
||||
decay_parameters = get_parameter_names(
|
||||
model,
|
||||
[nn.Embedding, nn.LayerNorm],
|
||||
["scaler", "bias", "layernorm", "rmsnorm", "latents_q"],
|
||||
)
|
||||
return decay_parameters
|
||||
|
||||
|
||||
def train_model(
|
||||
model,
|
||||
training_args,
|
||||
train_dataset=None,
|
||||
val_dataset=None,
|
||||
train_collator=None,
|
||||
compute_metrics=None,
|
||||
):
|
||||
checkpoint = None
|
||||
if training_args.resume_from_checkpoint is not None:
|
||||
checkpoint = training_args.resume_from_checkpoint
|
||||
logger.info(f"Resuming from the checkpoint: {checkpoint}")
|
||||
|
||||
trainer_kwargs = dict(
|
||||
model=model,
|
||||
args=training_args,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=val_dataset,
|
||||
data_collator=train_collator,
|
||||
compute_metrics=compute_metrics,
|
||||
)
|
||||
|
||||
is_modulated_model = isinstance(model, ModulatedPretrainedModel)
|
||||
trainer_cls = Trainer
|
||||
if is_modulated_model:
|
||||
logger.info("Training with modulated model.")
|
||||
trainer_cls = CrossEntropyTrainer
|
||||
trainer_kwargs["gen_lora_l1_reg_coef"] = training_args.gen_lora_l1_reg_coef
|
||||
trainer_kwargs["use_per_ctx_average_loss"] = (
|
||||
training_args.use_per_ctx_average_loss
|
||||
)
|
||||
del training_args.gen_lora_l1_reg_coef
|
||||
del training_args.use_per_ctx_average_loss
|
||||
|
||||
if training_args.use_kl_loss:
|
||||
logger.info("Training with distillation loss. Using DistillationTrainer.")
|
||||
trainer_cls = DistillationTrainer
|
||||
del training_args.use_kl_loss
|
||||
|
||||
if training_args.auto_find_batch_size:
|
||||
# set the batch size to some high number
|
||||
# which will be lowered by the Trainer
|
||||
training_args.per_device_train_batch_size = 128
|
||||
|
||||
trainer = trainer_cls(**trainer_kwargs)
|
||||
# if getattr(trainer, "use_per_ctx_average_loss", False):
|
||||
# trainer.get_batch_samples = trainer.get_batch_samples_ctx
|
||||
|
||||
# MONKEY PATCH: remove embedding layers from weight decay
|
||||
trainer.get_decay_parameter_names = get_decay_parameter_names
|
||||
|
||||
# Trainer loads the best model after training
|
||||
# is done when load_best_model_at_end=True
|
||||
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_model()
|
||||
|
||||
# TODO: add benchmark eval?
|
||||
# clear_gpu()
|
||||
275
src/ctx_to_lora/utils.py
Normal file
275
src/ctx_to_lora/utils.py
Normal file
|
|
@ -0,0 +1,275 @@
|
|||
import ast
|
||||
import gc
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from peft import PeftConfig, PeftModel
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer, check_target_module_exists
|
||||
from peft.utils import get_peft_model_state_dict
|
||||
|
||||
TRAINING_TASK = Enum("TRAINING_TASK", ["CAUSAL_LM", "COMPLETION"])
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
# taken from https://discuss.pytorch.org/t/opinion-eval-should-be-a-context-manager/18998/3
|
||||
@contextmanager
|
||||
def evaluating(*models):
|
||||
"""Temporarily switch to evaluation mode."""
|
||||
is_training = [model.training if model is not None else False for model in models]
|
||||
try:
|
||||
for model in models:
|
||||
if model is not None:
|
||||
model.eval()
|
||||
yield models
|
||||
finally:
|
||||
for model, training in zip(models, is_training):
|
||||
if model is not None:
|
||||
model.train(training)
|
||||
|
||||
|
||||
def get_layers(model):
|
||||
if hasattr(model, "model"):
|
||||
return get_layers(model.model)
|
||||
return model.layers
|
||||
|
||||
|
||||
def get_num_layers(model):
|
||||
return len(get_layers(model))
|
||||
|
||||
|
||||
def get_base_model(model):
|
||||
if hasattr(model, "model"):
|
||||
return get_base_model(model.model)
|
||||
return model
|
||||
|
||||
|
||||
def get_num_params(model):
|
||||
total_params = 0
|
||||
trainable_params = 0
|
||||
for p in model.parameters():
|
||||
total_params += p.numel()
|
||||
if p.requires_grad:
|
||||
trainable_params += p.numel()
|
||||
|
||||
return total_params, trainable_params
|
||||
|
||||
|
||||
def log_num_train_params(model):
|
||||
logger.debug("Trainable model parameters:")
|
||||
for name, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
logger.debug(f"{name}, dtype:{p.dtype}")
|
||||
|
||||
num_total_params, num_trainable_params = get_num_params(model)
|
||||
logger.info(
|
||||
f"trainable params: {num_trainable_params:,d} "
|
||||
f"|| all params: {num_total_params:,d} "
|
||||
f"|| trainable%: {100 * num_trainable_params / num_total_params:.4f}"
|
||||
)
|
||||
|
||||
|
||||
def get_run_name(seed_str: str | None = None):
|
||||
if not seed_str:
|
||||
uuid = "".join(
|
||||
[random.choice(string.ascii_letters + string.digits) for _ in range(8)]
|
||||
)
|
||||
run_name = time.strftime("%Y%m%d-%H%M%S") + f"_{uuid}"
|
||||
else:
|
||||
# Generate a UUID from the seed string
|
||||
hash_object = hashlib.sha256(seed_str.encode())
|
||||
uuid = hash_object.hexdigest()[:8] # Take the first 8 characters of the hash
|
||||
run_name = seed_str + f"_{uuid}"
|
||||
return run_name
|
||||
|
||||
|
||||
def try_convert(s):
|
||||
try:
|
||||
return ast.literal_eval(s)
|
||||
except:
|
||||
return s
|
||||
|
||||
|
||||
def extract_cli_args(argv: list[str]):
|
||||
out = dict()
|
||||
for elem in argv:
|
||||
if elem.endswith(".yaml"):
|
||||
out["config"] = elem
|
||||
|
||||
elif elem.startswith("--"):
|
||||
k, v = elem.split("=")
|
||||
k = k.split("--")[1]
|
||||
v = try_convert(v)
|
||||
# if k.startswith('env_'):
|
||||
# k = k.split('_')[1]
|
||||
out[k] = v
|
||||
return out
|
||||
|
||||
|
||||
def setup_logging(output_dir, debug=False):
|
||||
global logger
|
||||
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
log_formatter = logging.Formatter(
|
||||
fmt="%(asctime)s %(levelname)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
stream_level = logging.DEBUG if debug else logging.INFO
|
||||
stream_handler = logging.StreamHandler()
|
||||
stream_handler.setFormatter(log_formatter)
|
||||
stream_handler.setLevel(stream_level)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
log_path = f"{output_dir}/debug.log"
|
||||
debug_handler = logging.FileHandler(log_path, delay=True)
|
||||
debug_handler.setLevel(logging.DEBUG)
|
||||
debug_handler.setFormatter(log_formatter)
|
||||
logger.addHandler(debug_handler)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.info(f"Logging to: {log_path}")
|
||||
|
||||
|
||||
def validate_args(args_list):
|
||||
# there shouldn't be overlap between args
|
||||
keys = set()
|
||||
for args in args_list:
|
||||
logger.debug(args)
|
||||
args_keys = set(vars(args).keys())
|
||||
assert len(keys & args_keys) == 0, "Overlap between args"
|
||||
keys |= args_keys
|
||||
|
||||
|
||||
def save_yaml(data, path):
|
||||
# Filter out non-primitive fields
|
||||
data = {
|
||||
k: v
|
||||
for k, v in data.items()
|
||||
if isinstance(v, (int, float, str, bool, list, dict, type(None)))
|
||||
}
|
||||
|
||||
with open(path, "w") as file:
|
||||
yaml.dump(data, file)
|
||||
|
||||
|
||||
def get_peft_modules(model: PeftModel, peft_config: PeftConfig) -> list[dict[str, str]]:
|
||||
return [
|
||||
{"name": name, "module": module}
|
||||
for name, module in model.named_modules()
|
||||
if name.split(".")[-1] in peft_config.target_modules
|
||||
and isinstance(module, BaseTunerLayer)
|
||||
and check_target_module_exists(peft_config, name)
|
||||
]
|
||||
|
||||
|
||||
def get_peft_in_out_features(
|
||||
model: PeftModel,
|
||||
peft_config: PeftConfig | None = None,
|
||||
) -> tuple[dict[str, int], dict[str, int]]:
|
||||
if peft_config is None:
|
||||
return None, None
|
||||
in_features = dict()
|
||||
out_features = dict()
|
||||
for module_info in get_peft_modules(model, peft_config):
|
||||
module_name = module_info["name"]
|
||||
module = module_info["module"]
|
||||
# support just Linear layer for now
|
||||
# all modules should be a leave module that is Linear layer
|
||||
assert isinstance(module.base_layer, torch.nn.Linear), (
|
||||
"all modules should be a leave module that is Linear layer"
|
||||
)
|
||||
|
||||
# this should always pass
|
||||
name = module_name.split(".")[-1]
|
||||
assert name in peft_config.target_modules
|
||||
|
||||
if name not in in_features:
|
||||
in_features[name] = module.in_features
|
||||
out_features[name] = module.out_features
|
||||
else:
|
||||
# assumes each module has the same input and output features
|
||||
assert in_features[name] == module.in_features
|
||||
assert out_features[name] == module.out_features
|
||||
|
||||
return in_features, out_features
|
||||
|
||||
|
||||
def generated_lora_to_state_dict(
|
||||
lora_dict: dict,
|
||||
module_names: dict,
|
||||
target_modules: list[str],
|
||||
layer_indices: Iterable[int],
|
||||
) -> dict:
|
||||
lora_state_dict = dict()
|
||||
for target_module in target_modules:
|
||||
for layer_idx in layer_indices:
|
||||
for module_name in module_names[target_module][layer_idx]:
|
||||
if "lora_A" in module_name:
|
||||
lora_state_dict[module_name] = (
|
||||
lora_dict[target_module]["A"][layer_idx].cpu().contiguous()
|
||||
)
|
||||
elif "lora_B" in module_name:
|
||||
lora_state_dict[module_name] = (
|
||||
lora_dict[target_module]["B"][layer_idx].cpu().contiguous()
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected module name: {module_name}")
|
||||
return lora_state_dict
|
||||
|
||||
|
||||
def get_lora_module_names(
|
||||
model: PeftModel,
|
||||
target_modules: list[str],
|
||||
layer_indices: Iterable[int],
|
||||
) -> dict[str, list[str]]:
|
||||
module_names = {
|
||||
target_module: [[] for _ in range(len(layer_indices))]
|
||||
for target_module in target_modules
|
||||
}
|
||||
for k in get_peft_model_state_dict(model):
|
||||
if "lora" not in k:
|
||||
continue
|
||||
layer_idx = int(k.split("layers.")[-1].split(".")[0])
|
||||
if layer_idx in layer_indices:
|
||||
for target_module in target_modules:
|
||||
if target_module in k:
|
||||
module_names[target_module][layer_idx].append(k)
|
||||
break
|
||||
return module_names
|
||||
|
||||
|
||||
def compile_linear(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
module.compile()
|
||||
|
||||
|
||||
def clear_gpu():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_max_memory_cached()
|
||||
|
||||
|
||||
def concat_list(l):
|
||||
out = []
|
||||
for x in l:
|
||||
out += x
|
||||
return out
|
||||
|
||||
|
||||
def check_is_iterable(x):
|
||||
try:
|
||||
iter(x)
|
||||
except TypeError:
|
||||
return False
|
||||
return True
|
||||
419
train.py
Executable file
419
train.py
Executable file
|
|
@ -0,0 +1,419 @@
|
|||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import wandb
|
||||
from datasets import disable_caching
|
||||
from peft import PeftModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
set_seed,
|
||||
)
|
||||
|
||||
from ctx_to_lora.configs import (
|
||||
AggregatorArguments,
|
||||
ArgumentParser,
|
||||
CtxEncoderArguments,
|
||||
CtxTrainingArguments,
|
||||
DataArguments,
|
||||
ExperimentSetup,
|
||||
HypernetArguments,
|
||||
LoRAArguments,
|
||||
ModelArguments,
|
||||
TrainingArguments,
|
||||
)
|
||||
from ctx_to_lora.data.collator import ( # train_packed_collator,; DefaultDataCollator,
|
||||
flatten_if_not_packed,
|
||||
)
|
||||
from ctx_to_lora.data.processing import get_tokenized_dataset, pack
|
||||
from ctx_to_lora.metrics import (
|
||||
Evaluator,
|
||||
compute_metrics,
|
||||
compute_per_token_acc,
|
||||
compute_perplexity,
|
||||
compute_prefix_matching,
|
||||
)
|
||||
from ctx_to_lora.model_loading import (
|
||||
check_is_vision_model,
|
||||
get_lora_config,
|
||||
get_model_and_tokenizer,
|
||||
get_tokenizer,
|
||||
)
|
||||
from ctx_to_lora.modeling.hypernet import (
|
||||
ModulatedPretrainedModel,
|
||||
get_hypernet_config,
|
||||
)
|
||||
from ctx_to_lora.trainer import train_model
|
||||
from ctx_to_lora.utils import (
|
||||
compile_linear,
|
||||
extract_cli_args,
|
||||
get_run_name,
|
||||
log_num_train_params,
|
||||
save_yaml,
|
||||
setup_logging,
|
||||
validate_args,
|
||||
)
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
|
||||
|
||||
|
||||
def main():
|
||||
############ Argument parsing
|
||||
parser = ArgumentParser(
|
||||
(
|
||||
DataArguments,
|
||||
CtxTrainingArguments,
|
||||
ModelArguments,
|
||||
LoRAArguments,
|
||||
TrainingArguments,
|
||||
HypernetArguments,
|
||||
AggregatorArguments,
|
||||
CtxEncoderArguments,
|
||||
)
|
||||
)
|
||||
(
|
||||
data_args,
|
||||
ctx_args,
|
||||
model_args,
|
||||
lora_args,
|
||||
training_args,
|
||||
hypernet_args,
|
||||
aggregator_args,
|
||||
ctx_encoder_args,
|
||||
) = parser.parse()
|
||||
|
||||
# there shouldn't be overlap between args
|
||||
validate_args(
|
||||
[
|
||||
data_args,
|
||||
ctx_args,
|
||||
model_args,
|
||||
lora_args,
|
||||
training_args,
|
||||
hypernet_args,
|
||||
aggregator_args,
|
||||
ctx_encoder_args,
|
||||
]
|
||||
)
|
||||
|
||||
assert ctx_args.use_sequence_packing, (
|
||||
f"Please set use_sequence_packing=True in {ctx_args}. It's faster!"
|
||||
)
|
||||
|
||||
set_seed(training_args.seed)
|
||||
checkpoint_dir = training_args.resume_from_checkpoint
|
||||
|
||||
# should be the same across processes
|
||||
# still possible to have a name crash though
|
||||
# logging_dir is just "runs/DATE_TIME_HOSTNAME"
|
||||
slurm_job_id = f"_{os.getenv('SLURM_JOB_ID')}" if os.getenv("SLURM_JOB_ID") else ""
|
||||
run_name = (
|
||||
get_run_name(seed_str=training_args.logging_dir.strip("runs/") + slurm_job_id)
|
||||
if not checkpoint_dir
|
||||
else checkpoint_dir.strip("/").split("/")[-2]
|
||||
)
|
||||
|
||||
output_dir = f"train_outputs/runs/{run_name}"
|
||||
setup_logging(output_dir, debug=os.getenv("DEBUG", False))
|
||||
logger.debug(f"CMD: {' '.join(os.sys.argv)}")
|
||||
cli_args = extract_cli_args(os.sys.argv)
|
||||
save_yaml(cli_args, f"{output_dir}/cli_args.yaml")
|
||||
if "config" in cli_args:
|
||||
config_name = os.path.basename(cli_args["config"]).split(".yaml")[0]
|
||||
os.environ["WANDB_TAGS"] = config_name
|
||||
|
||||
run_name = os.path.basename(output_dir)
|
||||
training_args.run_name = run_name
|
||||
training_args.output_dir = output_dir
|
||||
training_args.logging_dir = output_dir
|
||||
|
||||
if (
|
||||
training_args.lr_scheduler_type == "cosine_with_min_lr"
|
||||
and training_args.lr_scheduler_kwargs is None
|
||||
):
|
||||
training_args.lr_scheduler_kwargs = {"min_lr": 1e-7}
|
||||
args = {
|
||||
**vars(deepcopy(data_args)),
|
||||
**vars(deepcopy(ctx_args)),
|
||||
**vars(deepcopy(model_args)),
|
||||
**vars(deepcopy(lora_args)),
|
||||
**vars(deepcopy(training_args)),
|
||||
**vars(deepcopy(hypernet_args)),
|
||||
**vars(deepcopy(aggregator_args)),
|
||||
**vars(deepcopy(ctx_encoder_args)),
|
||||
}
|
||||
args["deepspeed_plugin"] = None
|
||||
logger.debug(f"args: {args}")
|
||||
save_yaml(args, f"{output_dir}/args.yaml")
|
||||
|
||||
############ Model setup
|
||||
if not ctx_args.from_pretrained_checkpoint:
|
||||
model_name = model_args.model_name_or_path
|
||||
base_model, tokenizer = get_model_and_tokenizer(
|
||||
**vars(model_args),
|
||||
train=True,
|
||||
requires_grad=False,
|
||||
peft_config=get_lora_config(model_name, **vars(lora_args)),
|
||||
)
|
||||
ctx_name = ctx_encoder_args.ctx_encoder_model_name_or_path
|
||||
if ctx_name is not None:
|
||||
ctx_encoder_model_config = AutoConfig.from_pretrained(
|
||||
ctx_name, trust_remote_code=True
|
||||
)
|
||||
if ("Llama" in ctx_name and "Vision" in ctx_name) or check_is_vision_model(
|
||||
ctx_name
|
||||
):
|
||||
ctx_encoder_model_config = ctx_encoder_model_config.text_config
|
||||
ctx_tokenizer = get_tokenizer(ctx_name, train=True)
|
||||
else:
|
||||
ctx_name = base_model.base_model.config.name_or_path
|
||||
ctx_encoder_model_config = base_model.config
|
||||
ctx_tokenizer = tokenizer
|
||||
|
||||
if ctx_args.exp_setup == ExperimentSetup.HYPERLORA:
|
||||
logger.info("Using HyperLoRA")
|
||||
if not ctx_args.from_pretrained_checkpoint:
|
||||
hypernet_config = get_hypernet_config(
|
||||
base_model,
|
||||
ctx_encoder_model_config,
|
||||
hypernet_args,
|
||||
aggregator_args,
|
||||
ctx_encoder_args,
|
||||
)
|
||||
if ctx_encoder_args.layer_idx is None:
|
||||
ctx_encoder_args.layer_idx = (
|
||||
ctx_encoder_model_config.num_hidden_layers // 4
|
||||
)
|
||||
logger.info(
|
||||
f"Using the first {ctx_encoder_args.layer_idx} layers"
|
||||
" as the context encoder"
|
||||
)
|
||||
ctx_name = ctx_encoder_args.ctx_encoder_model_name_or_path
|
||||
if ctx_encoder_args.ctx_encoder_last_layer is None and (
|
||||
ctx_name is not None and ctx_name != base_model.name_or_path
|
||||
):
|
||||
logger.info(
|
||||
f"Setting ctx_encoder_last_layer to {base_model.name_or_path} max layers"
|
||||
f":{base_model.config.num_hidden_layers}"
|
||||
)
|
||||
ctx_encoder_args.ctx_encoder_last_layer = (
|
||||
base_model.config.num_hidden_layers
|
||||
)
|
||||
|
||||
model = ModulatedPretrainedModel(
|
||||
base_model, hypernet_config, ctx_encoder_args
|
||||
)
|
||||
|
||||
else:
|
||||
if checkpoint_dir:
|
||||
ctx_args.from_pretrained_checkpoint = (
|
||||
f"{checkpoint_dir}/pytorch_model.bin"
|
||||
)
|
||||
logger.info(
|
||||
f"Loading from checkpoint: {ctx_args.from_pretrained_checkpoint}"
|
||||
)
|
||||
|
||||
model = ModulatedPretrainedModel.from_state_dict(
|
||||
torch.load(ctx_args.from_pretrained_checkpoint, weights_only=False),
|
||||
train=True,
|
||||
use_flash_attn=model_args.use_flash_attn,
|
||||
)
|
||||
tokenizer = get_tokenizer(model.base_model.config.name_or_path, train=True)
|
||||
ctx_name = model.ctx_encoder_args.ctx_encoder_model_name_or_path
|
||||
if ctx_name is None:
|
||||
ctx_name = model.base_model.config.name_or_path
|
||||
ctx_tokenizer = get_tokenizer(ctx_name, train=True)
|
||||
|
||||
training_args.gen_lora_l1_reg_coef = ctx_args.gen_lora_l1_reg_coef
|
||||
training_args.use_kl_loss = ctx_args.use_kl_loss
|
||||
training_args.use_per_ctx_average_loss = ctx_args.use_per_ctx_average_loss
|
||||
|
||||
if len([p for p in model.ctx_encoder.parameters() if p.requires_grad]):
|
||||
raise ValueError("ctx_encoder contains trainable parameters")
|
||||
if len([p for p in model.base_model.parameters() if p.requires_grad]):
|
||||
raise ValueError("base model contains trainable parameters")
|
||||
|
||||
model.hypernet.compile(fullgraph=True, mode="max-autotune")
|
||||
|
||||
else:
|
||||
# activate LoRA
|
||||
base_model_config = AutoConfig.from_pretrained(
|
||||
model_args.model_name_or_path, trust_remote_code=True
|
||||
)
|
||||
base_model_config.save_pretrained(output_dir)
|
||||
logger.info("Using LoRA")
|
||||
model.set_adapter("default")
|
||||
model = torch.compile(model)
|
||||
|
||||
model.train()
|
||||
logger.debug(model)
|
||||
log_num_train_params(model)
|
||||
|
||||
############ Dataset setup
|
||||
logger.info("Loading dataset...")
|
||||
|
||||
add_ctx_to_chat = not isinstance(model, ModulatedPretrainedModel)
|
||||
ctx_model_max_len = model.ctx_encoder.config.max_position_embeddings
|
||||
if ctx_args.max_ctx_len > 0:
|
||||
ctx_model_max_len = ctx_args.max_ctx_len
|
||||
if ctx_args.max_ctx_chunk_len <= 0:
|
||||
# set default chunk size to max length of the ctx encoder
|
||||
ctx_args.max_ctx_chunk_len = ctx_model_max_len
|
||||
|
||||
if ctx_args.num_chunk_probs is not None:
|
||||
ctx_args.num_chunk_probs = {
|
||||
int(k): float(v) for k, v in ctx_args.num_chunk_probs.items()
|
||||
}
|
||||
|
||||
_get_tokenized_dataset = partial(
|
||||
get_tokenized_dataset,
|
||||
max_qas_len=ctx_args.max_qas_len,
|
||||
max_qas_per_sample=ctx_args.max_qas_per_sample,
|
||||
base_model_max_len=model.base_model.config.max_position_embeddings,
|
||||
tokenizer=tokenizer,
|
||||
ctx_model_max_len=ctx_model_max_len,
|
||||
ctx_tokenizer=ctx_tokenizer,
|
||||
add_ctx_to_chat=add_ctx_to_chat,
|
||||
add_negative_prompt=ctx_args.add_negative_prompt,
|
||||
max_ctx_chunk_len=ctx_args.max_ctx_chunk_len,
|
||||
min_ctx_chunk_len=ctx_args.min_ctx_chunk_len,
|
||||
num_chunk_probs=ctx_args.num_chunk_probs,
|
||||
max_ctx_chunk_num=ctx_args.max_ctx_chunk_num,
|
||||
use_kl_loss=ctx_args.use_kl_loss,
|
||||
)
|
||||
splits = ["train"]
|
||||
if training_args.eval_strategy != "no":
|
||||
splits.append("validation")
|
||||
tokenized_ds = {split: {} for split in splits}
|
||||
for split, ds_names in zip(
|
||||
splits,
|
||||
[data_args.train_ds_names, data_args.val_ds_names],
|
||||
):
|
||||
if not ds_names:
|
||||
continue
|
||||
ctx_mgr = (
|
||||
training_args.main_process_first()
|
||||
if split == "train"
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
with ctx_mgr:
|
||||
# process and tokenize on the main process
|
||||
# then other replicas can just load the cached dataset
|
||||
# we dont save cache for validation ds
|
||||
for ds_name in ds_names:
|
||||
ds = _get_tokenized_dataset(ds_name, split)
|
||||
|
||||
base_name = os.path.basename(ds_name)
|
||||
if ds_name.startswith("self_gen/"):
|
||||
ds_name = "self_gen/" + base_name
|
||||
else:
|
||||
ds_name = base_name
|
||||
|
||||
tokenized_ds[split][ds_name] = ds
|
||||
|
||||
train_ds = tokenized_ds["train"]
|
||||
if data_args.max_train_samples_per_ds is not None:
|
||||
for ds_name, ds in train_ds.items():
|
||||
if data_args.max_train_samples_per_ds >= len(ds):
|
||||
continue
|
||||
train_ds[ds_name] = ds.take(data_args.max_train_samples_per_ds)
|
||||
logging.info(f"train_ds: {train_ds}")
|
||||
|
||||
val_ds = dict()
|
||||
if "validation" in tokenized_ds:
|
||||
n_val_samples = data_args.max_val_samples_per_ds
|
||||
for ds_name, ds in tokenized_ds["validation"].items():
|
||||
if ds is None:
|
||||
# take some samples from train_ds
|
||||
ds = train_ds[ds_name].take(n_val_samples)
|
||||
train_ds[ds_name] = train_ds[ds_name].skip(n_val_samples)
|
||||
|
||||
val_ds[ds_name] = ds
|
||||
val_indices = np.random.permutation(len(ds))[:n_val_samples]
|
||||
val_ds[ds_name] = val_ds[ds_name].select(val_indices)
|
||||
|
||||
with training_args.main_process_first():
|
||||
train_ds = pack(
|
||||
train_ds,
|
||||
ctx_args.max_packed_inp_len,
|
||||
ctx_args.max_packed_ctx_len,
|
||||
max_packed_size=-1,
|
||||
seed=training_args.seed,
|
||||
num_proc=30,
|
||||
)
|
||||
logger.info("Setting per_device_train_batch_size to 1")
|
||||
training_args.per_device_train_batch_size = 1
|
||||
|
||||
logger.info(f"train_ds: {train_ds}")
|
||||
logger.info(f"val_ds: {val_ds}")
|
||||
|
||||
collator = flatten_if_not_packed
|
||||
|
||||
if isinstance(model, ModulatedPretrainedModel):
|
||||
if isinstance(model.base_model, PeftModel):
|
||||
base_model = model.base_model.base_model
|
||||
else:
|
||||
base_model = model.base_model
|
||||
|
||||
if ctx_name is not None:
|
||||
logger.info("Compiling ctx_encoder_model")
|
||||
ctx_base_model = model.ctx_encoder.base_model
|
||||
compile_linear(ctx_base_model)
|
||||
|
||||
elif isinstance(model, PeftModel):
|
||||
base_model = model.base_model
|
||||
|
||||
logger.info("Compiling base_model")
|
||||
base_model.compile(fullgraph=True, mode="max-autotune")
|
||||
|
||||
if LOCAL_RANK == 0:
|
||||
wandb.init(
|
||||
project=os.getenv("WANDB_PROJECT"),
|
||||
name=run_name,
|
||||
group=run_name,
|
||||
config=args,
|
||||
tags=os.getenv("WANDB_TAGS").split(","),
|
||||
notes=ctx_args.notes,
|
||||
resume="allow",
|
||||
)
|
||||
else:
|
||||
wandb.init(mode="disabled")
|
||||
|
||||
train_model(
|
||||
model,
|
||||
training_args,
|
||||
train_ds,
|
||||
val_ds,
|
||||
collator,
|
||||
compute_metrics=partial(
|
||||
compute_metrics,
|
||||
evaluator=Evaluator(
|
||||
[compute_per_token_acc, compute_prefix_matching, compute_perplexity]
|
||||
),
|
||||
),
|
||||
)
|
||||
logger.info(f"Training run finished and saved to {output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "true"
|
||||
os.environ["WANDB_DIR"] = ".wandb/"
|
||||
os.environ["WANDB_PROJECT"] = os.getenv("WANDB_PROJECT") or "ctx_to_lora"
|
||||
os.environ["WANDB_WATCH"] = ""
|
||||
os.environ["WANDB_CONSOLE"] = "off"
|
||||
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
||||
os.environ["OMP_NUM_THREADS"] = "23"
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
if os.getenv("DEBUG", False):
|
||||
disable_caching()
|
||||
main()
|
||||
122
watcher.py
Normal file
122
watcher.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
import itertools
|
||||
import os
|
||||
import time
|
||||
from argparse import Namespace
|
||||
from glob import glob
|
||||
|
||||
import wandb
|
||||
import yaml
|
||||
|
||||
from ctx_to_lora.eval_utils import run_eval
|
||||
from ctx_to_lora.utils import clear_gpu
|
||||
|
||||
CP_PATTERN = "train_outputs/runs/*/checkpoint*/pytorch_model.bin"
|
||||
|
||||
|
||||
def flatten(l):
|
||||
return itertools.chain.from_iterable(l)
|
||||
|
||||
|
||||
# handmade file watcher using glob
|
||||
# not using watchdog because there are too many saved files
|
||||
# but we want to just watch CP_PATTERN files
|
||||
class Watcher:
|
||||
def __init__(self, patterns):
|
||||
self.patterns = patterns
|
||||
self.files = self.get_files()
|
||||
self.last_files = self.files
|
||||
|
||||
def get_files(self):
|
||||
return set(flatten(glob(pattern) for pattern in self.patterns))
|
||||
|
||||
def watch(self):
|
||||
self.files = self.get_files()
|
||||
new_files = self.files - self.last_files
|
||||
return sorted(list(new_files))
|
||||
|
||||
def update(self, file):
|
||||
if file in self.last_files:
|
||||
return
|
||||
self.last_files.add(file)
|
||||
print(f"Added {file} to evaluated files.")
|
||||
|
||||
def save_state(self):
|
||||
with open("watcher_state.yaml", "w") as f:
|
||||
yaml.dump({"last_files": self.last_files}, f)
|
||||
|
||||
def load_state(self):
|
||||
if not os.path.exists("watcher_state.yaml"):
|
||||
return
|
||||
with open("watcher_state.yaml") as f:
|
||||
state = yaml.safe_load(f)
|
||||
self.last_files = state["last_files"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true"
|
||||
os.environ["FLASH_ATTENTION_DETERMINISTIC"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
os.environ["WANDB_PROJECT"] = "ctx_to_lora"
|
||||
|
||||
watcher = Watcher([CP_PATTERN])
|
||||
watcher.load_state()
|
||||
print("Watching for new files...")
|
||||
while True:
|
||||
time.sleep(10)
|
||||
new_files = watcher.watch()
|
||||
for file in new_files:
|
||||
# workaround to prevent loading incomplete checkpoints
|
||||
time.sleep(20)
|
||||
if not os.path.exists(file):
|
||||
# cp is delete before we can read it
|
||||
continue
|
||||
run_dir = file.split("/checkpoint")[0]
|
||||
run_name = run_dir.split("/")[-1]
|
||||
print(f"Evaluating {file}")
|
||||
args = Namespace(**yaml.unsafe_load(open(f"{run_dir}/args.yaml")))
|
||||
curstep = int(file.split("checkpoint-")[1].split("/")[0])
|
||||
wandb_kwargs = {
|
||||
"project": os.getenv("WANDB_PROJECT"),
|
||||
"group": run_name,
|
||||
"name": f"{run_name}-eval",
|
||||
"id": f"{run_name}-eval",
|
||||
"resume": "allow",
|
||||
}
|
||||
wandb.init(**wandb_kwargs)
|
||||
|
||||
# TODO: have to change this for bigger models
|
||||
eval_batch_size = 8
|
||||
eval_batch_size_gen = 8
|
||||
metrics = {}
|
||||
|
||||
# try:
|
||||
# # metrics = run_eval(
|
||||
# # checkpoint_path=file,
|
||||
# # eval_batch_size=eval_batch_size,
|
||||
# # split="validation",
|
||||
# # generative=False,
|
||||
# # )
|
||||
# except FileNotFoundError as e:
|
||||
# print(f"Error evaluating {file}: {e}. The checkpoint might be deleted.")
|
||||
# continue
|
||||
try:
|
||||
gen_metrics = run_eval(
|
||||
checkpoint_path=file,
|
||||
split="validation",
|
||||
eval_batch_size=eval_batch_size_gen,
|
||||
max_ctx_chunk_len=args.max_ctx_chunk_len,
|
||||
generative=True,
|
||||
)
|
||||
except FileNotFoundError as e:
|
||||
print(f"The checkpoint might be deleted. Error evaluating {file}: {e}.")
|
||||
gen_metrics = {}
|
||||
file = ""
|
||||
metrics.update(gen_metrics)
|
||||
for k in metrics:
|
||||
wandb.log(metrics[k], step=curstep)
|
||||
wandb.finish()
|
||||
print(f"Logged metrics: {metrics}")
|
||||
print("=" * 80)
|
||||
clear_gpu()
|
||||
watcher.update(file)
|
||||
watcher.save_state()
|
||||
31
webui/SELF_GEN_VIEWER.md
Normal file
31
webui/SELF_GEN_VIEWER.md
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# Self-Gen Data Viewer
|
||||
|
||||
Thanks Claude.
|
||||
|
||||
Running the viewer
|
||||
```bash
|
||||
uv run self_gen_viewer.py
|
||||
```
|
||||
|
||||
Then open your browser and go to: **http://localhost:5001**
|
||||
|
||||
## Usage
|
||||
|
||||
1. **Select a Model Folder**: Choose from the dropdown list (e.g., `google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0`)
|
||||
2. **Select a Parquet File**: Once a folder is selected, available parquet files will appear
|
||||
3. **Set Number of Samples**: Adjust the sample count (default: 100, max: 1000)
|
||||
4. **Click "Load Data"**: View the visualized data with context and Q&A pairs
|
||||
|
||||
## Data Structure
|
||||
|
||||
The viewer expects data in the following structure:
|
||||
```
|
||||
data/raw_datasets/self_gen/
|
||||
├── google/
|
||||
│ └── gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0/
|
||||
│ └── fw_qa_v2/
|
||||
│ └── *.parquet
|
||||
└── mistralai/
|
||||
└── Mistral-7B-Instruct-v0.2_temp_0.0_closed_qa_prob_1.0/
|
||||
└── *.parquet
|
||||
```
|
||||
170
webui/self_gen_viewer.py
Normal file
170
webui/self_gen_viewer.py
Normal file
|
|
@ -0,0 +1,170 @@
|
|||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
from datasets import load_dataset
|
||||
from flask import Flask, jsonify, render_template, request
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# Base path for self_gen data
|
||||
BASE_DATA_PATH = Path(__file__).parent.parent / "data" / "raw_datasets" / "self_gen"
|
||||
|
||||
# Cache for tokenizers
|
||||
tokenizer_cache = {}
|
||||
|
||||
|
||||
def get_tokenizer(model_path):
|
||||
"""Get or create tokenizer with caching"""
|
||||
if model_path not in tokenizer_cache:
|
||||
try:
|
||||
tokenizer_cache[model_path] = AutoTokenizer.from_pretrained(model_path)
|
||||
except Exception as e:
|
||||
print(f"Error loading tokenizer for {model_path}: {e}")
|
||||
return None
|
||||
return tokenizer_cache[model_path]
|
||||
|
||||
|
||||
def discover_folders():
|
||||
"""Discover all model folders in self_gen directory"""
|
||||
folders = []
|
||||
if not BASE_DATA_PATH.exists():
|
||||
return folders
|
||||
|
||||
for vendor_dir in BASE_DATA_PATH.iterdir():
|
||||
if vendor_dir.is_dir():
|
||||
for model_dir in vendor_dir.iterdir():
|
||||
if model_dir.is_dir():
|
||||
rel_path = model_dir.relative_to(BASE_DATA_PATH)
|
||||
folders.append(str(rel_path))
|
||||
|
||||
return sorted(folders)
|
||||
|
||||
|
||||
def discover_parquet_files(folder_path):
|
||||
"""Discover all parquet files in a folder"""
|
||||
full_path = BASE_DATA_PATH / folder_path
|
||||
parquet_files = []
|
||||
|
||||
if full_path.exists():
|
||||
for parquet_file in full_path.glob("**/*.parquet"):
|
||||
rel_path = parquet_file.relative_to(full_path)
|
||||
parquet_files.append(str(rel_path))
|
||||
|
||||
return sorted(parquet_files)
|
||||
|
||||
|
||||
def extract_model_name_from_folder(folder_path):
|
||||
"""Extract base model name from folder path"""
|
||||
# e.g., "google/gemma-2-2b-it_temp_0.0_closed_qa_prob_1.0" -> "google/gemma-2-2b-it"
|
||||
parts = folder_path.split("/")
|
||||
if len(parts) >= 2:
|
||||
vendor = parts[0]
|
||||
model_part = parts[1].split("_temp_")[0]
|
||||
return f"{vendor}/{model_part}"
|
||||
return None
|
||||
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
"""Main page"""
|
||||
folders = discover_folders()
|
||||
return render_template("self_gen_viewer.html", folders=folders)
|
||||
|
||||
|
||||
@app.route("/api/folders")
|
||||
def api_folders():
|
||||
"""API endpoint to get available folders"""
|
||||
folders = discover_folders()
|
||||
return jsonify({"folders": folders})
|
||||
|
||||
|
||||
@app.route("/api/parquet_files")
|
||||
def api_parquet_files():
|
||||
"""API endpoint to get parquet files in a folder"""
|
||||
folder = request.args.get("folder", "")
|
||||
if not folder:
|
||||
return jsonify({"error": "No folder specified"}), 400
|
||||
|
||||
files = discover_parquet_files(folder)
|
||||
return jsonify({"files": files})
|
||||
|
||||
|
||||
@app.route("/api/load_data")
|
||||
def api_load_data():
|
||||
"""API endpoint to load and display data from a parquet file"""
|
||||
folder = request.args.get("folder", "")
|
||||
parquet_file = request.args.get("file", "")
|
||||
num_samples = int(request.args.get("num_samples", 100))
|
||||
|
||||
if not folder or not parquet_file:
|
||||
return jsonify({"error": "Missing parameters"}), 400
|
||||
|
||||
try:
|
||||
# Construct full path
|
||||
full_path = BASE_DATA_PATH / folder / parquet_file
|
||||
|
||||
if not full_path.exists():
|
||||
return jsonify({"error": f"File not found: {full_path}"}), 404
|
||||
|
||||
# Extract model name for tokenizer
|
||||
model_name = extract_model_name_from_folder(folder)
|
||||
if not model_name:
|
||||
return jsonify({"error": "Could not extract model name from folder"}), 400
|
||||
|
||||
# Load tokenizer
|
||||
tokenizer = get_tokenizer(model_name)
|
||||
if tokenizer is None:
|
||||
return jsonify({"error": f"Could not load tokenizer for {model_name}"}), 500
|
||||
|
||||
# Load dataset
|
||||
ds = load_dataset(
|
||||
"parquet", data_files=str(full_path), split=f"train[:{num_samples}]"
|
||||
)
|
||||
|
||||
# Process samples
|
||||
samples = []
|
||||
for i, sample in enumerate(ds):
|
||||
processed_sample = {
|
||||
"index": i,
|
||||
"ctx": tokenizer.decode(sample["ctx_ids"], skip_special_tokens=False)
|
||||
if "ctx_ids" in sample
|
||||
else "N/A",
|
||||
"questions": [],
|
||||
}
|
||||
|
||||
# Decode input_ids if present
|
||||
if "input_ids" in sample:
|
||||
if isinstance(sample["input_ids"][0], list):
|
||||
# Multiple Q&A pairs
|
||||
processed_sample["questions"] = [
|
||||
tokenizer.decode(qa, skip_special_tokens=False)
|
||||
for qa in sample["input_ids"]
|
||||
]
|
||||
else:
|
||||
# Single item
|
||||
processed_sample["questions"] = [
|
||||
tokenizer.decode(sample["input_ids"], skip_special_tokens=False)
|
||||
]
|
||||
|
||||
samples.append(processed_sample)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"success": True,
|
||||
"num_samples": len(samples),
|
||||
"model_name": model_name,
|
||||
"file_path": str(parquet_file),
|
||||
"samples": samples,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(f"Data path: {BASE_DATA_PATH}")
|
||||
print(f"Available folders: {discover_folders()}")
|
||||
app.run(debug=True, host="0.0.0.0", port=5001)
|
||||
422
webui/templates/self_gen_viewer.html
Normal file
422
webui/templates/self_gen_viewer.html
Normal file
|
|
@ -0,0 +1,422 @@
|
|||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<title>Self-Gen Data Viewer</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
min-height: 100vh;
|
||||
padding: 20px;
|
||||
}
|
||||
|
||||
.container {
|
||||
max-width: 1400px;
|
||||
margin: 0 auto;
|
||||
background: white;
|
||||
border-radius: 15px;
|
||||
box-shadow: 0 10px 40px rgba(0, 0, 0, 0.3);
|
||||
padding: 30px;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #667eea;
|
||||
margin-bottom: 25px;
|
||||
text-align: center;
|
||||
font-size: 2.5em;
|
||||
}
|
||||
|
||||
.controls {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 15px;
|
||||
margin-bottom: 25px;
|
||||
}
|
||||
|
||||
.control-group {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.samples-row {
|
||||
display: grid;
|
||||
grid-template-columns: 1fr auto;
|
||||
gap: 15px;
|
||||
align-items: end;
|
||||
}
|
||||
|
||||
label {
|
||||
font-weight: 600;
|
||||
margin-bottom: 5px;
|
||||
color: #333;
|
||||
font-size: 0.9em;
|
||||
}
|
||||
|
||||
select,
|
||||
input {
|
||||
padding: 10px;
|
||||
border: 2px solid #ddd;
|
||||
border-radius: 8px;
|
||||
font-size: 1em;
|
||||
transition: border-color 0.3s;
|
||||
}
|
||||
|
||||
select:focus,
|
||||
input:focus {
|
||||
outline: none;
|
||||
border-color: #667eea;
|
||||
}
|
||||
|
||||
button {
|
||||
padding: 10px 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border: none;
|
||||
border-radius: 8px;
|
||||
cursor: pointer;
|
||||
font-size: 1em;
|
||||
font-weight: 600;
|
||||
transition: transform 0.2s, box-shadow 0.2s;
|
||||
margin-top: auto;
|
||||
}
|
||||
|
||||
button:hover {
|
||||
transform: translateY(-2px);
|
||||
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
|
||||
}
|
||||
|
||||
button:active {
|
||||
transform: translateY(0);
|
||||
}
|
||||
|
||||
button:disabled {
|
||||
background: #ccc;
|
||||
cursor: not-allowed;
|
||||
transform: none;
|
||||
}
|
||||
|
||||
.info-box {
|
||||
background: #f8f9fa;
|
||||
padding: 15px;
|
||||
border-radius: 8px;
|
||||
margin-bottom: 20px;
|
||||
border-left: 4px solid #667eea;
|
||||
}
|
||||
|
||||
.info-box p {
|
||||
margin: 5px 0;
|
||||
color: #555;
|
||||
}
|
||||
|
||||
.loading {
|
||||
text-align: center;
|
||||
padding: 40px;
|
||||
color: #667eea;
|
||||
font-size: 1.2em;
|
||||
}
|
||||
|
||||
.error {
|
||||
background: #fee;
|
||||
color: #c33;
|
||||
padding: 15px;
|
||||
border-radius: 8px;
|
||||
margin: 20px 0;
|
||||
border-left: 4px solid #c33;
|
||||
}
|
||||
|
||||
.sample {
|
||||
background: #f8f9fa;
|
||||
border: 1px solid #e0e0e0;
|
||||
border-radius: 10px;
|
||||
padding: 20px;
|
||||
margin-bottom: 20px;
|
||||
transition: box-shadow 0.3s;
|
||||
}
|
||||
|
||||
.sample:hover {
|
||||
box-shadow: 0 5px 15px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
|
||||
.sample-header {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
padding: 10px 15px;
|
||||
border-radius: 6px;
|
||||
margin-bottom: 15px;
|
||||
font-weight: 600;
|
||||
}
|
||||
|
||||
.section {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.section-title {
|
||||
font-weight: 600;
|
||||
color: #667eea;
|
||||
margin-bottom: 10px;
|
||||
font-size: 1.1em;
|
||||
border-bottom: 2px solid #667eea;
|
||||
padding-bottom: 5px;
|
||||
}
|
||||
|
||||
.content {
|
||||
background: white;
|
||||
padding: 15px;
|
||||
border-radius: 6px;
|
||||
border: 1px solid #e0e0e0;
|
||||
white-space: pre-wrap;
|
||||
word-wrap: break-word;
|
||||
font-family: 'Courier New', monospace;
|
||||
font-size: 0.9em;
|
||||
line-height: 1.6;
|
||||
max-height: 400px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.question-item {
|
||||
background: #fff;
|
||||
padding: 12px;
|
||||
border-radius: 6px;
|
||||
border: 1px solid #ddd;
|
||||
margin-bottom: 10px;
|
||||
}
|
||||
|
||||
.question-number {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
padding: 3px 8px;
|
||||
border-radius: 4px;
|
||||
font-size: 0.85em;
|
||||
font-weight: 600;
|
||||
display: inline-block;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
#results {
|
||||
margin-top: 30px;
|
||||
}
|
||||
|
||||
.load-more {
|
||||
text-align: center;
|
||||
margin-top: 20px;
|
||||
}
|
||||
|
||||
.stats {
|
||||
display: grid;
|
||||
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
|
||||
gap: 15px;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
|
||||
.stat-card {
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
padding: 15px;
|
||||
border-radius: 8px;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.stat-value {
|
||||
font-size: 2em;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
.stat-label {
|
||||
font-size: 0.9em;
|
||||
opacity: 0.9;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar {
|
||||
width: 8px;
|
||||
height: 8px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-track {
|
||||
background: #f1f1f1;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb {
|
||||
background: #667eea;
|
||||
border-radius: 4px;
|
||||
}
|
||||
|
||||
::-webkit-scrollbar-thumb:hover {
|
||||
background: #764ba2;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🔍 Self-Gen Data Viewer</h1>
|
||||
|
||||
<div class="controls">
|
||||
<div class="control-group">
|
||||
<label for="folder-select">Model Folder:</label>
|
||||
<select id="folder-select">
|
||||
<option value="">-- Select a folder --</option>
|
||||
{% for folder in folders %}
|
||||
<option value="{{ folder }}">{{ folder }}</option>
|
||||
{% endfor %}
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="control-group">
|
||||
<label for="file-select">Parquet File:</label>
|
||||
<select id="file-select" disabled>
|
||||
<option value="">-- Select a file --</option>
|
||||
</select>
|
||||
</div>
|
||||
|
||||
<div class="samples-row">
|
||||
<div class="control-group">
|
||||
<label for="num-samples">Number of Samples:</label>
|
||||
<input type="number" id="num-samples" value="100" min="1" max="1000" step="10">
|
||||
</div>
|
||||
<button id="load-btn" disabled>Load Data</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div id="info" style="display: none;"></div>
|
||||
<div id="results"></div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
const folderSelect = document.getElementById('folder-select');
|
||||
const fileSelect = document.getElementById('file-select');
|
||||
const numSamplesInput = document.getElementById('num-samples');
|
||||
const loadBtn = document.getElementById('load-btn');
|
||||
const infoDiv = document.getElementById('info');
|
||||
const resultsDiv = document.getElementById('results');
|
||||
|
||||
// When folder is selected, load parquet files
|
||||
folderSelect.addEventListener('change', async () => {
|
||||
const folder = folderSelect.value;
|
||||
fileSelect.innerHTML = '<option value="">-- Select a file --</option>';
|
||||
fileSelect.disabled = true;
|
||||
loadBtn.disabled = true;
|
||||
|
||||
if (!folder) return;
|
||||
|
||||
try {
|
||||
const response = await fetch(`/api/parquet_files?folder=${encodeURIComponent(folder)}`);
|
||||
const data = await response.json();
|
||||
|
||||
if (data.files && data.files.length > 0) {
|
||||
data.files.forEach(file => {
|
||||
const option = document.createElement('option');
|
||||
option.value = file;
|
||||
option.textContent = file;
|
||||
fileSelect.appendChild(option);
|
||||
});
|
||||
fileSelect.disabled = false;
|
||||
} else {
|
||||
alert('No parquet files found in this folder');
|
||||
}
|
||||
} catch (error) {
|
||||
alert('Error loading files: ' + error.message);
|
||||
}
|
||||
});
|
||||
|
||||
// Enable load button when file is selected
|
||||
fileSelect.addEventListener('change', () => {
|
||||
loadBtn.disabled = !fileSelect.value;
|
||||
});
|
||||
|
||||
// Load data button
|
||||
loadBtn.addEventListener('click', loadData);
|
||||
|
||||
async function loadData() {
|
||||
const folder = folderSelect.value;
|
||||
const file = fileSelect.value;
|
||||
const numSamples = numSamplesInput.value;
|
||||
|
||||
if (!folder || !file) return;
|
||||
|
||||
resultsDiv.innerHTML = '<div class="loading">⏳ Loading data...</div>';
|
||||
infoDiv.style.display = 'none';
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/api/load_data?folder=${encodeURIComponent(folder)}&file=${encodeURIComponent(file)}&num_samples=${numSamples}`
|
||||
);
|
||||
const data = await response.json();
|
||||
|
||||
if (data.error) {
|
||||
resultsDiv.innerHTML = `<div class="error"><strong>Error:</strong> ${data.error}</div>`;
|
||||
return;
|
||||
}
|
||||
|
||||
displayData(data);
|
||||
} catch (error) {
|
||||
resultsDiv.innerHTML = `<div class="error"><strong>Error:</strong> ${error.message}</div>`;
|
||||
}
|
||||
}
|
||||
|
||||
function displayData(data) {
|
||||
// Show info box
|
||||
infoDiv.style.display = 'block';
|
||||
infoDiv.innerHTML = `
|
||||
<div class="info-box">
|
||||
<div class="stats">
|
||||
<div class="stat-card">
|
||||
<div class="stat-value">${data.num_samples}</div>
|
||||
<div class="stat-label">Samples</div>
|
||||
</div>
|
||||
<div class="stat-card">
|
||||
<div class="stat-value">${data.samples.length > 0 ? data.samples[0].questions.length : 0}</div>
|
||||
<div class="stat-label">Questions per Sample</div>
|
||||
</div>
|
||||
</div>
|
||||
<p><strong>Model:</strong> ${data.model_name}</p>
|
||||
<p><strong>File:</strong> ${data.file_path}</p>
|
||||
</div>
|
||||
`;
|
||||
|
||||
// Show samples
|
||||
let html = '';
|
||||
data.samples.forEach(sample => {
|
||||
html += `
|
||||
<div class="sample">
|
||||
<div class="sample-header">Sample #${sample.index}</div>
|
||||
|
||||
<div class="section">
|
||||
<div class="section-title">📝 Context</div>
|
||||
<div class="content">${escapeHtml(sample.ctx)}</div>
|
||||
</div>
|
||||
|
||||
<div class="section">
|
||||
<div class="section-title">❓ Questions & Answers (${sample.questions.length})</div>
|
||||
${sample.questions.map((q, i) => `
|
||||
<div class="question-item">
|
||||
<span class="question-number">Q&A ${i + 1}</span>
|
||||
<div class="content">${escapeHtml(q)}</div>
|
||||
</div>
|
||||
`).join('')}
|
||||
</div>
|
||||
</div>
|
||||
`;
|
||||
});
|
||||
|
||||
resultsDiv.innerHTML = html;
|
||||
}
|
||||
|
||||
function escapeHtml(text) {
|
||||
const div = document.createElement('div');
|
||||
div.textContent = text;
|
||||
return div.innerHTML;
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
Loading…
Add table
Add a link
Reference in a new issue