Add files via upload

YangQianli92 2024-04-17 11:05:24 +08:00 committed by GitHub
parent 42791af6ef
commit b05663bf74

@@ -17,9 +17,12 @@ from metagpt.rag.schema import (
     FAISSRetrieverConfig,
     FlagEmbeddingConfig,
     LLMRankerConfig,
+    CohereRerankConfig,
+    ColbertRerankConfig,
 )
 from metagpt.utils.common import write_json_file
 import time
+import pdb
 
 DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/summary_writer.txt"
 QUESTION = "2023年7月20日应急管理部、财政部联合下发《因灾倒塌、损坏住房恢复重建救助工作规范》的通知规范倒损住房恢复重建救助相关工作。"
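
The new imports swap the example's rerankers: LLMRankerConfig and FlagEmbeddingConfig give way to CohereRerankConfig and ColbertRerankConfig in the hunks below. As a minimal sketch of how these schema configs are consumed, the snippet that follows builds a one-document engine over the demo file and runs the sample question. The metagpt.const and metagpt.rag.engines import locations are assumed from MetaGPT's usual layout, and the no-argument config construction mirrors the diff below.

import asyncio

from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ColbertRerankConfig, FAISSRetrieverConfig

DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/summary_writer.txt"
QUESTION = "2023年7月20日应急管理部、财政部联合下发《因灾倒塌、损坏住房恢复重建救助工作规范》的通知规范倒损住房恢复重建救助相关工作。"


async def main():
    # Build an engine with the same retriever/reranker pairing the benchmark
    # below uses; configs are constructed with defaults, as in the diff.
    engine = SimpleEngine.from_docs(
        input_files=[DOC_PATH],
        retriever_configs=[FAISSRetrieverConfig()],
        ranker_configs=[ColbertRerankConfig()],
    )
    answer = await engine.aquery(QUESTION)
    print(answer.response)


if __name__ == "__main__":
    asyncio.run(main())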
@@ -44,27 +47,28 @@ class RAGExample:
         self.embedding = get_rag_embedding()
         self.llm = get_rag_llm()
 
-    async def rag_evaluate_pipeline(self, dataset_name: list[str] = ["CRUD"]):
+    async def rag_evaluate_pipeline(self, dataset_name: list[str] = ["RGB_en"]):
         dataset_config = self.benchmark.load_dataset(dataset_name)
         for dataset in dataset_config.datasets:
             if "all" in dataset_name or dataset.name in dataset_name:
                 output_dir = DATA_PATH / f"{dataset.name}"
                 if os.path.exists(output_dir):
-                    logger.info("Loading Exists index!")
+                    logger.info(f"Index Path:{output_dir}")
                     self.engine = SimpleEngine.from_index(
-                        index_config=FAISSIndexConfig(persist_path=output_dir),
-                        ranker_configs=[LLMRankerConfig()],
+                        index_config=FAISSIndexConfig(),
+                        ranker_configs=[ColbertRerankConfig()],
+                        retriever_configs=[FAISSRetrieverConfig(persist_path=output_dir), BM25RetrieverConfig()],
                     )
                 else:
                     logger.info("Loading index from document!")
                     self.engine = SimpleEngine.from_docs(
                         input_files=dataset.document_files,
-                        retriever_configs=[FAISSRetrieverConfig(persist_path=output_dir), BM25RetrieverConfig()],
-                        ranker_configs=[FlagEmbeddingConfig()],
-                        transformations=[SentenceSplitter(chunk_size=256, chunk_overlap=0)],
+                        retriever_configs=[FAISSRetrieverConfig()],
+                        ranker_configs=[CohereRerankConfig()],
+                        transformations=[SentenceSplitter(chunk_size=1024, chunk_overlap=0)],
                     )
                 results = []
                 for gt_info in dataset.gt_info:
@@ -74,7 +78,7 @@ class RAGExample:
                         ground_truth=gt_info["gt_answer"],
                     )
                     results.append(result)
-                logger.info(f"=====The {dataset.name} BenchMark dataset assessment is complete!=====")
+                logger.info(f"=====The {dataset.name} Benchmark dataset assessment is complete!=====")
                 self._print_bm_result(results)
                 write_json_file(os.path.join(EXAMPLE_BENCHMARK_PATH, dataset.name, "bm_result.json"), results, "utf-8")
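
Taken together, the two hunks above turn rag_evaluate_pipeline into a per-dataset loop: load (or build) an index, answer each gt_info question, then print and persist the scores. A minimal driver might look like the following sketch; the no-argument RAGExample() construction is an assumption, since the constructor is not shown in this diff.

import asyncio


async def main():
    example = RAGExample()  # assumption: constructed without arguments
    # Evaluate a single dataset; per the `"all" in dataset_name` check above,
    # passing ["all"] would run every configured dataset instead.
    await example.rag_evaluate_pipeline(dataset_name=["RGB_en"])
    # Scores land in EXAMPLE_BENCHMARK_PATH/<dataset>/bm_result.json via write_json_file.


if __name__ == "__main__":
    asyncio.run(main())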
@@ -115,37 +119,18 @@ class RAGExample:
         self._print_title("RAG Pipeline")
 
         try:
             nodes = await self.engine.aretrieve(question)
+            # pdb.set_trace()
             self._print_result(nodes, state="Retrieve")
 
             answer = await self.engine.aquery(question)
             self._print_result(answer, state="Query")
         except Exception as e:
-            print(e)
-            return {
-                "metrics": {
-                    "bleu-avg": 0.0,
-                    "bleu-1": 0.0,
-                    "bleu-2": 0.0,
-                    "bleu-3": 0.0,
-                    "bleu-4": 0.0,
-                    "rouge-L": 0.0,
-                    "semantic similarity": 0.0,
-                    "recall": 0.0,
-                    "hit_rate": 0.0,
-                    "mrr": 0.0,
-                    "length": 0,
-                },
-                "log": {
-                    "generated_text": "Retrieve failed due to LLM wasn't follow instruction",
-                    "ground_truth_text": ground_truth,
-                    "question": question,
-                },
-            }
+            logger.error(e)
+            return self.benchmark._set_metrics(
+                generated_text=LLM_ERROR, ground_truth_text=ground_truth, question=question
+            )
 
-        result = await self.evaluate_result(answer.response, ground_truth, nodes, reference)
-        result["log"]["question"] = question
+        result = await self.evaluate_result(answer.response, ground_truth, nodes, reference, question)
 
         logger.info("==========RAG BenchMark result demo as follows==========")
         logger.info(result)
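
The failure path now delegates to benchmark._set_metrics instead of hand-building a zeroed result dict, and logs the exception properly. The deleted dict pins down the shape _set_metrics must return; the sketch below is a hypothetical reconstruction of that contract, not the shipped implementation, and the LLM_ERROR value is a placeholder since the diff only references the constant.

# Hypothetical placeholder: the diff references LLM_ERROR but does not show its value.
LLM_ERROR = "Retrieve failed because the LLM did not follow the instruction"

METRIC_NAMES = ("bleu-avg", "bleu-1", "bleu-2", "bleu-3", "bleu-4",
                "rouge-L", "semantic similarity", "recall", "hit_rate", "mrr")


def _set_metrics(generated_text: str, ground_truth_text: str, question: str) -> dict:
    # Zero-fill every score, mirroring the dict this commit deletes from the example.
    return {
        "metrics": {**{name: 0.0 for name in METRIC_NAMES}, "length": 0},
        "log": {
            "generated_text": generated_text,
            "ground_truth_text": ground_truth_text,
            "question": question,
        },
    }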
@@ -186,32 +171,10 @@ class RAGExample:
         reference: str = None,
         nodes: list[NodeWithScore] = None,
         reference_doc: list[str] = None,
+        question: str = None,
     ):
-        recall = self.benchmark.recall(nodes, reference_doc)
-        bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.benchmark.bleu_score(response, reference)
-        rouge_l = self.benchmark.rougel_score(response, reference)
-        hit_rate = self.benchmark.HitRate(nodes, reference_doc)
-        mrr = self.benchmark.MRR(nodes, reference_doc)
+        result = await self.benchmark.compute_metric(response, reference, nodes, reference_doc, question)
 
-        result = {
-            "metrics": {
-                "bleu-avg": bleu_avg or 0.0,
-                "bleu-1": bleu1 or 0.0,
-                "bleu-2": bleu2 or 0.0,
-                "bleu-3": bleu3 or 0.0,
-                "bleu-4": bleu4 or 0.0,
-                "rouge-L": rouge_l,
-                "semantic similarity": await self.benchmark.SemanticSimilarity(response, reference),
-                "recall": recall,
-                "hit_rate": hit_rate,
-                "mrr": mrr,
-                "length": len(response),
-            },
-            "log": {
-                "generated_text": response,
-                "ground_truth_text": reference,
-            },
-        }
         return result
 
     @staticmethod
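
evaluate_result now passes everything through to benchmark.compute_metric, so the metric suite lives in one place. The deleted body shows exactly which benchmark calls that method must consolidate; the following is a sketch of such a method assembled from those calls, written as it would appear on the benchmark class (so self replaces self.benchmark). The shipped compute_metric may differ in detail.

async def compute_metric(self, response, reference, nodes, reference_doc, question):
    # Consolidates the scores the example previously computed inline.
    bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.bleu_score(response, reference)
    return {
        "metrics": {
            "bleu-avg": bleu_avg or 0.0,
            "bleu-1": bleu1 or 0.0,
            "bleu-2": bleu2 or 0.0,
            "bleu-3": bleu3 or 0.0,
            "bleu-4": bleu4 or 0.0,
            "rouge-L": self.rougel_score(response, reference),
            "semantic similarity": await self.SemanticSimilarity(response, reference),
            "recall": self.recall(nodes, reference_doc),
            "hit_rate": self.HitRate(nodes, reference_doc),
            "mrr": self.MRR(nodes, reference_doc),
            "length": len(response),
        },
        "log": {
            "generated_text": response,
            "ground_truth_text": reference,
            "question": question,
        },
    }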