mirror of
https://github.com/SakanaAI/doc-to-lora.git
synced 2026-04-26 08:36:23 +02:00
Doc-to-LoRA release
This commit is contained in:
commit
1abe8ae16d
92 changed files with 22131 additions and 0 deletions
620
data/self_generate_qa.py
Normal file
620
data/self_generate_qa.py
Normal file
|
|
@ -0,0 +1,620 @@
|
|||
import argparse
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from datasets import Dataset, load_dataset
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
from ctx_to_lora.data.definitions import (
|
||||
CLOSED_QA_INTX_TEMPLATES,
|
||||
RAW_DATA_DIR,
|
||||
SELF_GEN_DATA_DIR,
|
||||
)
|
||||
from ctx_to_lora.data.processing import (
|
||||
filter_none,
|
||||
get_preprocessing_fn,
|
||||
load_and_process_dataset,
|
||||
tokenize_ctx_text,
|
||||
)
|
||||
from ctx_to_lora.data.self_gen_template import (
|
||||
PRE_CTX,
|
||||
PROMPT_TEMPLATE,
|
||||
QA_PROMPT_TEMPLATE,
|
||||
SELF_GEN_SYSTEM_MSG,
|
||||
SELF_QA_INTX,
|
||||
)
|
||||
from ctx_to_lora.model_loading import get_tokenizer
|
||||
from ctx_to_lora.utils import clear_gpu
|
||||
|
||||
# Per-model stop strings.  NOTE(review): not referenced anywhere in this
# file's visible code — presumably consumed by another module; confirm.
STOP_STRINGS = {
    "google/gemma-2-2b-it": ["<eos>", "<end_of_turn>"],
}

# Maximum context length (in tokens) per model; used for the vLLM engine's
# max_model_len and for middle-truncation of over-long prompts.
MODEL_CTX_LEN = {
    "google/gemma-2-27b-it": 8192,
    "google/gemma-2-2b-it": 8192,
    "google/gemma-2-9b-it": 8192,
    # qwen 4b has 256k ctx length but using lower max lengths is faster
    "Qwen/Qwen3-4B-Instruct-2507": 2**13 + 2**12,
}
|
||||
|
||||
|
||||
def truncate_middle_if_too_long(
    input_ids: list[int],
    max_length: int,
    max_new_tokens: int = 256,
) -> list[int]:
    """Truncate the middle of a token sequence to fit within ``max_length``.

    If the sequence is too long, the beginning and the end are kept and the
    middle is dropped, reserving room for ``max_new_tokens`` generated
    tokens on top of the truncated prompt.

    Args:
        input_ids: List of token IDs.
        max_length: Maximum total length (prompt plus generation budget).
        max_new_tokens: Number of tokens reserved for generation.

    Returns:
        The original list if it already fits, otherwise a truncated copy of
        length ``2 * (max_length // 2 - max_new_tokens // 2)``.
    """
    max_new_tokens_half = max_new_tokens // 2
    # leave max_new_tokens for generation
    half = max_length // 2 - max_new_tokens_half
    if len(input_ids) > max_length:
        return input_ids[:half] + input_ids[-half:]
    return input_ids
|
||||
|
||||
|
||||
def get_prompt(context: str, q: str, remove_qa_template: bool) -> str:
    """Format a generation prompt from a context and a question.

    Uses the plain ``PROMPT_TEMPLATE`` when ``remove_qa_template`` is set,
    otherwise the QA-specific ``QA_PROMPT_TEMPLATE``.
    """
    if remove_qa_template:
        template = PROMPT_TEMPLATE
    else:
        template = QA_PROMPT_TEMPLATE
    return template.format(context=context, question=q)
|
||||
|
||||
|
||||
def add_closed_qa_prompt(q: str, closed_qa_prob: float = 0.1) -> str:
    """With probability ``closed_qa_prob``, wrap ``q`` in a randomly chosen
    closed-QA instruction template; otherwise return ``q`` unchanged."""
    roll = random.random()
    if roll > closed_qa_prob:
        return q
    template = random.choice(CLOSED_QA_INTX_TEMPLATES)
    return template.format(input=q)
