mirror of
https://github.com/ranausmanai/tinyforge-zero.git
synced 2026-06-08 20:55:13 +02:00
Ship every paper-referenced experiment script
Reorganizes the repo so every section of the paper has a corresponding
script. Previously only the core recipe + control + evals were here.
New subdirs:
- tts/ — test-time sampling (§2.2, §3.3): scaling sweep, HE, MATH-500,
AIME, 14B-recipe + TTS, 8B-raw-TTS control.
- experiments/ — every §3 finding as a runnable script:
· self_consistency (§3.4)
· recipe_x_tts_synergy (§3.5, novel)
· mbpp_seeded_cross_arch (§3.9)
· cross_domain_code_to_math (§3.10)
· self_correction_math_{naive,fixed} (§3.10, the
catastrophic-then-recovered case)
· math500_seeded_mining (§3.10 distribution mismatch)
· bcb_hard_eval (§3.10 distribution mismatch)
· recursive_bootstrap (§3.10 plateau)
· diversity_cued_mining (§3.10 low yield)
· aime_scaling (TTS curve)
· star_baseline_gsm8k (related-work baseline)
- evals/ — moved out of recipe/ (eval_raw, eval_plus, confirm)
Also adds: bootstrap_14b_4bit_harvest, curriculum_code, math_bootstrap to
recipe/ for completeness.
REPRODUCE.md now maps each paper section / table / figure to its exact
script and expected output.
This commit is contained in:
parent
c867697f7c
commit
826f934d2e
27 changed files with 4467 additions and 134 deletions
91
experiments/aime_scaling.py
Normal file
91
experiments/aime_scaling.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
"""TTS scaling on AIME — pass@k curve from k=1 to k=64."""
|
||||
import os, json, time, re, argparse
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
def extract_int(text):
|
||||
m = re.search(r"\\boxed\{(\d+)\}", text)
|
||||
if m:
|
||||
try: return int(m.group(1))
|
||||
except: return None
|
||||
nums = re.findall(r"\b(\d+)\b", text.strip().split("\n")[-3:][-1] if text.strip().split("\n") else "")
|
||||
if nums:
|
||||
try: return int(nums[-1])
|
||||
except: pass
|
||||
return None
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--tag", required=True)
|
||||
ap.add_argument("--out_dir", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from transformers import AutoTokenizer
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=3072)
|
||||
log("loaded")
|
||||
|
||||
ds = list(load_dataset("AI-MO/aimo-validation-aime", split="train"))
|
||||
log(f" AIME: {len(ds)} problems")
|
||||
|
||||
UTMPL = "Solve this AIME problem. Answer is integer 0-999. End with \\boxed{{N}}.\n\nProblem: {p}\n\nSolution:"
|
||||
prompts = []
|
||||
for p in ds:
|
||||
try:
|
||||
msgs = [{"role": "system", "content": "AIME solver. End with \\boxed{integer}."},
|
||||
{"role": "user", "content": UTMPL.format(p=p["problem"])}]
|
||||
prompts.append(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
|
||||
except Exception:
|
||||
prompts.append(UTMPL.format(p=p["problem"]))
|
||||
|
||||
MAX_N = 64
|
||||
sp = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1500, n=MAX_N)
|
||||
log(f"generating {MAX_N} samples per problem...")
|
||||
t0 = time.time()
|
||||
outs = llm.generate(prompts, sp, use_tqdm=False)
|
||||
log(f" gen in {time.time()-t0:.1f}s")
|
||||
|
||||
# Per-task per-sample correctness
|
||||
per_task_results = []
|
||||
for p, outset in zip(ds, outs):
|
||||
gold = int(p["answer"])
|
||||
per_sample = []
|
||||
for o in outset.outputs:
|
||||
pred = extract_int(o.text)
|
||||
per_sample.append(pred == gold)
|
||||
per_task_results.append(per_sample)
|
||||
|
||||
NS = [1, 2, 4, 8, 16, 32, 64]
|
||||
scaling = {}
|
||||
for k in NS:
|
||||
scaling[k] = sum(1 for r in per_task_results if any(r[:k]))
|
||||
|
||||
result = {"model": args.model, "tag": args.tag, "MAX_N": MAX_N,
|
||||
"n_total": len(ds), "pass_at_k": scaling, "elapsed_s": time.time() - T0}
|
||||
with open(f"{args.out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — AIME TTS SCALING")
|
||||
for k in NS:
|
||||
print(f" pass@{k:<3}: {scaling[k]:>3}/{len(ds)} ({100*scaling[k]/len(ds):.1f}%)")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
190
experiments/bcb_hard_eval.py
Normal file
190
experiments/bcb_hard_eval.py
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
"""Train Qwen3-8B-Base with 40-pair recipe, eval on BigCodeBench-Hard.
|
||||
|
||||
BigCodeBench is harder than HumanEval (real-world Python tasks, library use).
|
||||
Qwen3-8B-Base likely has headroom there (~30-45% baseline). Tests if recipe
|
||||
generalizes to newer model AND harder benchmark.
|
||||
"""
|
||||
import os, json, time, re, subprocess, tempfile, argparse
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
|
||||
from datasets import load_dataset, Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
def extract_code(text):
|
||||
if "```python" in text: text = text.split("```python", 1)[1]
|
||||
elif "```" in text: text = text.split("```", 1)[1]
|
||||
if "```" in text: text = text.split("```", 1)[0]
|
||||
return text.strip()
|
||||
|
||||
|
||||
def verify_bcb(code, test_code):
|
||||
runner = "\n\nif __name__ == '__main__':\n import unittest; unittest.main(argv=['x'], exit=False, verbosity=0)\n"
|
||||
body = code + "\n\n" + test_code + runner
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(body); path = f.name
|
||||
try:
|
||||
r = subprocess.run(["python3", path], capture_output=True, timeout=20, text=True, cwd="/tmp")
|
||||
out = (r.stdout or "") + "\n" + (r.stderr or "")
|
||||
if "OK" in out and "FAILED" not in out and "Error" not in out and r.returncode == 0:
|
||||
return True
|
||||
return False
|
||||
except subprocess.TimeoutExpired:
|
||||
return False
|
||||
finally:
|
||||
try: os.unlink(path)
|
||||
except: pass
|
||||
|
||||
|
||||
def gen_batch(model, tok, prompts, max_new=600, temperature=0.0, batch=4):
|
||||
outs = []
|
||||
for i in range(0, len(prompts), batch):
|
||||
chunk = prompts[i:i+batch]
|
||||
texts = []
|
||||
for p in chunk:
|
||||
msgs = [{"role": "system", "content": "You are an expert Python coder. Output one ```python block with the complete solution."},
|
||||
{"role": "user", "content": p}]
|
||||
texts.append(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
|
||||
inp = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=2000).to(model.device)
|
||||
with torch.no_grad():
|
||||
out = model.generate(**inp, max_new_tokens=max_new, do_sample=temperature > 0,
|
||||
temperature=temperature if temperature > 0 else 1.0, top_p=0.95,
|
||||
pad_token_id=tok.eos_token_id)
|
||||
for j in range(out.size(0)):
|
||||
outs.append(tok.decode(out[j][inp.input_ids.shape[1]:], skip_special_tokens=True))
|
||||
return outs
|
||||
|
||||
|
||||
def eval_bcb_hard(model, tok, label, max_n=148):
|
||||
bcb = list(load_dataset("bigcode/bigcodebench-hard", split="v0.1.4"))[:max_n]
|
||||
log(f" BCB-Hard [{label}] ({len(bcb)})")
|
||||
prompts = [p["instruct_prompt"] for p in bcb]
|
||||
outs = gen_batch(model, tok, prompts, max_new=700, batch=4)
|
||||
correct = 0
|
||||
for i, (p, raw) in enumerate(zip(bcb, outs)):
|
||||
code = extract_code(raw) if "```" in raw else raw
|
||||
if verify_bcb(code, p["test"]): correct += 1
|
||||
if (i+1) % 20 == 0: log(f" {label} BCB {i+1}/{len(bcb)}: {correct}")
|
||||
return correct, len(bcb)
|
||||
|
||||
|
||||
def eval_humaneval(model, tok, label):
|
||||
he = list(load_dataset("openai_humaneval", split="test"))
|
||||
log(f" HumanEval [{label}] ({len(he)})")
|
||||
prompts = [p["prompt"] + "\n# Complete the function above." for p in he]
|
||||
outs = gen_batch(model, tok, prompts, max_new=400, batch=4)
|
||||
correct = 0
|
||||
for i, (p, raw) in enumerate(zip(he, outs)):
|
||||
code = extract_code(raw) if "```" in raw else raw
|
||||
full = p["prompt"] + "\n" + code if "def " not in code else code
|
||||
test_code = full + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})"
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(test_code); path = f.name
|
||||
try:
|
||||
r = subprocess.run(["python3", path], capture_output=True, timeout=10, text=True, cwd="/tmp")
|
||||
if r.returncode == 0: correct += 1
|
||||
except subprocess.TimeoutExpired: pass
|
||||
finally:
|
||||
try: os.unlink(path)
|
||||
except: pass
|
||||
if (i+1) % 40 == 0: log(f" {label} HE {i+1}/{len(he)}: {correct}")
|
||||
return correct, len(he)
|
||||
|
||||
|
||||
def make_example(r, tok):
|
||||
user = (f"Implement: {r['signature']}\n\n"
|
||||
f"Tests:\n{chr(10).join(r['tests'])}\n\n"
|
||||
f"My attempt:\n```python\n{r['broken']}\n```\n\n"
|
||||
f"Error:\n{r.get('error','')}\n\n"
|
||||
f"Fix and output the corrected code only.")
|
||||
assistant = f"```python\n{r['fixed']}\n```"
|
||||
msgs_pre = [{"role": "system", "content": "You are an expert Python coder. Output one ```python block with the complete solution."},
|
||||
{"role": "user", "content": user}]
|
||||
msgs_full = msgs_pre + [{"role": "assistant", "content": assistant}]
|
||||
pre = tok.apply_chat_template(msgs_pre, tokenize=False, add_generation_prompt=True)
|
||||
full = tok.apply_chat_template(msgs_full, tokenize=False)
|
||||
pre_ids = tok(pre, add_special_tokens=False)["input_ids"]
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1024
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids)
|
||||
n_pre = min(len(pre_ids), len(labels))
|
||||
for i in range(n_pre): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--pairs", default="/workspace/saved_pairs/pairs_40.jsonl")
|
||||
ap.add_argument("--n_pairs", type=int, default=40)
|
||||
ap.add_argument("--tag", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
out_dir = f"/workspace/bcb_eval/{args.tag}"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
tok.padding_side = "left"
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
log(f" loaded mem={torch.cuda.memory_allocated('cuda:0')/1e9:.1f}GB")
|
||||
|
||||
model.eval()
|
||||
log("=== BASE evals ===")
|
||||
base_he, _ = eval_humaneval(model, tok, "BASE")
|
||||
base_bcb, _ = eval_bcb_hard(model, tok, "BASE")
|
||||
log(f" BASE: HumanEval={base_he}/164 BCB-Hard={base_bcb}/148")
|
||||
|
||||
pairs = [json.loads(l) for l in open(args.pairs)][:args.n_pairs]
|
||||
log(f"=== TRAINING — {len(pairs)} pairs ===")
|
||||
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
tok.padding_side = "right"
|
||||
ds = HFDataset.from_list([make_example(r, tok) for r in pairs])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{out_dir}/ckpt", num_train_epochs=2,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, bf16=True, logging_steps=10,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds, processing_class=tok).train()
|
||||
log(" training done")
|
||||
tok.padding_side = "left"
|
||||
|
||||
model.eval()
|
||||
log("=== TRAINED evals ===")
|
||||
tr_he, _ = eval_humaneval(model, tok, "TRAINED")
|
||||
tr_bcb, _ = eval_bcb_hard(model, tok, "TRAINED")
|
||||
|
||||
result = {
|
||||
"model": args.model, "method": "warmup 40 pairs",
|
||||
"humaneval": {"base": base_he, "trained": tr_he, "delta": tr_he-base_he, "n": 164},
|
||||
"bcb_hard": {"base": base_bcb, "trained": tr_bcb, "delta": tr_bcb-base_bcb, "n": 148},
|
||||
"elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model}")
|
||||
print(f" HumanEval: base={base_he}/164 trained={tr_he}/164 Δ={tr_he-base_he:+d}")
|
||||
print(f" BCB-Hard: base={base_bcb}/148 trained={tr_bcb}/148 Δ={tr_bcb-base_bcb:+d}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
222
experiments/cross_domain_code_to_math.py
Normal file
222
experiments/cross_domain_code_to_math.py
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
"""Cross-domain transfer: train recipe on CODE, eval on MATH (no math training).
|
||||
Tests if self-bootstrap teaches generic reasoning vs domain-specific patterns."""
|
||||
import os, json, time, re, subprocess, tempfile, argparse, gc, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
def run_python(code, timeout=10):
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(code); path = f.name
|
||||
try:
|
||||
r = subprocess.run(["python3", path], capture_output=True, timeout=timeout, text=True, cwd="/tmp")
|
||||
return r.returncode == 0
|
||||
except subprocess.TimeoutExpired: return False
|
||||
finally:
|
||||
try: os.unlink(path)
|
||||
except: pass
|
||||
|
||||
|
||||
def extract_boxed(text):
|
||||
idx = text.rfind("\\boxed{")
|
||||
if idx < 0: return None
|
||||
start = idx + len("\\boxed{"); depth = 1; i = start
|
||||
while i < len(text) and depth > 0:
|
||||
if text[i] == "{": depth += 1
|
||||
elif text[i] == "}": depth -= 1
|
||||
i += 1
|
||||
if depth != 0: return None
|
||||
return text[start:i-1].strip()
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--train_domain", choices=["code", "math"], default="code")
|
||||
ap.add_argument("--tag", required=True)
|
||||
ap.add_argument("--out_dir", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
random.seed(42)
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from transformers import AutoTokenizer
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048)
|
||||
log("loaded")
|
||||
|
||||
# Eval sets
|
||||
he = list(load_dataset("openai_humaneval", split="test"))[:80]
|
||||
math500 = list(load_dataset("HuggingFaceH4/MATH-500", split="test"))[:100]
|
||||
|
||||
# Build prompts
|
||||
he_prompts = [p["prompt"] for p in he]
|
||||
math_prompts = []
|
||||
for p in math500:
|
||||
try:
|
||||
msgs = [{"role": "system", "content": "Math solver. End with \\boxed{answer}."},
|
||||
{"role": "user", "content": f"Solve. Problem: {p['problem']}\n\nSolution:"}]
|
||||
math_prompts.append(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
|
||||
except Exception:
|
||||
math_prompts.append(f"Solve. Problem: {p['problem']}\n\nSolution:")
|
||||
|
||||
import sympy
|
||||
from sympy.parsing.latex import parse_latex
|
||||
def sympy_eq(a, b):
|
||||
if a is None or b is None: return False
|
||||
if a.strip() == b.strip(): return True
|
||||
try:
|
||||
if sympy.simplify(parse_latex(a) - parse_latex(b)) == 0: return True
|
||||
except Exception: pass
|
||||
try:
|
||||
if abs(float(a) - float(b)) < 1e-6: return True
|
||||
except Exception: pass
|
||||
return False
|
||||
|
||||
def eval_he(llm, lora_req=None):
|
||||
sp = SamplingParams(temperature=0, max_tokens=400, stop=["\nclass ", "\nif __name__", "\n\nprint"])
|
||||
outs = llm.generate(he_prompts, sp, lora_request=lora_req, use_tqdm=False) if lora_req else \
|
||||
llm.generate(he_prompts, sp, use_tqdm=False)
|
||||
outs = [o.outputs[0].text for o in outs]
|
||||
c = 0
|
||||
for p, raw in zip(he, outs):
|
||||
full = p["prompt"] + "\n" + raw
|
||||
test_code = full + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})"
|
||||
if run_python(test_code, 10): c += 1
|
||||
return c, len(he)
|
||||
|
||||
def eval_math(llm, lora_req=None):
|
||||
sp = SamplingParams(temperature=0, max_tokens=800)
|
||||
outs = llm.generate(math_prompts, sp, lora_request=lora_req, use_tqdm=False) if lora_req else \
|
||||
llm.generate(math_prompts, sp, use_tqdm=False)
|
||||
outs = [o.outputs[0].text for o in outs]
|
||||
c = 0
|
||||
for p, raw in zip(math500, outs):
|
||||
if sympy_eq(extract_boxed(raw), p["answer"]): c += 1
|
||||
return c, len(math500)
|
||||
|
||||
log("=== BASE evals ===")
|
||||
base_he = eval_he(llm)
|
||||
base_math = eval_math(llm)
|
||||
log(f" base HE: {base_he[0]}/{base_he[1]} MATH: {base_math[0]}/{base_math[1]}")
|
||||
|
||||
# Mine code pairs
|
||||
log("mining code pairs...")
|
||||
mbpp_full = list(load_dataset("mbpp", split="train"))
|
||||
random.shuffle(mbpp_full)
|
||||
seeds = []
|
||||
for p in mbpp_full[:200]:
|
||||
prompt_text = p.get("prompt") or p.get("text", "")
|
||||
if prompt_text and p.get("test_list"):
|
||||
seeds.append({"prompt": prompt_text, "test_list": p["test_list"]})
|
||||
|
||||
def mbpp_prompt(p): return f"# Task: {p['prompt']}\n# Tests:\n# " + "\n# ".join(p["test_list"]) + "\n\n"
|
||||
|
||||
sp = SamplingParams(temperature=0, max_tokens=400, stop=["\nclass Test", "\nif __name__"])
|
||||
g_outs = [o.outputs[0].text for o in llm.generate([mbpp_prompt(p) for p in seeds], sp, use_tqdm=False)]
|
||||
hard_idx = []
|
||||
for i, (p, raw) in enumerate(zip(seeds, g_outs)):
|
||||
if not run_python(raw + "\n\n" + "\n".join(p["test_list"]), 8):
|
||||
hard_idx.append(i)
|
||||
log(f" greedy: {len(seeds)-len(hard_idx)} pass, {len(hard_idx)} hard")
|
||||
pairs = []
|
||||
if hard_idx:
|
||||
sp2 = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=400, n=8,
|
||||
stop=["\nclass Test", "\nif __name__"])
|
||||
hard_prompts = [mbpp_prompt(seeds[i]) for i in hard_idx]
|
||||
sample_outs = llm.generate(hard_prompts, sp2, use_tqdm=False)
|
||||
for j, i in enumerate(hard_idx):
|
||||
attempts = [o.text for o in sample_outs[j].outputs]
|
||||
for a in attempts:
|
||||
if run_python(a + "\n\n" + "\n".join(seeds[i]["test_list"]), 8):
|
||||
pairs.append({"problem": seeds[i]["prompt"], "tests": seeds[i]["test_list"],
|
||||
"broken": g_outs[i].strip(), "fixed": a.strip()})
|
||||
break
|
||||
log(f" mined {len(pairs)} code pairs")
|
||||
|
||||
if len(pairs) < 5:
|
||||
log("too few pairs, skipping train")
|
||||
result = {"model": args.model, "n_pairs": len(pairs),
|
||||
"base_he": base_he[0], "base_math": base_math[0]}
|
||||
with open(f"{args.out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
return
|
||||
|
||||
# Tear down vLLM, train LoRA
|
||||
del llm; gc.collect(); torch.cuda.empty_cache()
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from datasets import Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
def mk_ex(r):
|
||||
user = (f"# Task: {r['problem']}\n# Tests:\n# " + "\n# ".join(r['tests']) + "\n"
|
||||
f"# My broken attempt:\n{r['broken']}\n# Corrected:\n")
|
||||
full = user + r["fixed"]
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
user_ids = tok(user, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1024
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids); n_user = min(len(user_ids), len(labels))
|
||||
for i in range(n_user): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
log("training LoRA on code pairs...")
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
ds_train = HFDataset.from_list([mk_ex(r) for r in pairs])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{args.out_dir}/ckpt", num_train_epochs=2,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds_train, tokenizer=tok).train()
|
||||
adapter_dir = f"{args.out_dir}/adapter"
|
||||
model.save_pretrained(adapter_dir)
|
||||
del model; gc.collect(); torch.cuda.empty_cache()
|
||||
log("training done")
|
||||
|
||||
# Re-eval with adapter
|
||||
log("=== TRAINED evals ===")
|
||||
from vllm import LLM as LLM2
|
||||
from vllm.lora.request import LoRARequest
|
||||
llm = LLM2(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048,
|
||||
enable_lora=True, max_lora_rank=16)
|
||||
lora_req = LoRARequest("trained", 1, adapter_dir)
|
||||
tr_he = eval_he(llm, lora_req)
|
||||
tr_math = eval_math(llm, lora_req)
|
||||
log(f" trained HE: {tr_he[0]}/{tr_he[1]} MATH: {tr_math[0]}/{tr_math[1]}")
|
||||
|
||||
result = {
|
||||
"model": args.model, "train_domain": args.train_domain,
|
||||
"n_pairs": len(pairs),
|
||||
"humaneval": {"base": base_he[0], "trained": tr_he[0], "delta": tr_he[0]-base_he[0], "n": base_he[1]},
|
||||
"math500": {"base": base_math[0], "trained": tr_math[0], "delta": tr_math[0]-base_math[0], "n": base_math[1]},
|
||||
"elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{args.out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — CROSS-DOMAIN ({args.train_domain} train, eval HE+MATH)")
|
||||
print(f" HE: base={base_he[0]}/{base_he[1]} trained={tr_he[0]}/{tr_he[1]} Δ={tr_he[0]-base_he[0]:+d}")
|
||||
print(f" MATH: base={base_math[0]}/{base_math[1]} trained={tr_math[0]}/{tr_math[1]} Δ={tr_math[0]-base_math[0]:+d}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
180
experiments/diversity_cued_mining.py
Normal file
180
experiments/diversity_cued_mining.py
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
"""Diversity-aware mining: prompt model with multiple cognitive lenses, mine pairs WITHOUT including failed code.
|
||||
Train on (problem, best_approach_summary, working_code) — minimal traces."""
|
||||
import os, json, time, re, subprocess, tempfile, argparse, gc, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
def run_python(code, timeout=10):
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(code); path = f.name
|
||||
try:
|
||||
r = subprocess.run(["python3", path], capture_output=True, timeout=timeout, text=True, cwd="/tmp")
|
||||
return r.returncode == 0
|
||||
except subprocess.TimeoutExpired: return False
|
||||
finally:
|
||||
try: os.unlink(path)
|
||||
except: pass
|
||||
|
||||
|
||||
LENS_PROMPTS = [
|
||||
("brute force iteration", "# Loop and check each case."),
|
||||
("math formula", "# Use a closed-form formula."),
|
||||
("hash map/set", "# Use a hashmap/set for O(1) lookup."),
|
||||
("recursion", "# Solve recursively."),
|
||||
]
|
||||
|
||||
|
||||
def mbpp_prompt(p): return f"# Task: {p['prompt']}\n# Tests:\n# " + "\n# ".join(p["test_list"]) + "\n\n"
|
||||
def he_prompt(p): return p["prompt"]
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--n_mining", type=int, default=150)
|
||||
ap.add_argument("--tag", required=True)
|
||||
ap.add_argument("--out_dir", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
random.seed(42)
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from transformers import AutoTokenizer
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048)
|
||||
log("loaded")
|
||||
|
||||
he = list(load_dataset("openai_humaneval", split="test"))
|
||||
mbpp_test = list(load_dataset("mbpp", "sanitized", split="test"))[:100]
|
||||
mbpp_full = list(load_dataset("mbpp", split="train"))
|
||||
random.shuffle(mbpp_full)
|
||||
seeds = []
|
||||
for p in mbpp_full[:args.n_mining]:
|
||||
prompt_text = p.get("prompt") or p.get("text", "")
|
||||
if prompt_text and p.get("test_list"):
|
||||
seeds.append({"prompt": prompt_text, "test_list": p["test_list"]})
|
||||
log(f" HE: {len(he)}, MBPP-test: {len(mbpp_test)}, mining: {len(seeds)}")
|
||||
|
||||
# Base eval
|
||||
sp_g = SamplingParams(temperature=0, max_tokens=400, stop=["\nclass ", "\nif __name__", "\n\nprint"])
|
||||
he_outs = [o.outputs[0].text for o in llm.generate([he_prompt(p) for p in he], sp_g, use_tqdm=False)]
|
||||
base_he = sum(1 for p, raw in zip(he, he_outs)
|
||||
if run_python(p["prompt"] + "\n" + raw + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})", 10))
|
||||
mbpp_outs = [o.outputs[0].text for o in llm.generate([mbpp_prompt(p) for p in mbpp_test], sp_g, use_tqdm=False)]
|
||||
base_mbpp = sum(1 for p, raw in zip(mbpp_test, mbpp_outs)
|
||||
if run_python(raw + "\n\n" + "\n".join(p["test_list"]), 10))
|
||||
log(f"BASE: HE={base_he}/{len(he)} MBPP={base_mbpp}/{len(mbpp_test)}")
|
||||
|
||||
# Mine: for each problem, generate 4 lens-cued attempts, keep one that works
|
||||
log("mining with cued diversity...")
|
||||
pairs = []
|
||||
for lens_name, lens_hint in LENS_PROMPTS:
|
||||
log(f" lens: {lens_name}")
|
||||
# Prefill prompts with lens hint
|
||||
prefilled = []
|
||||
for s in seeds:
|
||||
base = mbpp_prompt(s) + f"# Approach: {lens_name}.\n{lens_hint}\ndef solution"
|
||||
prefilled.append(base)
|
||||
sp = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=300,
|
||||
stop=["\nclass Test", "\nif __name__", "\n\nprint", "\n# Task"])
|
||||
outs = [o.outputs[0].text for o in llm.generate(prefilled, sp, use_tqdm=False)]
|
||||
# Verify each
|
||||
for s, raw in zip(seeds, outs):
|
||||
code = "def solution" + raw
|
||||
if run_python(code + "\n\n" + "\n".join(s["test_list"]), 8):
|
||||
# Greedy attempt to use as broken
|
||||
greedy = [o.outputs[0].text for o in llm.generate([mbpp_prompt(s)], sp_g, use_tqdm=False)][0]
|
||||
if not run_python(greedy + "\n\n" + "\n".join(s["test_list"]), 8):
|
||||
pairs.append({"problem": s["prompt"], "tests": s["test_list"],
|
||||
"broken": greedy.strip(), "fixed": code.strip(),
|
||||
"lens": lens_name})
|
||||
log(f"mined {len(pairs)} pairs across lenses")
|
||||
|
||||
with open(f"{args.out_dir}/pairs.jsonl", "w") as fh:
|
||||
for r in pairs: fh.write(json.dumps(r) + "\n")
|
||||
|
||||
if len(pairs) < 5:
|
||||
result = {"model": args.model, "n_pairs": len(pairs), "base_he": base_he, "base_mbpp": base_mbpp}
|
||||
with open(f"{args.out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
return
|
||||
|
||||
# Train flat
|
||||
del llm; gc.collect(); torch.cuda.empty_cache()
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from datasets import Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
def mk_ex(r):
|
||||
user = (f"# Task: {r['problem']}\n# Tests:\n# " + "\n# ".join(r['tests']) + "\n"
|
||||
f"# My broken attempt:\n{r['broken']}\n# Corrected:\n")
|
||||
full = user + r["fixed"]
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
user_ids = tok(user, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1024
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids); n_user = min(len(user_ids), len(labels))
|
||||
for i in range(n_user): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
log("training...")
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
ds_train = HFDataset.from_list([mk_ex(r) for r in pairs])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{args.out_dir}/ckpt", num_train_epochs=2,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds_train, tokenizer=tok).train()
|
||||
adapter_dir = f"{args.out_dir}/adapter"
|
||||
model.save_pretrained(adapter_dir)
|
||||
del model; gc.collect(); torch.cuda.empty_cache()
|
||||
|
||||
# Trained eval
|
||||
from vllm import LLM as LLM2
|
||||
from vllm.lora.request import LoRARequest
|
||||
llm = LLM2(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048,
|
||||
enable_lora=True, max_lora_rank=16)
|
||||
lora_req = LoRARequest("trained", 1, adapter_dir)
|
||||
he_outs = [o.outputs[0].text for o in llm.generate([he_prompt(p) for p in he], sp_g, lora_request=lora_req, use_tqdm=False)]
|
||||
tr_he = sum(1 for p, raw in zip(he, he_outs)
|
||||
if run_python(p["prompt"] + "\n" + raw + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})", 10))
|
||||
mbpp_outs = [o.outputs[0].text for o in llm.generate([mbpp_prompt(p) for p in mbpp_test], sp_g, lora_request=lora_req, use_tqdm=False)]
|
||||
tr_mbpp = sum(1 for p, raw in zip(mbpp_test, mbpp_outs)
|
||||
if run_python(raw + "\n\n" + "\n".join(p["test_list"]), 10))
|
||||
|
||||
result = {
|
||||
"model": args.model, "n_pairs": len(pairs),
|
||||
"humaneval": {"base": base_he, "trained": tr_he, "delta": tr_he-base_he, "n": len(he)},
|
||||
"mbpp": {"base": base_mbpp, "trained": tr_mbpp, "delta": tr_mbpp-base_mbpp, "n": len(mbpp_test)},
|
||||
"elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{args.out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — DIVERSITY-CUED MINING ({len(pairs)} pairs)")
|
||||
print(f" HE: base={base_he}/{len(he)} trained={tr_he}/{len(he)} Δ={tr_he-base_he:+d}")
|
||||
print(f" MBPP: base={base_mbpp}/{len(mbpp_test)} trained={tr_mbpp}/{len(mbpp_test)} Δ={tr_mbpp-base_mbpp:+d}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
276
experiments/math500_seeded_mining.py
Normal file
276
experiments/math500_seeded_mining.py
Normal file
|
|
@ -0,0 +1,276 @@
|
|||
"""TinyForge-Zero math with MATH-train-split as problem seeds.
|
||||
|
||||
Recipe:
|
||||
1. Sample N problems from MATH train split (NOT test).
|
||||
2. Greedy solve each. Verify with sympy against gold answer.
|
||||
3. If greedy correct → save (problem, greedy_solution) as positive.
|
||||
4. If greedy wrong, sample 4 attempts at temp=0.8.
|
||||
Some pass → mine pair: (problem, sampled_correct_solution).
|
||||
5. Repeat until max_pairs.
|
||||
6. Train LoRA on pairs.
|
||||
7. Eval on MATH-500 (test).
|
||||
|
||||
Uses MATH train as problem source — model still self-generates ALL solutions.
|
||||
No human solutions used.
|
||||
"""
|
||||
import os, json, time, re, argparse, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
|
||||
from datasets import load_dataset, Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
import sympy
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
SOLVE_PROMPT = """Solve this competition math problem. Show your reasoning, then put the final answer in \\boxed{{...}}.
|
||||
|
||||
Problem: {problem}
|
||||
|
||||
Solution:"""
|
||||
|
||||
|
||||
def extract_boxed(text):
|
||||
idx = text.rfind("\\boxed{")
|
||||
if idx < 0: return None
|
||||
start = idx + len("\\boxed{")
|
||||
depth = 1; i = start
|
||||
while i < len(text) and depth > 0:
|
||||
if text[i] == "{": depth += 1
|
||||
elif text[i] == "}": depth -= 1
|
||||
i += 1
|
||||
if depth != 0: return None
|
||||
return text[start:i-1].strip()
|
||||
|
||||
|
||||
def normalize(s):
|
||||
if s is None: return None
|
||||
s = s.strip()
|
||||
s = re.sub(r"^\$|\$$", "", s).strip()
|
||||
s = re.sub(r"\\text\{([^}]*)\}", r"\1", s)
|
||||
s = re.sub(r"\\mbox\{([^}]*)\}", r"\1", s)
|
||||
s = re.sub(r"(?<=\d),(?=\d)", "", s)
|
||||
s = s.replace("\\left", "").replace("\\right", "").replace("^\\circ", "").replace("^{\\circ}", "")
|
||||
return s.strip()
|
||||
|
||||
|
||||
def sympy_equal(a, b):
|
||||
if a is None or b is None: return False
|
||||
a, b = normalize(a), normalize(b)
|
||||
if a == b: return True
|
||||
try:
|
||||
ea = parse_latex(a); eb = parse_latex(b)
|
||||
if sympy.simplify(ea - eb) == 0: return True
|
||||
except Exception: pass
|
||||
try:
|
||||
fa = float(a); fb = float(b)
|
||||
if abs(fa - fb) < 1e-6: return True
|
||||
except Exception: pass
|
||||
return False
|
||||
|
||||
|
||||
def gen_batch(model, tok, prompts, max_new=600, temperature=0.0, batch=16):
|
||||
outs = []
|
||||
for i in range(0, len(prompts), batch):
|
||||
chunk = prompts[i:i+batch]
|
||||
texts = []
|
||||
for p in chunk:
|
||||
msgs = [{"role": "system", "content": "You are a careful math problem solver."},
|
||||
{"role": "user", "content": p}]
|
||||
try:
|
||||
texts.append(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
|
||||
except Exception:
|
||||
texts.append(p)
|
||||
inp = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=1500).to(model.device)
|
||||
with torch.no_grad():
|
||||
out = model.generate(**inp, max_new_tokens=max_new, do_sample=temperature > 0,
|
||||
temperature=temperature if temperature > 0 else 1.0, top_p=0.95,
|
||||
pad_token_id=tok.eos_token_id)
|
||||
for j in range(out.size(0)):
|
||||
outs.append(tok.decode(out[j][inp.input_ids.shape[1]:], skip_special_tokens=True))
|
||||
return outs
|
||||
|
||||
|
||||
def math500_eval(model, tok, n=500, batch=16):
|
||||
ds = list(load_dataset("HuggingFaceH4/MATH-500", split="test"))[:n]
|
||||
log(f" eval on MATH-500 ({len(ds)} problems)")
|
||||
prompts = [SOLVE_PROMPT.format(problem=p["problem"]) for p in ds]
|
||||
outs = gen_batch(model, tok, prompts, max_new=600, temperature=0.0, batch=batch)
|
||||
correct = 0
|
||||
for p, raw in zip(ds, outs):
|
||||
pred = extract_boxed(raw)
|
||||
if sympy_equal(pred, p["answer"]): correct += 1
|
||||
return correct, len(ds)
|
||||
|
||||
|
||||
def make_train_example(problem, solution, tok):
|
||||
user = SOLVE_PROMPT.format(problem=problem)
|
||||
msgs_pre = [{"role": "system", "content": "You are a careful math problem solver."},
|
||||
{"role": "user", "content": user}]
|
||||
msgs_full = msgs_pre + [{"role": "assistant", "content": solution}]
|
||||
pre = tok.apply_chat_template(msgs_pre, tokenize=False, add_generation_prompt=True)
|
||||
full = tok.apply_chat_template(msgs_full, tokenize=False)
|
||||
pre_ids = tok(pre, add_special_tokens=False)["input_ids"]
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1280
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids)
|
||||
n_pre = min(len(pre_ids), len(labels))
|
||||
for i in range(n_pre): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
|
||||
def train_on_pairs(model, tok, pairs, out_dir, lr=1e-4, epochs=2, rank=16):
|
||||
log(f" training on {len(pairs)} pairs (lr={lr}, e={epochs}, r={rank})")
|
||||
lora_cfg = LoraConfig(r=rank, lora_alpha=rank*2, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
tok.padding_side = "right"
|
||||
ds = HFDataset.from_list([make_train_example(p["problem"], p["solution"], tok) for p in pairs])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{out_dir}/ckpt", num_train_epochs=epochs,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=lr, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds, processing_class=tok).train()
|
||||
tok.padding_side = "left"
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--iterations", type=int, default=6)
|
||||
ap.add_argument("--problems_per_iter", type=int, default=32)
|
||||
ap.add_argument("--n_eval", type=int, default=500)
|
||||
ap.add_argument("--max_pairs", type=int, default=120)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
ap.add_argument("--tag", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
out_dir = f"/workspace/math500_seeded/{args.tag}"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
random.seed(args.seed); torch.manual_seed(args.seed)
|
||||
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
tok.padding_side = "left"
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
log(f" loaded mem={torch.cuda.memory_allocated('cuda:0')/1e9:.1f}GB")
|
||||
|
||||
log("loading MATH train split")
|
||||
train_ds = []
|
||||
for cfg in ["algebra","counting_and_probability","geometry","intermediate_algebra","number_theory","prealgebra","precalculus"]:
|
||||
try:
|
||||
sub = list(load_dataset("EleutherAI/hendrycks_math", cfg, split="train"))
|
||||
train_ds.extend(sub)
|
||||
except Exception as e:
|
||||
log(f" warn: failed to load {cfg}: {e}")
|
||||
log(f" {len(train_ds)} train problems")
|
||||
random.shuffle(train_ds)
|
||||
|
||||
model.eval()
|
||||
log("INITIAL eval on MATH-500")
|
||||
base_c, base_n = math500_eval(model, tok, n=args.n_eval)
|
||||
log(f" MATH-500 base: {base_c}/{base_n} ({100*base_c/base_n:.1f}%)")
|
||||
|
||||
pairs = []
|
||||
cursor = 0
|
||||
|
||||
def gold_of(p):
|
||||
ans = p.get("solution", "")
|
||||
b = extract_boxed(ans)
|
||||
return b
|
||||
|
||||
for it in range(1, args.iterations + 1):
|
||||
log(f"--- iter {it} ---")
|
||||
batch_size = args.problems_per_iter
|
||||
# Sample with gold extractable
|
||||
batch_problems = []
|
||||
while len(batch_problems) < batch_size and cursor < len(train_ds):
|
||||
p = train_ds[cursor]; cursor += 1
|
||||
gold = gold_of(p)
|
||||
if gold is not None:
|
||||
batch_problems.append({"problem": p["problem"], "gold": gold})
|
||||
if not batch_problems:
|
||||
log(" exhausted train problems"); break
|
||||
|
||||
# Greedy
|
||||
prompts = [SOLVE_PROMPT.format(problem=p["problem"]) for p in batch_problems]
|
||||
greedy_outs = gen_batch(model, tok, prompts, max_new=600, temperature=0.0, batch=16)
|
||||
greedy_correct, hard_idx = 0, []
|
||||
for i, (p, raw) in enumerate(zip(batch_problems, greedy_outs)):
|
||||
pred = extract_boxed(raw)
|
||||
if sympy_equal(pred, p["gold"]):
|
||||
pairs.append({"problem": p["problem"], "solution": raw.strip(), "source": "greedy"})
|
||||
greedy_correct += 1
|
||||
else:
|
||||
hard_idx.append(i)
|
||||
log(f" iter {it}: {greedy_correct} greedy-correct, {len(hard_idx)} hard")
|
||||
|
||||
# Sampled for hard
|
||||
if hard_idx:
|
||||
hard_problems = [batch_problems[i] for i in hard_idx]
|
||||
sample_prompts = []
|
||||
for p in hard_problems:
|
||||
sample_prompts.extend([SOLVE_PROMPT.format(problem=p["problem"])] * 4)
|
||||
sample_outs = gen_batch(model, tok, sample_prompts, max_new=600, temperature=0.8, batch=16)
|
||||
sampled_correct = 0
|
||||
for i, p in enumerate(hard_problems):
|
||||
attempts = sample_outs[i*4:(i+1)*4]
|
||||
preds = [extract_boxed(a) for a in attempts]
|
||||
correct_idx = [j for j, pr in enumerate(preds) if sympy_equal(pr, p["gold"])]
|
||||
if correct_idx:
|
||||
pairs.append({"problem": p["problem"], "solution": attempts[correct_idx[0]].strip(), "source": "sampled"})
|
||||
sampled_correct += 1
|
||||
log(f" iter {it}: {sampled_correct} sampled-correct (from {len(hard_idx)} hard)")
|
||||
|
||||
log(f" iter {it}: pairs total = {len(pairs)}")
|
||||
if len(pairs) >= args.max_pairs:
|
||||
log(f" reached max_pairs={args.max_pairs}, stopping")
|
||||
break
|
||||
|
||||
log(f"=== mined {len(pairs)} total pairs ===")
|
||||
with open(f"{out_dir}/pairs.jsonl", "w") as fh:
|
||||
for p in pairs: fh.write(json.dumps(p) + "\n")
|
||||
|
||||
if not pairs:
|
||||
log("no pairs — exiting"); return
|
||||
|
||||
model = train_on_pairs(model, tok, pairs, out_dir)
|
||||
log("training done")
|
||||
|
||||
model.eval()
|
||||
log("FINAL eval on MATH-500")
|
||||
tr_c, tr_n = math500_eval(model, tok, n=args.n_eval)
|
||||
log(f" MATH-500 trained: {tr_c}/{tr_n} ({100*tr_c/tr_n:.1f}%)")
|
||||
|
||||
result = {
|
||||
"model": args.model, "n_pairs": len(pairs),
|
||||
"base": base_c, "trained": tr_c, "n": tr_n,
|
||||
"delta": tr_c - base_c, "elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model}")
|
||||
print(f" MATH-500: base={base_c}/{tr_n} trained={tr_c}/{tr_n} Δ={tr_c-base_c:+d}")
|
||||
print(f" Pairs mined: {len(pairs)}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
241
experiments/mbpp_seeded_cross_arch.py
Normal file
241
experiments/mbpp_seeded_cross_arch.py
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
"""Self-bootstrap with MBPP-train as problem seeds + vLLM on H100.
|
||||
|
||||
- Use MBPP train (374 problems) as PROBLEM seeds (no human solutions used).
|
||||
- For each: greedy attempt. If fails, sample N attempts at temp=0.8.
|
||||
- Mine at-edge pairs (broken, fixed).
|
||||
- Train LoRA. Eval on HumanEval + MBPP-test.
|
||||
"""
|
||||
import os, json, time, re, subprocess, tempfile, argparse, gc, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
def run_python(code, timeout=8):
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(code); path = f.name
|
||||
try:
|
||||
r = subprocess.run(["python3", path], capture_output=True, timeout=timeout, text=True, cwd="/tmp")
|
||||
return r.returncode == 0, (r.stderr or "")[:200]
|
||||
except subprocess.TimeoutExpired: return False, "timeout"
|
||||
finally:
|
||||
try: os.unlink(path)
|
||||
except: pass
|
||||
|
||||
|
||||
def vllm_gen(llm, prompts, max_new=400, temperature=0.0, n=1, stops=None):
|
||||
from vllm import SamplingParams
|
||||
sp = SamplingParams(temperature=temperature, top_p=0.95 if temperature > 0 else 1.0,
|
||||
max_tokens=max_new, n=n,
|
||||
stop=stops or ["\nclass ", "\nif __name__", "\n\nprint", "\n\ndef "])
|
||||
out = llm.generate(prompts, sp, use_tqdm=False)
|
||||
# returns list of lists when n>1
|
||||
if n == 1:
|
||||
return [o.outputs[0].text for o in out]
|
||||
return [[c.text for c in o.outputs] for o in out]
|
||||
|
||||
|
||||
def he_prompt(p): return p["prompt"]
|
||||
def mbpp_prompt(p):
|
||||
return (f"# Task: {p['prompt']}\n"
|
||||
f"# Tests:\n# " + "\n# ".join(p["test_list"]) + "\n\n")
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--attempts_per", type=int, default=8)
|
||||
ap.add_argument("--max_pairs", type=int, default=200)
|
||||
ap.add_argument("--tag", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
out_dir = f"/workspace/selfmine_mbpp/{args.tag}"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
from vllm import LLM
|
||||
from transformers import AutoTokenizer
|
||||
log(f"loading {args.model} into vLLM")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048)
|
||||
log(f" loaded")
|
||||
|
||||
# --- Load benchmarks
|
||||
he = list(load_dataset("openai_humaneval", split="test"))
|
||||
mbpp_test = list(load_dataset("mbpp", "sanitized", split="test"))[:200]
|
||||
mbpp_train = list(load_dataset("mbpp", "sanitized", split="train"))
|
||||
log(f" HE: {len(he)}, MBPP-test: {len(mbpp_test)}, MBPP-train: {len(mbpp_train)}")
|
||||
|
||||
# --- BASE eval
|
||||
log("=== BASE evals ===")
|
||||
t0 = time.time()
|
||||
he_outs = vllm_gen(llm, [he_prompt(p) for p in he], max_new=400)
|
||||
log(f" HE base gen done in {time.time()-t0:.1f}s")
|
||||
base_he = 0
|
||||
for p, raw in zip(he, he_outs):
|
||||
full = p["prompt"] + raw
|
||||
test_code = full + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})"
|
||||
ok, _ = run_python(test_code, timeout=10)
|
||||
if ok: base_he += 1
|
||||
|
||||
t1 = time.time()
|
||||
mbpp_outs = vllm_gen(llm, [mbpp_prompt(p) for p in mbpp_test], max_new=400)
|
||||
log(f" MBPP-test base gen done in {time.time()-t1:.1f}s")
|
||||
base_mbpp = 0
|
||||
for p, raw in zip(mbpp_test, mbpp_outs):
|
||||
test_code = raw + "\n\n" + "\n".join(p["test_list"])
|
||||
ok, _ = run_python(test_code, timeout=10)
|
||||
if ok: base_mbpp += 1
|
||||
log(f" BASE: HE={base_he}/{len(he)} MBPP={base_mbpp}/{len(mbpp_test)}")
|
||||
|
||||
# --- Mine pairs from MBPP-train
|
||||
log(f"=== mining from {len(mbpp_train)} MBPP-train problems ===")
|
||||
train_prompts = [mbpp_prompt(p) for p in mbpp_train]
|
||||
# greedy attempt
|
||||
t0 = time.time()
|
||||
greedy_outs = vllm_gen(llm, train_prompts, max_new=400)
|
||||
log(f" greedy gen in {time.time()-t0:.1f}s")
|
||||
pairs = []
|
||||
hard_indices = []
|
||||
for i, (p, raw) in enumerate(zip(mbpp_train, greedy_outs)):
|
||||
test_code = raw + "\n\n" + "\n".join(p["test_list"])
|
||||
ok, err = run_python(test_code, timeout=8)
|
||||
if not ok:
|
||||
hard_indices.append((i, p, raw, err))
|
||||
log(f" {len(mbpp_train) - len(hard_indices)} greedy-correct, {len(hard_indices)} hard")
|
||||
|
||||
if not hard_indices:
|
||||
log("nothing to mine — base too strong"); return
|
||||
|
||||
# sample N attempts per hard problem
|
||||
log(f" sampling {args.attempts_per} attempts × {len(hard_indices)} hard problems...")
|
||||
hard_prompts = []
|
||||
for _i, p, _r, _e in hard_indices:
|
||||
hard_prompts.append(mbpp_prompt(p))
|
||||
t1 = time.time()
|
||||
sample_outs = vllm_gen(llm, hard_prompts, max_new=400, temperature=0.8, n=args.attempts_per)
|
||||
log(f" sample gen in {time.time()-t1:.1f}s")
|
||||
|
||||
t2 = time.time()
|
||||
for (idx, p, greedy_raw, err), attempts in zip(hard_indices, sample_outs):
|
||||
# check each attempt
|
||||
passes = []
|
||||
for a in attempts:
|
||||
test_code = a + "\n\n" + "\n".join(p["test_list"])
|
||||
ok, _ = run_python(test_code, timeout=8)
|
||||
if ok: passes.append(a)
|
||||
if passes:
|
||||
pairs.append({
|
||||
"problem": p["prompt"],
|
||||
"tests": p["test_list"],
|
||||
"broken": greedy_raw.strip(),
|
||||
"fixed": passes[0].strip(),
|
||||
"error": err,
|
||||
})
|
||||
if len(pairs) >= args.max_pairs: break
|
||||
log(f" verification in {time.time()-t2:.1f}s — mined {len(pairs)} pairs")
|
||||
|
||||
with open(f"{out_dir}/pairs.jsonl", "w") as fh:
|
||||
for r in pairs: fh.write(json.dumps(r) + "\n")
|
||||
|
||||
if len(pairs) < 5:
|
||||
log("too few pairs — exiting"); return
|
||||
|
||||
# --- Train LoRA
|
||||
log("=== TRAINING ===")
|
||||
del llm; gc.collect(); torch.cuda.empty_cache()
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from datasets import Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
def make_ex(r):
|
||||
user = (f"# Task: {r['problem']}\n"
|
||||
f"# Tests:\n# " + "\n# ".join(r['tests']) + "\n"
|
||||
f"# My broken attempt:\n{r['broken']}\n"
|
||||
f"# Error: {r.get('error','')[:120]}\n"
|
||||
f"# Corrected:\n")
|
||||
target = r["fixed"]
|
||||
full = user + target
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
user_ids = tok(user, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1024
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids)
|
||||
n_user = min(len(user_ids), len(labels))
|
||||
for i in range(n_user): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
ds = HFDataset.from_list([make_ex(r) for r in pairs])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{out_dir}/ckpt", num_train_epochs=2,
|
||||
per_device_train_batch_size=2, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds, tokenizer=tok).train()
|
||||
log("training done")
|
||||
adapter_dir = f"{out_dir}/adapter"
|
||||
model.save_pretrained(adapter_dir)
|
||||
del model; gc.collect(); torch.cuda.empty_cache()
|
||||
|
||||
# --- TRAINED eval
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048,
|
||||
enable_lora=True, max_lora_rank=16)
|
||||
lora_req = LoRARequest("tf_adapter", 1, adapter_dir)
|
||||
from vllm import SamplingParams
|
||||
sp = SamplingParams(temperature=0, max_tokens=400, stop=["\nclass ", "\nif __name__", "\n\nprint", "\n\ndef "])
|
||||
|
||||
log("=== TRAINED evals ===")
|
||||
t0 = time.time()
|
||||
he_outs = [o.outputs[0].text for o in llm.generate([he_prompt(p) for p in he], sp, lora_request=lora_req, use_tqdm=False)]
|
||||
log(f" HE trained gen in {time.time()-t0:.1f}s")
|
||||
tr_he = 0
|
||||
for p, raw in zip(he, he_outs):
|
||||
full = p["prompt"] + raw
|
||||
test_code = full + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})"
|
||||
ok, _ = run_python(test_code, timeout=10)
|
||||
if ok: tr_he += 1
|
||||
|
||||
t1 = time.time()
|
||||
mbpp_outs = [o.outputs[0].text for o in llm.generate([mbpp_prompt(p) for p in mbpp_test], sp, lora_request=lora_req, use_tqdm=False)]
|
||||
log(f" MBPP-test trained gen in {time.time()-t1:.1f}s")
|
||||
tr_mbpp = 0
|
||||
for p, raw in zip(mbpp_test, mbpp_outs):
|
||||
test_code = raw + "\n\n" + "\n".join(p["test_list"])
|
||||
ok, _ = run_python(test_code, timeout=10)
|
||||
if ok: tr_mbpp += 1
|
||||
|
||||
result = {
|
||||
"model": args.model, "n_pairs": len(pairs),
|
||||
"humaneval": {"base": base_he, "trained": tr_he, "delta": tr_he-base_he, "n": len(he)},
|
||||
"mbpp": {"base": base_mbpp, "trained": tr_mbpp, "delta": tr_mbpp-base_mbpp, "n": len(mbpp_test)},
|
||||
"elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — MBPP-train SEEDED ({len(pairs)} pairs)")
|
||||
print(f" HumanEval: base={base_he}/{len(he)} trained={tr_he}/{len(he)} Δ={tr_he-base_he:+d}")
|
||||
print(f" MBPP: base={base_mbpp}/{len(mbpp_test)} trained={tr_mbpp}/{len(mbpp_test)} Δ={tr_mbpp-base_mbpp:+d}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
210
experiments/recipe_x_tts_synergy.py
Normal file
210
experiments/recipe_x_tts_synergy.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""Compound recipe + TTS: train recipe, then measure best-of-N on TOP of recipe-trained model.
|
||||
Tests if recipe-trained model has BETTER sample diversity / quality at inference."""
|
||||
import os, json, time, re, subprocess, tempfile, argparse, gc, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
def run_python(code, timeout=10):
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(code); path = f.name
|
||||
try:
|
||||
r = subprocess.run(["python3", path], capture_output=True, timeout=timeout, text=True, cwd="/tmp")
|
||||
return r.returncode == 0
|
||||
except subprocess.TimeoutExpired: return False
|
||||
finally:
|
||||
try: os.unlink(path)
|
||||
except: pass
|
||||
|
||||
|
||||
def mbpp_prompt(p): return f"# Task: {p['prompt']}\n# Tests:\n# " + "\n# ".join(p["test_list"]) + "\n\n"
|
||||
def he_prompt(p): return p["prompt"]
|
||||
|
||||
|
||||
def he_score_outputs(he, outs):
|
||||
c = 0
|
||||
for p, raw in zip(he, outs):
|
||||
code = raw
|
||||
if "```python" in code:
|
||||
code = code.split("```python",1)[1]
|
||||
if "```" in code: code = code.split("```",1)[0]
|
||||
full = p["prompt"] + "\n" + code
|
||||
test_code = full + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})"
|
||||
if run_python(test_code, 10): c += 1
|
||||
return c
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--tag", required=True)
|
||||
ap.add_argument("--out_dir", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
random.seed(42)
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from transformers import AutoTokenizer
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048)
|
||||
log("loaded")
|
||||
|
||||
he = list(load_dataset("openai_humaneval", split="test"))
|
||||
|
||||
# 4 metrics:
|
||||
# A) raw greedy
|
||||
# B) raw + best-of-8
|
||||
# C) recipe greedy
|
||||
# D) recipe + best-of-8
|
||||
|
||||
sp_g = SamplingParams(temperature=0, max_tokens=400, stop=["\nclass ", "\nif __name__", "\n\nprint"])
|
||||
sp_s = SamplingParams(temperature=0.6, top_p=0.95, max_tokens=400, n=8,
|
||||
stop=["\nclass ", "\nif __name__", "\n\nprint"])
|
||||
|
||||
log("A) raw greedy")
|
||||
he_outs = [o.outputs[0].text for o in llm.generate([he_prompt(p) for p in he], sp_g, use_tqdm=False)]
|
||||
A_raw_greedy = he_score_outputs(he, he_outs)
|
||||
log(f" raw greedy: {A_raw_greedy}/{len(he)}")
|
||||
|
||||
log("B) raw best-of-8")
|
||||
he_samples = llm.generate([he_prompt(p) for p in he], sp_s, use_tqdm=False)
|
||||
B_raw_bo8 = 0
|
||||
for p, outset in zip(he, he_samples):
|
||||
for o in outset.outputs:
|
||||
code = o.text
|
||||
if "```python" in code:
|
||||
code = code.split("```python",1)[1]
|
||||
if "```" in code: code = code.split("```",1)[0]
|
||||
full = p["prompt"] + "\n" + code
|
||||
test_code = full + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})"
|
||||
if run_python(test_code, 10):
|
||||
B_raw_bo8 += 1; break
|
||||
log(f" raw best-of-8: {B_raw_bo8}/{len(he)}")
|
||||
|
||||
# Mine pairs
|
||||
log("mining pairs from MBPP-train...")
|
||||
mbpp_full = list(load_dataset("mbpp", split="train"))
|
||||
random.shuffle(mbpp_full)
|
||||
seeds = []
|
||||
for p in mbpp_full[:200]:
|
||||
prompt_text = p.get("prompt") or p.get("text", "")
|
||||
if prompt_text and p.get("test_list"):
|
||||
seeds.append({"prompt": prompt_text, "test_list": p["test_list"]})
|
||||
|
||||
sp_mine = SamplingParams(temperature=0, max_tokens=400, stop=["\nclass Test", "\nif __name__"])
|
||||
g_outs = [o.outputs[0].text for o in llm.generate([mbpp_prompt(p) for p in seeds], sp_mine, use_tqdm=False)]
|
||||
hard_idx = [i for i, (p, raw) in enumerate(zip(seeds, g_outs))
|
||||
if not run_python(raw + "\n\n" + "\n".join(p["test_list"]), 8)]
|
||||
log(f" hard: {len(hard_idx)}")
|
||||
pairs = []
|
||||
if hard_idx:
|
||||
sp_m2 = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=400, n=8,
|
||||
stop=["\nclass Test", "\nif __name__"])
|
||||
hard_prompts = [mbpp_prompt(seeds[i]) for i in hard_idx]
|
||||
sample_outs = llm.generate(hard_prompts, sp_m2, use_tqdm=False)
|
||||
for j, i in enumerate(hard_idx):
|
||||
for o in sample_outs[j].outputs:
|
||||
if run_python(o.text + "\n\n" + "\n".join(seeds[i]["test_list"]), 8):
|
||||
pairs.append({"problem": seeds[i]["prompt"], "tests": seeds[i]["test_list"],
|
||||
"broken": g_outs[i].strip(), "fixed": o.text.strip()}); break
|
||||
log(f" mined {len(pairs)} pairs")
|
||||
|
||||
# Train LoRA
|
||||
del llm; gc.collect(); torch.cuda.empty_cache()
|
||||
if len(pairs) < 5:
|
||||
log("too few pairs, exit"); return
|
||||
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from datasets import Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
def mk_ex(r):
|
||||
user = (f"# Task: {r['problem']}\n# Tests:\n# " + "\n# ".join(r['tests']) + "\n"
|
||||
f"# My broken attempt:\n{r['broken']}\n# Corrected:\n")
|
||||
full = user + r["fixed"]
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
user_ids = tok(user, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1024
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids); n_user = min(len(user_ids), len(labels))
|
||||
for i in range(n_user): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
log("training...")
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
ds_train = HFDataset.from_list([mk_ex(r) for r in pairs])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{args.out_dir}/ckpt", num_train_epochs=2,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds_train, tokenizer=tok).train()
|
||||
adapter_dir = f"{args.out_dir}/adapter"
|
||||
model.save_pretrained(adapter_dir)
|
||||
del model; gc.collect(); torch.cuda.empty_cache()
|
||||
|
||||
# C, D
|
||||
from vllm import LLM as LLM2
|
||||
from vllm.lora.request import LoRARequest
|
||||
llm = LLM2(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048,
|
||||
enable_lora=True, max_lora_rank=16)
|
||||
lora_req = LoRARequest("trained", 1, adapter_dir)
|
||||
|
||||
log("C) recipe greedy")
|
||||
he_outs = [o.outputs[0].text for o in llm.generate([he_prompt(p) for p in he], sp_g, lora_request=lora_req, use_tqdm=False)]
|
||||
C_rec_greedy = he_score_outputs(he, he_outs)
|
||||
log(f" recipe greedy: {C_rec_greedy}/{len(he)}")
|
||||
|
||||
log("D) recipe best-of-8")
|
||||
he_samples = llm.generate([he_prompt(p) for p in he], sp_s, lora_request=lora_req, use_tqdm=False)
|
||||
D_rec_bo8 = 0
|
||||
for p, outset in zip(he, he_samples):
|
||||
for o in outset.outputs:
|
||||
code = o.text
|
||||
if "```python" in code:
|
||||
code = code.split("```python",1)[1]
|
||||
if "```" in code: code = code.split("```",1)[0]
|
||||
full = p["prompt"] + "\n" + code
|
||||
test_code = full + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})"
|
||||
if run_python(test_code, 10):
|
||||
D_rec_bo8 += 1; break
|
||||
log(f" recipe best-of-8: {D_rec_bo8}/{len(he)}")
|
||||
|
||||
result = {
|
||||
"model": args.model, "n_pairs": len(pairs),
|
||||
"raw_greedy": A_raw_greedy, "raw_bo8": B_raw_bo8,
|
||||
"recipe_greedy": C_rec_greedy, "recipe_bo8": D_rec_bo8,
|
||||
"n": len(he), "elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{args.out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — RECIPE × TTS COMPOUND (HumanEval, n={len(he)}, {len(pairs)} pairs)")
|
||||
print(f" A) Raw greedy: {A_raw_greedy:>3}/{len(he)} ({100*A_raw_greedy/len(he):.1f}%)")
|
||||
print(f" B) Raw best-of-8: {B_raw_bo8:>3}/{len(he)} ({100*B_raw_bo8/len(he):.1f}%)")
|
||||
print(f" C) Recipe greedy: {C_rec_greedy:>3}/{len(he)} ({100*C_rec_greedy/len(he):.1f}%)")
|
||||
print(f" D) Recipe best-of-8: {D_rec_bo8:>3}/{len(he)} ({100*D_rec_bo8/len(he):.1f}%)")
|
||||
print(f" Synergy: D - max(B,C) = {D_rec_bo8 - max(B_raw_bo8, C_rec_greedy):+d} (>0 = real synergy)")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
219
experiments/recursive_bootstrap.py
Normal file
219
experiments/recursive_bootstrap.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
"""Recursive self-bootstrap: iter1->iter2->iter3.
|
||||
|
||||
Iter k:
|
||||
- Use model from previous iter (or base for iter 1)
|
||||
- Mine pairs on MBPP-train
|
||||
- Train fresh LoRA from BASE on accumulated pairs
|
||||
- Eval on HE
|
||||
"""
|
||||
import os, json, time, re, subprocess, tempfile, argparse, gc, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
def run_python(code, timeout=10):
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(code); path = f.name
|
||||
try:
|
||||
r = subprocess.run(["python3", path], capture_output=True, timeout=timeout, text=True, cwd="/tmp")
|
||||
return r.returncode == 0
|
||||
except subprocess.TimeoutExpired: return False
|
||||
finally:
|
||||
try: os.unlink(path)
|
||||
except: pass
|
||||
|
||||
|
||||
def mbpp_prompt(p):
|
||||
return f"# Task: {p['prompt']}\n# Tests:\n# " + "\n# ".join(p["test_list"]) + "\n\n"
|
||||
|
||||
|
||||
def he_prompt(p): return p["prompt"]
|
||||
|
||||
|
||||
def vllm_gen(llm, prompts, max_new=400, temperature=0.0, n=1, lora_req=None, stops=None):
|
||||
from vllm import SamplingParams
|
||||
sp = SamplingParams(temperature=temperature, top_p=0.95 if temperature > 0 else 1.0,
|
||||
max_tokens=max_new, n=n,
|
||||
stop=stops or ["\nclass Test", "\nif __name__", "\n\nprint", "\nassert "])
|
||||
if lora_req:
|
||||
out = llm.generate(prompts, sp, lora_request=lora_req, use_tqdm=False)
|
||||
else:
|
||||
out = llm.generate(prompts, sp, use_tqdm=False)
|
||||
if n == 1: return [o.outputs[0].text for o in out]
|
||||
return [[c.text for c in o.outputs] for o in out]
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--tag", required=True)
|
||||
ap.add_argument("--out_dir", required=True)
|
||||
ap.add_argument("--n_iters", type=int, default=3)
|
||||
ap.add_argument("--n_mining", type=int, default=200)
|
||||
ap.add_argument("--attempts_per", type=int, default=8)
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.lora.request import LoRARequest
|
||||
from transformers import AutoTokenizer
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
|
||||
he = list(load_dataset("openai_humaneval", split="test"))
|
||||
mbpp_full = list(load_dataset("mbpp", split="train"))
|
||||
random.seed(42); random.shuffle(mbpp_full)
|
||||
seeds_pool = []
|
||||
for p in mbpp_full[:args.n_mining * args.n_iters]:
|
||||
prompt_text = p.get("prompt") or p.get("text", "")
|
||||
if prompt_text and p.get("test_list"):
|
||||
seeds_pool.append({"prompt": prompt_text, "test_list": p["test_list"]})
|
||||
log(f"seeds pool: {len(seeds_pool)}")
|
||||
|
||||
iter_results = []
|
||||
accumulated_pairs = []
|
||||
current_adapter = None # path
|
||||
|
||||
for it in range(1, args.n_iters + 1):
|
||||
log(f"\n========== ITER {it} ==========")
|
||||
# Load model (with current adapter if exists)
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85,
|
||||
max_model_len=2048,
|
||||
enable_lora=(current_adapter is not None), max_lora_rank=16)
|
||||
lora_req = LoRARequest("cur", 1, current_adapter) if current_adapter else None
|
||||
log(f" loaded {'(with adapter)' if current_adapter else '(base)'}")
|
||||
|
||||
# Mine pairs using current model
|
||||
seeds = seeds_pool[(it-1)*args.n_mining:it*args.n_mining]
|
||||
log(f" mining from {len(seeds)} new seeds")
|
||||
prompts = [mbpp_prompt(p) for p in seeds]
|
||||
greedy_outs = vllm_gen(llm, prompts, max_new=400, lora_req=lora_req)
|
||||
hard_idx = []
|
||||
for i, (p, raw) in enumerate(zip(seeds, greedy_outs)):
|
||||
test_code = raw + "\n\n" + "\n".join(p["test_list"])
|
||||
if not run_python(test_code, 8):
|
||||
hard_idx.append(i)
|
||||
log(f" greedy: {len(seeds)-len(hard_idx)} pass, {len(hard_idx)} hard")
|
||||
|
||||
if hard_idx:
|
||||
hard_prompts = [mbpp_prompt(seeds[i]) for i in hard_idx]
|
||||
sample_outs = vllm_gen(llm, hard_prompts, max_new=400, temperature=0.8,
|
||||
n=args.attempts_per, lora_req=lora_req)
|
||||
new_pairs = []
|
||||
for j, i in enumerate(hard_idx):
|
||||
attempts = sample_outs[j]
|
||||
passes = []
|
||||
for a in attempts:
|
||||
if run_python(a + "\n\n" + "\n".join(seeds[i]["test_list"]), 8):
|
||||
passes.append(a); break
|
||||
if passes:
|
||||
new_pairs.append({"problem": seeds[i]["prompt"], "tests": seeds[i]["test_list"],
|
||||
"broken": greedy_outs[i].strip(), "fixed": passes[0].strip(),
|
||||
"iter": it})
|
||||
accumulated_pairs.extend(new_pairs)
|
||||
log(f" mined {len(new_pairs)} new pairs (cumulative: {len(accumulated_pairs)})")
|
||||
|
||||
# Eval current model on HE
|
||||
log(f" eval HE...")
|
||||
he_outs = vllm_gen(llm, [he_prompt(p) for p in he], max_new=400, lora_req=lora_req,
|
||||
stops=["\nclass ", "\nif __name__", "\n\nprint"])
|
||||
he_correct = 0
|
||||
for p, raw in zip(he, he_outs):
|
||||
full = p["prompt"] + "\n" + raw
|
||||
test_code = full + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})"
|
||||
if run_python(test_code, 10): he_correct += 1
|
||||
log(f" HE iter{it} (pre-train): {he_correct}/{len(he)}")
|
||||
iter_results.append({"iter": it, "he_pretrain": he_correct, "cumulative_pairs": len(accumulated_pairs)})
|
||||
|
||||
# Tear down vLLM, train new adapter on accumulated pairs
|
||||
del llm; gc.collect(); torch.cuda.empty_cache()
|
||||
|
||||
if len(accumulated_pairs) < 5:
|
||||
log(f" too few pairs to train, skipping iter {it} training")
|
||||
continue
|
||||
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from datasets import Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
def mk_ex(r):
|
||||
user = (f"# Task: {r['problem']}\n# Tests:\n# " + "\n# ".join(r['tests']) + "\n"
|
||||
f"# My broken attempt:\n{r['broken']}\n# Corrected:\n")
|
||||
target = r["fixed"]
|
||||
full = user + target
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
user_ids = tok(user, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1024
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids)
|
||||
n_user = min(len(user_ids), len(labels))
|
||||
for i in range(n_user): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
log(f" training fresh adapter on {len(accumulated_pairs)} pairs...")
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
ds_train = HFDataset.from_list([mk_ex(r) for r in accumulated_pairs])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{args.out_dir}/iter{it}_ckpt", num_train_epochs=2,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds_train, tokenizer=tok).train()
|
||||
adapter_dir = f"{args.out_dir}/iter{it}_adapter"
|
||||
model.save_pretrained(adapter_dir)
|
||||
del model; gc.collect(); torch.cuda.empty_cache()
|
||||
current_adapter = adapter_dir
|
||||
|
||||
# Re-eval with new adapter to get post-train HE
|
||||
log(f" eval post-train HE...")
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048,
|
||||
enable_lora=True, max_lora_rank=16)
|
||||
lora_req = LoRARequest(f"iter{it}", it, current_adapter)
|
||||
he_outs = vllm_gen(llm, [he_prompt(p) for p in he], max_new=400, lora_req=lora_req,
|
||||
stops=["\nclass ", "\nif __name__", "\n\nprint"])
|
||||
he_correct = 0
|
||||
for p, raw in zip(he, he_outs):
|
||||
full = p["prompt"] + "\n" + raw
|
||||
test_code = full + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})"
|
||||
if run_python(test_code, 10): he_correct += 1
|
||||
log(f" HE iter{it} (post-train): {he_correct}/{len(he)}")
|
||||
iter_results[-1]["he_posttrain"] = he_correct
|
||||
|
||||
del llm; gc.collect(); torch.cuda.empty_cache()
|
||||
|
||||
# Save pairs and results
|
||||
with open(f"{args.out_dir}/pairs.jsonl", "w") as fh:
|
||||
for r in accumulated_pairs: fh.write(json.dumps(r) + "\n")
|
||||
result = {"model": args.model, "tag": args.tag, "n_iters": args.n_iters,
|
||||
"iter_results": iter_results, "total_pairs": len(accumulated_pairs),
|
||||
"elapsed_s": time.time() - T0}
|
||||
with open(f"{args.out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — RECURSIVE BOOTSTRAP")
|
||||
for r in iter_results:
|
||||
pre = r.get("he_pretrain", "-")
|
||||
post = r.get("he_posttrain", "-")
|
||||
print(f" iter {r['iter']}: cum_pairs={r['cumulative_pairs']} HE_pre={pre} HE_post={post}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
129
experiments/self_consistency.py
Normal file
129
experiments/self_consistency.py
Normal file
|
|
@ -0,0 +1,129 @@
|
|||
"""Self-consistency selection: majority vote on N samples WITHOUT oracle access.
|
||||
Tests if model's self-agreement is a good selector (deployable TTS without test cases)."""
|
||||
import os, json, time, re, argparse
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from collections import Counter
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
def extract_boxed(text):
|
||||
idx = text.rfind("\\boxed{")
|
||||
if idx < 0: return None
|
||||
start = idx + len("\\boxed{"); depth = 1; i = start
|
||||
while i < len(text) and depth > 0:
|
||||
if text[i] == "{": depth += 1
|
||||
elif text[i] == "}": depth -= 1
|
||||
i += 1
|
||||
if depth != 0: return None
|
||||
return text[start:i-1].strip()
|
||||
|
||||
|
||||
def normalize(s):
|
||||
if s is None: return None
|
||||
s = s.strip().lower()
|
||||
s = re.sub(r"[,$\s]", "", s)
|
||||
return s
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--n_samples", type=int, default=16)
|
||||
ap.add_argument("--tag", required=True)
|
||||
ap.add_argument("--out_dir", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from transformers import AutoTokenizer
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048)
|
||||
log("loaded")
|
||||
|
||||
math500 = list(load_dataset("HuggingFaceH4/MATH-500", split="test"))[:200]
|
||||
prompts = []
|
||||
for p in math500:
|
||||
try:
|
||||
msgs = [{"role": "system", "content": "Math solver. End with \\boxed{answer}."},
|
||||
{"role": "user", "content": f"Solve. Problem: {p['problem']}\n\nSolution:"}]
|
||||
prompts.append(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
|
||||
except Exception:
|
||||
prompts.append(f"Solve. Problem: {p['problem']}\n\nSolution:")
|
||||
|
||||
log(f"generating {args.n_samples} samples per problem...")
|
||||
sp = SamplingParams(temperature=0.7, top_p=0.95, max_tokens=800, n=args.n_samples)
|
||||
t0 = time.time()
|
||||
outs = llm.generate(prompts, sp, use_tqdm=False)
|
||||
log(f" gen in {time.time()-t0:.1f}s")
|
||||
|
||||
import sympy
|
||||
from sympy.parsing.latex import parse_latex
|
||||
def sympy_eq(a, b):
|
||||
if a is None or b is None: return False
|
||||
if a == b: return True
|
||||
try:
|
||||
if sympy.simplify(parse_latex(a) - parse_latex(b)) == 0: return True
|
||||
except Exception: pass
|
||||
try:
|
||||
if abs(float(a) - float(b)) < 1e-6: return True
|
||||
except Exception: pass
|
||||
return False
|
||||
|
||||
# Three metrics:
|
||||
# 1. Greedy: take first sample
|
||||
# 2. Oracle pass@N: any correct
|
||||
# 3. Self-consistency: majority vote on extracted boxed answer (normalize numbers/text)
|
||||
greedy_correct = 0
|
||||
oracle_correct = 0
|
||||
sc_correct = 0
|
||||
|
||||
for p, outset in zip(math500, outs):
|
||||
attempts = [o.text for o in outset.outputs]
|
||||
preds = [extract_boxed(a) for a in attempts]
|
||||
# Greedy: first sample
|
||||
if sympy_eq(preds[0], p["answer"]): greedy_correct += 1
|
||||
# Oracle: any pass
|
||||
if any(sympy_eq(pr, p["answer"]) for pr in preds): oracle_correct += 1
|
||||
# Self-consistency: majority vote on normalized answer
|
||||
normalized = [normalize(pr) for pr in preds if pr is not None]
|
||||
if normalized:
|
||||
most_common, _ = Counter(normalized).most_common(1)[0]
|
||||
# Find an original pred with this normalized form
|
||||
for pr in preds:
|
||||
if pr and normalize(pr) == most_common:
|
||||
if sympy_eq(pr, p["answer"]): sc_correct += 1
|
||||
break
|
||||
|
||||
result = {
|
||||
"model": args.model, "n_samples": args.n_samples,
|
||||
"greedy_first": greedy_correct,
|
||||
"oracle_pass_at_N": oracle_correct,
|
||||
"self_consistency": sc_correct,
|
||||
"n": len(math500),
|
||||
"elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{args.out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — SELF-CONSISTENCY vs ORACLE on MATH-500 (n={args.n_samples})")
|
||||
print(f" First sample (greedy-like): {greedy_correct}/{len(math500)} ({100*greedy_correct/len(math500):.1f}%)")
|
||||
print(f" Self-consistency (vote): {sc_correct}/{len(math500)} ({100*sc_correct/len(math500):.1f}%)")
|
||||
print(f" Oracle (any-pass): {oracle_correct}/{len(math500)} ({100*oracle_correct/len(math500):.1f}%)")
|
||||
sc_recovery = 100*(sc_correct - greedy_correct)/(oracle_correct - greedy_correct) if oracle_correct > greedy_correct else 0
|
||||
print(f" SC recovers {sc_recovery:.0f}% of oracle-greedy gap")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
236
experiments/self_correction_code.py
Normal file
236
experiments/self_correction_code.py
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
"""Self-correction recipe for CODE. Same pattern as math sc_v2 (which gave +5 recovery).
|
||||
|
||||
Pipeline:
|
||||
1. MBPP-train problems (374 sanitized + extended).
|
||||
2. Greedy attempt. If passes → save as right→stays-right positive.
|
||||
3. If fails → prompt with "Wait, let me reconsider" + sample 4 at temp=0.8.
|
||||
If any pass → mine (problem, wrong, reflection, correct) self-correction trace.
|
||||
4. Train on mixed dataset.
|
||||
5. Eval HE + MBPP.
|
||||
|
||||
Mix teaches model: commit to right answers, fix wrong ones.
|
||||
"""
|
||||
import os, json, time, re, subprocess, tempfile, argparse, gc, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
RECONSIDER_TAG = "\n\n# Wait — that doesn't look right. Let me reconsider:\n\n"
|
||||
|
||||
|
||||
def run_python(code, timeout=8):
|
||||
with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
|
||||
f.write(code); path = f.name
|
||||
try:
|
||||
r = subprocess.run(["python3", path], capture_output=True, timeout=timeout, text=True, cwd="/tmp")
|
||||
return r.returncode == 0
|
||||
except subprocess.TimeoutExpired: return False
|
||||
finally:
|
||||
try: os.unlink(path)
|
||||
except: pass
|
||||
|
||||
|
||||
def vllm_gen(llm, prompts, max_new=400, temperature=0.0, n=1, prefill_texts=None):
|
||||
from vllm import SamplingParams
|
||||
sp = SamplingParams(temperature=temperature, top_p=0.95 if temperature > 0 else 1.0,
|
||||
max_tokens=max_new, n=n,
|
||||
stop=["\nclass Test", "\nif __name__", "\n\nprint", "\nassert "])
|
||||
if prefill_texts is None:
|
||||
out = llm.generate(prompts, sp, use_tqdm=False)
|
||||
else:
|
||||
# Each prompt is concatenated with prefill text
|
||||
full_prompts = [p + pre for p, pre in zip(prompts, prefill_texts)]
|
||||
out = llm.generate(full_prompts, sp, use_tqdm=False)
|
||||
if n == 1: return [o.outputs[0].text for o in out]
|
||||
return [[c.text for c in o.outputs] for o in out]
|
||||
|
||||
|
||||
def he_prompt(p): return p["prompt"]
|
||||
def mbpp_prompt(p):
|
||||
return f"# Task: {p['prompt']}\n# Tests:\n# " + "\n# ".join(p["test_list"]) + "\n\n"
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--n_mining", type=int, default=300)
|
||||
ap.add_argument("--max_self_corrections", type=int, default=80)
|
||||
ap.add_argument("--max_positives", type=int, default=80)
|
||||
ap.add_argument("--tag", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
out_dir = f"/workspace/code_sc/{args.tag}"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
random.seed(42)
|
||||
|
||||
from vllm import LLM
|
||||
from transformers import AutoTokenizer
|
||||
log(f"loading {args.model} into vLLM")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048)
|
||||
log(f" loaded")
|
||||
|
||||
he = list(load_dataset("openai_humaneval", split="test"))
|
||||
mbpp_test = list(load_dataset("mbpp", "sanitized", split="test"))[:100]
|
||||
mbpp_full = list(load_dataset("mbpp", split="train"))
|
||||
random.shuffle(mbpp_full)
|
||||
seeds = []
|
||||
for p in mbpp_full[:args.n_mining]:
|
||||
prompt_text = p.get("prompt") or p.get("text", "")
|
||||
if prompt_text and p.get("test_list"):
|
||||
seeds.append({"prompt": prompt_text, "test_list": p["test_list"]})
|
||||
log(f" HE: {len(he)}, MBPP-test: {len(mbpp_test)}, mining seeds: {len(seeds)}")
|
||||
|
||||
# --- BASE eval
|
||||
log("=== BASE eval ===")
|
||||
he_outs = vllm_gen(llm, [he_prompt(p) for p in he], max_new=400)
|
||||
base_he = sum(1 for p, raw in zip(he, he_outs)
|
||||
if run_python(p["prompt"] + raw + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})", 10))
|
||||
log(f" HE base: {base_he}/{len(he)}")
|
||||
mbpp_outs = vllm_gen(llm, [mbpp_prompt(p) for p in mbpp_test], max_new=400)
|
||||
base_mbpp = sum(1 for p, raw in zip(mbpp_test, mbpp_outs)
|
||||
if run_python(raw + "\n\n" + "\n".join(p["test_list"]), 10))
|
||||
log(f" MBPP base: {base_mbpp}/{len(mbpp_test)}")
|
||||
|
||||
# --- Mine: greedy on all seeds
|
||||
log(f"=== mining: greedy attempt on {len(seeds)} seeds ===")
|
||||
t0 = time.time()
|
||||
greedy_outs = vllm_gen(llm, [mbpp_prompt(p) for p in seeds], max_new=400)
|
||||
log(f" greedy gen in {time.time()-t0:.1f}s")
|
||||
t1 = time.time()
|
||||
right = [] # greedy correct (positives)
|
||||
wrong = [] # greedy wrong (candidates for self-correction)
|
||||
for p, raw in zip(seeds, greedy_outs):
|
||||
test_code = raw + "\n\n" + "\n".join(p["test_list"])
|
||||
if run_python(test_code, timeout=8):
|
||||
right.append({"problem": p["prompt"], "tests": p["test_list"], "solution": raw.strip()})
|
||||
else:
|
||||
wrong.append({"problem": p["prompt"], "tests": p["test_list"], "wrong": raw.strip()})
|
||||
log(f" verify: {len(right)} greedy-correct, {len(wrong)} hard")
|
||||
|
||||
# --- For wrong: prefill wrong + reconsider tag, sample 4 attempts
|
||||
log(f"=== self-correction sampling on {len(wrong)} hard problems ===")
|
||||
sc_pairs = []
|
||||
if wrong:
|
||||
base_prompts = [mbpp_prompt({"prompt": w["problem"], "test_list": w["tests"]}) for w in wrong]
|
||||
prefills = [w["wrong"] + RECONSIDER_TAG for w in wrong]
|
||||
# Generate 4 attempts each via temperature
|
||||
t0 = time.time()
|
||||
sc_outs = vllm_gen(llm, base_prompts, max_new=400, temperature=0.8, n=4, prefill_texts=prefills)
|
||||
log(f" sc gen in {time.time()-t0:.1f}s")
|
||||
t1 = time.time()
|
||||
for w, attempts in zip(wrong, sc_outs):
|
||||
for a in attempts:
|
||||
test_code = a + "\n\n" + "\n".join(w["tests"])
|
||||
if run_python(test_code, timeout=8):
|
||||
full_trace = w["wrong"] + RECONSIDER_TAG + a.strip()
|
||||
sc_pairs.append({"problem": w["problem"], "tests": w["tests"],
|
||||
"full_trace": full_trace})
|
||||
break # one per problem
|
||||
log(f" sc verify in {time.time()-t1:.1f}s — {len(sc_pairs)} self-correction traces")
|
||||
|
||||
# Cap and sample
|
||||
random.shuffle(right); random.shuffle(sc_pairs)
|
||||
right = right[:args.max_positives]
|
||||
sc_pairs = sc_pairs[:args.max_self_corrections]
|
||||
log(f"=== final: {len(sc_pairs)} self-correction + {len(right)} right→stays-right = {len(sc_pairs)+len(right)} examples ===")
|
||||
|
||||
if len(sc_pairs) + len(right) < 10:
|
||||
log("too few examples — exiting"); return
|
||||
|
||||
with open(f"{out_dir}/sc_pairs.jsonl", "w") as fh:
|
||||
for r in sc_pairs: fh.write(json.dumps(r) + "\n")
|
||||
with open(f"{out_dir}/positives.jsonl", "w") as fh:
|
||||
for r in right: fh.write(json.dumps(r) + "\n")
|
||||
|
||||
# --- Train LoRA on MIXED dataset
|
||||
log("=== TRAINING ===")
|
||||
del llm; gc.collect(); torch.cuda.empty_cache()
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from datasets import Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
train_examples = []
|
||||
for r in sc_pairs:
|
||||
train_examples.append({"problem": r["problem"], "tests": r["tests"], "target": r["full_trace"]})
|
||||
for r in right:
|
||||
train_examples.append({"problem": r["problem"], "tests": r["tests"], "target": r["solution"]})
|
||||
random.shuffle(train_examples)
|
||||
|
||||
def mk_ex(r):
|
||||
user = f"# Task: {r['problem']}\n# Tests:\n# " + "\n# ".join(r['tests']) + "\n\n"
|
||||
target = r["target"]
|
||||
full = user + target
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
user_ids = tok(user, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1280
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids)
|
||||
n_user = min(len(user_ids), len(labels))
|
||||
for i in range(n_user): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
ds_train = HFDataset.from_list([mk_ex(r) for r in train_examples])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{out_dir}/ckpt", num_train_epochs=2,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds_train, tokenizer=tok).train()
|
||||
log("training done")
|
||||
adapter_dir = f"{out_dir}/adapter"
|
||||
model.save_pretrained(adapter_dir)
|
||||
del model; gc.collect(); torch.cuda.empty_cache()
|
||||
|
||||
# --- TRAINED eval
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048,
|
||||
enable_lora=True, max_lora_rank=16)
|
||||
lora_req = LoRARequest("tf_adapter", 1, adapter_dir)
|
||||
from vllm import SamplingParams
|
||||
sp = SamplingParams(temperature=0, max_tokens=500, stop=["\nclass Test", "\nif __name__"])
|
||||
|
||||
log("=== TRAINED eval ===")
|
||||
he_outs = [o.outputs[0].text for o in llm.generate([he_prompt(p) for p in he], sp, lora_request=lora_req, use_tqdm=False)]
|
||||
tr_he = sum(1 for p, raw in zip(he, he_outs)
|
||||
if run_python(p["prompt"] + raw + "\n\n" + p["test"] + f"\n\ncheck({p['entry_point']})", 10))
|
||||
mbpp_outs = [o.outputs[0].text for o in llm.generate([mbpp_prompt(p) for p in mbpp_test], sp, lora_request=lora_req, use_tqdm=False)]
|
||||
tr_mbpp = sum(1 for p, raw in zip(mbpp_test, mbpp_outs)
|
||||
if run_python(raw + "\n\n" + "\n".join(p["test_list"]), 10))
|
||||
|
||||
result = {
|
||||
"model": args.model,
|
||||
"n_sc": len(sc_pairs), "n_positives": len(right), "n_total": len(train_examples),
|
||||
"humaneval": {"base": base_he, "trained": tr_he, "delta": tr_he-base_he, "n": len(he)},
|
||||
"mbpp": {"base": base_mbpp, "trained": tr_mbpp, "delta": tr_mbpp-base_mbpp, "n": len(mbpp_test)},
|
||||
"elapsed_s": time.time()-T0,
|
||||
}
|
||||
with open(f"{out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — CODE SELF-CORRECTION ({len(sc_pairs)} sc + {len(right)} positives)")
|
||||
print(f" HumanEval: base={base_he}/{len(he)} trained={tr_he}/{len(he)} Δ={tr_he-base_he:+d}")
|
||||
print(f" MBPP: base={base_mbpp}/{len(mbpp_test)} trained={tr_mbpp}/{len(mbpp_test)} Δ={tr_mbpp-base_mbpp:+d}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
256
experiments/self_correction_math_fixed.py
Normal file
256
experiments/self_correction_math_fixed.py
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
"""Self-correction recipe FIXED: mix wrong→fix triples WITH right→stays-right.
|
||||
|
||||
Previous failure: training only on wrong→fix taught model to over-doubt itself,
|
||||
causing -230 regression on Qwen3-4B-Base.
|
||||
|
||||
Fix:
|
||||
1. Use existing wrong→fix triples (mined yesterday).
|
||||
2. Add an equal/greater number of right→stays-right examples (greedy was correct).
|
||||
3. Train on the mixed dataset → model learns WHEN to self-correct.
|
||||
4. Eval on MATH-500.
|
||||
|
||||
Uses vLLM on H100 for fast generation.
|
||||
"""
|
||||
import os, json, time, re, argparse, gc, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
import sympy
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
SOLVE_PROMPT = """Solve this competition math problem. Show your reasoning, then put the final answer in \\boxed{{...}}.
|
||||
|
||||
Problem: {problem}
|
||||
|
||||
Solution:"""
|
||||
|
||||
|
||||
RECONSIDER_TAG = "\n\nWait, let me reconsider — I think there's an error above.\n\n"
|
||||
|
||||
|
||||
def extract_boxed(text):
|
||||
idx = text.rfind("\\boxed{")
|
||||
if idx < 0: return None
|
||||
start = idx + len("\\boxed{")
|
||||
depth = 1; i = start
|
||||
while i < len(text) and depth > 0:
|
||||
if text[i] == "{": depth += 1
|
||||
elif text[i] == "}": depth -= 1
|
||||
i += 1
|
||||
if depth != 0: return None
|
||||
return text[start:i-1].strip()
|
||||
|
||||
|
||||
def normalize(s):
|
||||
if s is None: return None
|
||||
s = s.strip()
|
||||
s = re.sub(r"^\$|\$$", "", s).strip()
|
||||
s = re.sub(r"\\text\{([^}]*)\}", r"\1", s)
|
||||
s = re.sub(r"\\mbox\{([^}]*)\}", r"\1", s)
|
||||
s = re.sub(r"(?<=\d),(?=\d)", "", s)
|
||||
s = s.replace("\\left", "").replace("\\right", "").replace("^\\circ", "").replace("^{\\circ}", "")
|
||||
return s.strip()
|
||||
|
||||
|
||||
def sympy_equal(a, b):
|
||||
if a is None or b is None: return False
|
||||
a, b = normalize(a), normalize(b)
|
||||
if a == b: return True
|
||||
try:
|
||||
ea = parse_latex(a); eb = parse_latex(b)
|
||||
if sympy.simplify(ea - eb) == 0: return True
|
||||
except Exception: pass
|
||||
try:
|
||||
fa = float(a); fb = float(b)
|
||||
if abs(fa - fb) < 1e-6: return True
|
||||
except Exception: pass
|
||||
return False
|
||||
|
||||
|
||||
def vllm_gen(llm, prompts, max_new=600, temperature=0.0, n=1):
|
||||
from vllm import SamplingParams
|
||||
sp = SamplingParams(temperature=temperature, top_p=0.95 if temperature > 0 else 1.0,
|
||||
max_tokens=max_new, n=n)
|
||||
out = llm.generate(prompts, sp, use_tqdm=False)
|
||||
if n == 1: return [o.outputs[0].text for o in out]
|
||||
return [[c.text for c in o.outputs] for o in out]
|
||||
|
||||
|
||||
def math500_eval(gen_func, label):
|
||||
ds = list(load_dataset("HuggingFaceH4/MATH-500", split="test"))
|
||||
log(f" eval MATH-500 [{label}] ({len(ds)})")
|
||||
prompts = [SOLVE_PROMPT.format(problem=p["problem"]) for p in ds]
|
||||
t0 = time.time()
|
||||
outs = gen_func(prompts, max_new=800)
|
||||
log(f" gen done in {time.time()-t0:.1f}s")
|
||||
correct = 0
|
||||
for p, raw in zip(ds, outs):
|
||||
if sympy_equal(extract_boxed(raw), p["answer"]): correct += 1
|
||||
return correct, len(ds)
|
||||
|
||||
|
||||
def make_train_example(problem, solution, tok):
|
||||
user = SOLVE_PROMPT.format(problem=problem)
|
||||
full = user + " " + solution
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
user_ids = tok(user + " ", add_special_tokens=False)["input_ids"]
|
||||
MAX = 1536
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids)
|
||||
n_user = min(len(user_ids), len(labels))
|
||||
for i in range(n_user): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--wrong_fix_pairs", required=True, help="Existing wrong→fix triples jsonl from prior run")
|
||||
ap.add_argument("--n_positives", type=int, default=100, help="Number of right→stays-right examples to mine")
|
||||
ap.add_argument("--tag", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
out_dir = f"/workspace/math500_sc_v2/{args.tag}"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
from vllm import LLM
|
||||
from transformers import AutoTokenizer
|
||||
log(f"loading {args.model} into vLLM")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048)
|
||||
log(f" loaded")
|
||||
|
||||
# --- BASE eval
|
||||
log("=== BASE eval ===")
|
||||
base_c, base_n = math500_eval(lambda P, max_new=800: vllm_gen(llm, P, max_new=max_new), "BASE")
|
||||
log(f" BASE: {base_c}/{base_n} ({100*base_c/base_n:.1f}%)")
|
||||
|
||||
# --- Load existing wrong→fix triples
|
||||
wrong_fix = [json.loads(l) for l in open(args.wrong_fix_pairs)]
|
||||
log(f" loaded {len(wrong_fix)} wrong→fix triples")
|
||||
|
||||
# --- Mine right→stays-right positives from MATH-train
|
||||
log(f"=== mining {args.n_positives} right→stays-right positives ===")
|
||||
train_ds = []
|
||||
for cfg in ["algebra","counting_and_probability","geometry","intermediate_algebra","number_theory","prealgebra","precalculus"]:
|
||||
try:
|
||||
sub = list(load_dataset("EleutherAI/hendrycks_math", cfg, split="train"))
|
||||
train_ds.extend(sub)
|
||||
except Exception: pass
|
||||
random.seed(42); random.shuffle(train_ds)
|
||||
log(f" {len(train_ds)} train problems available")
|
||||
|
||||
def gold_of(p):
|
||||
return extract_boxed(p.get("solution", ""))
|
||||
|
||||
positives = []
|
||||
cursor = 0
|
||||
while len(positives) < args.n_positives and cursor < len(train_ds):
|
||||
batch = []
|
||||
while len(batch) < 64 and cursor < len(train_ds):
|
||||
p = train_ds[cursor]; cursor += 1
|
||||
g = gold_of(p)
|
||||
if g is not None: batch.append({"problem": p["problem"], "gold": g})
|
||||
if not batch: break
|
||||
|
||||
prompts = [SOLVE_PROMPT.format(problem=p["problem"]) for p in batch]
|
||||
outs = vllm_gen(llm, prompts, max_new=600, temperature=0.0)
|
||||
for p, raw in zip(batch, outs):
|
||||
if sympy_equal(extract_boxed(raw), p["gold"]):
|
||||
# right→stays-right: model wrote a clean correct solution
|
||||
positives.append({"problem": p["problem"], "solution": raw.strip()})
|
||||
if len(positives) >= args.n_positives: break
|
||||
log(f" positives: {len(positives)} / {args.n_positives}")
|
||||
|
||||
log(f"=== final dataset: {len(wrong_fix)} wrong→fix + {len(positives)} right→stays-right = {len(wrong_fix)+len(positives)} examples ===")
|
||||
|
||||
with open(f"{out_dir}/positives.jsonl", "w") as fh:
|
||||
for p in positives: fh.write(json.dumps(p) + "\n")
|
||||
|
||||
# --- Build training data
|
||||
train_examples = []
|
||||
# wrong→fix as full self-correction traces
|
||||
for r in wrong_fix:
|
||||
train_examples.append({
|
||||
"problem": r["problem"],
|
||||
"solution": r["full_solution"], # already includes wrong + RECONSIDER_TAG + correct
|
||||
})
|
||||
# right→stays-right as plain solutions (no "wait" — model commits)
|
||||
for r in positives:
|
||||
train_examples.append({
|
||||
"problem": r["problem"],
|
||||
"solution": r["solution"],
|
||||
})
|
||||
random.shuffle(train_examples)
|
||||
|
||||
# --- Train LoRA
|
||||
log("=== TRAINING ===")
|
||||
del llm; gc.collect(); torch.cuda.empty_cache()
|
||||
from transformers import AutoModelForCausalLM, TrainingArguments, Trainer
|
||||
from datasets import Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
ds_train = HFDataset.from_list([make_train_example(r["problem"], r["solution"], tok) for r in train_examples])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{out_dir}/ckpt", num_train_epochs=2,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds_train, tokenizer=tok).train()
|
||||
log("training done")
|
||||
adapter_dir = f"{out_dir}/adapter"
|
||||
model.save_pretrained(adapter_dir)
|
||||
del model; gc.collect(); torch.cuda.empty_cache()
|
||||
|
||||
# --- TRAINED eval
|
||||
from vllm import LLM
|
||||
from vllm.lora.request import LoRARequest
|
||||
llm = LLM(model=args.model, dtype="bfloat16", gpu_memory_utilization=0.85, max_model_len=2048,
|
||||
enable_lora=True, max_lora_rank=16)
|
||||
lora_req = LoRARequest("tf_adapter", 1, adapter_dir)
|
||||
from vllm import SamplingParams
|
||||
def gen_trained(prompts, max_new=800):
|
||||
sp = SamplingParams(temperature=0, max_tokens=max_new)
|
||||
return [o.outputs[0].text for o in llm.generate(prompts, sp, lora_request=lora_req, use_tqdm=False)]
|
||||
|
||||
log("=== TRAINED eval ===")
|
||||
tr_c, tr_n = math500_eval(gen_trained, "TRAINED")
|
||||
log(f" TRAINED: {tr_c}/{tr_n} ({100*tr_c/tr_n:.1f}%)")
|
||||
|
||||
result = {
|
||||
"model": args.model,
|
||||
"n_wrong_fix": len(wrong_fix),
|
||||
"n_positives": len(positives),
|
||||
"n_total": len(train_examples),
|
||||
"base": base_c, "trained": tr_c, "n": tr_n,
|
||||
"delta": tr_c - base_c,
|
||||
"elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — SELF-CORRECTION V2 (mixed: {len(wrong_fix)} wrong→fix + {len(positives)} right→stays)")
|
||||
print(f" MATH-500: base={base_c}/{tr_n} ({100*base_c/tr_n:.1f}%) trained={tr_c}/{tr_n} ({100*tr_c/tr_n:.1f}%) Δ={tr_c-base_c:+d}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
286
experiments/self_correction_math_naive.py
Normal file
286
experiments/self_correction_math_naive.py
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
"""TinyForge-Zero self-correction for MATH-500.
|
||||
|
||||
Recipe:
|
||||
1. Sample real MATH-train problem (no human solutions used).
|
||||
2. Model greedy-attempt → wrong. Capture as wrong_attempt.
|
||||
3. Re-prompt model: {problem} + wrong_attempt + "Wait, let me reconsider:"
|
||||
Sample 4 completions at temp=0.8.
|
||||
4. If any completion gets correct boxed answer (verified via sympy against gold),
|
||||
MINE a triple: (problem, wrong_attempt, reflection+correct).
|
||||
5. Train LoRA on full traces — model learns to catch + fix own errors.
|
||||
6. Eval on MATH-500 (test). Model naturally produces self-correction.
|
||||
|
||||
Key difference from rejection-sampling: training data teaches the FIX,
|
||||
not just the answer. Same broken→fixed structure that worked for code.
|
||||
"""
|
||||
import os, json, time, re, argparse, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
|
||||
from datasets import load_dataset, Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
import sympy
|
||||
from sympy.parsing.latex import parse_latex
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
SOLVE_PROMPT = """Solve this competition math problem. Show your reasoning, then put the final answer in \\boxed{{...}}.
|
||||
|
||||
Problem: {problem}
|
||||
|
||||
Solution:"""
|
||||
|
||||
|
||||
RECONSIDER_TAG = "\n\nWait, let me reconsider — I think there's an error above.\n\n"
|
||||
|
||||
|
||||
def extract_boxed(text):
|
||||
idx = text.rfind("\\boxed{")
|
||||
if idx < 0: return None
|
||||
start = idx + len("\\boxed{")
|
||||
depth = 1; i = start
|
||||
while i < len(text) and depth > 0:
|
||||
if text[i] == "{": depth += 1
|
||||
elif text[i] == "}": depth -= 1
|
||||
i += 1
|
||||
if depth != 0: return None
|
||||
return text[start:i-1].strip()
|
||||
|
||||
|
||||
def normalize(s):
|
||||
if s is None: return None
|
||||
s = s.strip()
|
||||
s = re.sub(r"^\$|\$$", "", s).strip()
|
||||
s = re.sub(r"\\text\{([^}]*)\}", r"\1", s)
|
||||
s = re.sub(r"\\mbox\{([^}]*)\}", r"\1", s)
|
||||
s = re.sub(r"(?<=\d),(?=\d)", "", s)
|
||||
s = s.replace("\\left", "").replace("\\right", "").replace("^\\circ", "").replace("^{\\circ}", "")
|
||||
return s.strip()
|
||||
|
||||
|
||||
def sympy_equal(a, b):
|
||||
if a is None or b is None: return False
|
||||
a, b = normalize(a), normalize(b)
|
||||
if a == b: return True
|
||||
try:
|
||||
ea = parse_latex(a); eb = parse_latex(b)
|
||||
if sympy.simplify(ea - eb) == 0: return True
|
||||
except Exception: pass
|
||||
try:
|
||||
fa = float(a); fb = float(b)
|
||||
if abs(fa - fb) < 1e-6: return True
|
||||
except Exception: pass
|
||||
return False
|
||||
|
||||
|
||||
def chat_messages(user_content):
|
||||
return [{"role": "system", "content": "You are a careful math problem solver. If you make a mistake, catch it and correct yourself."},
|
||||
{"role": "user", "content": user_content}]
|
||||
|
||||
|
||||
def gen_batch(model, tok, prompts, max_new=600, temperature=0.0, batch=16, prefill_texts=None):
|
||||
"""If prefill_texts provided, append each to its chat-templated prompt (forcing the model to continue from there)."""
|
||||
outs = []
|
||||
for i in range(0, len(prompts), batch):
|
||||
chunk = prompts[i:i+batch]
|
||||
pref_chunk = prefill_texts[i:i+batch] if prefill_texts else [""] * len(chunk)
|
||||
texts = []
|
||||
for p, pre in zip(chunk, pref_chunk):
|
||||
msgs = chat_messages(p)
|
||||
try:
|
||||
base = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
|
||||
except Exception:
|
||||
base = p
|
||||
texts.append(base + pre)
|
||||
inp = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=2000).to(model.device)
|
||||
with torch.no_grad():
|
||||
out = model.generate(**inp, max_new_tokens=max_new, do_sample=temperature > 0,
|
||||
temperature=temperature if temperature > 0 else 1.0, top_p=0.95,
|
||||
pad_token_id=tok.eos_token_id)
|
||||
for j in range(out.size(0)):
|
||||
outs.append(tok.decode(out[j][inp.input_ids.shape[1]:], skip_special_tokens=True))
|
||||
return outs
|
||||
|
||||
|
||||
def math500_eval(model, tok, n=500, batch=16):
|
||||
ds = list(load_dataset("HuggingFaceH4/MATH-500", split="test"))[:n]
|
||||
log(f" eval on MATH-500 ({len(ds)} problems)")
|
||||
prompts = [SOLVE_PROMPT.format(problem=p["problem"]) for p in ds]
|
||||
outs = gen_batch(model, tok, prompts, max_new=800, temperature=0.0, batch=batch)
|
||||
correct = 0
|
||||
for p, raw in zip(ds, outs):
|
||||
pred = extract_boxed(raw)
|
||||
if sympy_equal(pred, p["answer"]): correct += 1
|
||||
return correct, len(ds)
|
||||
|
||||
|
||||
def make_train_example(problem, full_solution, tok):
|
||||
"""Train on the full self-correction trace."""
|
||||
user = SOLVE_PROMPT.format(problem=problem)
|
||||
msgs_pre = chat_messages(user)
|
||||
msgs_full = msgs_pre + [{"role": "assistant", "content": full_solution}]
|
||||
pre = tok.apply_chat_template(msgs_pre, tokenize=False, add_generation_prompt=True)
|
||||
full = tok.apply_chat_template(msgs_full, tokenize=False)
|
||||
pre_ids = tok(pre, add_special_tokens=False)["input_ids"]
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1536
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids)
|
||||
n_pre = min(len(pre_ids), len(labels))
|
||||
for i in range(n_pre): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
|
||||
def train_on_pairs(model, tok, pairs, out_dir, lr=1e-4, epochs=2, rank=16):
|
||||
log(f" training on {len(pairs)} traces (lr={lr}, e={epochs}, r={rank})")
|
||||
lora_cfg = LoraConfig(r=rank, lora_alpha=rank*2, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
tok.padding_side = "right"
|
||||
ds = HFDataset.from_list([make_train_example(p["problem"], p["full_solution"], tok) for p in pairs])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{out_dir}/ckpt", num_train_epochs=epochs,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=lr, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds, processing_class=tok).train()
|
||||
tok.padding_side = "left"
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", required=True)
|
||||
ap.add_argument("--iterations", type=int, default=8)
|
||||
ap.add_argument("--problems_per_iter", type=int, default=48)
|
||||
ap.add_argument("--n_eval", type=int, default=500)
|
||||
ap.add_argument("--max_pairs", type=int, default=100)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
ap.add_argument("--tag", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
out_dir = f"/workspace/math500_sc/{args.tag}"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
random.seed(args.seed); torch.manual_seed(args.seed)
|
||||
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
tok.padding_side = "left"
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=torch.bfloat16, device_map="cuda:0")
|
||||
log(f" loaded mem={torch.cuda.memory_allocated('cuda:0')/1e9:.1f}GB")
|
||||
|
||||
log("loading MATH train split")
|
||||
train_ds = []
|
||||
for cfg in ["algebra","counting_and_probability","geometry","intermediate_algebra","number_theory","prealgebra","precalculus"]:
|
||||
try:
|
||||
sub = list(load_dataset("EleutherAI/hendrycks_math", cfg, split="train"))
|
||||
train_ds.extend(sub)
|
||||
except Exception as e:
|
||||
log(f" warn: failed to load {cfg}: {e}")
|
||||
log(f" {len(train_ds)} train problems")
|
||||
random.shuffle(train_ds)
|
||||
|
||||
def gold_of(p):
|
||||
return extract_boxed(p.get("solution", ""))
|
||||
|
||||
model.eval()
|
||||
log("INITIAL eval on MATH-500")
|
||||
base_c, base_n = math500_eval(model, tok, n=args.n_eval)
|
||||
log(f" MATH-500 base: {base_c}/{base_n} ({100*base_c/base_n:.1f}%)")
|
||||
|
||||
pairs = []
|
||||
cursor = 0
|
||||
|
||||
for it in range(1, args.iterations + 1):
|
||||
log(f"--- iter {it} ---")
|
||||
# Sample problems from MATH-train
|
||||
batch_problems = []
|
||||
while len(batch_problems) < args.problems_per_iter and cursor < len(train_ds):
|
||||
p = train_ds[cursor]; cursor += 1
|
||||
g = gold_of(p)
|
||||
if g is not None: batch_problems.append({"problem": p["problem"], "gold": g})
|
||||
if not batch_problems:
|
||||
log(" exhausted train problems"); break
|
||||
|
||||
# Step 1: Greedy attempt
|
||||
prompts = [SOLVE_PROMPT.format(problem=p["problem"]) for p in batch_problems]
|
||||
greedy_outs = gen_batch(model, tok, prompts, max_new=600, temperature=0.0, batch=16)
|
||||
wrong_attempts = []
|
||||
for i, (p, raw) in enumerate(zip(batch_problems, greedy_outs)):
|
||||
pred = extract_boxed(raw)
|
||||
if not sympy_equal(pred, p["gold"]):
|
||||
wrong_attempts.append({"idx": i, "problem": p["problem"], "gold": p["gold"], "wrong": raw.strip()})
|
||||
log(f" iter {it}: {len(wrong_attempts)}/{len(batch_problems)} wrong on greedy (mining candidates)")
|
||||
if not wrong_attempts:
|
||||
continue
|
||||
|
||||
# Step 2: Self-correct prompt (prefill wrong attempt + reconsider tag, sample 4)
|
||||
sc_problems = []
|
||||
prefills = []
|
||||
for w in wrong_attempts:
|
||||
for _ in range(4):
|
||||
sc_problems.append(w["problem"])
|
||||
prefills.append(w["wrong"] + RECONSIDER_TAG)
|
||||
sc_prompts = [SOLVE_PROMPT.format(problem=p) for p in sc_problems]
|
||||
sc_outs = gen_batch(model, tok, sc_prompts, max_new=600, temperature=0.8, batch=16, prefill_texts=prefills)
|
||||
|
||||
mined_this_iter = 0
|
||||
for j, w in enumerate(wrong_attempts):
|
||||
attempts = sc_outs[j*4:(j+1)*4]
|
||||
preds = [extract_boxed(a) for a in attempts]
|
||||
correct_idx = [k for k, pr in enumerate(preds) if sympy_equal(pr, w["gold"])]
|
||||
if correct_idx:
|
||||
# construct full trace
|
||||
fix = attempts[correct_idx[0]].strip()
|
||||
full = w["wrong"] + RECONSIDER_TAG + fix
|
||||
pairs.append({"problem": w["problem"], "wrong_attempt": w["wrong"],
|
||||
"correction": fix, "full_solution": full})
|
||||
mined_this_iter += 1
|
||||
log(f" iter {it}: MINED {mined_this_iter} self-correction triples — total={len(pairs)}")
|
||||
|
||||
if len(pairs) >= args.max_pairs:
|
||||
log(f" reached max_pairs={args.max_pairs}, stopping"); break
|
||||
|
||||
log(f"=== mined {len(pairs)} total self-correction triples ===")
|
||||
with open(f"{out_dir}/pairs.jsonl", "w") as fh:
|
||||
for p in pairs: fh.write(json.dumps(p) + "\n")
|
||||
|
||||
if not pairs:
|
||||
log("no triples — exiting"); return
|
||||
|
||||
model = train_on_pairs(model, tok, pairs, out_dir)
|
||||
log("training done")
|
||||
|
||||
model.eval()
|
||||
log("FINAL eval on MATH-500")
|
||||
tr_c, tr_n = math500_eval(model, tok, n=args.n_eval)
|
||||
log(f" MATH-500 trained: {tr_c}/{tr_n} ({100*tr_c/tr_n:.1f}%)")
|
||||
|
||||
result = {
|
||||
"model": args.model, "n_pairs": len(pairs),
|
||||
"base": base_c, "trained": tr_c, "n": tr_n,
|
||||
"delta": tr_c - base_c, "elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{out_dir}/result.json", "w") as fh: json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" {args.model} — SELF-CORRECTION recipe")
|
||||
print(f" MATH-500: base={base_c}/{tr_n} trained={tr_c}/{tr_n} Δ={tr_c-base_c:+d}")
|
||||
print(f" Triples mined: {len(pairs)}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
204
experiments/star_baseline_gsm8k.py
Normal file
204
experiments/star_baseline_gsm8k.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
"""STaR / Rejection Sampling Fine-Tuning on GSM8K.
|
||||
|
||||
For each GSM8K-train problem:
|
||||
- sample N reasoning chains at temp=0.8
|
||||
- keep chains that produce correct final answer
|
||||
- train on (problem, correct chain) pairs
|
||||
Then eval on GSM8K-test.
|
||||
"""
|
||||
import os, sys, json, time, re, gc, argparse, random
|
||||
os.environ.setdefault("HF_HOME", "/workspace/hf")
|
||||
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
|
||||
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
|
||||
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
|
||||
from datasets import load_dataset, Dataset as HFDataset
|
||||
from peft import LoraConfig, get_peft_model
|
||||
|
||||
T0 = time.time()
|
||||
def log(m): print(f"[{time.time()-T0:7.1f}s] {m}", flush=True)
|
||||
|
||||
|
||||
def extract_answer(text: str):
|
||||
m = re.search(r"####\s*(-?\d+(?:\.\d+)?)", text)
|
||||
if m: return float(m.group(1))
|
||||
m = re.search(r"\\boxed\{(-?\d+(?:\.\d+)?)\}", text)
|
||||
if m: return float(m.group(1))
|
||||
matches = re.findall(r"-?\d+(?:\.\d+)?", text)
|
||||
if matches:
|
||||
try: return float(matches[-1])
|
||||
except: return None
|
||||
return None
|
||||
|
||||
|
||||
def gen_batch(model, tok, prompts, max_new=400, temperature=0.0, batch=8):
|
||||
outs = []
|
||||
for i in range(0, len(prompts), batch):
|
||||
chunk = prompts[i:i+batch]
|
||||
texts = []
|
||||
for p in chunk:
|
||||
msgs = [{"role": "system", "content": "You are a careful math tutor."},
|
||||
{"role": "user", "content": p}]
|
||||
texts.append(tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True))
|
||||
inp = tok(texts, return_tensors="pt", padding=True, truncation=True, max_length=1500).to(model.device)
|
||||
with torch.no_grad():
|
||||
out = model.generate(**inp, max_new_tokens=max_new, do_sample=temperature > 0,
|
||||
temperature=temperature if temperature > 0 else 1.0, top_p=0.95,
|
||||
pad_token_id=tok.eos_token_id)
|
||||
for j in range(out.size(0)):
|
||||
outs.append(tok.decode(out[j][inp.input_ids.shape[1]:], skip_special_tokens=True))
|
||||
return outs
|
||||
|
||||
|
||||
SOLVE_PROMPT = "Solve this math problem step by step. End with the answer on a new line as: #### <number>\n\nProblem: {problem}"
|
||||
|
||||
|
||||
def parse_gold(answer_field: str):
|
||||
m = re.search(r"####\s*(-?\d+(?:,\d+)*(?:\.\d+)?)", answer_field)
|
||||
return float(m.group(1).replace(",", "")) if m else None
|
||||
|
||||
|
||||
def gsm8k_eval(model, tok, n=200):
|
||||
ds = list(load_dataset("openai/gsm8k", "main", split="test"))[:n]
|
||||
log(f" eval on GSM8K-test ({len(ds)} problems)")
|
||||
prompts = [SOLVE_PROMPT.format(problem=p["question"]) for p in ds]
|
||||
outs = gen_batch(model, tok, prompts, max_new=400, temperature=0.0, batch=8)
|
||||
correct = 0
|
||||
for p, raw in zip(ds, outs):
|
||||
gold = parse_gold(p["answer"])
|
||||
if gold is None: continue
|
||||
pred = extract_answer(raw)
|
||||
if pred is not None and abs(pred - gold) < 0.01: correct += 1
|
||||
return correct, len(ds)
|
||||
|
||||
|
||||
def make_train_example(problem: str, solution: str, tok):
|
||||
user = SOLVE_PROMPT.format(problem=problem)
|
||||
msgs_pre = [{"role": "system", "content": "You are a careful math tutor."},
|
||||
{"role": "user", "content": user}]
|
||||
msgs_full = msgs_pre + [{"role": "assistant", "content": solution}]
|
||||
pre = tok.apply_chat_template(msgs_pre, tokenize=False, add_generation_prompt=True)
|
||||
full = tok.apply_chat_template(msgs_full, tokenize=False)
|
||||
pre_ids = tok(pre, add_special_tokens=False)["input_ids"]
|
||||
full_ids = tok(full, add_special_tokens=False)["input_ids"]
|
||||
MAX = 1024
|
||||
full_ids = full_ids[:MAX]
|
||||
labels = list(full_ids)
|
||||
n_pre = min(len(pre_ids), len(labels))
|
||||
for i in range(n_pre): labels[i] = -100
|
||||
pad = MAX - len(full_ids)
|
||||
return {"input_ids": full_ids + [tok.pad_token_id]*pad,
|
||||
"attention_mask": [1]*len(full_ids) + [0]*pad,
|
||||
"labels": labels + [-100]*pad}
|
||||
|
||||
|
||||
def main():
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--model", default="Qwen/Qwen2.5-3B")
|
||||
ap.add_argument("--n_train_problems", type=int, default=300)
|
||||
ap.add_argument("--n_chains", type=int, default=8)
|
||||
ap.add_argument("--n_eval", type=int, default=200)
|
||||
ap.add_argument("--epochs", type=int, default=2)
|
||||
ap.add_argument("--seed", type=int, default=42)
|
||||
ap.add_argument("--tag", required=True)
|
||||
args = ap.parse_args()
|
||||
|
||||
random.seed(args.seed); torch.manual_seed(args.seed)
|
||||
out_dir = f"/workspace/star/{args.tag}"
|
||||
os.makedirs(out_dir, exist_ok=True)
|
||||
|
||||
log(f"loading {args.model}")
|
||||
tok = AutoTokenizer.from_pretrained(args.model)
|
||||
if tok.pad_token is None: tok.pad_token = tok.eos_token
|
||||
tok.padding_side = "left"
|
||||
model = AutoModelForCausalLM.from_pretrained(args.model, dtype=torch.bfloat16, device_map="cuda:0")
|
||||
log(f" loaded mem={torch.cuda.memory_allocated('cuda:0')/1e9:.1f}GB")
|
||||
|
||||
# Initial eval on GSM8K-test
|
||||
model.eval()
|
||||
log("INITIAL eval on GSM8K-test")
|
||||
base_correct, base_total = gsm8k_eval(model, tok, n=args.n_eval)
|
||||
log(f" GSM8K-test base: {base_correct}/{base_total}")
|
||||
|
||||
# Mine reasoning chains from GSM8K-train
|
||||
log(f"mining reasoning chains from GSM8K-train ({args.n_train_problems} problems × {args.n_chains} chains)")
|
||||
train_set = list(load_dataset("openai/gsm8k", "main", split="train"))[:args.n_train_problems]
|
||||
pairs = []
|
||||
BATCH_PROBLEMS = 8 # batch problems together
|
||||
for batch_start in range(0, len(train_set), BATCH_PROBLEMS):
|
||||
batch_end = min(batch_start + BATCH_PROBLEMS, len(train_set))
|
||||
batch_problems = train_set[batch_start:batch_end]
|
||||
# For each problem, generate N chains. So total = batch_size * N
|
||||
prompts = []
|
||||
for p in batch_problems:
|
||||
for _ in range(args.n_chains):
|
||||
prompts.append(SOLVE_PROMPT.format(problem=p["question"]))
|
||||
outs = gen_batch(model, tok, prompts, max_new=400, temperature=0.8, batch=8)
|
||||
# Outs are in problem-major × chain-major order
|
||||
for i, p in enumerate(batch_problems):
|
||||
gold = parse_gold(p["answer"])
|
||||
if gold is None: continue
|
||||
chain_outs = outs[i*args.n_chains : (i+1)*args.n_chains]
|
||||
for raw in chain_outs:
|
||||
pred = extract_answer(raw)
|
||||
if pred is not None and abs(pred - gold) < 0.01:
|
||||
pairs.append({"problem": p["question"], "solution": raw.strip()})
|
||||
break # take first correct chain per problem
|
||||
log(f" mined {len(pairs)} pairs from {batch_end} problems")
|
||||
|
||||
if not pairs:
|
||||
log("FATAL: no pairs mined")
|
||||
return
|
||||
with open(f"{out_dir}/pairs.jsonl", "w") as fh:
|
||||
for p in pairs: fh.write(json.dumps(p) + "\n")
|
||||
log(f"total pairs mined: {len(pairs)} from {len(train_set)} problems "
|
||||
f"(coverage: {len(pairs)/len(train_set)*100:.1f}%)")
|
||||
|
||||
# Train
|
||||
log(f"TRAINING on {len(pairs)} pairs, {args.epochs} epochs")
|
||||
lora_cfg = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
|
||||
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], task_type="CAUSAL_LM")
|
||||
model = get_peft_model(model, lora_cfg)
|
||||
tok.padding_side = "right"
|
||||
ds = HFDataset.from_list([make_train_example(p["problem"], p["solution"], tok) for p in pairs])
|
||||
targs = TrainingArguments(
|
||||
output_dir=f"{out_dir}/ckpt", num_train_epochs=args.epochs,
|
||||
per_device_train_batch_size=1, gradient_accumulation_steps=4,
|
||||
learning_rate=1e-4, bf16=True, logging_steps=20,
|
||||
save_strategy="no", report_to="none", remove_unused_columns=False, warmup_ratio=0.05,
|
||||
)
|
||||
Trainer(model=model, args=targs, train_dataset=ds, processing_class=tok).train()
|
||||
log("training done")
|
||||
tok.padding_side = "left"
|
||||
|
||||
# Final eval
|
||||
model.eval()
|
||||
log("FINAL eval on GSM8K-test")
|
||||
trained_correct, trained_total = gsm8k_eval(model, tok, n=args.n_eval)
|
||||
log(f" GSM8K-test trained: {trained_correct}/{trained_total}")
|
||||
|
||||
result = {
|
||||
"model": args.model, "n_train_problems": args.n_train_problems,
|
||||
"n_chains": args.n_chains, "n_pairs_mined": len(pairs),
|
||||
"epochs": args.epochs, "seed": args.seed,
|
||||
"base": [base_correct, base_total],
|
||||
"trained": [trained_correct, trained_total],
|
||||
"delta": trained_correct - base_correct,
|
||||
"elapsed_s": time.time() - T0,
|
||||
}
|
||||
with open(f"{out_dir}/result.json", "w") as fh:
|
||||
json.dump(result, fh, indent=2)
|
||||
|
||||
print()
|
||||
print("=" * 70)
|
||||
print(f" STaR / RFT on GSM8K — {args.model}")
|
||||
print(f" Mined {len(pairs)} pairs from {len(train_set)} GSM8K-train problems ({len(pairs)/len(train_set)*100:.1f}% coverage)")
|
||||
print(f" GSM8K-test: base={base_correct}/{base_total} trained={trained_correct}/{trained_total} Δ={trained_correct-base_correct:+d}")
|
||||
print(f" Time: {time.time()-T0:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue