diff --git a/examples/rag_bm.py b/examples/rag_bm.py
index bffe1e49b..9cfbac7a3 100644
--- a/examples/rag_bm.py
+++ b/examples/rag_bm.py
@@ -17,9 +17,12 @@ from metagpt.rag.schema import (
     FAISSRetrieverConfig,
     FlagEmbeddingConfig,
     LLMRankerConfig,
+    CohereRerankConfig,
+    ColbertRerankConfig,
 )
 from metagpt.utils.common import write_json_file
-
+import time
+import pdb

 DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/summary_writer.txt"
 QUESTION = "2023年7月20日,应急管理部、财政部联合下发《因灾倒塌、损坏住房恢复重建救助工作规范》的通知,规范倒损住房恢复重建救助相关工作。"
@@ -44,27 +47,28 @@ class RAGExample:
         self.embedding = get_rag_embedding()
         self.llm = get_rag_llm()

-    async def rag_evaluate_pipeline(self, dataset_name: list[str] = ["CRUD"]):
+    async def rag_evaluate_pipeline(self, dataset_name: list[str] = ["RGB_en"]):
         dataset_config = self.benchmark.load_dataset(dataset_name)
-
+

         for dataset in dataset_config.datasets:
             if "all" in dataset_name or dataset.name in dataset_name:
                 output_dir = DATA_PATH / f"{dataset.name}"
                 if os.path.exists(output_dir):
                     logger.info("Loading Exists index!")
+                    logger.info(f"Index Path:{output_dir}")
                     self.engine = SimpleEngine.from_index(
-                        index_config=FAISSIndexConfig(persist_path=output_dir),
-                        ranker_configs=[LLMRankerConfig()],
+                        index_config=FAISSIndexConfig(),
+                        ranker_configs=[ColbertRerankConfig()],
                         retriever_configs=[FAISSRetrieverConfig(persist_path=output_dir), BM25RetrieverConfig()],
                     )
                 else:
                     logger.info("Loading index from document!")
                     self.engine = SimpleEngine.from_docs(
                         input_files=dataset.document_files,
-                        retriever_configs=[FAISSRetrieverConfig(persist_path=output_dir), BM25RetrieverConfig()],
-                        ranker_configs=[FlagEmbeddingConfig()],
-                        transformations=[SentenceSplitter(chunk_size=256, chunk_overlap=0)],
+                        retriever_configs=[FAISSRetrieverConfig()],
+                        ranker_configs=[CohereRerankConfig()],
+                        transformations=[SentenceSplitter(chunk_size=1024, chunk_overlap=0)],
                     )
                 results = []
                 for gt_info in dataset.gt_info:
@@ -74,7 +78,7 @@ class RAGExample:
                         ground_truth=gt_info["gt_answer"],
                     )
                     results.append(result)
-                logger.info(f"=====The {dataset.name} BenchMark dataset assessment is complete!=====")
+                logger.info(f"=====The {dataset.name} Benchmark dataset assessment is complete!=====")
                 self._print_bm_result(results)
                 write_json_file(os.path.join(EXAMPLE_BENCHMARK_PATH, dataset.name, "bm_result.json"), results, "utf-8")

@@ -115,37 +119,18 @@ class RAGExample:
         self._print_title("RAG Pipeline")

         try:
             nodes = await self.engine.aretrieve(question)
-            # pdb.set_trace()
             self._print_result(nodes, state="Retrieve")
             answer = await self.engine.aquery(question)
             self._print_result(answer, state="Query")
         except Exception as e:
-            print(e)
-            return {
-                "metrics": {
-                    "bleu-avg": 0.0,
-                    "bleu-1": 0.0,
-                    "bleu-2": 0.0,
-                    "bleu-3": 0.0,
-                    "bleu-4": 0.0,
-                    "rouge-L": 0.0,
-                    "semantic similarity": 0.0,
-                    "recall": 0.0,
-                    "hit_rate": 0.0,
-                    "mrr": 0.0,
-                    "length": 0,
-                },
-                "log": {
-                    "generated_text": "Retrieve failed due to LLM wasn't follow instruction",
-                    "ground_truth_text": ground_truth,
-                    "question": question,
-                },
-            }
+            logger.error(e)
+            return self.benchmark._set_metrics(
+                generated_text=LLM_ERROR, ground_truth_text=ground_truth, question=question
+            )

-        result = await self.evaluate_result(answer.response, ground_truth, nodes, reference)
-        result["log"]["question"] = question
+        result = await self.evaluate_result(answer.response, ground_truth, nodes, reference, question)

         logger.info("==========RAG BenchMark result demo as follows==========")
         logger.info(result)

@@ -186,32 +171,10 @@ class RAGExample:
         reference: str = None,
         nodes: list[NodeWithScore] = None,
         reference_doc: list[str] = None,
+        question: str = None,
     ):
-        recall = self.benchmark.recall(nodes, reference_doc)
-        bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.benchmark.bleu_score(response, reference)
-        rouge_l = self.benchmark.rougel_score(response, reference)
-        hit_rate = self.benchmark.HitRate(nodes, reference_doc)
-        mrr = self.benchmark.MRR(nodes, reference_doc)
+        result = await self.benchmark.compute_metric(response, reference, nodes, reference_doc, question)

-        result = {
-            "metrics": {
-                "bleu-avg": bleu_avg or 0.0,
-                "bleu-1": bleu1 or 0.0,
-                "bleu-2": bleu2 or 0.0,
-                "bleu-3": bleu3 or 0.0,
-                "bleu-4": bleu4 or 0.0,
-                "rouge-L": rouge_l,
-                "semantic similarity": await self.benchmark.SemanticSimilarity(response, reference),
-                "recall": recall,
-                "hit_rate": hit_rate,
-                "mrr": mrr,
-                "length": len(response),
-            },
-            "log": {
-                "generated_text": response,
-                "ground_truth_text": reference,
-            },
-        }
         return result

     @staticmethod