|
||||
|
||||
|
||||
def load_config(config_path: str) -> dict:
    """Read the YAML experiment config (dataset names etc.) from disk."""
    with open(config_path) as fh:
        return yaml.safe_load(fh)
|
||||
|
||||
|
||||
def get_dataset_configs(
|
||||
ds_names: list[str] | None,
|
||||
config: dict | None,
|
||||
split: str | None,
|
||||
) -> list[tuple[str, str]]:
|
||||
assert not (ds_names and config), "Cannot provide both ds_names and config"
|
||||
if ds_names:
|
||||
assert split, "When using ds_names, --split must be provided"
|
||||
# Validate ds_names format
|
||||
for ds_name in ds_names:
|
||||
if not isinstance(ds_name, str):
|
||||
raise ValueError(f"Invalid dataset name: {ds_name}")
|
||||
return [(ds_name, split) for ds_name in ds_names]
|
||||
|
||||
if config:
|
||||
dataset_configs = []
|
||||
|
||||
# Process train datasets
|
||||
train_ds_names = config.get("train_ds_names", [])
|
||||
# self_gen_train_ds_names = [
|
||||
# (ds_name.split("/")[-1], "train")
|
||||
# for ds_name in train_ds_names
|
||||
# if ds_name.startswith("self_gen/")
|
||||
# ]
|
||||
self_gen_train_ds_names = [
|
||||
(ds_name, "train")
|
||||
for ds_name in train_ds_names
|
||||
if ds_name.startswith("self_gen/")
|
||||
]
|
||||
if not self_gen_train_ds_names:
|
||||
print("No self_gen datasets found in train_ds_names")
|
||||
dataset_configs.extend(self_gen_train_ds_names)
|
||||
|
||||
# Process validation datasets
|
||||
val_ds_names = config.get("val_ds_names", [])
|
||||
self_gen_val_ds_names = [
|
||||
(ds_name, "validation")
|
||||
for ds_name in val_ds_names
|
||||
if ds_name.startswith("self_gen/")
|
||||
]
|
||||
if not self_gen_val_ds_names:
|
||||
print("No self_gen datasets found in val_ds_names")
|
||||
dataset_configs.extend(self_gen_val_ds_names)
|
||||
|
||||
return dataset_configs
|
||||
|
||||
|
||||
def create_messages(
    ctxs: list[str],
    questions: list[list[str]],
    vllm_model: str,
    system_template: str,
    remove_qa_template: bool,
) -> list[list[dict]]:
    """Build single-turn chat messages, one per (context, question) pair.

    The system template is merged into the user turn rather than sent as a
    separate system message, because some target models (gemma) do not
    support a system role.

    Args:
        ctxs: One context string per sample.
        questions: One list of questions per sample, parallel to ``ctxs``.
        vllm_model: Model name.  Currently unused; kept so callers that
            pass it keep working.
        system_template: System text prepended to each user prompt.
        remove_qa_template: Forwarded to :func:`get_prompt`.

    Returns:
        Flat list of one-message conversations, ordered sample-by-sample
        and question-by-question within each sample.
    """
    return [
        [
            {
                "role": "user",
                "content": (
                    system_template + "\n\n\n" + get_prompt(ctx, q, remove_qa_template)
                ).strip(),
            }
        ]
        for ctx, q_list in zip(ctxs, questions)
        for q in q_list
    ]
