mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge branch 'main' of https://github.com/didiforgithub/MetaGPT-MathAI
This commit is contained in:
commit
7f45ef6231
3 changed files with 268 additions and 156 deletions
|
|
@ -1,65 +1,166 @@
|
|||
import json
|
||||
import asyncio
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Callable, Dict, Any, Set
|
||||
import string
|
||||
import re
|
||||
from typing import List, Tuple, Callable, Dict, Any, Set, Union
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
|
||||
from examples.ags.benchmark.utils import generate_random_indices
|
||||
|
||||
def is_number(text: str) -> bool:
|
||||
def _remove_articles(text: str) -> str:
|
||||
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
||||
return re.sub(regex, " ", text)
|
||||
|
||||
|
||||
def _white_space_fix(text: str) -> str:
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
EXCLUDE = set(string.punctuation)
|
||||
|
||||
def _is_number(text: str) -> bool:
|
||||
try:
|
||||
float(text)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
def normalize_answer(text):
|
||||
import re
|
||||
import string
|
||||
def _normalize_number(text: str) -> str:
|
||||
if _is_number(text):
|
||||
return str(float(text))
|
||||
else:
|
||||
return text
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
def _remove_punc(text: str) -> str:
|
||||
if not _is_number(text):
|
||||
return "".join(ch for ch in text if ch not in EXCLUDE)
|
||||
else:
|
||||
return text
|
||||
|
||||
def white_space_fix(text):
|
||||
return " ".join(text.split())
|
||||
|
||||
def remove_punc(text):
|
||||
exclude = set(string.punctuation)
|
||||
return "".join(ch for ch in text if ch not in exclude)
|
||||
def _lower(text: str) -> str:
|
||||
return text.lower()
|
||||
|
||||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
def tokenize(text):
|
||||
return re.split(" |-", text)
|
||||
def _tokenize(text: str) -> List[str]:
|
||||
return re.split(" |-", text)
|
||||
|
||||
def normalize_number(text: str) -> str:
|
||||
if is_number(text):
|
||||
return str(float(text))
|
||||
else:
|
||||
return text
|
||||
|
||||
def _normalize_answer(text: str) -> str:
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
parts = [
|
||||
white_space_fix(remove_articles(normalize_number(remove_punc(lower(token)))))
|
||||
for token in tokenize(text)
|
||||
_white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
|
||||
for token in _tokenize(text)
|
||||
]
|
||||
parts = [part for part in parts if part.strip()]
|
||||
normalized = " ".join(parts).strip()
|
||||
return normalized
|
||||
|
||||
def answer_to_bags(answer: str) -> Set[str]:
|
||||
raw_spans = [answer]
|
||||
|
||||
normalized_spans = []
|
||||
def _answer_to_bags(
|
||||
answer: Union[str, List[str], Tuple[str, ...]]
|
||||
) -> Tuple[List[str], List[Set[str]]]:
|
||||
if isinstance(answer, (list, tuple)):
|
||||
raw_spans = answer
|
||||
else:
|
||||
raw_spans = [answer]
|
||||
normalized_spans: List[str] = []
|
||||
token_bags = []
|
||||
for raw_span in raw_spans:
|
||||
normalized_span = normalize_answer(raw_span)
|
||||
normalized_span = _normalize_answer(raw_span)
|
||||
normalized_spans.append(normalized_span)
|
||||
token_bags.append(set(normalized_span.split()))
|
||||
return normalized_spans, token_bags
|
||||
|
||||
#!/usr/bin/python
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Set, Tuple, Union, Optional
|
||||
import json
|
||||
import argparse
|
||||
import string
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
|
||||
|
||||
# From here through _normalize_answer was originally copied from:
|
||||
# https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
|
||||
# Then cleaned up and modified a bit.
|
||||
def _remove_articles(text: str) -> str:
|
||||
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
||||
return re.sub(regex, " ", text)
|
||||
|
||||
|
||||
def _white_space_fix(text: str) -> str:
|
||||
return " ".join(text.split())
|
||||
|
||||
|
||||
EXCLUDE = set(string.punctuation)
|
||||
|
||||
|
||||
def _remove_punc(text: str) -> str:
|
||||
if not _is_number(text):
|
||||
return "".join(ch for ch in text if ch not in EXCLUDE)
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
def _lower(text: str) -> str:
|
||||
return text.lower()
|
||||
|
||||
|
||||
def _tokenize(text: str) -> List[str]:
|
||||
return re.split(" |-", text)
|
||||
|
||||
|
||||
def _normalize_answer(text: str) -> str:
|
||||
"""Lower text and remove punctuation, articles and extra whitespace."""
|
||||
|
||||
parts = [
|
||||
_white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
|
||||
for token in _tokenize(text)
|
||||
]
|
||||
parts = [part for part in parts if part.strip()]
|
||||
normalized = " ".join(parts).strip()
|
||||
return normalized
|
||||
|
||||
|
||||
def _is_number(text: str) -> bool:
|
||||
try:
|
||||
float(text)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _normalize_number(text: str) -> str:
|
||||
if _is_number(text):
|
||||
return str(float(text))
|
||||
else:
|
||||
return text
|
||||
|
||||
|
||||
def _answer_to_bags(
|
||||
answer: Union[str, List[str], Tuple[str, ...]]
|
||||
) -> Tuple[List[str], List[Set[str]]]:
|
||||
if isinstance(answer, (list, tuple)):
|
||||
raw_spans = answer
|
||||
else:
|
||||
raw_spans = [answer]
|
||||
normalized_spans: List[str] = []
|
||||
token_bags = []
|
||||
for raw_span in raw_spans:
|
||||
normalized_span = _normalize_answer(raw_span)
|
||||
normalized_spans.append(normalized_span)
|
||||
token_bags.append(set(normalized_span.split()))
|
||||
return normalized_spans, token_bags
|
||||
|
||||
|
||||
def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
|
||||
"""
|
||||
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
|
||||
|
|
@ -68,8 +169,8 @@ def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
|
|||
scores = np.zeros([len(gold), len(predicted)])
|
||||
for gold_index, gold_item in enumerate(gold):
|
||||
for pred_index, pred_item in enumerate(predicted):
|
||||
if match_numbers_if_present(gold_item, pred_item):
|
||||
scores[gold_index, pred_index] = f1_score(pred_item, gold_item)
|
||||
if _match_numbers_if_present(gold_item, pred_item):
|
||||
scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
|
||||
row_ind, col_ind = linear_sum_assignment(-scores)
|
||||
|
||||
max_scores = np.zeros([max(len(gold), len(predicted))])
|
||||
|
|
@ -77,20 +178,8 @@ def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
|
|||
max_scores[row] = max(max_scores[row], scores[row, column])
|
||||
return max_scores
|
||||
|
||||
def match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
|
||||
gold_numbers = set()
|
||||
predicted_numbers = set()
|
||||
for word in gold_bag:
|
||||
if is_number(word):
|
||||
gold_numbers.add(word)
|
||||
for word in predicted_bag:
|
||||
if is_number(word):
|
||||
predicted_numbers.add(word)
|
||||
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
|
||||
return True
|
||||
return False
|
||||
|
||||
def f1_score(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
|
||||
def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
|
||||
intersection = len(gold_bag.intersection(predicted_bag))
|
||||
if not predicted_bag:
|
||||
precision = 1.0
|
||||
|
|
@ -100,9 +189,108 @@ def f1_score(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
|
|||
recall = 1.0
|
||||
else:
|
||||
recall = intersection / float(len(gold_bag))
|
||||
f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0
|
||||
f1 = (
|
||||
(2 * precision * recall) / (precision + recall)
|
||||
if not (precision == 0.0 and recall == 0.0)
|
||||
else 0.0
|
||||
)
|
||||
return f1
|
||||
|
||||
def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
|
||||
gold_numbers = set()
|
||||
predicted_numbers = set()
|
||||
for word in gold_bag:
|
||||
if _is_number(word):
|
||||
gold_numbers.add(word)
|
||||
for word in predicted_bag:
|
||||
if _is_number(word):
|
||||
predicted_numbers.add(word)
|
||||
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
|
||||
intersection = len(gold_bag.intersection(predicted_bag))
|
||||
if not predicted_bag:
|
||||
precision = 1.0
|
||||
else:
|
||||
precision = intersection / float(len(predicted_bag))
|
||||
if not gold_bag:
|
||||
recall = 1.0
|
||||
else:
|
||||
recall = intersection / float(len(gold_bag))
|
||||
f1 = (
|
||||
(2 * precision * recall) / (precision + recall)
|
||||
if not (precision == 0.0 and recall == 0.0)
|
||||
else 0.0
|
||||
)
|
||||
return f1
|
||||
|
||||
def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
|
||||
"""
|
||||
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
|
||||
between them and gets maximum metric values over all the answers.
|
||||
"""
|
||||
scores = np.zeros([len(gold), len(predicted)])
|
||||
for gold_index, gold_item in enumerate(gold):
|
||||
for pred_index, pred_item in enumerate(predicted):
|
||||
if _match_numbers_if_present(gold_item, pred_item):
|
||||
scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
|
||||
row_ind, col_ind = linear_sum_assignment(-scores)
|
||||
|
||||
max_scores = np.zeros([max(len(gold), len(predicted))])
|
||||
for row, column in zip(row_ind, col_ind):
|
||||
max_scores[row] = max(max_scores[row], scores[row, column])
|
||||
return max_scores
|
||||
|
||||
def get_metrics(
|
||||
predicted: Union[str, List[str], Tuple[str, ...]], gold: Union[str, List[str], Tuple[str, ...]]
|
||||
) -> Tuple[float, float]:
|
||||
"""
|
||||
Takes a predicted answer and a gold answer (that are both either a string or a list of
|
||||
strings), and returns exact match and the DROP F1 metric for the prediction. If you are
|
||||
writing a script for evaluating objects in memory (say, the output of predictions during
|
||||
validation, or while training), this is the function you want to call, after using
|
||||
:func:`answer_json_to_strings` when reading the gold answer from the released data file.
|
||||
"""
|
||||
predicted_bags = _answer_to_bags(predicted)
|
||||
gold_bags = _answer_to_bags(gold)
|
||||
|
||||
if set(predicted_bags[0]) == set(gold_bags[0]) and len(predicted_bags[0]) == len(gold_bags[0]):
|
||||
exact_match = 1.0
|
||||
else:
|
||||
exact_match = 0.0
|
||||
|
||||
f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
|
||||
f1 = np.mean(f1_per_bag)
|
||||
f1 = round(f1, 2)
|
||||
return exact_match, f1
|
||||
|
||||
def answer_json_to_strings(answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]:
|
||||
"""
|
||||
Takes an answer JSON blob from the DROP data release and converts it into strings used for
|
||||
evaluation.
|
||||
"""
|
||||
if "number" in answer and answer["number"]:
|
||||
return tuple([str(answer["number"])]), "number"
|
||||
elif "spans" in answer and answer["spans"]:
|
||||
return tuple(answer["spans"]), "span" if len(answer["spans"]) == 1 else "spans"
|
||||
elif "date" in answer:
|
||||
return (
|
||||
tuple(
|
||||
[
|
||||
"{0} {1} {2}".format(
|
||||
answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]
|
||||
)
|
||||
]
|
||||
),
|
||||
"date",
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Answer type not found, should be one of number, spans or date at: {json.dumps(answer)}"
|
||||
)
|
||||
|
||||
def load_data(file_path: str, samples: int) -> List[Tuple[str, Dict[str, Any]]]:
|
||||
with open(file_path, mode="r") as file:
|
||||
data = json.load(file)
|
||||
|
|
@ -112,37 +300,15 @@ def load_data(file_path: str, samples: int) -> List[Tuple[str, Dict[str, Any]]]:
|
|||
data = [data[i] for i in random_indices]
|
||||
return data
|
||||
|
||||
async def evaluate_problem(question: str, passage: str, answers: List[Dict[str, Any]], graph: Callable) -> Tuple[str, str, float, str]:
|
||||
def answer_json_to_strings(answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]:
|
||||
if "number" in answer and answer["number"]:
|
||||
return tuple([str(answer["number"])]), "number"
|
||||
elif "spans" in answer and answer["spans"]:
|
||||
return tuple(answer["spans"]), "span" if len(answer["spans"]) == 1 else "spans"
|
||||
elif "date" in answer:
|
||||
return (
|
||||
tuple(
|
||||
[
|
||||
"{0} {1} {2}".format(
|
||||
answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]
|
||||
)
|
||||
]
|
||||
),
|
||||
"date",
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Answer type not found, should be one of number, spans or date at: {json.dumps(answer)}")
|
||||
async def evaluate_problem(question: str, passage: str, answers: List[Dict[str, Any]], graph: Callable) -> Tuple[str, str, float]:
|
||||
|
||||
prediction = await graph(question, passage)
|
||||
cost = prediction[1] # 添加这行来获取cost
|
||||
prediction = prediction[0] # 修改这行以获取实际的预测结果
|
||||
max_retries = 5
|
||||
retries = 0
|
||||
|
||||
def get_f1_score(prediction: str, golden_answer: str) -> float:
|
||||
predicted_bags = answer_to_bags(prediction)
|
||||
gold_bags = answer_to_bags(golden_answer)
|
||||
while retries < max_retries:
|
||||
try:
|
||||
prediction = await graph(question, passage)
|
||||
|
||||
f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
|
||||
score = np.mean(f1_per_bag)
|
||||
return score
|
||||
|
||||
max_score = 0.0
|
||||
best_answer = None
|
||||
|
|
@ -167,7 +333,7 @@ async def evaluate_all_passages(annotations: List[Tuple[str, Dict[str, Any]]], g
|
|||
question = qa_pair["question"]
|
||||
answers = [qa_pair["answer"]]
|
||||
if "validated_answers" in qa_pair and qa_pair["validated_answers"]:
|
||||
answers.extend(qa_pair["validated_answers"])
|
||||
answers += qa_pair["validated_answers"]
|
||||
best_answer, prediction, score, cost = await evaluate_problem(question, passage, answers, graph)
|
||||
results.append([id, question, prediction, best_answer, score, cost]) # 修改这行以包含cost
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,12 @@ import aiofiles
|
|||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import List, Tuple, Callable, Set
|
||||
from collections import Counter
|
||||
from tqdm.asyncio import tqdm_asyncio
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
import string
|
||||
import re
|
||||
|
||||
|
||||
from examples.ags.benchmark.utils import generate_random_indices
|
||||
|
||||
|
|
@ -16,9 +20,10 @@ def is_number(text: str) -> bool:
|
|||
except ValueError:
|
||||
return False
|
||||
|
||||
def normalize_answer(text):
|
||||
import re
|
||||
import string
|
||||
def normalize_answer(s):
|
||||
"""
|
||||
Normalize answers for evaluation.
|
||||
"""
|
||||
|
||||
def remove_articles(text):
|
||||
return re.sub(r"\b(a|an|the)\b", " ", text)
|
||||
|
|
@ -33,77 +38,24 @@ def normalize_answer(text):
|
|||
def lower(text):
|
||||
return text.lower()
|
||||
|
||||
def tokenize(text):
|
||||
return re.split(" |-", text)
|
||||
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
||||
|
||||
def normalize_number(text: str) -> str:
|
||||
if is_number(text):
|
||||
return str(float(text))
|
||||
else:
|
||||
return text
|
||||
|
||||
parts = [
|
||||
white_space_fix(remove_articles(normalize_number(remove_punc(lower(token)))))
|
||||
for token in tokenize(text)
|
||||
]
|
||||
parts = [part for part in parts if part.strip()]
|
||||
normalized = " ".join(parts).strip()
|
||||
return normalized
|
||||
|
||||
def answer_to_bags(answer: str) -> Set[str]:
|
||||
raw_spans = [answer]
|
||||
|
||||
normalized_spans = []
|
||||
token_bags = []
|
||||
for raw_span in raw_spans:
|
||||
normalized_span = normalize_answer(raw_span)
|
||||
normalized_spans.append(normalized_span)
|
||||
token_bags.append(set(normalized_span.split()))
|
||||
return normalized_spans, token_bags
|
||||
|
||||
def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> List[float]:
|
||||
def f1_score(prediction, ground_truth):
|
||||
"""
|
||||
Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
|
||||
between them and gets maximum metric values over all the answers.
|
||||
Compute the F1 score between prediction and ground truth answers.
|
||||
"""
|
||||
scores = np.zeros([len(gold), len(predicted)])
|
||||
for gold_index, gold_item in enumerate(gold):
|
||||
for pred_index, pred_item in enumerate(predicted):
|
||||
if match_numbers_if_present(gold_item, pred_item):
|
||||
scores[gold_index, pred_index] = f1_score(pred_item, gold_item)
|
||||
row_ind, col_ind = linear_sum_assignment(-scores)
|
||||
|
||||
max_scores = np.zeros([max(len(gold), len(predicted))])
|
||||
for row, column in zip(row_ind, col_ind):
|
||||
max_scores[row] = max(max_scores[row], scores[row, column])
|
||||
return max_scores
|
||||
|
||||
def match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
|
||||
gold_numbers = set()
|
||||
predicted_numbers = set()
|
||||
for word in gold_bag:
|
||||
if is_number(word):
|
||||
gold_numbers.add(word)
|
||||
for word in predicted_bag:
|
||||
if is_number(word):
|
||||
predicted_numbers.add(word)
|
||||
if (not gold_numbers) or gold_numbers.intersection(predicted_numbers):
|
||||
return True
|
||||
return False
|
||||
|
||||
def f1_score(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
|
||||
intersection = len(gold_bag.intersection(predicted_bag))
|
||||
if not predicted_bag:
|
||||
precision = 1.0
|
||||
else:
|
||||
precision = intersection / float(len(predicted_bag))
|
||||
if not gold_bag:
|
||||
recall = 1.0
|
||||
else:
|
||||
recall = intersection / float(len(gold_bag))
|
||||
f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0
|
||||
prediction_tokens = normalize_answer(prediction).split()
|
||||
ground_truth_tokens = normalize_answer(ground_truth).split()
|
||||
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
||||
num_same = sum(common.values())
|
||||
if num_same == 0:
|
||||
return 0
|
||||
precision = 1.0 * num_same / len(prediction_tokens)
|
||||
recall = 1.0 * num_same / len(ground_truth_tokens)
|
||||
f1 = (2 * precision * recall) / (precision + recall)
|
||||
return f1
|
||||
|
||||
|
||||
async def load_data(file_path: str, samples=20, total_length=1000) -> List[dict]:
|
||||
data = []
|
||||
async with aiofiles.open(file_path, mode="r") as file:
|
||||
|
|
@ -120,12 +72,8 @@ async def evaluate_problem(input: str, context_str: str, graph: Callable, expect
|
|||
|
||||
while retries < max_retries:
|
||||
try:
|
||||
prediction, supporting_sentences, cost = await graph(input, context_str) if graph else ("None", None, 0)
|
||||
predicted_bags = answer_to_bags(prediction)
|
||||
gold_bags = answer_to_bags(expected_output)
|
||||
|
||||
f1_per_bag = _align_bags(predicted_bags[1], gold_bags[1])
|
||||
score = np.mean(f1_per_bag)
|
||||
prediction, cost = await graph(input, context_str) if graph else ("None", None, 0)
|
||||
score = f1_score(prediction, expected_output)
|
||||
|
||||
break
|
||||
except Exception as e:
|
||||
|
|
@ -135,12 +83,11 @@ async def evaluate_problem(input: str, context_str: str, graph: Callable, expect
|
|||
if retries == max_retries:
|
||||
print("Maximum retries reached. Skipping this sample.")
|
||||
prediction = None
|
||||
supporting_sentences = None
|
||||
score = 0
|
||||
cost = 0
|
||||
break
|
||||
|
||||
return input, prediction, expected_output, supporting_sentences, score, cost
|
||||
return input, prediction, expected_output, score, cost
|
||||
|
||||
async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurrent_tasks: int = 50):
|
||||
semaphore = asyncio.Semaphore(max_concurrent_tasks)
|
||||
|
|
@ -157,9 +104,9 @@ async def evaluate_all_problems(data: List[dict], graph: Callable, max_concurren
|
|||
|
||||
return await tqdm_asyncio.gather(*tasks, desc="Evaluating HotpotQA problems", total=len(data))
|
||||
|
||||
def save_results_to_csv(results: List[Tuple[str, str, str, str, float, str]], path: str) -> Tuple[float, float]:
|
||||
def save_results_to_csv(results: List[Tuple[str, str, str, float, str]], path: str) -> Tuple[float, float]:
|
||||
df = pd.DataFrame(
|
||||
results, columns=["question", "prediction", "expected_output", "supporting_sentences", "score", "cost"]
|
||||
results, columns=["question", "prediction", "expected_output", "score", "cost"]
|
||||
)
|
||||
average_score = df["score"].mean()
|
||||
total_cost = df["cost"].iloc[-1]
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ HOTPOTQA_PROMPT = """
|
|||
|
||||
class GenerateOp(BaseModel):
|
||||
answer: str = Field(default="", description="问题的答案")
|
||||
supporting_sentences: str = Field(default="", description="支持性句子")
|
||||
|
||||
class CoTGenerate(Operator):
|
||||
def __init__(self, llm: LLM, name: str = "Generate"):
|
||||
|
|
@ -32,7 +31,7 @@ class CoTGenerate(Operator):
|
|||
fill_kwargs["mode"] = mode
|
||||
node = await ActionNode.from_pydantic(GenerateOp).fill(**fill_kwargs)
|
||||
response = node.instruct_content.model_dump()
|
||||
return response["answer"], response["supporting_sentences"]
|
||||
return response["answer"]
|
||||
|
||||
class CoTSolveGraph(SolveGraph):
|
||||
def __init__(self, name: str, llm_config, dataset: str):
|
||||
|
|
@ -40,8 +39,8 @@ class CoTSolveGraph(SolveGraph):
|
|||
self.cot_generate = CoTGenerate(self.llm)
|
||||
|
||||
async def __call__(self, question: str, context: str) -> Tuple[str, str]:
|
||||
answer, supporting_sentences = await self.cot_generate(question, context, mode="context_fill")
|
||||
return answer, supporting_sentences
|
||||
answer = await self.cot_generate(question, context, mode="context_fill")
|
||||
return answer
|
||||
|
||||
if __name__ == "__main__":
|
||||
async def main():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue