From 445a2e6048846221e7dfa8d9c478c575aab93464 Mon Sep 17 00:00:00 2001
From: Zhaoyang Yu
Date: Tue, 10 Sep 2024 18:59:13 +0800
Subject: [PATCH] Update QA

Make the QA solve graphs return an (answer, cost) pair, track the
accumulated LLM cost in the DROP and HotpotQA evaluators, and drop a
block of answer-normalization helpers that was duplicated in drop.py.
---
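Note for reviewers (below the tearline, so `git am` ignores it): every
solve graph now returns a two-tuple instead of a bare answer string, and
the evaluators print the running LLM cost after the average score. A
minimal, self-contained sketch of the new calling convention follows;
DummyGraph is a hypothetical stand-in for CoTSolveGraph, whose real cost
value is self.llm.cost_manager.total_cost:

    import asyncio
    from typing import Tuple

    class DummyGraph:
        """Hypothetical stand-in for CoTSolveGraph (illustrative only)."""

        async def __call__(self, question: str, context: str) -> Tuple[str, float]:
            # Real graphs return the model answer plus the running
            # total cost tracked by their LLM's cost manager.
            return "42", 0.0013

    async def main() -> None:
        graph = DummyGraph()
        # Benchmark evaluators now unpack two values where they
        # previously took one:
        prediction, cost = await graph("How many touchdowns?", "some passage")
        print(prediction, f"total cost: {cost}")

    asyncio.run(main())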
 examples/ags/benchmark/drop.py                 | 92 ++-----------------
 examples/ags/benchmark/hotpotqa.py             |  8 +-
 .../ags/experiments/baselines/cot_drop.py      |  2 +-
 .../ags/experiments/baselines/cot_hotpotqa.py  |  2 +-
 4 files changed, 16 insertions(+), 88 deletions(-)

diff --git a/examples/ags/benchmark/drop.py b/examples/ags/benchmark/drop.py
index 7fc99134a..ff3a8065b 100644
--- a/examples/ags/benchmark/drop.py
+++ b/examples/ags/benchmark/drop.py
@@ -10,6 +10,9 @@ from tqdm.asyncio import tqdm_asyncio
 
 from examples.ags.benchmark.utils import generate_random_indices
 
+global cost
+cost = 0
+
 def _remove_articles(text: str) -> str:
     regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
     return re.sub(regex, " ", text)
@@ -60,90 +63,6 @@ def _normalize_answer(text: str) -> str:
     normalized = " ".join(parts).strip()
     return normalized
 
-def _answer_to_bags(
-    answer: Union[str, List[str], Tuple[str, ...]]
-) -> Tuple[List[str], List[Set[str]]]:
-    if isinstance(answer, (list, tuple)):
-        raw_spans = answer
-    else:
-        raw_spans = [answer]
-    normalized_spans: List[str] = []
-    token_bags = []
-    for raw_span in raw_spans:
-        normalized_span = _normalize_answer(raw_span)
-        normalized_spans.append(normalized_span)
-        token_bags.append(set(normalized_span.split()))
-    return normalized_spans, token_bags
-
-#!/usr/bin/python
-
-from collections import defaultdict
-from typing import Any, Dict, List, Set, Tuple, Union, Optional
-import json
-import argparse
-import string
-import re
-
-import numpy as np
-from scipy.optimize import linear_sum_assignment
-
-
-# From here through _normalize_answer was originally copied from:
-# https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
-# Then cleaned up and modified a bit.
-def _remove_articles(text: str) -> str:
-    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
-    return re.sub(regex, " ", text)
-
-
-def _white_space_fix(text: str) -> str:
-    return " ".join(text.split())
-
-
-EXCLUDE = set(string.punctuation)
-
-
-def _remove_punc(text: str) -> str:
-    if not _is_number(text):
-        return "".join(ch for ch in text if ch not in EXCLUDE)
-    else:
-        return text
-
-
-def _lower(text: str) -> str:
-    return text.lower()
-
-
-def _tokenize(text: str) -> List[str]:
-    return re.split(" |-", text)
-
-
-def _normalize_answer(text: str) -> str:
-    """Lower text and remove punctuation, articles and extra whitespace."""
-
-    parts = [
-        _white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
-        for token in _tokenize(text)
-    ]
-    parts = [part for part in parts if part.strip()]
-    normalized = " ".join(parts).strip()
-    return normalized
-
-
-def _is_number(text: str) -> bool:
-    try:
-        float(text)
-        return True
-    except ValueError:
-        return False
-
-
-def _normalize_number(text: str) -> str:
-    if _is_number(text):
-        return str(float(text))
-    else:
-        return text
-
 def _answer_to_bags(
     answer: Union[str, List[str], Tuple[str, ...]]
 ) -> Tuple[List[str], List[Set[str]]]:
@@ -307,7 +226,8 @@ async def evaluate_problem(question: str, passage: str, answers: List[Dict[str,
 
     while retries < max_retries:
         try:
-            prediction = await graph(question, passage)
+            global cost
+            prediction, cost = await graph(question, passage)
 
             max_score = 0.0
 
@@ -372,4 +292,6 @@ async def drop_evaluation(graph: Callable, file_path: str, samples: int, path: s
     results = await evaluate_all_passages(data, graph, max_concurrent_tasks=20)
     average_score = save_results_to_csv(results, path=path)
     print(f"Average score on DROP dataset: {average_score:.5f}")
+    global cost
+    print(f"Total cost: {cost}")
     return average_score
diff --git a/examples/ags/benchmark/hotpotqa.py b/examples/ags/benchmark/hotpotqa.py
index 1d977c11b..19873aa37 100644
--- a/examples/ags/benchmark/hotpotqa.py
+++ b/examples/ags/benchmark/hotpotqa.py
@@ -13,6 +13,9 @@ import re
 from examples.ags.benchmark.utils import generate_random_indices
 
+global cost
+cost = 0
+
 def is_number(text: str) -> bool:
     try:
         float(text)
         return True
@@ -72,7 +75,8 @@ async def evaluate_problem(input: str, context_str: str, graph: Callable, expect
 
     while retries < max_retries:
         try:
-            prediction = await graph(input, context_str) if graph else "None"
+            global cost
+            prediction, cost = (await graph(input, context_str)) if graph else ("None", 0)
             score = f1_score(prediction, expected_output)
 
             break
@@ -120,4 +124,6 @@ async def hotpotqa_evaluation(graph: Callable, file_path: str, samples: int, pat
     results = await evaluate_all_problems(data, graph, max_concurrent_tasks=20)
     average_score = save_results_to_csv(results, path=path)
     print(f"Average score on HotpotQA dataset: {average_score:.5f}")
+    global cost
+    print(f"Total cost: {cost}")
     return average_score
diff --git a/examples/ags/experiments/baselines/cot_drop.py b/examples/ags/experiments/baselines/cot_drop.py
index d0d3ecb34..83513d07e 100644
--- a/examples/ags/experiments/baselines/cot_drop.py
+++ b/examples/ags/experiments/baselines/cot_drop.py
@@ -40,7 +40,7 @@ class CoTSolveGraph(SolveGraph):
 
     async def __call__(self, question: str, context: str) -> Tuple[str, str]:
         answer = await self.cot_generate(question, context, mode="context_fill")
-        return answer
+        return answer, self.llm.cost_manager.total_cost
 
 if __name__ == "__main__":
     async def main():
diff --git a/examples/ags/experiments/baselines/cot_hotpotqa.py b/examples/ags/experiments/baselines/cot_hotpotqa.py
index b3a919fff..e9f5592de 100644
--- a/examples/ags/experiments/baselines/cot_hotpotqa.py
+++ b/examples/ags/experiments/baselines/cot_hotpotqa.py
@@ -40,7 +40,7 @@ class CoTSolveGraph(SolveGraph):
 
     async def __call__(self, question: str, context: str) -> Tuple[str, str]:
         answer = await self.cot_generate(question, context, mode="context_fill")
-        return answer
+        return answer, self.llm.cost_manager.total_cost
 
 if __name__ == "__main__":
     async def main():