|
||||
|
||||
|
||||
def self_generate(
    ds_name: str,
    split: str,
    args: argparse.Namespace,
    llm: LLM,
    system_template: str,
    parquet_file: str | None = None,
    do_truncate: bool = False,
) -> None:
    """Process a single dataset and generate QA pairs.

    Loads the dataset (either by name or from a parquet shard), derives the
    sampling temperature and closed-QA probability (CLI args, optionally
    overridden by markers embedded in ``ds_name``), tokenizes contexts, and
    hands 1000-sample chunks to :func:`execute_qa_generation`, which writes
    one parquet shard per chunk.

    Args:
        ds_name: Dataset name; may embed ``_temp_<x>`` /
            ``_closed_qa_prob_<x>`` overrides.  May be None when
            ``parquet_file`` is given.
        split: Dataset split to load (forced to "train" for parquet input).
        args: Parsed CLI arguments.
        llm: Initialized vLLM engine.
        system_template: System message merged into each user prompt.
        parquet_file: Optional parquet shard to load instead of ``ds_name``.
        do_truncate: Middle-truncate over-long prompts to the model context
            length (intended for evaluation data only).

    Raises:
        ValueError: If both a CLI override and a ``ds_name``-embedded
            override specify temperature or closed_qa_prob.
    """

    shard_name = ""

    # Conflict checks for ds_name-derived overrides: refuse two sources of
    # truth for the same setting.
    if ds_name is not None:
        # temperature & closed_qa already handled later; add new ones
        if "_temp_" in ds_name and args.temp != 0.0:
            raise ValueError(
                f"Multiple sources of truth for temperature: CLI arg --temp={args.temp} and dataset name contains temp specification."
            )
        if "_closed_qa_prob_" in ds_name and args.closed_qa_prob != 0.0:
            raise ValueError(
                f"Multiple sources of truth for closed_qa_prob: CLI arg --closed_qa_prob={args.closed_qa_prob} and dataset name contains closed_qa_prob specification."
            )

    # Base values from args
    temp = args.temp
    closed_qa_prob = args.closed_qa_prob

    # Overrides from ds_name pattern if present
    if ds_name is not None:
        if "_temp_" in ds_name:
            m = re.search(r"_temp_([\d.]+)", ds_name)
            if m:
                temp = float(m.group(1))
        if "_closed_qa_prob_" in ds_name:
            m = re.search(r"_closed_qa_prob_([\d.]+)", ds_name)
            if m:
                closed_qa_prob = float(m.group(1))

    print(f"Processing dataset: {ds_name}, split: {split}")
    print(f"Using temperature: {temp}")
    print(f"Using closed QA prompt probability: {closed_qa_prob}")

    if parquet_file:
        print(f"Loading dataset from parquet file: {parquet_file}")

        # Parquet shards are always treated as training data; the dataset
        # name is reconstructed from the path relative to RAW_DATA_DIR.
        split = "train"
        ds_name = "/".join(parquet_file.split(RAW_DATA_DIR)[-1].split("/")[:-1])

        # Shard suffix keeps output files from different shards distinct.
        shard_name = "_" + os.path.basename(parquet_file).replace(".parquet", "")
        ds = load_dataset(path="parquet", data_files=[parquet_file], split="train")
        processing_fn = get_preprocessing_fn(ds_name, is_eval=False)
        ds = ds.map(processing_fn, num_proc=8)

    else:
        ds_name = ds_name.split("/")[-1]  # Extract just the dataset name

        print(f"Loading dataset: {ds_name} with split: {split}")
        kwargs = dict(ds_name=ds_name, split=split)

        ds = load_and_process_dataset(**kwargs, num_proc=8, remove_cols=False)
        print(f"Loaded dataset: {ds_name} with split: {split}")

    if args.debug:
        # Tiny subset for quick end-to-end runs.
        ds = ds.take(10)

    ds = ds.filter(filter_none, batched=False, num_proc=8)

    tk = get_tokenizer(args.vllm_model, train=True)

    # Token patterns used later to locate the question and system-message
    # boundaries inside the prompt.  NOTE(review): the leading [1:] slice
    # presumably drops a BOS-like first token — confirm for each tokenizer.
    self_qa_intx_tokens = tk(SELF_QA_INTX, add_special_tokens=False)["input_ids"][1:]
    if args.remove_qa_template:
        self_qa_intx_tokens = tk("\n\n", add_special_tokens=False)["input_ids"]
    n_self_qa_intx_tokens = len(self_qa_intx_tokens)
    # NOTE(review): pre_ctx_tokens / n_pre_ctx_tokens are computed but never
    # used in this function — possible leftovers.
    pre_ctx_tokens = tk(PRE_CTX, add_special_tokens=False)["input_ids"]
    n_pre_ctx_tokens = len(pre_ctx_tokens)
    # Only the first line of the system template is matched; the trailing
    # token is dropped via [:-1].
    sys_tokens = tk(system_template.split("\n")[0], add_special_tokens=False)[
        "input_ids"
    ][:-1]
    n_sys_tokens = len(sys_tokens)
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    ds = ds.map(
        tokenize_ctx_text,
        fn_kwargs={"tokenizer": tk},
        batched=True,
        batch_size=50_000,
        keep_in_memory=True,
    )

    ctxs = [sample["context"] for sample in ds]
    questions = [
        [add_closed_qa_prompt(q, closed_qa_prob) for q in sample["prompts"] if q]
        for sample in ds
    ]

    # NOTE(review): this reassignment discards the closed-QA-augmented
    # questions built just above, and its `len(q_list) > 0` filter can drop
    # samples, misaligning `questions` with `ctxs` in the zip inside
    # create_messages.  Looks like a leftover — confirm which of the two
    # assignments is intended.
    questions = [q_list for q_list in ds["prompts"] if len(q_list) > 0]

    print(f"Loaded {len(ctxs)} contexts and {len(questions)} questions")

    # k: number of top logprobs to record per generated token.
    k = 16
    fpath = f"{SELF_GEN_DATA_DIR}/{args.vllm_model}_temp_{temp}_closed_qa_prob_{closed_qa_prob}/{ds_name}/{split}/ds{shard_name}"

    # Process in chunks so each chunk is generated and persisted separately.
    chunk_size = 1_000
    for chunk_idx, start in enumerate(range(0, len(ctxs), chunk_size)):
        print(f"Processing chunk {chunk_idx}")

        chunk_ctxs = ctxs[start : start + chunk_size]
        chunk_questions = questions[start : start + chunk_size]
        chunk_messages = create_messages(
            chunk_ctxs,
            chunk_questions,
            args.vllm_model,
            SELF_GEN_SYSTEM_MSG,
            args.remove_qa_template,
        )

        if do_truncate:
            # we should only do this for evaluation data
            tokenized_contents = tk(
                [m[0]["content"] for m in chunk_messages],
                add_special_tokens=False,
                return_attention_mask=False,
            )
            # Middle-truncate each prompt so prompt + generation fits the
            # model context window.
            tokenized_contents["input_ids"] = [
                truncate_middle_if_too_long(
                    ids,
                    max_length=MODEL_CTX_LEN[args.vllm_model],
                    max_new_tokens=args.max_new_tokens,
                )
                for ids in tokenized_contents["input_ids"]
            ]
            contents = tk.batch_decode(
                tokenized_contents["input_ids"], skip_special_tokens=True
            )
            # Write the truncated text back into the chat messages in place.
            for c, m in zip(contents, chunk_messages):
                m[0]["content"] = c

        print(f"Generating from {len(chunk_messages)} contexts")

        # Clear GPU memory before processing the next chunk
        clear_gpu()
        execute_qa_generation(
            fpath + f"_{chunk_idx:04d}",
            args,
            llm,
            temp,
            tk,
            self_qa_intx_tokens,
            n_self_qa_intx_tokens,
            sys_tokens,
            n_sys_tokens,
            chunk_ctxs,
            ds[start : start + chunk_size]["ctx_ids"],
            chunk_questions,
            chunk_messages,
            k,
        )
