Update QA

This commit is contained in:
Zhaoyang Yu 2024-09-10 18:59:13 +08:00
parent 257b994409
commit 445a2e6048
4 changed files with 16 additions and 88 deletions

View file

@ -10,6 +10,9 @@ from tqdm.asyncio import tqdm_asyncio
from examples.ags.benchmark.utils import generate_random_indices
global cost
cost = 0
def _remove_articles(text: str) -> str:
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
@ -60,90 +63,6 @@ def _normalize_answer(text: str) -> str:
normalized = " ".join(parts).strip()
return normalized
def _answer_to_bags(
answer: Union[str, List[str], Tuple[str, ...]]
) -> Tuple[List[str], List[Set[str]]]:
if isinstance(answer, (list, tuple)):
raw_spans = answer
else:
raw_spans = [answer]
normalized_spans: List[str] = []
token_bags = []
for raw_span in raw_spans:
normalized_span = _normalize_answer(raw_span)
normalized_spans.append(normalized_span)
token_bags.append(set(normalized_span.split()))
return normalized_spans, token_bags
#!/usr/bin/python
from collections import defaultdict
from typing import Any, Dict, List, Set, Tuple, Union, Optional
import json
import argparse
import string
import re
import numpy as np
from scipy.optimize import linear_sum_assignment
# From here through _normalize_answer was originally copied from:
# https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
# Then cleaned up and modified a bit.
def _remove_articles(text: str) -> str:
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
return re.sub(regex, " ", text)
def _white_space_fix(text: str) -> str:
return " ".join(text.split())
EXCLUDE = set(string.punctuation)
def _remove_punc(text: str) -> str:
if not _is_number(text):
return "".join(ch for ch in text if ch not in EXCLUDE)
else:
return text
def _lower(text: str) -> str:
return text.lower()
def _tokenize(text: str) -> List[str]:
return re.split(" |-", text)
def _normalize_answer(text: str) -> str:
"""Lower text and remove punctuation, articles and extra whitespace."""
parts = [
_white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
for token in _tokenize(text)
]
parts = [part for part in parts if part.strip()]
normalized = " ".join(parts).strip()
return normalized
def _is_number(text: str) -> bool:
try:
float(text)
return True
except ValueError:
return False
def _normalize_number(text: str) -> str:
if _is_number(text):
return str(float(text))
else:
return text
def _answer_to_bags(
answer: Union[str, List[str], Tuple[str, ...]]
@ -307,7 +226,8 @@ async def evaluate_problem(question: str, passage: str, answers: List[Dict[str,
while retries < max_retries:
try:
prediction = await graph(question, passage)
global cost
prediction, cost = await graph(question, passage)
max_score = 0.0
@ -372,4 +292,6 @@ async def drop_evaluation(graph: Callable, file_path: str, samples: int, path: s
results = await evaluate_all_passages(data, graph, max_concurrent_tasks=20)
average_score = save_results_to_csv(results, path=path)
print(f"Average score on DROP dataset: {average_score:.5f}")
global cost
print(f"Total cost: {cost}")
return average_score

View file

@ -13,6 +13,9 @@ import re
from examples.ags.benchmark.utils import generate_random_indices
global cost
cost = 0
def is_number(text: str) -> bool:
try:
float(text)
@ -72,7 +75,8 @@ async def evaluate_problem(input: str, context_str: str, graph: Callable, expect
while retries < max_retries:
try:
prediction = await graph(input, context_str) if graph else "None"
global cost
prediction, cost = await graph(input, context_str) if graph else "None"
score = f1_score(prediction, expected_output)
break
@ -120,4 +124,6 @@ async def hotpotqa_evaluation(graph: Callable, file_path: str, samples: int, pat
results = await evaluate_all_problems(data, graph, max_concurrent_tasks=20)
average_score = save_results_to_csv(results, path=path)
print(f"Average score on HotpotQA dataset: {average_score:.5f}")
global cost
print(f"Total cost: {cost}")
return average_score

View file

@ -40,7 +40,7 @@ class CoTSolveGraph(SolveGraph):
async def __call__(self, question: str, context: str) -> Tuple[str, str]:
answer = await self.cot_generate(question, context, mode="context_fill")
return answer
return answer, self.llm.cost_manager.total_cost
if __name__ == "__main__":
async def main():

View file

@ -40,7 +40,7 @@ class CoTSolveGraph(SolveGraph):
async def __call__(self, question: str, context: str) -> Tuple[str, str]:
answer = await self.cot_generate(question, context, mode="context_fill")
return answer
return answer, self.llm.cost_manager.total_cost
if __name__ == "__main__":
async def main():