Update Hotpotqa

This commit is contained in:
Zhaoyang Yu 2024-09-10 18:27:20 +08:00
parent 4ce18d7f48
commit 68e87da378
2 changed files with 29 additions and 83 deletions

View file

@ -4,8 +4,12 @@ import aiofiles
import pandas as pd
import numpy as np
from typing import List, Tuple, Callable, Set
from collections import Counter
from tqdm.asyncio import tqdm_asyncio
from scipy.optimize import linear_sum_assignment
import string
import re
from examples.ags.benchmark.utils import generate_random_indices
@ -16,9 +20,10 @@ def is_number(text: str) -> bool:
except ValueError:
return False
def normalize_answer(text):
import re
import string
def normalize_answer(s):
"""
Normalize answers for evaluation.
"""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
@ -33,77 +38,24 @@ def normalize_answer(text):
def lower(text):
return text.lower()
def tokenize(text):
return re.split(" |-", text)
return white_space_fix(remove_articles(remove_punc(lower(s))))
def normalize_number(text: str) -> str:
if is_number(text):
return str(float(text))
else:
return text
parts = [
white_space_fix(remove_articles(normalize_number(remove_punc(lower(token)))))
for token in tokenize(text)
]
parts = [part for part in parts if part.strip()]
normalized = " ".join(parts).strip()
return normalized
def answer_to_bags(answer: str) -> Set[str]:
raw_spans = [answer]
normalized_spans = []
token_bags = []
for raw_span in raw_spans:
normalized_span = normalize_answer(raw_span)
normalized_spans.append(normalized_span)
token_bags.append(set(normalized_span.split()))
return normalized_spans, token_bags
def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
def f1_score(prediction, ground_truth):
"""
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
between them and gets maximum metric values over all the answers.
Compute the F1 score between prediction and ground truth answers.
"""
scores = np.zeros([len(gold), len(predicted)])
for gold_index, gold_item in enumerate(gold):
for pred_index, pred_item in enumerate(predicted):
if match_numbers_if_present(gold_item, pred_item):
scores[gold_index, pred_index] = f1_score(pred_item, gold_item)
row_ind, col_ind = linear_sum_assignment(-scores)
max_scores = np.zeros([max(len(gold), len(predicted))])
for row, column in zip(row_ind, col_ind):
max_scores[row] = max(max_scores[row], scores[row, column])
return max_scores
def match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
gold_numbers = set()
predicted_numbers = set()
for word in gold_bag:
if is_number(word):
gold_numbers.add(word)
for word in predicted_bag:
if is_number(word):
predicted_numbers.add(word)
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
return True
return False
def f1_score(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
intersection = len(gold_bag.intersection(predicted_bag))
if not predicted_bag:
precision = 1.0
else:
precision = intersection / float(len(predicted_bag))
if not gold_bag:
recall = 1.0
else:
recall = intersection / float(len(gold_bag))
f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0
prediction_tokens = normalize_answer(prediction).split()
ground_truth_tokens = normalize_answer(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return 0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
async def load_data(file_path: str, samples=20, total_length=1000) -> List[dict]:
data = []
async with aiofiles.open(file_path, mode="r") as file:
@ -120,12 +72,8 @@ async def evaluate_problem(input: str, context_str: str, graph: Callable, expect
while retries < max_retries:
try:
prediction, supporting_sentences = await graph(input, context_str) if graph else "None"
predicted_bags = answer_to_bags(prediction)
gold_bags = answer_to_bags(expected_output)
f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
score = np.mean(f1_per_bag)
prediction = await graph(input, context_str) if graph else "None"
score = f1_score(prediction, expected_output)
break
except Exception as e:
@ -135,11 +83,10 @@ async def evaluate_problem(input: str, context_str: str, graph: Callable, expect
if retries == max_retries:
print("Maximum retries reached. Skipping this sample.")
prediction = None
supporting_sentences = None
score = 0
break
return input, prediction, expected_output, supporting_sentences, score
return input, prediction, expected_output, score
async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurrent_tasks: int = 50):
semaphore = asyncio.Semaphore(max_concurrent_tasks)
@ -156,9 +103,9 @@ async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurren
return await tqdm_asyncio.gather(*tasks, desc="Evaluating HotpotQA problems", total=len(data))
def save_results_to_csv(results: List[Tuple[str, str, str, str, float]], path: str) -> float:
def save_results_to_csv(results: List[Tuple[str, str, str, float]], path: str) -> float:
df = pd.DataFrame(
results, columns=["question", "prediction", "expected_output", "supporting_sentences", "score"]
results, columns=["question", "prediction", "expected_output", "score"]
)
average_score = df["score"].mean()

View file

@ -19,7 +19,6 @@ HOTPOTQA_PROMPT = """
class GenerateOp(BaseModel):
answer: str = Field(default="", description="问题的答案")
supporting_sentences: str = Field(default="", description="支持性句子")
class CoTGenerate(Operator):
def __init__(self, llm: LLM, name: str = "Generate"):
@ -32,7 +31,7 @@ class CoTGenerate(Operator):
fill_kwargs["mode"] = mode
node = await ActionNode.from_pydantic(GenerateOp).fill(**fill_kwargs)
response = node.instruct_content.model_dump()
return response["answer"], response["supporting_sentences"]
return response["answer"]
class CoTSolveGraph(SolveGraph):
def __init__(self, name: str, llm_config, dataset: str):
@ -40,8 +39,8 @@ class CoTSolveGraph(SolveGraph):
self.cot_generate = CoTGenerate(self.llm)
async def __call__(self, question: str, context: str) -> Tuple[str, str]:
answer, supporting_sentences = await self.cot_generate(question, context, mode="context_fill")
return answer, supporting_sentences
answer = await self.cot_generate(question, context, mode="context_fill")
return answer
if __name__ == "__main__":
async def main():