mirror of
https://github.com/alainnothere/llm-circuit-finder.git
synced 2026-04-24 20:56:21 +02:00
Add files via upload
This commit is contained in:
parent
91017bbd68
commit
add2765a15
1 changed files with 307 additions and 0 deletions
307
comprehensive_probe.py
Normal file
307
comprehensive_probe.py
Normal file
|
|
@ -0,0 +1,307 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Comprehensive RYS evaluation probe.
|
||||
|
||||
Pulls questions from cached HuggingFace datasets (BBH, GSM8K) plus
|
||||
our custom EQ scenarios. All questions produce short outputs and are
|
||||
objectively scorable without a judge model.
|
||||
|
||||
Usage:
|
||||
python comprehensive_probe.py --port 8089 --output results_base.json
|
||||
python comprehensive_probe.py --port 8089 --output results_rys.json
|
||||
python comprehensive_probe.py --compare results_base.json results_rys.json
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from datasets import load_dataset
|
||||
|
||||
from eq_probe import EQ_SCENARIOS, build_eq_prompt, parse_eq_response, score_eq_response
|
||||
|
||||
|
||||
REQUEST_TIMEOUT = 120
|
||||
|
||||
|
||||
def query_model(prompt: str, port: int, max_tokens: int = 512) -> str | None:
|
||||
url = f"http://127.0.0.1:{port}/v1/chat/completions"
|
||||
payload = {
|
||||
"model": "test",
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0.0,
|
||||
}
|
||||
try:
|
||||
r = requests.post(url, json=payload, timeout=REQUEST_TIMEOUT)
|
||||
if r.status_code == 200:
|
||||
return r.json()["choices"][0]["message"]["content"]
|
||||
except (requests.ConnectionError, requests.Timeout) as e:
|
||||
print(f" [WARN] {e}", file=sys.stderr)
|
||||
return None
|
||||
|
||||
|
||||
# ─── Dataset loaders ───────────────────────────────────────────────
|
||||
|
||||
def load_gsm8k_questions(limit=50):
|
||||
ds = load_dataset("openai/gsm8k", "main")
|
||||
test = ds["test"]
|
||||
total = len(test)
|
||||
step = max(1, total // limit)
|
||||
indices = list(range(total - 1, -1, -step))[:limit]
|
||||
indices.reverse()
|
||||
|
||||
questions = []
|
||||
for idx in indices:
|
||||
item = test[idx]
|
||||
answer_match = re.search(r'####\s*(-?[\d,]+)', item["answer"])
|
||||
if answer_match:
|
||||
answer = answer_match.group(1).replace(",", "")
|
||||
questions.append({
|
||||
"prompt": item["question"] + "\n\nSolve step by step. End with 'The answer is [NUMBER]'.",
|
||||
"answer": answer,
|
||||
"type": "gsm8k",
|
||||
})
|
||||
return questions
|
||||
|
||||
|
||||
def load_bbh_questions(subtask, limit=50):
|
||||
ds = load_dataset("SaylorTwift/bbh", subtask)
|
||||
test = ds["test"]
|
||||
total = len(test)
|
||||
step = max(1, total // limit)
|
||||
indices = list(range(0, total, step))[:limit]
|
||||
|
||||
questions = []
|
||||
for idx in indices:
|
||||
item = test[idx]
|
||||
answer_match = re.search(r'the answer is (.+?)\.?\s*$', item["target"], re.IGNORECASE)
|
||||
if answer_match:
|
||||
answer = answer_match.group(1).strip()
|
||||
else:
|
||||
answer = item["target"].strip().split()[-1].rstrip(".")
|
||||
|
||||
questions.append({
|
||||
"prompt": item["input"] + "\n\nThink step by step, then give your final answer.",
|
||||
"answer": answer,
|
||||
"type": f"bbh_{subtask}",
|
||||
})
|
||||
return questions
|
||||
|
||||
|
||||
def load_eq_questions():
|
||||
questions = []
|
||||
for scenario in EQ_SCENARIOS:
|
||||
questions.append({
|
||||
"prompt": build_eq_prompt(scenario),
|
||||
"answer": scenario["reference"],
|
||||
"emotions": scenario["emotions"],
|
||||
"type": "eq",
|
||||
"id": scenario["id"],
|
||||
})
|
||||
return questions
|
||||
|
||||
|
||||
# ─── Scoring ───────────────────────────────────────────────────────
|
||||
|
||||
def extract_final_answer(response: str) -> str:
|
||||
match = re.search(r'the answer is\s+(.+?)[\.\!\n]', response, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
match = re.search(r'####\s*(.+)', response)
|
||||
if match:
|
||||
return match.group(1).strip()
|
||||
lines = response.strip().split('\n')
|
||||
return lines[-1].strip()
|
||||
|
||||
|
||||
def extract_number(text: str) -> str | None:
|
||||
nums = re.findall(r'-?[\d,]+\.?\d*', text)
|
||||
if nums:
|
||||
return nums[-1].replace(",", "")
|
||||
return None
|
||||
|
||||
|
||||
def score_question(question: dict, response: str) -> dict:
|
||||
if response is None:
|
||||
return {"score": 0.0, "parsed": None, "correct": question["answer"]}
|
||||
|
||||
qtype = question["type"]
|
||||
|
||||
if qtype == "eq":
|
||||
predicted = parse_eq_response(response, len(question["answer"]))
|
||||
score = score_eq_response(question["answer"], predicted) / 100.0
|
||||
return {"score": score, "parsed": predicted, "correct": question["answer"]}
|
||||
|
||||
elif qtype == "gsm8k":
|
||||
final = extract_final_answer(response)
|
||||
pred_num = extract_number(final)
|
||||
correct_num = question["answer"]
|
||||
if pred_num is not None and pred_num == correct_num:
|
||||
score = 1.0
|
||||
elif pred_num is not None:
|
||||
try:
|
||||
diff = abs(float(pred_num) - float(correct_num))
|
||||
max_val = max(abs(float(correct_num)), 1)
|
||||
score = max(0, 1.0 - diff / max_val) * 0.5
|
||||
except ValueError:
|
||||
score = 0.0
|
||||
else:
|
||||
score = 0.0
|
||||
return {"score": score, "parsed": pred_num, "correct": correct_num}
|
||||
|
||||
else: # BBH
|
||||
final = extract_final_answer(response)
|
||||
correct = question["answer"].strip().lower()
|
||||
final_clean = final.strip().lower()
|
||||
final_clean = re.sub(r'[^a-z0-9\s\(\)]', '', final_clean).strip()
|
||||
correct_clean = re.sub(r'[^a-z0-9\s\(\)]', '', correct).strip()
|
||||
|
||||
if correct_clean in final_clean or final_clean == correct_clean:
|
||||
score = 1.0
|
||||
elif correct_clean in ("yes", "no"):
|
||||
if correct_clean in final_clean.split():
|
||||
score = 1.0
|
||||
else:
|
||||
score = 0.0
|
||||
else:
|
||||
score = 0.0
|
||||
return {"score": score, "parsed": final, "correct": question["answer"]}
|
||||
|
||||
|
||||
# ─── Main evaluation ───────────────────────────────────────────────
|
||||
|
||||
def run_full_evaluation(port: int, gsm8k_limit: int = 50, bbh_limit: int = 50) -> dict:
|
||||
print("Loading questions...")
|
||||
all_questions = []
|
||||
|
||||
try:
|
||||
gsm = load_gsm8k_questions(limit=gsm8k_limit)
|
||||
all_questions.extend(gsm)
|
||||
print(f" GSM8K: {len(gsm)} questions")
|
||||
except Exception as e:
|
||||
print(f" GSM8K: FAILED ({e})")
|
||||
|
||||
for subtask in ["causal_judgement", "date_understanding",
|
||||
"logical_deduction_five_objects", "navigate",
|
||||
"boolean_expressions", "tracking_shuffled_objects_three_objects"]:
|
||||
try:
|
||||
bbh = load_bbh_questions(subtask, limit=bbh_limit)
|
||||
all_questions.extend(bbh)
|
||||
print(f" BBH {subtask}: {len(bbh)} questions")
|
||||
except Exception as e:
|
||||
print(f" BBH {subtask}: FAILED ({e})")
|
||||
|
||||
eq = load_eq_questions()
|
||||
all_questions.extend(eq)
|
||||
print(f" EQ: {len(eq)} scenarios")
|
||||
|
||||
total = len(all_questions)
|
||||
print(f"\nTotal: {total} questions")
|
||||
print("Running evaluation...\n")
|
||||
|
||||
results_by_type = {}
|
||||
start_time = time.time()
|
||||
|
||||
for i, q in enumerate(all_questions):
|
||||
qtype = q["type"]
|
||||
if qtype not in results_by_type:
|
||||
results_by_type[qtype] = {"scores": [], "total": 0, "correct": 0}
|
||||
|
||||
response = query_model(q["prompt"], port,
|
||||
max_tokens=512 if qtype == "gsm8k" else 256)
|
||||
result = score_question(q, response)
|
||||
|
||||
results_by_type[qtype]["scores"].append(result["score"])
|
||||
results_by_type[qtype]["total"] += 1
|
||||
if result["score"] >= 0.99:
|
||||
results_by_type[qtype]["correct"] += 1
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
rate = (i + 1) / elapsed if elapsed > 0 else 0
|
||||
eta = (total - i - 1) / rate if rate > 0 else 0
|
||||
print(f"\r [{i+1}/{total}] {qtype:40s} "
|
||||
f"score={result['score']:.2f} "
|
||||
f"({elapsed:.0f}s elapsed, ~{eta:.0f}s remaining)", end="", flush=True)
|
||||
|
||||
print("\n")
|
||||
elapsed_total = time.time() - start_time
|
||||
|
||||
print("=" * 70)
|
||||
print(f"{'Probe':40s} {'Accuracy':>8} {'Avg Score':>10} {'N':>5}")
|
||||
print("-" * 70)
|
||||
|
||||
overall_scores = []
|
||||
summary = {}
|
||||
|
||||
for qtype in sorted(results_by_type.keys()):
|
||||
data = results_by_type[qtype]
|
||||
avg = sum(data["scores"]) / len(data["scores"]) if data["scores"] else 0
|
||||
acc = data["correct"] / data["total"] if data["total"] > 0 else 0
|
||||
print(f" {qtype:38s} {acc:>8.1%} {avg:>10.4f} {data['total']:>5}")
|
||||
overall_scores.extend(data["scores"])
|
||||
summary[qtype] = {"accuracy": acc, "avg_score": avg, "n": data["total"]}
|
||||
|
||||
overall_avg = sum(overall_scores) / len(overall_scores) if overall_scores else 0
|
||||
print("-" * 70)
|
||||
print(f" {'OVERALL':38s} {'':>8} {overall_avg:>10.4f} {len(overall_scores):>5}")
|
||||
print(f"\nCompleted in {elapsed_total:.0f}s")
|
||||
print("=" * 70)
|
||||
|
||||
return {
|
||||
"summary": summary,
|
||||
"overall": overall_avg,
|
||||
"elapsed": elapsed_total,
|
||||
"n_questions": len(overall_scores),
|
||||
}
|
||||
|
||||
|
||||
def compare_results(file1: str, file2: str):
|
||||
with open(file1) as f:
|
||||
r1 = json.load(f)
|
||||
with open(file2) as f:
|
||||
r2 = json.load(f)
|
||||
|
||||
print(f"\n{'Probe':40s} {'Base':>8} {'RYS':>8} {'Delta':>8}")
|
||||
print("-" * 70)
|
||||
|
||||
for qtype in sorted(set(list(r1["summary"].keys()) + list(r2["summary"].keys()))):
|
||||
s1 = r1["summary"].get(qtype, {}).get("avg_score", 0)
|
||||
s2 = r2["summary"].get(qtype, {}).get("avg_score", 0)
|
||||
delta = s2 - s1
|
||||
print(f" {qtype:38s} {s1:>8.4f} {s2:>8.4f} {delta:>+8.4f}")
|
||||
|
||||
print("-" * 70)
|
||||
delta_overall = r2["overall"] - r1["overall"]
|
||||
print(f" {'OVERALL':38s} {r1['overall']:>8.4f} {r2['overall']:>8.4f} {delta_overall:>+8.4f}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Comprehensive RYS evaluation")
|
||||
parser.add_argument("--port", type=int, default=8089)
|
||||
parser.add_argument("--gsm8k-limit", type=int, default=50)
|
||||
parser.add_argument("--bbh-limit", type=int, default=50)
|
||||
parser.add_argument("--output", type=str, default=None,
|
||||
help="Save results to JSON file")
|
||||
parser.add_argument("--compare", nargs=2, metavar=("BASE", "RYS"),
|
||||
help="Compare two result files")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.compare:
|
||||
compare_results(args.compare[0], args.compare[1])
|
||||
return
|
||||
|
||||
results = run_full_evaluation(args.port, args.gsm8k_limit, args.bbh_limit)
|
||||
|
||||
if args.output:
|
||||
with open(args.output, "w") as f:
|
||||
json.dump(results, f, indent=2)
|
||||
print(f"\nSaved to {args.output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Add table
Add a link
Reference in a new issue