diff --git a/metagpt/rag/benchmark/base.py b/metagpt/rag/benchmark/base.py index 53bb2f0af..01dc0fbb0 100644 --- a/metagpt/rag/benchmark/base.py +++ b/metagpt/rag/benchmark/base.py @@ -13,6 +13,7 @@ from pydantic import BaseModel from metagpt.const import EXAMPLE_BENCHMARK_PATH from metagpt.logs import logger from metagpt.rag.factories import get_rag_embedding +from metagpt.utils.common import read_json_file class DatasetInfo(BaseModel): @@ -31,9 +32,49 @@ class RAGBenchmark: embed_model: BaseEmbedding = None, ): self.evaluator = SemanticSimilarityEvaluator( - embed_model=embed_model or get_rag_embedding(), + embed_model=embed_model or get_rag_embedding(), ) + def _set_metrics( + self, + bleu_avg :float = 0.0, + bleu_1 :float = 0.0, + bleu_2 :float = 0.0, + bleu_3 :float = 0.0, + bleu_4 :float = 0.0, + rouge_l :float = 0.0, + semantic_similarity :float = 0.0, + recall :float = 0.0, + hit_rate :float = 0.0, + mrr :float = 0.0, + length :float = 0.0, + generated_text :str = None, + ground_truth_text: str = None, + question: str = None + ): + metrics = { + "bleu-avg": bleu_avg, + "bleu-1": bleu_1, + "bleu-2": bleu_2, + "bleu-3": bleu_3, + "bleu-4": bleu_4, + "rouge-L": rouge_l, + "semantic similarity": semantic_similarity, + "recall": recall, + "hit_rate": hit_rate, + "mrr": mrr, + "length": length, + } + + log = { + "generated_text": generated_text, + "ground_truth_text": ground_truth_text, + "question": question, + } + + return {"metrics": metrics, "log": log} + + def bleu_score(self, response: str, reference: str, with_penalty=False) -> float: f = lambda text: list(jieba.cut(text)) bleu = evaluate.load(path="bleu") @@ -82,7 +123,7 @@ class RAGBenchmark: mrr_sum += 1.0 / i break - return mrr_sum / len(nodes) if reference_docs else 0.0 + return mrr_sum / len(reference_docs) if reference_docs else 0.0 async def SemanticSimilarity(self, response: str, reference: str) -> float: result = await self.evaluator.aevaluate( @@ -92,27 +133,48 @@ class RAGBenchmark: return result.score + async def compute_metric( + self, + response: str = None, + reference: str = None, + nodes: list[NodeWithScore] = None, + reference_doc: list[str] = None, + question: str = None, + ): + recall = self.recall(nodes, reference_doc) + bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.bleu_score(response, reference) + rouge_l = self.rougel_score(response, reference) + hit_rate = self.HitRate(nodes, reference_doc) + mrr = self.MRR(nodes, reference_doc) + + similarity = await self.SemanticSimilarity(response, reference) + result = self._set_metrics( + bleu_avg, bleu1, bleu2, bleu3, bleu4, rouge_l, + similarity, + recall, hit_rate, mrr, len(response), response, reference, question + ) + + return result + @staticmethod - def load_dataset(ds_names: list[str] = ["CRUD"]): - with open(os.path.join(EXAMPLE_BENCHMARK_PATH, "dataset_info.json"), "r", encoding="utf-8") as f: - infos = json.load(f) - dataset_config = DatasetConfig( - datasets=[ - DatasetInfo( - name=name, - document_files=[ - os.path.join(EXAMPLE_BENCHMARK_PATH, name, file) - for file in info["document_files"] - ], - gt_info=json.load( - open(os.path.join(EXAMPLE_BENCHMARK_PATH, name, info["gt_file"]), "r", encoding="utf-8") - ), - ) - for dataset_info in infos - for name, info in dataset_info.items() - if name in ds_names or "all" in ds_names - ] - ) + def load_dataset(ds_names: list[str] = ["all"]): + infos = read_json_file(os.path.join(EXAMPLE_BENCHMARK_PATH, "dataset_info.json")) + dataset_config = DatasetConfig( + datasets=[ + DatasetInfo( + name=name, + document_files=[ + os.path.join(EXAMPLE_BENCHMARK_PATH, name, file) + for file in info["document_file"] + ], + gt_info=read_json_file(os.path.join(EXAMPLE_BENCHMARK_PATH, name, info["gt_file"])), + ) + for dataset_info in infos + for name, info in dataset_info.items() + if name in ds_names or "all" in ds_names + ] + ) + return dataset_config @@ -121,16 +183,14 @@ if __name__ == "__main__": answer = "是的,根据提供的信息,2023年7月20日,应急管理部和财政部确实联合发布了《因灾倒塌、损坏住房恢复重建救助工作规范》的通知。这份《规范》旨在进一步规范因灾倒塌、损坏住房的恢复重建救助相关工作。它明确了地方各级政府负责实施救助工作,应急管理部和财政部则负责统筹指导。地方财政应安排足够的资金,中央财政也会提供适当的补助。救助资金将通过专账管理,并采取特定的管理方式。救助对象是那些因自然灾害导致住房倒塌或损坏,并向政府提出申请且符合条件的受灾家庭。相关部门将组织调查统计救助对象信息,并建立档案。此外,《规范》还强调了资金发放的具体方式和公开透明的要求。" ground_truth = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。" bleu_avg, bleu1, bleu2, bleu3, bleu4 = benchmark.bleu_score(answer, ground_truth) - rougeL_score = benchmark.rougel_score(answer, ground_truth) - similarity = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth)) - logger.info( - f"BLEU Scores:\n" - f"bleu_avg = {bleu_avg}\n" - f"bleu1 = {bleu1}\n" - f"bleu2 = {bleu2}\n" - f"bleu3 = {bleu3}\n" - f"bleu4 = {bleu4}\n" - f"rougeL_score = {rougeL_score}\n" - f"similarity = {similarity}\n" - ) + logger.info(f"bleu_avg = {bleu_avg}") + logger.info(f"bleu1 = {bleu1}") + logger.info(f"bleu2 = {bleu2}") + logger.info(f"bleu3 = {bleu3}") + logger.info(f"bleu4 = {bleu4}") + rougeL_score = benchmark.rougel_score(answer, ground_truth) + logger.info(f"rougeL_score = {rougeL_score}") + + similarity = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth)) + logger.info(f"similarity = {similarity}")