mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 10:26:32 +02:00
197 lines
7.4 KiB
Python
197 lines
7.4 KiB
Python
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
import os
|
|||
|
|
from typing import List
|
|||
|
|
|
|||
|
|
import evaluate
|
|||
|
|
import jieba
|
|||
|
|
from llama_index.core.embeddings import BaseEmbedding
|
|||
|
|
from llama_index.core.evaluation import SemanticSimilarityEvaluator
|
|||
|
|
from llama_index.core.schema import NodeWithScore
|
|||
|
|
from pydantic import BaseModel
|
|||
|
|
|
|||
|
|
from metagpt.const import EXAMPLE_BENCHMARK_PATH
|
|||
|
|
from metagpt.logs import logger
|
|||
|
|
from metagpt.rag.factories import get_rag_embedding
|
|||
|
|
from metagpt.utils.common import read_json_file
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DatasetInfo(BaseModel):
    """Describes a single benchmark dataset: its name, source documents, and ground truth."""

    # Dataset identifier; load_dataset uses it as a directory name under EXAMPLE_BENCHMARK_PATH.
    name: str
    # Paths of the documents belonging to this dataset (joined under EXAMPLE_BENCHMARK_PATH/<name>).
    document_files: List[str]
    # Ground-truth records parsed from the dataset's gt_file JSON.
    gt_info: List[dict]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DatasetConfig(BaseModel):
    """Container for the set of datasets selected for a benchmark run."""

    # One DatasetInfo per selected dataset.
    datasets: List[DatasetInfo]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RAGBenchmark:
    """Benchmark harness for RAG pipelines.

    Computes generation-quality metrics (BLEU, ROUGE-L, semantic similarity)
    and retrieval-quality metrics (recall, hit rate, MRR) for a response
    against its ground truth and reference documents.
    """

    def __init__(
        self,
        embed_model: BaseEmbedding = None,
    ):
        """Initialize the benchmark.

        Args:
            embed_model: Embedding model used for semantic-similarity scoring.
                Falls back to the project's default RAG embedding when omitted.
        """
        self.evaluator = SemanticSimilarityEvaluator(
            embed_model=embed_model or get_rag_embedding(),
        )
        # Lazily-created HuggingFace `evaluate` metric objects. Cached so
        # repeated scoring calls don't reload the metric implementation
        # (evaluate.load is comparatively expensive).
        self._bleu_metric = None
        self._rouge_metric = None

    def _set_metrics(
        self,
        bleu_avg: float = 0.0,
        bleu_1: float = 0.0,
        bleu_2: float = 0.0,
        bleu_3: float = 0.0,
        bleu_4: float = 0.0,
        rouge_l: float = 0.0,
        semantic_similarity: float = 0.0,
        recall: float = 0.0,
        hit_rate: float = 0.0,
        mrr: float = 0.0,
        length: float = 0.0,
        generated_text: str = None,
        ground_truth_text: str = None,
        question: str = None,
    ) -> dict:
        """Pack metric values and the generation log into one result dict.

        Returns:
            ``{"metrics": {...}, "log": {...}}`` where ``metrics`` holds the
            numeric scores and ``log`` holds the texts they were computed from.
        """
        metrics = {
            "bleu-avg": bleu_avg,
            "bleu-1": bleu_1,
            "bleu-2": bleu_2,
            "bleu-3": bleu_3,
            "bleu-4": bleu_4,
            "rouge-L": rouge_l,
            "semantic similarity": semantic_similarity,
            "recall": recall,
            "hit_rate": hit_rate,
            "mrr": mrr,
            "length": length,
        }

        log = {
            "generated_text": generated_text,
            "ground_truth_text": ground_truth_text,
            "question": question,
        }

        return {"metrics": metrics, "log": log}

    def bleu_score(
        self, response: str, reference: str, with_penalty: bool = False
    ) -> tuple[float, float, float, float, float]:
        """Compute BLEU of `response` against `reference` with jieba tokenization.

        Args:
            response: Generated text.
            reference: Ground-truth text.
            with_penalty: If True, return HuggingFace's brevity-penalized BLEU
                average; otherwise divide the brevity penalty back out.

        Returns:
            ``(bleu_avg, bleu_1, bleu_2, bleu_3, bleu_4)``.
        """

        def tokenize(text: str) -> list[str]:
            # jieba segments Chinese text; it passes other scripts through.
            return list(jieba.cut(text))

        if self._bleu_metric is None:
            self._bleu_metric = evaluate.load(path="bleu")
        results = self._bleu_metric.compute(predictions=[response], references=[[reference]], tokenizer=tokenize)

        bleu_avg = results["bleu"]
        bleu1, bleu2, bleu3, bleu4 = results["precisions"]
        brevity_penalty = results["brevity_penalty"]

        if with_penalty:
            return bleu_avg, bleu1, bleu2, bleu3, bleu4
        # Remove the brevity penalty from the average; guard against a zero
        # penalty (which would imply a zero score anyway).
        return 0.0 if brevity_penalty == 0 else bleu_avg / brevity_penalty, bleu1, bleu2, bleu3, bleu4

    def rougel_score(self, response: str, reference: str) -> float:
        """Compute ROUGE-L of `response` against `reference`.

        Requires the `rouge_score` package (pip install rouge_score).
        """

        def tokenize(text: str) -> list[str]:
            return list(jieba.cut(text))

        if self._rouge_metric is None:
            self._rouge_metric = evaluate.load(path="rouge")

        results = self._rouge_metric.compute(
            predictions=[response], references=[[reference]], tokenizer=tokenize, rouge_types=["rougeL"]
        )
        return results["rougeL"]

    def recall(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
        """Fraction of reference docs containing the text of at least one retrieved node.

        Returns 0.0 when there are no retrieved nodes or no reference docs
        (the latter guards against a ZeroDivisionError).
        """
        if not nodes or not reference_docs:
            return 0.0
        hits = sum(any(node.text in doc for node in nodes) for doc in reference_docs)
        return hits / len(reference_docs)

    def HitRate(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
        """Return 1.0 if any retrieved node's text appears in any reference doc, else 0.0."""
        if not nodes:
            return 0.0
        return 1.0 if any(node.text in doc for doc in reference_docs for node in nodes) else 0.0

    def MRR(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
        """Mean reciprocal rank over the reference docs.

        NOTE(review): the rank is the reference doc's position in
        `reference_docs`, not the retrieved node's rank — preserved as-is
        from the original implementation; confirm this is intended.
        """
        if not nodes or not reference_docs:
            # Guard empty inputs (the unguarded loop would divide by zero
            # or iterate None).
            return 0.0

        mrr_sum = 0.0
        for rank, doc in enumerate(reference_docs, start=1):
            if any(node.text in doc for node in nodes):
                mrr_sum += 1.0 / rank

        return mrr_sum / len(reference_docs)

    async def SemanticSimilarity(self, response: str, reference: str) -> float:
        """Embedding-based similarity between `response` and `reference`."""
        result = await self.evaluator.aevaluate(
            response=response,
            reference=reference,
        )

        return result.score

    async def compute_metric(
        self,
        response: str = None,
        reference: str = None,
        nodes: list[NodeWithScore] = None,
        reference_doc: list[str] = None,
        question: str = None,
    ) -> dict:
        """Compute every supported metric for one QA sample.

        Args:
            response: Generated answer text.
            reference: Ground-truth answer text.
            nodes: Retrieved nodes for the retrieval metrics.
            reference_doc: Reference documents for the retrieval metrics.
            question: Original question, recorded in the log only.

        Returns:
            The ``{"metrics": ..., "log": ...}`` dict from ``_set_metrics``.
        """
        recall = self.recall(nodes, reference_doc)
        bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.bleu_score(response, reference)
        rouge_l = self.rougel_score(response, reference)
        hit_rate = self.HitRate(nodes, reference_doc)
        mrr = self.MRR(nodes, reference_doc)

        similarity = await self.SemanticSimilarity(response, reference)

        return self._set_metrics(
            bleu_avg,
            bleu1,
            bleu2,
            bleu3,
            bleu4,
            rouge_l,
            similarity,
            recall,
            hit_rate,
            mrr,
            len(response),
            response,
            reference,
            question,
        )

    @staticmethod
    def load_dataset(ds_names: list[str] = None) -> DatasetConfig:
        """Load benchmark dataset configurations from EXAMPLE_BENCHMARK_PATH.

        Args:
            ds_names: Names of datasets to load. ``None`` (the default) or a
                list containing ``"all"`` selects every dataset. An explicit
                empty list selects none, as before.

        Returns:
            A DatasetConfig with one DatasetInfo per selected dataset.
        """
        # `is None` check (not `or`) replaces the original mutable default
        # ["all"] while preserving the empty-list behavior.
        if ds_names is None:
            ds_names = ["all"]

        infos = read_json_file(os.path.join(EXAMPLE_BENCHMARK_PATH, "dataset_info.json"))
        return DatasetConfig(
            datasets=[
                DatasetInfo(
                    name=name,
                    document_files=[
                        os.path.join(EXAMPLE_BENCHMARK_PATH, name, file) for file in info["document_file"]
                    ],
                    gt_info=read_json_file(os.path.join(EXAMPLE_BENCHMARK_PATH, name, info["gt_file"])),
                )
                for dataset_info in infos
                for name, info in dataset_info.items()
                if name in ds_names or "all" in ds_names
            ]
        )
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Smoke-test the benchmark on one hand-written (answer, ground truth) pair.
    benchmark = RAGBenchmark()

    answer = "是的,根据提供的信息,2023年7月20日,应急管理部和财政部确实联合发布了《因灾倒塌、损坏住房恢复重建救助工作规范》的通知。这份《规范》旨在进一步规范因灾倒塌、损坏住房的恢复重建救助相关工作。它明确了地方各级政府负责实施救助工作,应急管理部和财政部则负责统筹指导。地方财政应安排足够的资金,中央财政也会提供适当的补助。救助资金将通过专账管理,并采取特定的管理方式。救助对象是那些因自然灾害导致住房倒塌或损坏,并向政府提出申请且符合条件的受灾家庭。相关部门将组织调查统计救助对象信息,并建立档案。此外,《规范》还强调了资金发放的具体方式和公开透明的要求。"
    ground_truth = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"

    # BLEU: log each sub-score under its original label.
    bleu_results = benchmark.bleu_score(answer, ground_truth)
    for label, score in zip(("bleu_avg", "bleu1", "bleu2", "bleu3", "bleu4"), bleu_results):
        logger.info(f"{label} = {score}")

    rouge_l = benchmark.rougel_score(answer, ground_truth)
    logger.info(f"rougeL_score = {rouge_l}")

    sim = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth))
    logger.info(f"similarity = {sim}")
|