Add files via upload

This commit is contained in:
YangQianli92 2024-04-17 10:50:21 +08:00 committed by GitHub
parent 52a062db60
commit b93e6779b9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -13,6 +13,7 @@ from pydantic import BaseModel
from metagpt.const import EXAMPLE_BENCHMARK_PATH
from metagpt.logs import logger
from metagpt.rag.factories import get_rag_embedding
from metagpt.utils.common import read_json_file
class DatasetInfo(BaseModel):
@ -31,9 +32,49 @@ class RAGBenchmark:
embed_model: BaseEmbedding = None,
):
self.evaluator = SemanticSimilarityEvaluator(
embed_model=embed_model or get_rag_embedding(),
embed_model=embed_model or get_rag_embedding(),
)
def _set_metrics(
self,
bleu_avg :float = 0.0,
bleu_1 :float = 0.0,
bleu_2 :float = 0.0,
bleu_3 :float = 0.0,
bleu_4 :float = 0.0,
rouge_l :float = 0.0,
semantic_similarity :float = 0.0,
recall :float = 0.0,
hit_rate :float = 0.0,
mrr :float = 0.0,
length :float = 0.0,
generated_text :str = None,
ground_truth_text: str = None,
question: str = None
):
metrics = {
"bleu-avg": bleu_avg,
"bleu-1": bleu_1,
"bleu-2": bleu_2,
"bleu-3": bleu_3,
"bleu-4": bleu_4,
"rouge-L": rouge_l,
"semantic similarity": semantic_similarity,
"recall": recall,
"hit_rate": hit_rate,
"mrr": mrr,
"length": length,
}
log = {
"generated_text": generated_text,
"ground_truth_text": ground_truth_text,
"question": question,
}
return {"metrics": metrics, "log": log}
def bleu_score(self, response: str, reference: str, with_penalty=False) -> float:
f = lambda text: list(jieba.cut(text))
bleu = evaluate.load(path="bleu")
@ -82,7 +123,7 @@ class RAGBenchmark:
mrr_sum += 1.0 / i
break
return mrr_sum / len(nodes) if reference_docs else 0.0
return mrr_sum / len(reference_docs) if reference_docs else 0.0
async def SemanticSimilarity(self, response: str, reference: str) -> float:
result = await self.evaluator.aevaluate(
@ -92,27 +133,48 @@ class RAGBenchmark:
return result.score
async def compute_metric(
    self,
    response: str = None,
    reference: str = None,
    nodes: list[NodeWithScore] = None,
    reference_doc: list[str] = None,
    question: str = None,
):
    """Run every metric on one sample and return the packed result.

    ``nodes`` and ``reference_doc`` are handed to ``self.recall``,
    ``self.HitRate`` and ``self.MRR``; ``response`` and ``reference`` are
    handed to ``self.bleu_score``, ``self.rougel_score`` and
    ``self.SemanticSimilarity``. The scores, ``len(response)`` as the
    ``length`` entry, the two texts and ``question`` are then forwarded to
    ``self._set_metrics``.

    Note: although every parameter defaults to ``None``, ``len(response)``
    is always evaluated, so ``response`` must be a sized value.

    Returns:
        Whatever ``self._set_metrics`` returns for the computed values.
    """
    # The metric calls are kept in their original order.
    recall_score = self.recall(nodes, reference_doc)
    bleu_mean, bleu_n1, bleu_n2, bleu_n3, bleu_n4 = self.bleu_score(response, reference)
    rouge_score = self.rougel_score(response, reference)
    hit_score = self.HitRate(nodes, reference_doc)
    mrr_score = self.MRR(nodes, reference_doc)
    # Only the semantic-similarity evaluator is awaited.
    similarity_score = await self.SemanticSimilarity(response, reference)

    return self._set_metrics(
        bleu_avg=bleu_mean,
        bleu_1=bleu_n1,
        bleu_2=bleu_n2,
        bleu_3=bleu_n3,
        bleu_4=bleu_n4,
        rouge_l=rouge_score,
        semantic_similarity=similarity_score,
        recall=recall_score,
        hit_rate=hit_score,
        mrr=mrr_score,
        length=len(response),
        generated_text=response,
        ground_truth_text=reference,
        question=question,
    )
@staticmethod
def load_dataset(ds_names: list[str] = None):
    """Build the dataset configuration for the requested benchmark datasets.

    Reads ``dataset_info.json`` from ``EXAMPLE_BENCHMARK_PATH``. The file is
    iterated as a sequence of mappings from dataset name to an info dict;
    each info dict supplies ``"document_file"`` (file names joined onto the
    dataset's directory) and ``"gt_file"`` (a JSON file loaded as the
    ground-truth info).

    Args:
        ds_names: Names of the datasets to load. If ``"all"`` is among them,
            every dataset listed in ``dataset_info.json`` is loaded.
            ``None`` (the default) is treated as ``["all"]``.

    Returns:
        A ``DatasetConfig`` holding one ``DatasetInfo`` per selected dataset.
    """
    # A None sentinel replaces the former list default so that no mutable
    # object is shared between calls; the effective default is still ["all"].
    selected = ["all"] if ds_names is None else ds_names
    # Evaluate the wildcard test once instead of once per dataset.
    load_everything = "all" in selected

    infos = read_json_file(os.path.join(EXAMPLE_BENCHMARK_PATH, "dataset_info.json"))

    datasets = []
    for dataset_info in infos:
        for name, info in dataset_info.items():
            if not (load_everything or name in selected):
                continue
            dataset_dir = os.path.join(EXAMPLE_BENCHMARK_PATH, name)
            datasets.append(
                DatasetInfo(
                    name=name,
                    document_files=[
                        os.path.join(dataset_dir, file) for file in info["document_file"]
                    ],
                    # read_json_file is used rather than json.load(open(...)),
                    # which left the file handle open.
                    gt_info=read_json_file(os.path.join(dataset_dir, info["gt_file"])),
                )
            )

    return DatasetConfig(datasets=datasets)
@ -121,16 +183,14 @@ if __name__ == "__main__":
answer = "是的根据提供的信息2023年7月20日应急管理部和财政部确实联合发布了《因灾倒塌、损坏住房恢复重建救助工作规范》的通知。这份《规范》旨在进一步规范因灾倒塌、损坏住房的恢复重建救助相关工作。它明确了地方各级政府负责实施救助工作应急管理部和财政部则负责统筹指导。地方财政应安排足够的资金中央财政也会提供适当的补助。救助资金将通过专账管理并采取特定的管理方式。救助对象是那些因自然灾害导致住房倒塌或损坏并向政府提出申请且符合条件的受灾家庭。相关部门将组织调查统计救助对象信息并建立档案。此外《规范》还强调了资金发放的具体方式和公开透明的要求。"
ground_truth = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"
bleu_avg, bleu1, bleu2, bleu3, bleu4 = benchmark.bleu_score(answer, ground_truth)
rougeL_score = benchmark.rougel_score(answer, ground_truth)
similarity = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth))
logger.info(
f"BLEU Scores:\n"
f"bleu_avg = {bleu_avg}\n"
f"bleu1 = {bleu1}\n"
f"bleu2 = {bleu2}\n"
f"bleu3 = {bleu3}\n"
f"bleu4 = {bleu4}\n"
f"rougeL_score = {rougeL_score}\n"
f"similarity = {similarity}\n"
)
logger.info(f"bleu_avg = {bleu_avg}")
logger.info(f"bleu1 = {bleu1}")
logger.info(f"bleu2 = {bleu2}")
logger.info(f"bleu3 = {bleu3}")
logger.info(f"bleu4 = {bleu4}")
rougeL_score = benchmark.rougel_score(answer, ground_truth)
logger.info(f"rougeL_score = {rougeL_score}")
similarity = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth))
logger.info(f"similarity = {similarity}")