Mirror of https://github.com/FoundationAgents/MetaGPT.git, synced 2026-05-02 20:32:38 +02:00
Update QA
commit 445a2e6048 (parent 257b994409)
4 changed files with 16 additions and 88 deletions
@@ -10,6 +10,9 @@ from tqdm.asyncio import tqdm_asyncio
 
 from examples.ags.benchmark.utils import generate_random_indices
 
+global cost
+cost = 0
+
 def _remove_articles(text: str) -> str:
     regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
     return re.sub(regex, " ", text)
@@ -60,90 +63,6 @@ def _normalize_answer(text: str) -> str:
     normalized = " ".join(parts).strip()
     return normalized
 
-def _answer_to_bags(
-    answer: Union[str, List[str], Tuple[str, ...]]
-) -> Tuple[List[str], List[Set[str]]]:
-    if isinstance(answer, (list, tuple)):
-        raw_spans = answer
-    else:
-        raw_spans = [answer]
-    normalized_spans: List[str] = []
-    token_bags = []
-    for raw_span in raw_spans:
-        normalized_span = _normalize_answer(raw_span)
-        normalized_spans.append(normalized_span)
-        token_bags.append(set(normalized_span.split()))
-    return normalized_spans, token_bags
-
-#!/usr/bin/python
-
-from collections import defaultdict
-from typing import Any, Dict, List, Set, Tuple, Union, Optional
-import json
-import argparse
-import string
-import re
-
-import numpy as np
-from scipy.optimize import linear_sum_assignment
-
-
-# From here through _normalize_answer was originally copied from:
-# https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
-# Then cleaned up and modified a bit.
-def _remove_articles(text: str) -> str:
-    regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
-    return re.sub(regex, " ", text)
-
-
-def _white_space_fix(text: str) -> str:
-    return " ".join(text.split())
-
-
-EXCLUDE = set(string.punctuation)
-
-
-def _remove_punc(text: str) -> str:
-    if not _is_number(text):
-        return "".join(ch for ch in text if ch not in EXCLUDE)
-    else:
-        return text
-
-
-def _lower(text: str) -> str:
-    return text.lower()
-
-
-def _tokenize(text: str) -> List[str]:
-    return re.split(" |-", text)
-
-
-def _normalize_answer(text: str) -> str:
-    """Lower text and remove punctuation, articles and extra whitespace."""
-
-    parts = [
-        _white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
-        for token in _tokenize(text)
-    ]
-    parts = [part for part in parts if part.strip()]
-    normalized = " ".join(parts).strip()
-    return normalized
-
-
-def _is_number(text: str) -> bool:
-    try:
-        float(text)
-        return True
-    except ValueError:
-        return False
-
-
-def _normalize_number(text: str) -> str:
-    if _is_number(text):
-        return str(float(text))
-    else:
-        return text
-
-
 def _answer_to_bags(
     answer: Union[str, List[str], Tuple[str, ...]]
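
For reference, the surviving copy of these helpers implements the usual DROP answer normalization: lowercase each token, strip punctuation from non-numeric tokens, canonicalize numbers through float(), drop articles, and collapse whitespace. A condensed, self-contained restatement of the kept functions, showing the intended behavior:

    import re
    import string

    EXCLUDE = set(string.punctuation)

    def _is_number(text: str) -> bool:
        try:
            float(text)
            return True
        except ValueError:
            return False

    def _normalize_token(token: str) -> str:
        # Same pipeline order as the kept code:
        # lower -> remove_punc -> normalize_number -> remove_articles -> whitespace fix
        token = token.lower()
        if not _is_number(token):
            token = "".join(ch for ch in token if ch not in EXCLUDE)
        if _is_number(token):
            token = str(float(token))
        token = re.sub(r"\b(a|an|the)\b", " ", token)
        return " ".join(token.split())

    def _normalize_answer(text: str) -> str:
        parts = [_normalize_token(tok) for tok in re.split(" |-", text)]
        return " ".join(p for p in parts if p.strip()).strip()

    assert _normalize_answer("The 42.0 Dogs!") == "42.0 dogs"
    assert _normalize_answer("a well-known fact") == "well known fact"
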
@@ -307,7 +226,8 @@ async def evaluate_problem(question: str, passage: str, answers: List[Dict[str, Any]]
 
     while retries < max_retries:
         try:
-            prediction = await graph(question, passage)
+            global cost
+            prediction, cost = await graph(question, passage)
 
 
             max_score = 0.0

@@ -372,4 +292,6 @@ async def drop_evaluation(graph: Callable, file_path: str, samples: int, path: str
     results = await evaluate_all_passages(data, graph, max_concurrent_tasks=20)
     average_score = save_results_to_csv(results, path=path)
     print(f"Average score on DROP dataset: {average_score:.5f}")
+    global cost
+    print(f"Total cost: {cost}")
     return average_score
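
Taken together, the DROP changes thread LLM cost through the benchmark: the solve graph now returns a (prediction, cost) pair, each evaluation rebinds the module-level `cost`, and the total is printed beside the average score. (The `global cost` at module scope is a no-op; only the declarations inside the assigning functions matter.) A minimal, self-contained sketch of the pattern, with `fake_graph` as a hypothetical stand-in for the real solve graph:

    import asyncio

    cost = 0  # module-level accumulator, as in the patched benchmark files

    async def fake_graph(question: str, passage: str) -> tuple[str, float]:
        # Hypothetical stand-in: the real graph returns its answer plus the
        # running total reported by the LLM's cost manager.
        return "42", 0.0013

    async def evaluate_problem(question: str, passage: str) -> str:
        global cost  # rebind the module-level accumulator
        prediction, cost = await fake_graph(question, passage)
        return prediction

    async def main() -> None:
        await evaluate_problem("How many?", "Some passage.")
        print(f"Total cost: {cost}")

    asyncio.run(main())

Because each call appears to report a cumulative total rather than a per-call delta, concurrent evaluations that overwrite `cost` still leave the final total in place once all tasks have finished.
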
@@ -13,6 +13,9 @@ import re
 
 from examples.ags.benchmark.utils import generate_random_indices
 
+global cost
+cost = 0
+
 def is_number(text: str) -> bool:
     try:
         float(text)

@@ -72,7 +75,8 @@ async def evaluate_problem(input: str, context_str: str, graph: Callable, expected_output
 
     while retries < max_retries:
         try:
-            prediction = await graph(input, context_str) if graph else "None"
+            global cost
+            prediction, cost = await graph(input, context_str) if graph else "None"
             score = f1_score(prediction, expected_output)
 
             break

@@ -120,4 +124,6 @@ async def hotpotqa_evaluation(graph: Callable, file_path: str, samples: int, path
     results = await evaluate_all_problems(data, graph, max_concurrent_tasks=20)
     average_score = save_results_to_csv(results, path=path)
     print(f"Average score on HotpotQA dataset: {average_score:.5f}")
+    global cost
+    print(f"Total cost: {cost}")
     return average_score
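
The HotpotQA changes mirror the DROP ones. One caveat in the new line: when `graph` is falsy, the conditional expression evaluates to the bare string "None", which cannot be unpacked into `prediction, cost` and would raise a ValueError. A sketch of a fallback that keeps the unpack safe (illustrative only, not from the repo):

    async def run_graph(graph, input: str, context_str: str) -> tuple[str, float]:
        # Return a (prediction, cost) pair on both branches so that
        # `prediction, cost = await run_graph(...)` can never fail to unpack.
        if graph:
            return await graph(input, context_str)
        return "None", 0.0
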
@@ -40,7 +40,7 @@ class CoTSolveGraph(SolveGraph):
 
     async def __call__(self, question: str, context: str) -> Tuple[str, str]:
         answer = await self.cot_generate(question, context, mode="context_fill")
-        return answer
+        return answer, self.llm.cost_manager.total_cost
 
 if __name__ == "__main__":
     async def main():

@@ -40,7 +40,7 @@ class CoTSolveGraph(SolveGraph):
 
     async def __call__(self, question: str, context: str) -> Tuple[str, str]:
         answer = await self.cot_generate(question, context, mode="context_fill")
-        return answer
+        return answer, self.llm.cost_manager.total_cost
 
 if __name__ == "__main__":
     async def main():
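
In both entry scripts, `CoTSolveGraph.__call__` now returns the cost manager's running total alongside the answer. Note that the annotation `Tuple[str, str]` no longer matches: `total_cost` is presumably a float, so `Tuple[str, float]` would be accurate. A minimal sketch of the new contract, with placeholder classes standing in for the real LLM and cost manager:

    from typing import Tuple

    class CostManager:
        # Placeholder for the cost manager the diff reaches via self.llm.
        def __init__(self) -> None:
            self.total_cost: float = 0.0

    class FakeLLM:
        def __init__(self) -> None:
            self.cost_manager = CostManager()

    class CoTSolveGraph:
        def __init__(self) -> None:
            self.llm = FakeLLM()

        async def cot_generate(self, question: str, context: str, mode: str) -> str:
            self.llm.cost_manager.total_cost += 0.001  # pretend one billed LLM call
            return "a chain-of-thought answer"

        async def __call__(self, question: str, context: str) -> Tuple[str, float]:
            answer = await self.cot_generate(question, context, mode="context_fill")
            # Return the accumulated LLM cost with the answer, as in this commit.
            return answer, self.llm.cost_manager.total_cost
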