|
||||
|
||||
|
||||
def execute_qa_generation(
    fpath,
    args,
    llm,
    temp,
    tk,
    self_qa_intx_tokens,
    n_self_qa_intx_tokens,
    sys_tokens,
    n_sys_tokens,
    ctxs,
    ctx_ids,
    questions,
    messages,
    k,
):
    """Generate responses for one chunk and persist them as a parquet shard.

    Runs ``llm.chat`` over ``messages`` (one conversation per question),
    records the top-``k`` logprobs for each generated token, strips the
    system message and pre-question boilerplate out of each prompt, and
    writes one row per context to ``<fpath>.parquet``.

    Args:
        fpath: Output path prefix (``.parquet`` is appended).
        args: Parsed CLI arguments (max_new_tokens, debug).
        llm: vLLM engine.
        temp: Sampling temperature.
        tk: Tokenizer (used only for debug decoding here).
        self_qa_intx_tokens: Token pattern that marks the start of the
            question within a prompt.
        n_self_qa_intx_tokens: Length of that pattern.
        sys_tokens: Token pattern that marks the system-message start.
        n_sys_tokens: Length of that pattern.
        ctxs: Context strings for this chunk.
        ctx_ids: Per-context token-id lists, parallel to ``ctxs``.
        questions: Per-context question lists, parallel to ``ctxs``.
        messages: Flat list of chat conversations (one per question),
            ordered context-by-context, matching ``completions`` order.
        k: Number of top logprobs stored per generated token.
    """
    completions = llm.chat(
        messages,
        sampling_params=SamplingParams(
            max_tokens=args.max_new_tokens,
            logprobs=k,
            temperature=temp,
            seed=42,
            spaces_between_special_tokens=False,
            skip_special_tokens=False,
            include_stop_str_in_output=True,
        ),
    )

    # One accumulator per context.  NOTE: the comprehension target `ctx_ids`
    # deliberately shadows the parameter of the same name — `zip(ctxs,
    # ctx_ids)` still sees the parameter, while the dict value is the
    # per-context item.
    self_gen_data = {
        ctx: {
            "ctx_ids": ctx_ids,
            "input_ids": [],
            "response_start_end": [],
            "logprobs_vals": [],
            "logprobs_indices": [],
        }
        for ctx, ctx_ids in zip(ctxs, ctx_ids)
    }
    # c: index of the first completion belonging to the current context.
    c = 0
    n_skips = 0
    # sys_start: prompt index where the system message begins; searched once
    # and reused (assumed identical across prompts).
    sys_start = None
    for ctx, q_list in zip(ctxs, questions):
        for i, _ in enumerate(q_list):
            reason = completions[c + i].outputs[0].finish_reason
            if reason != "stop":
                # Truncated (or otherwise abnormal) generations are dropped.
                print(f"finish_reason: {completions[c + i].outputs[0].finish_reason}")
                print(f"Skipping due to finish_reason={reason} != 'stop'")
                n_skips += 1
                continue

            # includes the logprob before the first response token
            # but excludes the logprob from eos token
            logp = completions[c + i].outputs[0].logprobs

            # len = num response tokens
            n_response_tokens = len(completions[c + i].outputs[0].token_ids)

            logp_indices = np.empty((n_response_tokens, k), dtype=np.int32)
            # float-16 is better for this range
            logp_vals = np.empty((n_response_tokens, k), dtype=np.float16)
            assert len(logp) == n_response_tokens, (
                f"Expected {n_response_tokens} logp entries, got {len(logp)}"
            )

            # Flatten vLLM's per-position {token_id: info} dicts into dense
            # (n_response_tokens, k) arrays.
            for li, info_d in enumerate(logp):
                for j, (idx, tok_info) in enumerate(info_d.items()):
                    logp_indices[li, j] = idx
                    logp_vals[li, j] = tok_info.logprob

            prompt_ids = completions[c + i].prompt_token_ids  # 1d list
            # token_ids only includes generated tokens, not the prompt
            response_token_ids = completions[c + i].outputs[0].token_ids  # 1d list
            all_ids = prompt_ids + response_token_ids
            res_start = len(prompt_ids)
            res_end = res_start + n_response_tokens

            if sys_start is None:
                # Linear scan for the system-message token pattern.
                for ii in range(len(prompt_ids) - n_sys_tokens):
                    if prompt_ids[ii : ii + n_sys_tokens] == sys_tokens:
                        # found the start of the system message
                        sys_start = ii
                        break

            # Scan backwards so the LAST occurrence of the question-intro
            # pattern wins.
            q_start = None
            for ii in range(
                len(prompt_ids) - n_self_qa_intx_tokens,
                -1,
                -1,
            ):
                if prompt_ids[ii : ii + n_self_qa_intx_tokens] == self_qa_intx_tokens:
                    # found the start of the user input
                    q_start = ii + n_self_qa_intx_tokens
                    break

            # bos + question + eos + start model turn + response + eos
            # i.e. everything before the system message plus everything from
            # the question onward — the system message and context between
            # them are cut out.
            input_ids = all_ids[:sys_start] + all_ids[q_start:res_end]

            # relative to the input_ids
            res_start = res_start - q_start + sys_start
            res_end = res_start + n_response_tokens

            # arrays will be saved as nested lists of numbers

            self_gen_data[ctx]["input_ids"].append(input_ids)
            # assume single-turn chat
            self_gen_data[ctx]["response_start_end"].append((res_start, res_end))
            self_gen_data[ctx]["logprobs_vals"].append(logp_vals)
            self_gen_data[ctx]["logprobs_indices"].append(logp_indices)

        # Advance to the first completion of the next context.
        # NOTE(review): uses the last loop index `i`; if a `q_list` is ever
        # empty this reuses the stale `i` from the previous context (or
        # raises NameError on the first context) — `c += len(q_list)` would
        # be robust.  Upstream filtering appears to drop empty lists, but
        # confirm.
        c += i + 1

    print(f"Skipped {n_skips} responses due to missing stop strings")
    # One output row per context, aggregating all of its QA pairs.
    samples = [
        {
            "ctx_ids": self_gen_data[ctx]["ctx_ids"],
            "input_ids": self_gen_data[ctx]["input_ids"],
            "response_start_end": self_gen_data[ctx]["response_start_end"],
            "logprobs_vals": self_gen_data[ctx]["logprobs_vals"],
            "logprobs_indices": self_gen_data[ctx]["logprobs_indices"],
        }
        for ctx, q_list in zip(ctxs, questions)
    ]

    if args.debug:
        # Decode and dump every sample for manual inspection.
        for sample in samples:
            print(f"QA={[tk.decode(ids) for ids in sample['input_ids']]}")

            for input_ids, (start, end) in zip(
                sample["input_ids"], sample["response_start_end"]
            ):
                print(f"start={start}, end={end}")
                print(f"response={tk.decode(input_ids[start:end])}")

            print(f"logprobs_vals={[x.shape for x in sample['logprobs_vals']]}")
            print(f"logprobs_indices={[x.shape for x in sample['logprobs_indices']]}")
            for indices in sample["logprobs_indices"]:
                print(f"logprobs_indices={indices[-1]}")
            print("=" * 80)

    print(f"Generated {len(samples)} samples")

    # Save results
    ds_out = Dataset.from_list(samples)

    if args.debug:
        fpath += "_debug"
    os.makedirs(os.path.dirname(fpath), exist_ok=True)

    fpath = f"{fpath}.parquet"
    ds_out.to_parquet(fpath)
    print(f"Saved to {fpath}")

    # Cleanup: drop large locals (the `del` on parameter names only unbinds
    # them in this frame) and free GPU memory before the next chunk.
    del samples, ds_out, completions, messages, ctxs, questions
    clear_gpu()
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
    """Parse CLI arguments for self-QA generation.

    Exactly one of --config, --ds_names, or --glob_pattern selects the data
    source; the remaining flags tune sampling and prompt formatting.
    """
    p = argparse.ArgumentParser(description="Generate QA pairs using VLLM")
    p.add_argument(
        "--vllm_model",
        type=str,
        required=True,
        help="VLLM model name (e.g., google/gemma-2-2b-it)",
    )
    p.add_argument(
        "--debug",
        action="store_true",
        help="Enable debug mode (process only 10 samples)",
    )

    # Either config file OR ds_names + split
    src = p.add_mutually_exclusive_group(required=True)
    src.add_argument(
        "--config",
        type=str,
        help="Path to YAML config file with train_ds_names/val_ds_names",
    )
    src.add_argument(
        "--ds_names",
        type=str,
        nargs="+",
        help="List of dataset names/shard patterns",
    )
    src.add_argument(
        "--glob_pattern",
        type=str,
        help="Glob pattern to match dataset names (e.g., 'data/raw_datasets/fw_qa_3/*')",
    )

    p.add_argument(
        "--split",
        type=str,
        help="Dataset split to use when using --ds_names (required with --ds_names)",
    )
    p.add_argument(
        "--temp",
        type=float,
        default=0.0,
        help="Temperature for sampling (default: 0.0)",
    )
    p.add_argument(
        "--closed_qa_prob",
        type=float,
        default=0.0,
        help="Probability of using closed QA prompt template (default: 0.0)",
    )
    p.add_argument(
        "--do_truncate",
        action="store_true",
        help="Truncate contexts to fit model context length",
    )
    p.add_argument(
        "--remove_qa_template",
        action="store_true",
        help="Remove QA template formatting from prompts",
    )
    p.add_argument(
        "--max_new_tokens",
        type=int,
        default=256,
        help="Maximum number of new tokens to generate (default: 256)",
    )
    return p.parse_args()
