Doc-to-LoRA release

51616 2026-02-27 03:47:04 +00:00
commit 1abe8ae16d
92 changed files with 22131 additions and 0 deletions

data/self_generate_qa.py (new file)

@@ -0,0 +1,620 @@
import argparse
import os
import random
import re
from glob import glob
import numpy as np
import yaml
from datasets import Dataset, load_dataset
from vllm import LLM, SamplingParams
from ctx_to_lora.data.definitions import (
CLOSED_QA_INTX_TEMPLATES,
RAW_DATA_DIR,
SELF_GEN_DATA_DIR,
)
from ctx_to_lora.data.processing import (
filter_none,
get_preprocessing_fn,
load_and_process_dataset,
tokenize_ctx_text,
)
from ctx_to_lora.data.self_gen_template import (
PRE_CTX,
PROMPT_TEMPLATE,
QA_PROMPT_TEMPLATE,
SELF_GEN_SYSTEM_MSG,
SELF_QA_INTX,
)
from ctx_to_lora.model_loading import get_tokenizer
from ctx_to_lora.utils import clear_gpu

# Model-specific stop strings passed to SamplingParams during generation;
# models without an entry rely on the tokenizer's default EOS handling.
STOP_STRINGS = {
    "google/gemma-2-2b-it": ["<eos>", "<end_of_turn>"],
}
MODEL_CTX_LEN = {
"google/gemma-2-27b-it": 8192,
"google/gemma-2-2b-it": 8192,
"google/gemma-2-9b-it": 8192,
    # Qwen3-4B supports a 256k context window, but a lower cap is faster;
    # 2**13 + 2**12 = 12288 tokens
    "Qwen/Qwen3-4B-Instruct-2507": 2**13 + 2**12,
}
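

# Note: MODEL_CTX_LEN caps both vLLM's max_model_len (see the __main__ block)
# and the middle-truncation in truncate_middle_if_too_long, so prompt plus
# generated tokens fit the serving window when --do_truncate is set.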
def truncate_middle_if_too_long(
input_ids: list[int],
max_length: int,
max_new_tokens: int = 256,
) -> list[int]:
"""
    Truncate the middle of a token sequence so that it, plus up to
    max_new_tokens generated tokens, fits within max_length.

    Args:
        input_ids: List of token IDs
        max_length: Maximum total length (prompt + generation)
        max_new_tokens: Token budget reserved for generation

    Returns:
        List of truncated token IDs (middle removed if too long)
"""
max_new_tokens_half = max_new_tokens // 2
# leave max_new_tokens for generation
half = max_length // 2 - max_new_tokens_half
if len(input_ids) > max_length:
return input_ids[:half] + input_ids[-half:]
return input_ids
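

# Worked example (illustrative): with max_length=8192 and max_new_tokens=256,
# half = 8192 // 2 - 128 = 3968, so a 10_000-token prompt keeps its first 3968
# and last 3968 tokens (7936 total), leaving 256 positions for generation.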
def get_prompt(context: str, q: str, remove_qa_template: bool) -> str:
prompt = QA_PROMPT_TEMPLATE if not remove_qa_template else PROMPT_TEMPLATE
return prompt.format(context=context, question=q)
def add_closed_qa_prompt(q: str, closed_qa_prob: float = 0.1) -> str:
if random.random() <= closed_qa_prob:
q = random.choice(CLOSED_QA_INTX_TEMPLATES).format(input=q)
return q
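

# Example (illustrative): with closed_qa_prob=0.1, roughly one question in ten
# is wrapped in a randomly chosen closed-QA instruction template from
# CLOSED_QA_INTX_TEMPLATES; the rest pass through unchanged.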
def load_config(config_path: str) -> dict:
"""Load dataset names from YAML config file."""
with open(config_path) as f:
config = yaml.safe_load(f)
return config
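

# Example (illustrative): load_config("cfg.yaml") returns the parsed YAML as a
# dict, e.g. {"train_ds_names": [...], "val_ds_names": [...]}, which
# get_dataset_configs below consumes.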
def get_dataset_configs(
ds_names: list[str] | None,
config: dict | None,
split: str | None,
) -> list[tuple[str, str]]:
assert not (ds_names and config), "Cannot provide both ds_names and config"
if ds_names:
assert split, "When using ds_names, --split must be provided"
# Validate ds_names format
for ds_name in ds_names:
if not isinstance(ds_name, str):
raise ValueError(f"Invalid dataset name: {ds_name}")
return [(ds_name, split) for ds_name in ds_names]
if config:
dataset_configs = []
# Process train datasets
train_ds_names = config.get("train_ds_names", [])
self_gen_train_ds_names = [
(ds_name, "train")
for ds_name in train_ds_names
if ds_name.startswith("self_gen/")
]
if not self_gen_train_ds_names:
print("No self_gen datasets found in train_ds_names")
dataset_configs.extend(self_gen_train_ds_names)
# Process validation datasets
val_ds_names = config.get("val_ds_names", [])
self_gen_val_ds_names = [
(ds_name, "validation")
for ds_name in val_ds_names
if ds_name.startswith("self_gen/")
]
if not self_gen_val_ds_names:
print("No self_gen datasets found in val_ds_names")
dataset_configs.extend(self_gen_val_ds_names)
return dataset_configs
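

# Example (illustrative): a config with
#   train_ds_names: ["self_gen/fw_qa", "squad"]
#   val_ds_names: ["self_gen/fw_qa"]
# yields [("self_gen/fw_qa", "train"), ("self_gen/fw_qa", "validation")];
# entries without the self_gen/ prefix are skipped here.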
def create_messages(
ctxs: list[str],
questions: list[list[str]],
vllm_model: str,
system_template: str,
remove_qa_template: bool,
) -> list[list[dict]]:
"""Create chat messages for the model."""
# if "gemma" in vllm_model:
# gemma models do not support system messages
return [
[
{
"role": "user",
"content": (
system_template + "\n\n\n" + get_prompt(ctx, q, remove_qa_template)
).strip(),
}
]
for ctx, q_list in zip(ctxs, questions)
for q in q_list
]
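

# Resulting structure (illustrative): one single-turn chat per (context,
# question) pair, flattened across all contexts:
#   [{"role": "user", "content": "<system template>\n\n\n<QA prompt>"}]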
def self_generate(
ds_name: str,
split: str,
args: argparse.Namespace,
llm: LLM,
system_template: str,
parquet_file: str | None = None,
do_truncate: bool = False,
) -> None:
"""Process a single dataset and generate QA pairs."""
shard_name = ""
# Conflict checks for ds_name-derived overrides
if ds_name is not None:
# temperature & closed_qa already handled later; add new ones
if "_temp_" in ds_name and args.temp != 0.0:
raise ValueError(
f"Multiple sources of truth for temperature: CLI arg --temp={args.temp} and dataset name contains temp specification."
)
if "_closed_qa_prob_" in ds_name and args.closed_qa_prob != 0.0:
raise ValueError(
f"Multiple sources of truth for closed_qa_prob: CLI arg --closed_qa_prob={args.closed_qa_prob} and dataset name contains closed_qa_prob specification."
)
# Base values from args
temp = args.temp
closed_qa_prob = args.closed_qa_prob
# Overrides from ds_name pattern if present
if ds_name is not None:
if "_temp_" in ds_name:
m = re.search(r"_temp_([\d.]+)", ds_name)
if m:
temp = float(m.group(1))
if "_closed_qa_prob_" in ds_name:
m = re.search(r"_closed_qa_prob_([\d.]+)", ds_name)
if m:
closed_qa_prob = float(m.group(1))
print(f"Processing dataset: {ds_name}, split: {split}")
print(f"Using temperature: {temp}")
print(f"Using closed QA prompt probability: {closed_qa_prob}")
if parquet_file:
print(f"Loading dataset from parquet file: {parquet_file}")
split = "train"
ds_name = "/".join(parquet_file.split(RAW_DATA_DIR)[-1].split("/")[:-1])
shard_name = "_" + os.path.basename(parquet_file).replace(".parquet", "")
ds = load_dataset(path="parquet", data_files=[parquet_file], split="train")
processing_fn = get_preprocessing_fn(ds_name, is_eval=False)
ds = ds.map(processing_fn, num_proc=8)
else:
ds_name = ds_name.split("/")[-1] # Extract just the dataset name
print(f"Loading dataset: {ds_name} with split: {split}")
kwargs = dict(ds_name=ds_name, split=split)
ds = load_and_process_dataset(**kwargs, num_proc=8, remove_cols=False)
print(f"Loaded dataset: {ds_name} with split: {split}")
if args.debug:
ds = ds.take(10)
ds = ds.filter(filter_none, batched=False, num_proc=8)
tk = get_tokenizer(args.vllm_model, train=True)
self_qa_intx_tokens = tk(SELF_QA_INTX, add_special_tokens=False)["input_ids"][1:]
if args.remove_qa_template:
self_qa_intx_tokens = tk("\n\n", add_special_tokens=False)["input_ids"]
n_self_qa_intx_tokens = len(self_qa_intx_tokens)
pre_ctx_tokens = tk(PRE_CTX, add_special_tokens=False)["input_ids"]
n_pre_ctx_tokens = len(pre_ctx_tokens)
sys_tokens = tk(system_template.split("\n")[0], add_special_tokens=False)[
"input_ids"
][:-1]
n_sys_tokens = len(sys_tokens)
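
    # These token sequences serve as anchors for locating the system message
    # and the question inside each tokenized prompt; execute_qa_generation uses
    # them to splice the context out of the saved examples. The [1:] and [:-1]
    # slices presumably drop boundary tokens whose IDs depend on adjacent text.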
os.environ["TOKENIZERS_PARALLELISM"] = "true"
ds = ds.map(
tokenize_ctx_text,
fn_kwargs={"tokenizer": tk},
batched=True,
batch_size=50_000,
keep_in_memory=True,
)
ctxs = [sample["context"] for sample in ds]
    questions = [
        [add_closed_qa_prompt(q, closed_qa_prob) for q in sample["prompts"] if q]
        for sample in ds
    ]
    # ctxs and questions stay aligned one-to-one; contexts with an empty
    # question list simply contribute no chat messages below
    print(f"Loaded {len(ctxs)} contexts with {sum(len(q) for q in questions)} questions")
    k = 16  # number of top logprobs stored per generated token
    fpath = f"{SELF_GEN_DATA_DIR}/{args.vllm_model}_temp_{temp}_closed_qa_prob_{closed_qa_prob}/{ds_name}/{split}/ds{shard_name}"
    chunk_size = 1_000  # generate and save in chunks to bound peak memory
for chunk_idx, start in enumerate(range(0, len(ctxs), chunk_size)):
print(f"Processing chunk {chunk_idx}")
chunk_ctxs = ctxs[start : start + chunk_size]
chunk_questions = questions[start : start + chunk_size]
chunk_messages = create_messages(
chunk_ctxs,
chunk_questions,
args.vllm_model,
SELF_GEN_SYSTEM_MSG,
args.remove_qa_template,
)
if do_truncate:
# we should only do this for evaluation data
tokenized_contents = tk(
[m[0]["content"] for m in chunk_messages],
add_special_tokens=False,
return_attention_mask=False,
)
tokenized_contents["input_ids"] = [
truncate_middle_if_too_long(
ids,
max_length=MODEL_CTX_LEN[args.vllm_model],
max_new_tokens=args.max_new_tokens,
)
for ids in tokenized_contents["input_ids"]
]
contents = tk.batch_decode(
tokenized_contents["input_ids"], skip_special_tokens=True
)
for c, m in zip(contents, chunk_messages):
m[0]["content"] = c
print(f"Generating from {len(chunk_messages)} contexts")
        # Free cached GPU memory before generating this chunk
        clear_gpu()
execute_qa_generation(
fpath + f"_{chunk_idx:04d}",
args,
llm,
temp,
tk,
self_qa_intx_tokens,
n_self_qa_intx_tokens,
sys_tokens,
n_sys_tokens,
chunk_ctxs,
ds[start : start + chunk_size]["ctx_ids"],
chunk_questions,
chunk_messages,
k,
)
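

# Flow summary: self_generate loads and preprocesses a dataset, tokenizes the
# contexts, builds one chat prompt per (context, question) pair, optionally
# middle-truncates, then generates chunk by chunk via execute_qa_generation.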
def execute_qa_generation(
fpath,
args,
llm,
temp,
tk,
self_qa_intx_tokens,
n_self_qa_intx_tokens,
sys_tokens,
n_sys_tokens,
ctxs,
ctx_ids,
questions,
messages,
k,
):
    completions = llm.chat(
        messages,
        sampling_params=SamplingParams(
            max_tokens=args.max_new_tokens,
            logprobs=k,
            temperature=temp,
            seed=42,
            # model-specific stop strings, if any (see STOP_STRINGS above)
            stop=STOP_STRINGS.get(args.vllm_model),
            spaces_between_special_tokens=False,
            skip_special_tokens=False,
            include_stop_str_in_output=True,
        ),
    )
    self_gen_data = {
        ctx: {
            "ctx_ids": c_ids,
            "input_ids": [],
            "response_start_end": [],
            "logprobs_vals": [],
            "logprobs_indices": [],
        }
        for ctx, c_ids in zip(ctxs, ctx_ids)
    }
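
    # Per-context record: its token IDs plus parallel lists over generated
    # answers: spliced input_ids, [start, end) response spans, and
    # (n_response_tokens, k) top-k logprob value/index arrays.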
c = 0
n_skips = 0
sys_start = None
    for ctx, q_list in zip(ctxs, questions):
        for i, _ in enumerate(q_list):
            reason = completions[c + i].outputs[0].finish_reason
            if reason != "stop":
                print(f"Skipping idx {c + i}: finish_reason={reason} != 'stop'")
                n_skips += 1
                continue
# includes the logprob before the first response token
# but excludes the logprob from eos token
logp = completions[c + i].outputs[0].logprobs
# len = num response tokens
n_response_tokens = len(completions[c + i].outputs[0].token_ids)
logp_indices = np.empty((n_response_tokens, k), dtype=np.int32)
            # float16 is precise enough for logprobs and halves storage
logp_vals = np.empty((n_response_tokens, k), dtype=np.float16)
assert len(logp) == n_response_tokens, (
f"Expected {n_response_tokens} logp entries, got {len(logp)}"
)
for li, info_d in enumerate(logp):
for j, (idx, tok_info) in enumerate(info_d.items()):
logp_indices[li, j] = idx
logp_vals[li, j] = tok_info.logprob
            prompt_ids = list(completions[c + i].prompt_token_ids)
            # token_ids only includes generated tokens, not the prompt
            response_token_ids = list(completions[c + i].outputs[0].token_ids)
            all_ids = prompt_ids + response_token_ids
res_start = len(prompt_ids)
res_end = res_start + n_response_tokens
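
            # Prompt layout (roughly): [pre-system tokens][system message]
            # [context][question intro][question][model-turn scaffolding].
            # Locate the system-message start and the question start so the
            # context can be spliced out of the saved example below.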
if sys_start is None:
for ii in range(len(prompt_ids) - n_sys_tokens):
if prompt_ids[ii : ii + n_sys_tokens] == sys_tokens:
# found the start of the system message
sys_start = ii
break
q_start = None
for ii in range(
len(prompt_ids) - n_self_qa_intx_tokens,
-1,
-1,
):
if prompt_ids[ii : ii + n_self_qa_intx_tokens] == self_qa_intx_tokens:
# found the start of the user input
q_start = ii + n_self_qa_intx_tokens
break
            assert sys_start is not None and q_start is not None, (
                "could not locate system/question marker tokens in prompt"
            )
            # keep: bos + question + eos + start of model turn + response + eos
            input_ids = all_ids[:sys_start] + all_ids[q_start:res_end]
            # response span relative to the spliced input_ids
            res_start = res_start - q_start + sys_start
            res_end = res_start + n_response_tokens
# arrays will be saved as nested lists of numbers
self_gen_data[ctx]["input_ids"].append(input_ids)
# assume single-turn chat
self_gen_data[ctx]["response_start_end"].append((res_start, res_end))
self_gen_data[ctx]["logprobs_vals"].append(logp_vals)
self_gen_data[ctx]["logprobs_indices"].append(logp_indices)
        # advance past this context's questions; safe even when q_list is empty
        c += len(q_list)
    print(f"Skipped {n_skips} responses with finish_reason != 'stop'")
    samples = [
        {
            "ctx_ids": self_gen_data[ctx]["ctx_ids"],
            "input_ids": self_gen_data[ctx]["input_ids"],
            "response_start_end": self_gen_data[ctx]["response_start_end"],
            "logprobs_vals": self_gen_data[ctx]["logprobs_vals"],
            "logprobs_indices": self_gen_data[ctx]["logprobs_indices"],
        }
        for ctx in ctxs
    ]
if args.debug:
for sample in samples:
print(f"QA={[tk.decode(ids) for ids in sample['input_ids']]}")
for input_ids, (start, end) in zip(
sample["input_ids"], sample["response_start_end"]
):
print(f"start={start}, end={end}")
print(f"response={tk.decode(input_ids[start:end])}")
print(f"logprobs_vals={[x.shape for x in sample['logprobs_vals']]}")
print(f"logprobs_indices={[x.shape for x in sample['logprobs_indices']]}")
for indices in sample["logprobs_indices"]:
print(f"logprobs_indices={indices[-1]}")
print("=" * 80)
print(f"Generated {len(samples)} samples")
    # Save results
    ds_out = Dataset.from_list(samples)
if args.debug:
fpath += "_debug"
os.makedirs(os.path.dirname(fpath), exist_ok=True)
fpath = f"{fpath}.parquet"
ds_out.to_parquet(fpath)
print(f"Saved to {fpath}")
# Cleanup
del samples, ds_out, completions, messages, ctxs, questions
clear_gpu()
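

# Saved parquet schema (per context): ctx_ids, spliced input_ids (one list per
# QA pair), response_start_end spans, and top-k logprob value/index arrays;
# the numpy arrays are serialized as nested lists of numbers.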
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Generate QA pairs using vLLM")
parser.add_argument(
"--vllm_model",
type=str,
required=True,
        help="vLLM model name (e.g., google/gemma-2-2b-it)",
)
parser.add_argument(
"--debug",
action="store_true",
help="Enable debug mode (process only 10 samples)",
)
    # Exactly one of: --config, --ds_names (with --split), or --glob_pattern
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"--config",
type=str,
help="Path to YAML config file with train_ds_names/val_ds_names",
)
group.add_argument(
"--ds_names",
type=str,
nargs="+",
help="List of dataset names/shard patterns",
)
group.add_argument(
"--glob_pattern",
type=str,
help="Glob pattern to match dataset names (e.g., 'data/raw_datasets/fw_qa_3/*')",
)
parser.add_argument(
"--split",
type=str,
help="Dataset split to use when using --ds_names (required with --ds_names)",
)
parser.add_argument(
"--temp",
type=float,
default=0.0,
help="Temperature for sampling (default: 0.0)",
)
parser.add_argument(
"--closed_qa_prob",
type=float,
default=0.0,
help="Probability of using closed QA prompt template (default: 0.0)",
)
parser.add_argument(
"--do_truncate",
action="store_true",
help="Truncate contexts to fit model context length",
)
parser.add_argument(
"--remove_qa_template",
action="store_true",
help="Remove QA template formatting from prompts",
)
parser.add_argument(
"--max_new_tokens",
type=int,
default=256,
help="Maximum number of new tokens to generate (default: 256)",
)
return parser.parse_args()
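

# Example invocation (illustrative; dataset name is hypothetical):
#   python data/self_generate_qa.py \
#       --vllm_model google/gemma-2-2b-it \
#       --ds_names my_qa_dataset --split train \
#       --temp 0.7 --closed_qa_prob 0.1
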
if __name__ == "__main__":
args = parse_args()
# Validate arguments
if args.ds_names and not args.split:
raise ValueError("--split is required when using --ds_names")
vllm_model = args.vllm_model
print(f"Using model: {vllm_model}")
# Setup model-specific configurations
llm_kwargs = dict(
model=vllm_model,
dtype="bfloat16",
enable_prefix_caching=True,
enable_chunked_prefill=True,
        # models missing from MODEL_CTX_LEN fall back to their default length
        max_model_len=MODEL_CTX_LEN.get(vllm_model),
max_num_batched_tokens=16384,
max_num_seqs=32, # avoid oom when getting logprobs
)
print(f"{llm_kwargs=}")
llm = LLM(**llm_kwargs)
# Get dataset configs from config or CLI args
config = load_config(args.config) if args.config else None
if args.ds_names or args.config:
dataset_configs = get_dataset_configs(
ds_names=args.ds_names,
config=config,
split=args.split,
)
# Process each dataset
for ds_name, split in dataset_configs:
print(f"Processing dataset: {ds_name}, split: {split}")
self_generate(
ds_name, split, args, llm, SELF_GEN_SYSTEM_MSG, None, args.do_truncate
)
else:
assert args.glob_pattern, (
"glob_pattern must be provided if no ds_names or config"
)
files = glob(args.glob_pattern)
for file in files:
print(f"Processing file: {file}")
self_generate(
ds_name=None,
parquet_file=file,
split=args.split,
args=args,
llm=llm,
system_template=SELF_GEN_SYSTEM_MSG,
do_truncate=args.do_truncate,
)