From 68e87da378e5407d16c87370691083193e5c7042 Mon Sep 17 00:00:00 2001 From: Zhaoyang Yu Date: Tue, 10 Sep 2024 18:27:20 +0800 Subject: [PATCH] Update Hotpotqa --- examples/ags/benchmark/hotpotqa.py | 105 +++++------------- .../ags/experiments/baselines/cot_hotpotqa.py | 7 +- 2 files changed, 29 insertions(+), 83 deletions(-) diff --git a/examples/ags/benchmark/hotpotqa.py b/examples/ags/benchmark/hotpotqa.py index 375882511..1d977c11b 100644 --- a/examples/ags/benchmark/hotpotqa.py +++ b/examples/ags/benchmark/hotpotqa.py @@ -4,8 +4,12 @@ import aiofiles import pandas as pd import numpy as np from typing import List, Tuple, Callable, Set +from collections import Counter from tqdm.asyncio import tqdm_asyncio from scipy.optimize import linear_sum_assignment +import string +import re + from examples.ags.benchmark.utils import generate_random_indices @@ -16,9 +20,10 @@ def is_number(text: str) -> bool: except ValueError: return False -def normalize_answer(text): - import re - import string +def normalize_answer(s): + """ + Normalize answers for evaluation. + """ def remove_articles(text): return re.sub(r"\b(a|an|the)\b", " ", text) @@ -33,77 +38,24 @@ def normalize_answer(text): def lower(text): return text.lower() - def tokenize(text): - return re.split(" |-", text) + return white_space_fix(remove_articles(remove_punc(lower(s)))) - def normalize_number(text: str) -> str: - if is_number(text): - return str(float(text)) - else: - return text - - parts = [ - white_space_fix(remove_articles(normalize_number(remove_punc(lower(token))))) - for token in tokenize(text) - ] - parts = [part for part in parts if part.strip()] - normalized = " ".join(parts).strip() - return normalized - -def answer_to_bags(answer: str) -> Set[str]: - raw_spans = [answer] - - normalized_spans = [] - token_bags = [] - for raw_span in raw_spans: - normalized_span = normalize_answer(raw_span) - normalized_spans.append(normalized_span) - token_bags.append(set(normalized_span.split())) - return normalized_spans, token_bags - -def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]: +def f1_score(prediction, ground_truth): """ - Takes gold and predicted answer sets and first finds the optimal 1-1 alignment - between them and gets maximum metric values over all the answers. + Compute the F1 score between prediction and ground truth answers. """ - scores = np.zeros([len(gold), len(predicted)]) - for gold_index, gold_item in enumerate(gold): - for pred_index, pred_item in enumerate(predicted): - if match_numbers_if_present(gold_item, pred_item): - scores[gold_index, pred_index] = f1_score(pred_item, gold_item) - row_ind, col_ind = linear_sum_assignment(-scores) - - max_scores = np.zeros([max(len(gold), len(predicted))]) - for row, column in zip(row_ind, col_ind): - max_scores[row] = max(max_scores[row], scores[row, column]) - return max_scores - -def match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool: - gold_numbers = set() - predicted_numbers = set() - for word in gold_bag: - if is_number(word): - gold_numbers.add(word) - for word in predicted_bag: - if is_number(word): - predicted_numbers.add(word) - if (not gold_numbers) or gold_numbers.intersection(predicted_numbers): - return True - return False - -def f1_score(predicted_bag: Set[str], gold_bag: Set[str]) -> float: - intersection = len(gold_bag.intersection(predicted_bag)) - if not predicted_bag: - precision = 1.0 - else: - precision = intersection / float(len(predicted_bag)) - if not gold_bag: - recall = 1.0 - else: - recall = intersection / float(len(gold_bag)) - f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0 + prediction_tokens = normalize_answer(prediction).split() + ground_truth_tokens = normalize_answer(ground_truth).split() + common = Counter(prediction_tokens) & Counter(ground_truth_tokens) + num_same = sum(common.values()) + if num_same == 0: + return 0 + precision = 1.0 * num_same / len(prediction_tokens) + recall = 1.0 * num_same / len(ground_truth_tokens) + f1 = (2 * precision * recall) / (precision + recall) return f1 + async def load_data(file_path: str, samples=20, total_length=1000) -> List[dict]: data = [] async with aiofiles.open(file_path, mode="r") as file: @@ -120,12 +72,8 @@ async def evaluate_problem(input: str, context_str: str, graph: Callable, expect while retries < max_retries: try: - prediction, supporting_sentences = await graph(input, context_str) if graph else "None" - predicted_bags = answer_to_bags(prediction) - gold_bags = answer_to_bags(expected_output) - - f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1]) - score = np.mean(f1_per_bag) + prediction = await graph(input, context_str) if graph else "None" + score = f1_score(prediction, expected_output) break except Exception as e: @@ -135,11 +83,10 @@ async def evaluate_problem(input: str, context_str: str, graph: Callable, expect if retries == max_retries: print("Maximum retries reached. Skipping this sample.") prediction = None - supporting_sentences = None score = 0 break - return input, prediction, expected_output, supporting_sentences, score + return input, prediction, expected_output, score async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurrent_tasks: int = 50): semaphore = asyncio.Semaphore(max_concurrent_tasks) @@ -156,9 +103,9 @@ async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurren return await tqdm_asyncio.gather(*tasks, desc="Evaluating HotpotQA problems", total=len(data)) -def save_results_to_csv(results: List[Tuple[str, str, str, str, float]], path: str) -> float: +def save_results_to_csv(results: List[Tuple[str, str, str, float]], path: str) -> float: df = pd.DataFrame( - results, columns=["question", "prediction", "expected_output", "supporting_sentences", "score"] + results, columns=["question", "prediction", "expected_output", "score"] ) average_score = df["score"].mean() diff --git a/examples/ags/experiments/baselines/cot_hotpotqa.py b/examples/ags/experiments/baselines/cot_hotpotqa.py index 22bd69438..b3a919fff 100644 --- a/examples/ags/experiments/baselines/cot_hotpotqa.py +++ b/examples/ags/experiments/baselines/cot_hotpotqa.py @@ -19,7 +19,6 @@ HOTPOTQA_PROMPT = """ class GenerateOp(BaseModel): answer: str = Field(default="", description="问题的答案") - supporting_sentences: str = Field(default="", description="支持性句子") class CoTGenerate(Operator): def __init__(self, llm: LLM, name: str = "Generate"): @@ -32,7 +31,7 @@ class CoTGenerate(Operator): fill_kwargs["mode"] = mode node = await ActionNode.from_pydantic(GenerateOp).fill(**fill_kwargs) response = node.instruct_content.model_dump() - return response["answer"], response["supporting_sentences"] + return response["answer"] class CoTSolveGraph(SolveGraph): def __init__(self, name: str, llm_config, dataset: str): @@ -40,8 +39,8 @@ class CoTSolveGraph(SolveGraph): self.cot_generate = CoTGenerate(self.llm) async def __call__(self, question: str, context: str) -> Tuple[str, str]: - answer, supporting_sentences = await self.cot_generate(question, context, mode="context_fill") - return answer, supporting_sentences + answer = await self.cot_generate(question, context, mode="context_fill") + return answer if __name__ == "__main__": async def main():