mirror of
https://github.com/alainnothere/llm-circuit-finder.git
synced 2026-04-24 20:56:21 +02:00
Add files via upload
This commit is contained in:
parent
91017bbd68
commit
add2765a15
1 changed files with 307 additions and 0 deletions
307
comprehensive_probe.py
Normal file
307
comprehensive_probe.py
Normal file
|
|
@ -0,0 +1,307 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Comprehensive RYS evaluation probe.
|
||||||
|
|
||||||
|
Pulls questions from cached HuggingFace datasets (BBH, GSM8K) plus
|
||||||
|
our custom EQ scenarios. All questions produce short outputs and are
|
||||||
|
objectively scorable without a judge model.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python comprehensive_probe.py --port 8089 --output results_base.json
|
||||||
|
python comprehensive_probe.py --port 8089 --output results_rys.json
|
||||||
|
python comprehensive_probe.py --compare results_base.json results_rys.json
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from datasets import load_dataset
|
||||||
|
|
||||||
|
from eq_probe import EQ_SCENARIOS, build_eq_prompt, parse_eq_response, score_eq_response
|
||||||
|
|
||||||
|
|
||||||
|
# Timeout, in seconds, handed to `requests.post` for every model query.
REQUEST_TIMEOUT = 120
|
||||||
|
|
||||||
|
|
||||||
|
def query_model(prompt: str, port: int, max_tokens: int = 512) -> str | None:
    """Send a single user prompt to the local chat-completions endpoint.

    The request is POSTed to ``http://127.0.0.1:{port}/v1/chat/completions``
    with temperature 0.0 and a fixed model name of ``"test"``.

    Args:
        prompt: Text sent as the only (user) message.
        port: Local TCP port of the server.
        max_tokens: Value forwarded as ``max_tokens`` in the payload.

    Returns:
        The ``content`` of the first choice's message, or ``None`` when the
        request fails, the server answers with a non-200 status, or the body
        is not a well-formed chat-completion payload. Every ``None`` caused
        by a failure is accompanied by a warning on stderr.
    """
    url = f"http://127.0.0.1:{port}/v1/chat/completions"
    payload = {
        "model": "test",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }
    try:
        r = requests.post(url, json=payload, timeout=REQUEST_TIMEOUT)
    except requests.RequestException as e:
        # RequestException is the base of ConnectionError and Timeout, so the
        # previously handled cases are still covered.
        print(f" [WARN] {e}", file=sys.stderr)
        return None

    if r.status_code != 200:
        # Surface server-side errors instead of silently scoring them as 0.
        print(f" [WARN] HTTP {r.status_code} from {url}", file=sys.stderr)
        return None

    try:
        return r.json()["choices"][0]["message"]["content"]
    except (ValueError, KeyError, IndexError, TypeError) as e:
        # A non-JSON or unexpectedly shaped body must not abort the whole run.
        print(f" [WARN] malformed response: {e!r}", file=sys.stderr)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Dataset loaders ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def load_gsm8k_questions(limit=50):
    """Build up to ``limit`` question dicts from the GSM8K test split.

    Loads ``openai/gsm8k`` (config ``main``) through ``load_dataset`` and
    samples rows at a fixed stride, counting back from the last row and then
    restoring ascending order.

    Args:
        limit: Maximum number of rows to sample. A value of zero or less
            yields an empty list without touching the dataset.

    Returns:
        A list of dicts with keys ``prompt`` (question text plus the answer
        format instruction), ``answer`` (the number after ``####`` with
        commas removed, as a string) and ``type`` (``"gsm8k"``). Rows whose
        answer has no ``#### <number>`` marker are skipped.
    """
    if limit <= 0:
        # Avoids ZeroDivisionError below (limit == 0) and the confusing
        # negative-slice behaviour (limit < 0).
        return []

    ds = load_dataset("openai/gsm8k", "main")
    test = ds["test"]
    total = len(test)
    step = max(1, total // limit)
    # Stride backwards from the final row, cap at `limit`, then sort ascending.
    indices = list(range(total - 1, -1, -step))[:limit]
    indices.reverse()

    questions = []
    for idx in indices:
        item = test[idx]
        answer_match = re.search(r'####\s*(-?[\d,]+)', item["answer"])
        if answer_match:
            answer = answer_match.group(1).replace(",", "")
            questions.append({
                "prompt": item["question"] + "\n\nSolve step by step. End with 'The answer is [NUMBER]'.",
                "answer": answer,
                "type": "gsm8k",
            })
    return questions
|
||||||
|
|
||||||
|
|
||||||
|
def load_bbh_questions(subtask, limit=50):
    """Build up to ``limit`` question dicts from one BBH subtask.

    Loads the ``subtask`` config of ``SaylorTwift/bbh`` through
    ``load_dataset`` and samples rows of its test split at a fixed stride
    starting from the first row.

    Args:
        subtask: Dataset config name, also used in the resulting ``type``.
        limit: Maximum number of rows to sample. A value of zero or less
            yields an empty list without touching the dataset.

    Returns:
        A list of dicts with keys ``prompt`` (the row's ``input`` plus a
        step-by-step instruction), ``answer`` and ``type``
        (``"bbh_<subtask>"``). The answer is whatever follows "the answer is"
        in the row's ``target`` when that phrase is present; otherwise it is
        the last whitespace-separated token of ``target`` with trailing
        periods removed.
    """
    if limit <= 0:
        # Avoids ZeroDivisionError below (limit == 0) and the confusing
        # negative-slice behaviour (limit < 0).
        return []

    ds = load_dataset("SaylorTwift/bbh", subtask)
    test = ds["test"]
    total = len(test)
    step = max(1, total // limit)
    indices = list(range(0, total, step))[:limit]

    questions = []
    for idx in indices:
        item = test[idx]
        answer_match = re.search(r'the answer is (.+?)\.?\s*$', item["target"], re.IGNORECASE)
        if answer_match:
            answer = answer_match.group(1).strip()
        else:
            answer = item["target"].strip().split()[-1].rstrip(".")

        questions.append({
            "prompt": item["input"] + "\n\nThink step by step, then give your final answer.",
            "answer": answer,
            "type": f"bbh_{subtask}",
        })
    return questions
|
||||||
|
|
||||||
|
|
||||||
|
def load_eq_questions():
    """Convert every entry of ``EQ_SCENARIOS`` into a question dict.

    Returns:
        One dict per scenario, in the order of ``EQ_SCENARIOS``, carrying the
        prompt built by ``build_eq_prompt``, the scenario's ``reference`` as
        ``answer``, its ``emotions`` and ``id``, and ``type`` set to ``"eq"``.
    """
    return [
        {
            "prompt": build_eq_prompt(entry),
            "answer": entry["reference"],
            "emotions": entry["emotions"],
            "type": "eq",
            "id": entry["id"],
        }
        for entry in EQ_SCENARIOS
    ]
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Scoring ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def extract_final_answer(response: str) -> str:
    """Pull the final answer phrase out of a free-form model response.

    Strategies are tried in order:

    1. The text following the *last* "the answer is" (case-insensitive), up
       to a sentence-ending ``.``, a ``!``, a newline, or the end of the text.
       A ``.`` directly followed by a digit is treated as a decimal point and
       kept, so "3.5." yields "3.5".
    2. The text following a ``####`` marker.
    3. The last line of the response.

    Args:
        response: Raw model output.

    Returns:
        The extracted answer with surrounding whitespace removed. May be an
        empty string when the response itself is empty.
    """
    stated = [
        candidate.strip()
        for candidate in re.findall(
            r'the answer is\s+(.+?)(?:\.(?!\d)|[!\n]|$)', response, re.IGNORECASE
        )
        if candidate.strip()
    ]
    if stated:
        # Prefer the last statement: earlier ones are usually mid-reasoning.
        return stated[-1]
    match = re.search(r'####\s*(.+)', response)
    if match:
        return match.group(1).strip()
    lines = response.strip().split('\n')
    return lines[-1].strip()
|
||||||
|
|
||||||
|
|
||||||
|
def extract_number(text: str) -> str | None:
|
||||||
|
nums = re.findall(r'-?[\d,]+\.?\d*', text)
|
||||||
|
if nums:
|
||||||
|
return nums[-1].replace(",", "")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def score_question(question: dict, response: str) -> dict:
|
||||||
|
if response is None:
|
||||||
|
return {"score": 0.0, "parsed": None, "correct": question["answer"]}
|
||||||
|
|
||||||
|
qtype = question["type"]
|
||||||
|
|
||||||
|
if qtype == "eq":
|
||||||
|
predicted = parse_eq_response(response, len(question["answer"]))
|
||||||
|
score = score_eq_response(question["answer"], predicted) / 100.0
|
||||||
|
return {"score": score, "parsed": predicted, "correct": question["answer"]}
|
||||||
|
|
||||||
|
elif qtype == "gsm8k":
|
||||||
|
final = extract_final_answer(response)
|
||||||
|
pred_num = extract_number(final)
|
||||||
|
correct_num = question["answer"]
|
||||||
|
if pred_num is not None and pred_num == correct_num:
|
||||||
|
score = 1.0
|
||||||
|
elif pred_num is not None:
|
||||||
|
try:
|
||||||
|
diff = abs(float(pred_num) - float(correct_num))
|
||||||
|
max_val = max(abs(float(correct_num)), 1)
|
||||||
|
score = max(0, 1.0 - diff / max_val) * 0.5
|
||||||
|
except ValueError:
|
||||||
|
score = 0.0
|
||||||
|
else:
|
||||||
|
score = 0.0
|
||||||
|
return {"score": score, "parsed": pred_num, "correct": correct_num}
|
||||||
|
|
||||||
|
else: # BBH
|
||||||
|
final = extract_final_answer(response)
|
||||||
|
correct = question["answer"].strip().lower()
|
||||||
|
final_clean = final.strip().lower()
|
||||||
|
final_clean = re.sub(r'[^a-z0-9\s\(\)]', '', final_clean).strip()
|
||||||
|
correct_clean = re.sub(r'[^a-z0-9\s\(\)]', '', correct).strip()
|
||||||
|
|
||||||
|
if correct_clean in final_clean or final_clean == correct_clean:
|
||||||
|
score = 1.0
|
||||||
|
elif correct_clean in ("yes", "no"):
|
||||||
|
if correct_clean in final_clean.split():
|
||||||
|
score = 1.0
|
||||||
|
else:
|
||||||
|
score = 0.0
|
||||||
|
else:
|
||||||
|
score = 0.0
|
||||||
|
return {"score": score, "parsed": final, "correct": question["answer"]}
|
||||||
|
|
||||||
|
|
||||||
|
# ─── Main evaluation ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def run_full_evaluation(port: int, gsm8k_limit: int = 50, bbh_limit: int = 50) -> dict:
    """Load every probe, query the model on each question and report scores.

    Questions are gathered from GSM8K, six BBH subtasks and the EQ scenarios.
    A dataset loader that raises is reported as FAILED and skipped. Each
    question is sent to the model on ``port`` (512 max tokens for GSM8K, 256
    otherwise), scored, and tallied per question type. A response counts as
    correct when its score is at least 0.99. Progress and a summary table are
    printed to stdout.

    Args:
        port: Local port of the model server.
        gsm8k_limit: Sample size passed to the GSM8K loader.
        bbh_limit: Sample size passed to the loader of each BBH subtask.

    Returns:
        A dict with ``summary`` (per-type ``accuracy``, ``avg_score``, ``n``),
        ``overall`` (mean score over all questions), ``elapsed`` (seconds)
        and ``n_questions``.
    """
    bbh_subtasks = (
        "causal_judgement", "date_understanding",
        "logical_deduction_five_objects", "navigate",
        "boolean_expressions", "tracking_shuffled_objects_three_objects",
    )

    # ── Collect questions ──
    print("Loading questions...")
    question_pool = []

    try:
        gsm_items = load_gsm8k_questions(limit=gsm8k_limit)
        question_pool.extend(gsm_items)
        print(f" GSM8K: {len(gsm_items)} questions")
    except Exception as err:
        print(f" GSM8K: FAILED ({err})")

    for subtask in bbh_subtasks:
        try:
            bbh_items = load_bbh_questions(subtask, limit=bbh_limit)
            question_pool.extend(bbh_items)
            print(f" BBH {subtask}: {len(bbh_items)} questions")
        except Exception as err:
            print(f" BBH {subtask}: FAILED ({err})")

    eq_items = load_eq_questions()
    question_pool.extend(eq_items)
    print(f" EQ: {len(eq_items)} scenarios")

    n_questions = len(question_pool)
    print(f"\nTotal: {n_questions} questions")
    print("Running evaluation...\n")

    # ── Query and score ──
    per_type = {}
    started = time.time()

    for done, item in enumerate(question_pool, start=1):
        kind = item["type"]
        bucket = per_type.setdefault(kind, {"scores": [], "total": 0, "correct": 0})

        token_budget = 512 if kind == "gsm8k" else 256
        reply = query_model(item["prompt"], port, max_tokens=token_budget)
        outcome = score_question(item, reply)

        bucket["scores"].append(outcome["score"])
        bucket["total"] += 1
        if outcome["score"] >= 0.99:
            bucket["correct"] += 1

        # Single-line progress display, rewritten in place via "\r".
        elapsed = time.time() - started
        rate = done / elapsed if elapsed > 0 else 0
        eta = (n_questions - done) / rate if rate > 0 else 0
        print(f"\r [{done}/{n_questions}] {kind:40s} "
              f"score={outcome['score']:.2f} "
              f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)", end="", flush=True)

    print("\n")
    elapsed_total = time.time() - started

    # ── Summary table ──
    rule = "-" * 70
    print("=" * 70)
    print(f"{'Probe':40s} {'Accuracy':>8} {'Avg Score':>10} {'N':>5}")
    print(rule)

    every_score = []
    summary = {}
    for kind in sorted(per_type):
        bucket = per_type[kind]
        scores = bucket["scores"]
        count = bucket["total"]
        avg = sum(scores) / len(scores) if scores else 0
        acc = bucket["correct"] / count if count > 0 else 0
        print(f" {kind:38s} {acc:>8.1%} {avg:>10.4f} {count:>5}")
        every_score.extend(scores)
        summary[kind] = {"accuracy": acc, "avg_score": avg, "n": count}

    overall_avg = sum(every_score) / len(every_score) if every_score else 0
    print(rule)
    print(f" {'OVERALL':38s} {'':>8} {overall_avg:>10.4f} {len(every_score):>5}")
    print(f"\nCompleted in {elapsed_total:.0f}s")
    print("=" * 70)

    return {
        "summary": summary,
        "overall": overall_avg,
        "elapsed": elapsed_total,
        "n_questions": len(every_score),
    }
|
||||||
|
|
||||||
|
|
||||||
|
def compare_results(file1: str, file2: str):
    """Print a side-by-side table of average scores from two result files.

    Both files are JSON documents with a ``summary`` mapping (probe name to a
    dict holding ``avg_score``) and an ``overall`` number. A probe missing
    from one file is shown with a score of 0 for that file.

    Args:
        file1: Path of the baseline results (the "Base" column).
        file2: Path of the results to compare against it (the "RYS" column).
    """
    loaded = []
    for path in (file1, file2):
        with open(path) as handle:
            loaded.append(json.load(handle))
    base, rys = loaded

    rule = "-" * 70
    print(f"\n{'Probe':40s} {'Base':>8} {'RYS':>8} {'Delta':>8}")
    print(rule)

    base_summary = base["summary"]
    rys_summary = rys["summary"]
    # Union of probe names so probes present in only one file still appear.
    for name in sorted(set(base_summary) | set(rys_summary)):
        before = base_summary.get(name, {}).get("avg_score", 0)
        after = rys_summary.get(name, {}).get("avg_score", 0)
        change = after - before
        print(f" {name:38s} {before:>8.4f} {after:>8.4f} {change:>+8.4f}")

    print(rule)
    overall_change = rys["overall"] - base["overall"]
    print(f" {'OVERALL':38s} {base['overall']:>8.4f} {rys['overall']:>8.4f} {overall_change:>+8.4f}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Command-line entry point.

    With ``--compare BASE RYS`` the two result files are compared and nothing
    else happens. Otherwise the full evaluation runs against ``--port`` with
    the given ``--gsm8k-limit`` / ``--bbh-limit``, and the results are
    written as indented JSON to ``--output`` when that option is given.
    """
    cli = argparse.ArgumentParser(description="Comprehensive RYS evaluation")
    cli.add_argument("--port", type=int, default=8089)
    cli.add_argument("--gsm8k-limit", type=int, default=50)
    cli.add_argument("--bbh-limit", type=int, default=50)
    cli.add_argument("--output", type=str, default=None,
                     help="Save results to JSON file")
    cli.add_argument("--compare", nargs=2, metavar=("BASE", "RYS"),
                     help="Compare two result files")
    opts = cli.parse_args()

    # Comparison mode short-circuits: no model is queried.
    if opts.compare:
        base_file, rys_file = opts.compare
        compare_results(base_file, rys_file)
        return

    results = run_full_evaluation(opts.port, opts.gsm8k_limit, opts.bbh_limit)

    if not opts.output:
        return
    with open(opts.output, "w") as sink:
        json.dump(results, sink, indent=2)
    print(f"\nSaved to {opts.output}")
|
||||||
|
|
||||||
|
|
||||||
|
# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||||
Loading…
Add table
Add a link
Reference in a new issue