|
||||
|
||||
|
||||
# Entry point: build the vLLM engine once, then run self-generation over
# either named datasets (from --ds_names/--config) or raw parquet shards
# matched by --glob_pattern.
if __name__ == "__main__":
    args = parse_args()

    # Validate arguments
    if args.ds_names and not args.split:
        raise ValueError("--split is required when using --ds_names")

    vllm_model = args.vllm_model
    print(f"Using model: {vllm_model}")

    # Setup model-specific configurations
    llm_kwargs = dict(
        model=vllm_model,
        dtype="bfloat16",
        enable_prefix_caching=True,
        enable_chunked_prefill=True,
        # None for models not listed in MODEL_CTX_LEN (vLLM then uses the
        # model's own default).
        max_model_len=MODEL_CTX_LEN.get(vllm_model),
        max_num_batched_tokens=16384,
        max_num_seqs=32,  # avoid oom when getting logprobs
    )

    print(f"{llm_kwargs=}")
    llm = LLM(**llm_kwargs)

    # Get dataset configs from config or CLI args
    config = load_config(args.config) if args.config else None
    if args.ds_names or args.config:
        dataset_configs = get_dataset_configs(
            ds_names=args.ds_names,
            config=config,
            split=args.split,
        )

        # Process each dataset
        for ds_name, split in dataset_configs:
            print(f"Processing dataset: {ds_name}, split: {split}")
            self_generate(
                ds_name, split, args, llm, SELF_GEN_SYSTEM_MSG, None, args.do_truncate
            )
    else:
        # Fallback path: iterate raw parquet shards matching the glob.
        assert args.glob_pattern, (
            "glob_pattern must be provided if no ds_names or config"
        )
        files = glob(args.glob_pattern)
        for file in files:
            print(f"Processing file: {file}")
            self_generate(
                ds_name=None,
                parquet_file=file,
                split=args.split,
                args=args,
                llm=llm,
                system_template=SELF_GEN_SYSTEM_MSG,
                do_truncate=args.do_truncate,
            )
|
||||
Loading…
Add table
Add a link
Reference in a new issue