Mirror of https://github.com/FoundationAgents/MetaGPT.git, synced 2026-05-09 07:42:38 +02:00

Add files via upload

parent 42791af6ef
commit b05663bf74

1 changed file with 20 additions and 57 deletions
@@ -17,9 +17,12 @@ from metagpt.rag.schema import (
     FAISSRetrieverConfig,
     FlagEmbeddingConfig,
     LLMRankerConfig,
+    CohereRerankConfig,
+    ColbertRerankConfig,
 )
 from metagpt.utils.common import write_json_file
 
 import time
+import pdb
 DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/summary_writer.txt"
 QUESTION = "2023年7月20日,应急管理部、财政部联合下发《因灾倒塌、损坏住房恢复重建救助工作规范》的通知,规范倒损住房恢复重建救助相关工作。"
@@ -44,27 +47,28 @@ class RAGExample:
         self.embedding = get_rag_embedding()
         self.llm = get_rag_llm()
 
-    async def rag_evaluate_pipeline(self, dataset_name: list[str] = ["CRUD"]):
+    async def rag_evaluate_pipeline(self, dataset_name: list[str] = ["RGB_en"]):
         dataset_config = self.benchmark.load_dataset(dataset_name)
 
         for dataset in dataset_config.datasets:
             if "all" in dataset_name or dataset.name in dataset_name:
                 output_dir = DATA_PATH / f"{dataset.name}"
 
                 if os.path.exists(output_dir):
                     logger.info("Loading Exists index!")
                     logger.info(f"Index Path:{output_dir}")
                     self.engine = SimpleEngine.from_index(
-                        index_config=FAISSIndexConfig(persist_path=output_dir),
-                        ranker_configs=[LLMRankerConfig()],
+                        index_config=FAISSIndexConfig(),
+                        ranker_configs=[ColbertRerankConfig()],
+                        retriever_configs=[FAISSRetrieverConfig(persist_path=output_dir), BM25RetrieverConfig()],
                     )
                 else:
                     logger.info("Loading index from document!")
                     self.engine = SimpleEngine.from_docs(
                         input_files=dataset.document_files,
-                        retriever_configs=[FAISSRetrieverConfig(persist_path=output_dir), BM25RetrieverConfig()],
-                        ranker_configs=[FlagEmbeddingConfig()],
-                        transformations=[SentenceSplitter(chunk_size=256, chunk_overlap=0)],
+                        retriever_configs=[FAISSRetrieverConfig()],
+                        ranker_configs=[CohereRerankConfig()],
+                        transformations=[SentenceSplitter(chunk_size=1024, chunk_overlap=0)],
                     )
                 results = []
                 for gt_info in dataset.gt_info:
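For readers skimming the hunk above: the updated pipeline follows a load-or-build pattern, reusing a persisted FAISS index per dataset when one already exists and otherwise building it from the dataset's documents. Below is a condensed, standalone sketch of that pattern using the same calls as the diff. The build_engine helper is illustrative rather than part of the change, the SentenceSplitter import path assumes llama-index 0.10+, and persist_path is kept on the index config here (as in the pre-change code) since the loader needs a location to read from.

import os

from llama_index.core.node_parser import SentenceSplitter

from metagpt.const import DATA_PATH
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import (
    BM25RetrieverConfig,
    CohereRerankConfig,
    ColbertRerankConfig,
    FAISSIndexConfig,
    FAISSRetrieverConfig,
)


def build_engine(dataset_name: str, document_files: list[str]) -> SimpleEngine:
    # Per-dataset index directory, mirroring output_dir in the diff.
    output_dir = DATA_PATH / dataset_name

    if os.path.exists(output_dir):
        # Index already persisted: load it and rerank retrieved nodes with ColBERT.
        return SimpleEngine.from_index(
            index_config=FAISSIndexConfig(persist_path=output_dir),
            retriever_configs=[FAISSRetrieverConfig(persist_path=output_dir), BM25RetrieverConfig()],
            ranker_configs=[ColbertRerankConfig()],
        )

    # First run: chunk documents into 1024-token nodes, build the index, rerank with Cohere.
    return SimpleEngine.from_docs(
        input_files=document_files,
        retriever_configs=[FAISSRetrieverConfig()],
        ranker_configs=[CohereRerankConfig()],
        transformations=[SentenceSplitter(chunk_size=1024, chunk_overlap=0)],
    )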
@@ -74,7 +78,7 @@ class RAGExample:
                         ground_truth=gt_info["gt_answer"],
                     )
                     results.append(result)
-                logger.info(f"=====The {dataset.name} BenchMark dataset assessment is complete!=====")
+                logger.info(f"=====The {dataset.name} Benchmark dataset assessment is complete!=====")
                 self._print_bm_result(results)
 
                 write_json_file(os.path.join(EXAMPLE_BENCHMARK_PATH, dataset.name, "bm_result.json"), results, "utf-8")
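The per-question results land in EXAMPLE_BENCHMARK_PATH/<dataset>/bm_result.json. A small, hypothetical post-processing snippet for that file, assuming each entry follows the {"metrics": ..., "log": ...} layout shown in the later hunks of this diff; the file path below is a placeholder.

import json
from pathlib import Path

# Placeholder path; in the example it is EXAMPLE_BENCHMARK_PATH / dataset.name / "bm_result.json".
result_file = Path("data/benchmark/RGB_en/bm_result.json")

with open(result_file, encoding="utf-8") as f:
    results = json.load(f)

# Average the retrieval and generation metrics across all questions in the run.
keys = ["bleu-avg", "rouge-L", "semantic similarity", "recall", "hit_rate", "mrr"]
averages = {k: sum(r["metrics"][k] for r in results) / len(results) for k in keys}

for name, value in averages.items():
    print(f"{name}: {value:.4f}")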
@@ -115,37 +119,18 @@ class RAGExample:
         self._print_title("RAG Pipeline")
         try:
             nodes = await self.engine.aretrieve(question)
+            # pdb.set_trace()
             self._print_result(nodes, state="Retrieve")
 
             answer = await self.engine.aquery(question)
             self._print_result(answer, state="Query")
 
         except Exception as e:
-            print(e)
-            return {
-                "metrics": {
-                    "bleu-avg": 0.0,
-                    "bleu-1": 0.0,
-                    "bleu-2": 0.0,
-                    "bleu-3": 0.0,
-                    "bleu-4": 0.0,
-                    "rouge-L": 0.0,
-                    "semantic similarity": 0.0,
-                    "recall": 0.0,
-                    "hit_rate": 0.0,
-                    "mrr": 0.0,
-                    "length": 0,
-                },
-                "log": {
-                    "generated_text": "Retrieve failed due to LLM wasn't follow instruction",
-                    "ground_truth_text": ground_truth,
-                    "question": question,
-                },
-            }
+            logger.error(e)
+            return self.benchmark._set_metrics(
+                generated_text=LLM_ERROR, ground_truth_text=ground_truth, question=question
+            )
 
-        result = await self.evaluate_result(answer.response, ground_truth, nodes, reference)
-        result["log"]["question"] = question
+        result = await self.evaluate_result(answer.response, ground_truth, nodes, reference, question)
 
         logger.info("==========RAG BenchMark result demo as follows==========")
         logger.info(result)
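The hardcoded zeroed dictionary removed in this hunk is effectively the spec for the new fallback: on a retrieval or generation failure the benchmark records zero scores plus the error text. A rough sketch of a helper with that shape, reconstructed from the removed lines; the real _set_metrics lives on the benchmark class and its exact signature is not shown in this diff, and LLM_ERROR is presumably a constant defined elsewhere in the file.

def set_metrics(generated_text: str = "", ground_truth_text: str = "", question: str = "", **overrides) -> dict:
    # Zeroed metrics, matching the dictionary the old except-branch returned inline.
    metrics = {
        "bleu-avg": 0.0, "bleu-1": 0.0, "bleu-2": 0.0, "bleu-3": 0.0, "bleu-4": 0.0,
        "rouge-L": 0.0, "semantic similarity": 0.0,
        "recall": 0.0, "hit_rate": 0.0, "mrr": 0.0, "length": 0,
    }
    metrics.update(overrides)  # let callers override individual scores if they have them
    return {
        "metrics": metrics,
        "log": {
            "generated_text": generated_text,
            "ground_truth_text": ground_truth_text,
            "question": question,
        },
    }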
@@ -186,32 +171,10 @@ class RAGExample:
         reference: str = None,
         nodes: list[NodeWithScore] = None,
         reference_doc: list[str] = None,
+        question: str = None,
     ):
-        recall = self.benchmark.recall(nodes, reference_doc)
-        bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.benchmark.bleu_score(response, reference)
-        rouge_l = self.benchmark.rougel_score(response, reference)
-        hit_rate = self.benchmark.HitRate(nodes, reference_doc)
-        mrr = self.benchmark.MRR(nodes, reference_doc)
+        result = await self.benchmark.compute_metric(response, reference, nodes, reference_doc, question)
 
-        result = {
-            "metrics": {
-                "bleu-avg": bleu_avg or 0.0,
-                "bleu-1": bleu1 or 0.0,
-                "bleu-2": bleu2 or 0.0,
-                "bleu-3": bleu3 or 0.0,
-                "bleu-4": bleu4 or 0.0,
-                "rouge-L": rouge_l,
-                "semantic similarity": await self.benchmark.SemanticSimilarity(response, reference),
-                "recall": recall,
-                "hit_rate": hit_rate,
-                "mrr": mrr,
-                "length": len(response),
-            },
-            "log": {
-                "generated_text": response,
-                "ground_truth_text": reference,
-            },
-        }
         return result
 
     @staticmethod
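The inline metric computation removed above is what the new benchmark.compute_metric(...) call has to cover. A hedged reconstruction follows, assembled from the removed lines; only the call's name and argument order come from the diff, while the standalone async function and the benchmark parameter are illustrative.

async def compute_metric(benchmark, response, reference, nodes, reference_doc, question):
    # Retrieval metrics over the reranked nodes vs. the reference documents.
    recall = benchmark.recall(nodes, reference_doc)
    hit_rate = benchmark.HitRate(nodes, reference_doc)
    mrr = benchmark.MRR(nodes, reference_doc)

    # Generation metrics for the model answer vs. the ground-truth answer.
    bleu_avg, bleu1, bleu2, bleu3, bleu4 = benchmark.bleu_score(response, reference)
    rouge_l = benchmark.rougel_score(response, reference)
    similarity = await benchmark.SemanticSimilarity(response, reference)

    return {
        "metrics": {
            "bleu-avg": bleu_avg or 0.0,
            "bleu-1": bleu1 or 0.0,
            "bleu-2": bleu2 or 0.0,
            "bleu-3": bleu3 or 0.0,
            "bleu-4": bleu4 or 0.0,
            "rouge-L": rouge_l,
            "semantic similarity": similarity,
            "recall": recall,
            "hit_rate": hit_rate,
            "mrr": mrr,
            "length": len(response),
        },
        "log": {
            "generated_text": response,
            "ground_truth_text": reference,
            "question": question,
        },
    }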