mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
update crlf
This commit is contained in:
parent
2476672833
commit
95e6560924
2 changed files with 431 additions and 429 deletions
|
|
@ -1,234 +1,230 @@
|
|||
"""RAG benchmark pipeline"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
from metagpt.const import DATA_PATH, EXAMPLE_BENCHMARK_PATH, EXAMPLE_DATA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.benchmark import RAGBenchmark
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.factories import get_rag_embedding, get_rag_llm
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
FAISSIndexConfig,
|
||||
FAISSRetrieverConfig,
|
||||
BGERerankConfig,
|
||||
LLMRankerConfig,
|
||||
CohereRerankConfig,
|
||||
ColbertRerankConfig,
|
||||
)
|
||||
from metagpt.utils.common import write_json_file
|
||||
|
||||
|
||||
DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/summary_writer.txt"
|
||||
QUESTION = "2023年7月20日,应急管理部、财政部联合下发《因灾倒塌、损坏住房恢复重建救助工作规范》的通知,规范倒损住房恢复重建救助相关工作。"
|
||||
|
||||
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/documents.txt"
|
||||
TRAVEL_QUESTION = "国家卫生健康委在2023年7月28日开展的“启明行动”是为了防控哪个群体的哪种健康问题,并请列出活动发布的指导性文件名称。"
|
||||
|
||||
DATASET_PATH = EXAMPLE_DATA_PATH / "rag_bm/test.json"
|
||||
SAVE_PATH = EXAMPLE_DATA_PATH / "rag_bm/result.json"
|
||||
GROUND_TRUTH_TRANVEL = "2023-07-28 10:14:27作者:白剑峰来源:人民日报 ,正文:为在全社会形成重视儿童眼健康的良好氛围,持续推进综合防控儿童青少年近视工作落实,国家卫生健康委决定在全国持续开展“启明行动”——防控儿童青少年近视健康促进活动,并发布了《防控儿童青少年近视核心知识十条》。本次活动的主题为:重视儿童眼保健,守护孩子明眸“视”界。强调预防为主,推动关口前移,倡导和推动家庭及全社会共同行动起来,营造爱眼护眼的视觉友好环境,共同呵护好孩子的眼睛,让他们拥有一个光明的未来。国家卫生健康委要求,开展社会宣传和健康教育。充分利用网络、广播电视、报纸杂志、海报墙报、培训讲座等多种形式,广泛开展宣传倡导,向社会公众传播开展儿童眼保健、保护儿童视力健康的重要意义,以《防控儿童青少年近视核心知识十条》为重点普及预防近视科学知识。创新健康教育方式和载体,开发制作群众喜闻乐见的健康教育科普作品,利用互联网媒体扩大传播效果,提高健康教育的针对性、精准性和实效性。指导相关医疗机构将儿童眼保健和近视防控等科学知识纳入孕妇学校、家长课堂内容。开展儿童眼保健及视力检查咨询指导。医疗机构要以儿童家长和养育人为重点,结合眼保健和眼科临床服务,开展个性化咨询指导。要针对儿童常见眼病和近视防控等重点问题,通过面对面咨询指导,引导儿童家长树立近视防控意识,改变不良生活方式,加强户外活动,养成爱眼护眼健康行为习惯。提高儿童眼保健专科服务能力。各地要积极推进儿童眼保健专科建设,扎实组织好妇幼健康职业技能竞赛“儿童眼保健”项目,推动各层级开展比武练兵,提升业务能力。"
|
||||
GROUND_TRUTH_ANSWER = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"
|
||||
|
||||
LLM_TIP = "If you not sure, just answer I don't know."
|
||||
LLM_ERROR = "Retrieve failed due to LLM wasn't follow instruction"
|
||||
EMPTY_ERROR = "Empty Response"
|
||||
|
||||
|
||||
class RAGExample:
|
||||
"""Show how to use RAG for evaluation."""
|
||||
|
||||
def __init__(self):
|
||||
self.benchmark = RAGBenchmark()
|
||||
self.embedding = get_rag_embedding()
|
||||
self.llm = get_rag_llm()
|
||||
|
||||
async def rag_evaluate_pipeline(self, dataset_name: list[str] = ["all"]):
|
||||
dataset_config = self.benchmark.load_dataset(dataset_name)
|
||||
|
||||
for dataset in dataset_config.datasets:
|
||||
if "all" in dataset_name or dataset.name in dataset_name:
|
||||
output_dir = DATA_PATH / f"{dataset.name}"
|
||||
|
||||
if output_dir.exists():
|
||||
logger.info("Loading Existed index!")
|
||||
logger.info(f"Index Path:{output_dir}")
|
||||
self.engine = SimpleEngine.from_index(
|
||||
index_config=FAISSIndexConfig(persist_path=output_dir),
|
||||
ranker_configs=[ColbertRerankConfig()],
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
)
|
||||
else:
|
||||
logger.info("Loading index from documents!")
|
||||
self.engine = SimpleEngine.from_docs(
|
||||
input_files=dataset.document_files,
|
||||
retriever_configs=[FAISSRetrieverConfig()],
|
||||
ranker_configs=[CohereRerankConfig()],
|
||||
transformations=[SentenceSplitter(chunk_size=1024, chunk_overlap=0)],
|
||||
)
|
||||
results = []
|
||||
for gt_info in dataset.gt_info:
|
||||
result = await self.rag_evaluate_single(
|
||||
question=gt_info["question"],
|
||||
reference=gt_info["gt_reference"],
|
||||
ground_truth=gt_info["gt_answer"],
|
||||
)
|
||||
results.append(result)
|
||||
logger.info(f"=====The {dataset.name} Benchmark dataset assessment is complete!=====")
|
||||
self._print_bm_result(results)
|
||||
|
||||
write_json_file((EXAMPLE_BENCHMARK_PATH / dataset.name / "bm_result.json").as_posix(), results, "utf-8")
|
||||
|
||||
async def rag_evaluate_single(self, question, reference, ground_truth, print_title=True):
|
||||
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
|
||||
|
||||
Retrieve Result:
|
||||
0. Productivi..., 10.0
|
||||
1. I wrote cu..., 7.0
|
||||
2. I highly r..., 5.0
|
||||
|
||||
Query Result:
|
||||
Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer.
|
||||
|
||||
RAG BenchMark result:
|
||||
{
|
||||
'metrics':
|
||||
{
|
||||
'bleu-avg': 0.48318624982047,
|
||||
'bleu-1': 0.5609756097560976,
|
||||
'bleu-2': 0.5,
|
||||
'bleu-3': 0.46153846153846156,
|
||||
'bleu-4': 0.42105263157894735,
|
||||
'rouge-L': 0.6865671641791045,
|
||||
'semantic similarity': 0.9487444961487591,
|
||||
'length': 74
|
||||
},
|
||||
'log': {
|
||||
'generated_text':
|
||||
'国家卫生健康委在2023年7月28日开展的“启明行动”是为了防控儿童青少年的近视问题。活动发布的指导性文件名称为《防控儿童青少年近视核心知识十条》。',
|
||||
'ground_truth_text':
|
||||
'“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。'
|
||||
}
|
||||
}
|
||||
"""
|
||||
if print_title:
|
||||
self._print_title("RAG Pipeline")
|
||||
try:
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
|
||||
answer = await self.engine.aquery(question)
|
||||
self._print_result(answer, state="Query")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.benchmark.set_metrics(
|
||||
generated_text=LLM_ERROR, ground_truth_text=ground_truth, question=question
|
||||
)
|
||||
|
||||
result = await self.evaluate_result(answer.response, ground_truth, nodes, reference, question)
|
||||
|
||||
logger.info("==========RAG BenchMark result demo as follows==========")
|
||||
logger.info(result)
|
||||
|
||||
return result
|
||||
|
||||
async def rag_faissdb(self):
|
||||
"""This example show how to use FAISS. how to save and load index. will print something like:
|
||||
|
||||
Query Result:
|
||||
Bob likes traveling.
|
||||
"""
|
||||
self._print_title("RAG FAISS")
|
||||
|
||||
# save index
|
||||
output_dir = DATA_PATH / "rag_faiss"
|
||||
|
||||
SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig(dimensions=512, persist_path=output_dir)],
|
||||
)
|
||||
|
||||
# load index
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=FAISSIndexConfig(persist_path=output_dir),
|
||||
)
|
||||
|
||||
# query
|
||||
nodes = engine.retrieve(QUESTION)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
|
||||
answer = engine.query(TRAVEL_QUESTION)
|
||||
self._print_result(answer, state="Query")
|
||||
|
||||
async def evaluate_result(
|
||||
self,
|
||||
response: str = None,
|
||||
reference: str = None,
|
||||
nodes: list[NodeWithScore] = None,
|
||||
reference_doc: list[str] = None,
|
||||
question: str = None,
|
||||
):
|
||||
result = await self.benchmark.compute_metric(response, reference, nodes, reference_doc, question)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _print_title(title):
|
||||
logger.info(f"{'#'*30} {title} {'#'*30}")
|
||||
|
||||
@staticmethod
|
||||
def _print_result(result, state="Retrieve"):
|
||||
"""print retrieve or query result"""
|
||||
logger.info(f"{state} Result:")
|
||||
|
||||
if state == "Retrieve":
|
||||
for i, node in enumerate(result):
|
||||
logger.info(f"{i}. {node.text[:10]}..., {node.score}")
|
||||
logger.info("======Retrieve Finished======")
|
||||
return
|
||||
|
||||
logger.info(f"{result}\n")
|
||||
|
||||
@staticmethod
|
||||
def _print_bm_result(result):
|
||||
import pandas as pd
|
||||
|
||||
metrics = [
|
||||
item["metrics"]
|
||||
for item in result
|
||||
if item["log"]["generated_text"] != LLM_ERROR and item["log"]["generated_text"] != EMPTY_ERROR
|
||||
]
|
||||
|
||||
data = pd.DataFrame(metrics)
|
||||
logger.info(f"\n {data.mean()}")
|
||||
|
||||
llm_errors = [item for item in result if item["log"]["generated_text"] == LLM_ERROR]
|
||||
retrieve_errors = [item for item in result if item["log"]["generated_text"] == EMPTY_ERROR]
|
||||
logger.info(
|
||||
f"Percentage of retrieval failures due to incorrect LLM instruction following:"
|
||||
f" {100.0 * len(llm_errors) / len(result)}%"
|
||||
)
|
||||
logger.info(
|
||||
f"Percentage of retrieval failures due to retriever not recalling any documents is:"
|
||||
f" {100.0 * len(retrieve_errors) / len(result)}%"
|
||||
)
|
||||
|
||||
async def _retrieve_and_print(self, question):
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
return nodes
|
||||
|
||||
|
||||
async def main():
|
||||
"""RAG pipeline"""
|
||||
e = RAGExample()
|
||||
await e.rag_evaluate_pipeline()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
"""RAG benchmark pipeline"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
|
||||
from metagpt.const import DATA_PATH, EXAMPLE_BENCHMARK_PATH, EXAMPLE_DATA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.benchmark import RAGBenchmark
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.factories import get_rag_embedding, get_rag_llm
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
CohereRerankConfig,
|
||||
ColbertRerankConfig,
|
||||
FAISSIndexConfig,
|
||||
FAISSRetrieverConfig,
|
||||
)
|
||||
from metagpt.utils.common import write_json_file
|
||||
|
||||
DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/summary_writer.txt"
|
||||
QUESTION = "2023年7月20日,应急管理部、财政部联合下发《因灾倒塌、损坏住房恢复重建救助工作规范》的通知,规范倒损住房恢复重建救助相关工作。"
|
||||
|
||||
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag_bm/documents.txt"
|
||||
TRAVEL_QUESTION = "国家卫生健康委在2023年7月28日开展的“启明行动”是为了防控哪个群体的哪种健康问题,并请列出活动发布的指导性文件名称。"
|
||||
|
||||
DATASET_PATH = EXAMPLE_DATA_PATH / "rag_bm/test.json"
|
||||
SAVE_PATH = EXAMPLE_DATA_PATH / "rag_bm/result.json"
|
||||
GROUND_TRUTH_TRANVEL = "2023-07-28 10:14:27作者:白剑峰来源:人民日报 ,正文:为在全社会形成重视儿童眼健康的良好氛围,持续推进综合防控儿童青少年近视工作落实,国家卫生健康委决定在全国持续开展“启明行动”——防控儿童青少年近视健康促进活动,并发布了《防控儿童青少年近视核心知识十条》。本次活动的主题为:重视儿童眼保健,守护孩子明眸“视”界。强调预防为主,推动关口前移,倡导和推动家庭及全社会共同行动起来,营造爱眼护眼的视觉友好环境,共同呵护好孩子的眼睛,让他们拥有一个光明的未来。国家卫生健康委要求,开展社会宣传和健康教育。充分利用网络、广播电视、报纸杂志、海报墙报、培训讲座等多种形式,广泛开展宣传倡导,向社会公众传播开展儿童眼保健、保护儿童视力健康的重要意义,以《防控儿童青少年近视核心知识十条》为重点普及预防近视科学知识。创新健康教育方式和载体,开发制作群众喜闻乐见的健康教育科普作品,利用互联网媒体扩大传播效果,提高健康教育的针对性、精准性和实效性。指导相关医疗机构将儿童眼保健和近视防控等科学知识纳入孕妇学校、家长课堂内容。开展儿童眼保健及视力检查咨询指导。医疗机构要以儿童家长和养育人为重点,结合眼保健和眼科临床服务,开展个性化咨询指导。要针对儿童常见眼病和近视防控等重点问题,通过面对面咨询指导,引导儿童家长树立近视防控意识,改变不良生活方式,加强户外活动,养成爱眼护眼健康行为习惯。提高儿童眼保健专科服务能力。各地要积极推进儿童眼保健专科建设,扎实组织好妇幼健康职业技能竞赛“儿童眼保健”项目,推动各层级开展比武练兵,提升业务能力。"
|
||||
GROUND_TRUTH_ANSWER = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"
|
||||
|
||||
LLM_TIP = "If you not sure, just answer I don't know."
|
||||
LLM_ERROR = "Retrieve failed due to LLM wasn't follow instruction"
|
||||
EMPTY_ERROR = "Empty Response"
|
||||
|
||||
|
||||
class RAGExample:
|
||||
"""Show how to use RAG for evaluation."""
|
||||
|
||||
def __init__(self):
|
||||
self.benchmark = RAGBenchmark()
|
||||
self.embedding = get_rag_embedding()
|
||||
self.llm = get_rag_llm()
|
||||
|
||||
async def rag_evaluate_pipeline(self, dataset_name: list[str] = ["all"]):
|
||||
dataset_config = self.benchmark.load_dataset(dataset_name)
|
||||
|
||||
for dataset in dataset_config.datasets:
|
||||
if "all" in dataset_name or dataset.name in dataset_name:
|
||||
output_dir = DATA_PATH / f"{dataset.name}"
|
||||
|
||||
if output_dir.exists():
|
||||
logger.info("Loading Existed index!")
|
||||
logger.info(f"Index Path:{output_dir}")
|
||||
self.engine = SimpleEngine.from_index(
|
||||
index_config=FAISSIndexConfig(persist_path=output_dir),
|
||||
ranker_configs=[ColbertRerankConfig()],
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
)
|
||||
else:
|
||||
logger.info("Loading index from documents!")
|
||||
self.engine = SimpleEngine.from_docs(
|
||||
input_files=dataset.document_files,
|
||||
retriever_configs=[FAISSRetrieverConfig()],
|
||||
ranker_configs=[CohereRerankConfig()],
|
||||
transformations=[SentenceSplitter(chunk_size=1024, chunk_overlap=0)],
|
||||
)
|
||||
results = []
|
||||
for gt_info in dataset.gt_info:
|
||||
result = await self.rag_evaluate_single(
|
||||
question=gt_info["question"],
|
||||
reference=gt_info["gt_reference"],
|
||||
ground_truth=gt_info["gt_answer"],
|
||||
)
|
||||
results.append(result)
|
||||
logger.info(f"=====The {dataset.name} Benchmark dataset assessment is complete!=====")
|
||||
self._print_bm_result(results)
|
||||
|
||||
write_json_file((EXAMPLE_BENCHMARK_PATH / dataset.name / "bm_result.json").as_posix(), results, "utf-8")
|
||||
|
||||
async def rag_evaluate_single(self, question, reference, ground_truth, print_title=True):
|
||||
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
|
||||
|
||||
Retrieve Result:
|
||||
0. Productivi..., 10.0
|
||||
1. I wrote cu..., 7.0
|
||||
2. I highly r..., 5.0
|
||||
|
||||
Query Result:
|
||||
Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer.
|
||||
|
||||
RAG BenchMark result:
|
||||
{
|
||||
'metrics':
|
||||
{
|
||||
'bleu-avg': 0.48318624982047,
|
||||
'bleu-1': 0.5609756097560976,
|
||||
'bleu-2': 0.5,
|
||||
'bleu-3': 0.46153846153846156,
|
||||
'bleu-4': 0.42105263157894735,
|
||||
'rouge-L': 0.6865671641791045,
|
||||
'semantic similarity': 0.9487444961487591,
|
||||
'length': 74
|
||||
},
|
||||
'log': {
|
||||
'generated_text':
|
||||
'国家卫生健康委在2023年7月28日开展的“启明行动”是为了防控儿童青少年的近视问题。活动发布的指导性文件名称为《防控儿童青少年近视核心知识十条》。',
|
||||
'ground_truth_text':
|
||||
'“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。'
|
||||
}
|
||||
}
|
||||
"""
|
||||
if print_title:
|
||||
self._print_title("RAG Pipeline")
|
||||
try:
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
|
||||
answer = await self.engine.aquery(question)
|
||||
self._print_result(answer, state="Query")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return self.benchmark.set_metrics(
|
||||
generated_text=LLM_ERROR, ground_truth_text=ground_truth, question=question
|
||||
)
|
||||
|
||||
result = await self.evaluate_result(answer.response, ground_truth, nodes, reference, question)
|
||||
|
||||
logger.info("==========RAG BenchMark result demo as follows==========")
|
||||
logger.info(result)
|
||||
|
||||
return result
|
||||
|
||||
async def rag_faissdb(self):
|
||||
"""This example show how to use FAISS. how to save and load index. will print something like:
|
||||
|
||||
Query Result:
|
||||
Bob likes traveling.
|
||||
"""
|
||||
self._print_title("RAG FAISS")
|
||||
|
||||
# save index
|
||||
output_dir = DATA_PATH / "rag_faiss"
|
||||
|
||||
SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig(dimensions=512, persist_path=output_dir)],
|
||||
)
|
||||
|
||||
# load index
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=FAISSIndexConfig(persist_path=output_dir),
|
||||
)
|
||||
|
||||
# query
|
||||
nodes = engine.retrieve(QUESTION)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
|
||||
answer = engine.query(TRAVEL_QUESTION)
|
||||
self._print_result(answer, state="Query")
|
||||
|
||||
async def evaluate_result(
|
||||
self,
|
||||
response: str = None,
|
||||
reference: str = None,
|
||||
nodes: list[NodeWithScore] = None,
|
||||
reference_doc: list[str] = None,
|
||||
question: str = None,
|
||||
):
|
||||
result = await self.benchmark.compute_metric(response, reference, nodes, reference_doc, question)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _print_title(title):
|
||||
logger.info(f"{'#'*30} {title} {'#'*30}")
|
||||
|
||||
@staticmethod
|
||||
def _print_result(result, state="Retrieve"):
|
||||
"""print retrieve or query result"""
|
||||
logger.info(f"{state} Result:")
|
||||
|
||||
if state == "Retrieve":
|
||||
for i, node in enumerate(result):
|
||||
logger.info(f"{i}. {node.text[:10]}..., {node.score}")
|
||||
logger.info("======Retrieve Finished======")
|
||||
return
|
||||
|
||||
logger.info(f"{result}\n")
|
||||
|
||||
@staticmethod
|
||||
def _print_bm_result(result):
|
||||
import pandas as pd
|
||||
|
||||
metrics = [
|
||||
item["metrics"]
|
||||
for item in result
|
||||
if item["log"]["generated_text"] != LLM_ERROR and item["log"]["generated_text"] != EMPTY_ERROR
|
||||
]
|
||||
|
||||
data = pd.DataFrame(metrics)
|
||||
logger.info(f"\n {data.mean()}")
|
||||
|
||||
llm_errors = [item for item in result if item["log"]["generated_text"] == LLM_ERROR]
|
||||
retrieve_errors = [item for item in result if item["log"]["generated_text"] == EMPTY_ERROR]
|
||||
logger.info(
|
||||
f"Percentage of retrieval failures due to incorrect LLM instruction following:"
|
||||
f" {100.0 * len(llm_errors) / len(result)}%"
|
||||
)
|
||||
logger.info(
|
||||
f"Percentage of retrieval failures due to retriever not recalling any documents is:"
|
||||
f" {100.0 * len(retrieve_errors) / len(result)}%"
|
||||
)
|
||||
|
||||
async def _retrieve_and_print(self, question):
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_result(nodes, state="Retrieve")
|
||||
return nodes
|
||||
|
||||
|
||||
async def main():
|
||||
"""RAG pipeline"""
|
||||
e = RAGExample()
|
||||
await e.rag_evaluate_pipeline()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -1,195 +1,201 @@
|
|||
import asyncio
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
|
||||
import evaluate
|
||||
import jieba
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.evaluation import SemanticSimilarityEvaluator
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import EXAMPLE_BENCHMARK_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.factories import get_rag_embedding
|
||||
from metagpt.utils.common import read_json_file
|
||||
|
||||
|
||||
class DatasetInfo(BaseModel):
|
||||
name: str
|
||||
document_files: List[str]
|
||||
gt_info: List[dict]
|
||||
|
||||
|
||||
class DatasetConfig(BaseModel):
|
||||
datasets: List[DatasetInfo]
|
||||
|
||||
|
||||
class RAGBenchmark:
|
||||
def __init__(
|
||||
self,
|
||||
embed_model: BaseEmbedding = None,
|
||||
):
|
||||
self.evaluator = SemanticSimilarityEvaluator(
|
||||
embed_model=embed_model or get_rag_embedding(),
|
||||
)
|
||||
|
||||
def set_metrics(
|
||||
self,
|
||||
bleu_avg :float = 0.0,
|
||||
bleu_1 :float = 0.0,
|
||||
bleu_2 :float = 0.0,
|
||||
bleu_3 :float = 0.0,
|
||||
bleu_4 :float = 0.0,
|
||||
rouge_l :float = 0.0,
|
||||
semantic_similarity :float = 0.0,
|
||||
recall :float = 0.0,
|
||||
hit_rate :float = 0.0,
|
||||
mrr :float = 0.0,
|
||||
length :float = 0.0,
|
||||
generated_text :str = None,
|
||||
ground_truth_text: str = None,
|
||||
question: str = None
|
||||
):
|
||||
metrics = {
|
||||
"bleu-avg": bleu_avg,
|
||||
"bleu-1": bleu_1,
|
||||
"bleu-2": bleu_2,
|
||||
"bleu-3": bleu_3,
|
||||
"bleu-4": bleu_4,
|
||||
"rouge-L": rouge_l,
|
||||
"semantic similarity": semantic_similarity,
|
||||
"recall": recall,
|
||||
"hit_rate": hit_rate,
|
||||
"mrr": mrr,
|
||||
"length": length,
|
||||
}
|
||||
|
||||
log = {
|
||||
"generated_text": generated_text,
|
||||
"ground_truth_text": ground_truth_text,
|
||||
"question": question,
|
||||
}
|
||||
|
||||
return {"metrics": metrics, "log": log}
|
||||
|
||||
|
||||
def bleu_score(self, response: str, reference: str, with_penalty=False) -> Union[float, Tuple[float]]:
|
||||
f = lambda text: list(jieba.cut(text))
|
||||
bleu = evaluate.load(path="bleu")
|
||||
results = bleu.compute(predictions=[response], references=[[reference]], tokenizer=f)
|
||||
|
||||
bleu_avg = results["bleu"]
|
||||
bleu1 = results["precisions"][0]
|
||||
bleu2 = results["precisions"][1]
|
||||
bleu3 = results["precisions"][2]
|
||||
bleu4 = results["precisions"][3]
|
||||
brevity_penalty = results["brevity_penalty"]
|
||||
|
||||
if with_penalty:
|
||||
return bleu_avg, bleu1, bleu2, bleu3, bleu4
|
||||
else:
|
||||
return 0.0 if brevity_penalty == 0 else bleu_avg / brevity_penalty, bleu1, bleu2, bleu3, bleu4
|
||||
|
||||
def rougel_score(self, response: str, reference: str) -> float:
|
||||
# pip install rouge_score
|
||||
f = lambda text: list(jieba.cut(text))
|
||||
rouge = evaluate.load(path="rouge")
|
||||
|
||||
results = rouge.compute(predictions=[response], references=[[reference]], tokenizer=f, rouge_types=["rougeL"])
|
||||
score = results["rougeL"]
|
||||
return score
|
||||
|
||||
def recall(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
|
||||
if nodes:
|
||||
total_recall = sum(any(node.text in doc for node in nodes) for doc in reference_docs)
|
||||
return total_recall / len(reference_docs)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def hit_rate(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
|
||||
if nodes:
|
||||
return 1.0 if any(node.text in doc for doc in reference_docs for node in nodes) else 0.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def mean_reciprocal_rank(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
|
||||
mrr_sum = 0.0
|
||||
|
||||
for i, doc in enumerate(reference_docs, start=1):
|
||||
for node in nodes:
|
||||
if node.text in doc:
|
||||
mrr_sum += 1.0 / i
|
||||
break
|
||||
|
||||
return mrr_sum / len(reference_docs) if reference_docs else 0.0
|
||||
|
||||
async def semantic_similarity(self, response: str, reference: str) -> float:
|
||||
result = await self.evaluator.aevaluate(
|
||||
response=response,
|
||||
reference=reference,
|
||||
)
|
||||
|
||||
return result.score
|
||||
|
||||
async def compute_metric(
|
||||
self,
|
||||
response: str = None,
|
||||
reference: str = None,
|
||||
nodes: list[NodeWithScore] = None,
|
||||
reference_doc: list[str] = None,
|
||||
question: str = None,
|
||||
):
|
||||
recall = self.recall(nodes, reference_doc)
|
||||
bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.bleu_score(response, reference)
|
||||
rouge_l = self.rougel_score(response, reference)
|
||||
hit_rate = self.hit_rate(nodes, reference_doc)
|
||||
mrr = self.mean_reciprocal_rank(nodes, reference_doc)
|
||||
|
||||
similarity = await self.semantic_similarity(response, reference)
|
||||
|
||||
result = self.set_metrics(
|
||||
bleu_avg, bleu1, bleu2, bleu3, bleu4, rouge_l,
|
||||
similarity,
|
||||
recall, hit_rate, mrr, len(response), response, reference, question
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def load_dataset(ds_names: list[str] = ["all"]):
|
||||
infos = read_json_file((EXAMPLE_BENCHMARK_PATH / "dataset_info.json").as_posix())
|
||||
dataset_config = DatasetConfig(
|
||||
datasets=[
|
||||
DatasetInfo(
|
||||
name=name,
|
||||
document_files=[
|
||||
(EXAMPLE_BENCHMARK_PATH / name / file).as_posix()
|
||||
for file in info["document_file"]
|
||||
],
|
||||
gt_info=read_json_file((EXAMPLE_BENCHMARK_PATH / name / info["gt_file"]).as_posix()),
|
||||
)
|
||||
for dataset_info in infos
|
||||
for name, info in dataset_info.items()
|
||||
if name in ds_names or "all" in ds_names
|
||||
]
|
||||
)
|
||||
|
||||
return dataset_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark = RAGBenchmark()
|
||||
answer = "是的,根据提供的信息,2023年7月20日,应急管理部和财政部确实联合发布了《因灾倒塌、损坏住房恢复重建救助工作规范》的通知。这份《规范》旨在进一步规范因灾倒塌、损坏住房的恢复重建救助相关工作。它明确了地方各级政府负责实施救助工作,应急管理部和财政部则负责统筹指导。地方财政应安排足够的资金,中央财政也会提供适当的补助。救助资金将通过专账管理,并采取特定的管理方式。救助对象是那些因自然灾害导致住房倒塌或损坏,并向政府提出申请且符合条件的受灾家庭。相关部门将组织调查统计救助对象信息,并建立档案。此外,《规范》还强调了资金发放的具体方式和公开透明的要求。"
|
||||
ground_truth = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"
|
||||
bleu_avg, bleu1, bleu2, bleu3, bleu4 = benchmark.bleu_score(answer, ground_truth)
|
||||
rougeL_score = benchmark.rougel_score(answer, ground_truth)
|
||||
similarity = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth))
|
||||
|
||||
logger.info(
|
||||
f"BLEU Scores: bleu_avg = {bleu_avg}, bleu1 = {bleu1}, bleu2 = {bleu2}, bleu3 = {bleu3}, bleu4 = {bleu4}, "
|
||||
f"RougeL Score: {rougeL_score}, "
|
||||
f"Semantic Similarity: {similarity}"
|
||||
)
|
||||
|
||||
|
||||
import asyncio
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import jieba
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.evaluation import SemanticSimilarityEvaluator
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import EXAMPLE_BENCHMARK_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.factories import get_rag_embedding
|
||||
from metagpt.utils.common import read_json_file
|
||||
|
||||
|
||||
class DatasetInfo(BaseModel):
|
||||
name: str
|
||||
document_files: List[str]
|
||||
gt_info: List[dict]
|
||||
|
||||
|
||||
class DatasetConfig(BaseModel):
|
||||
datasets: List[DatasetInfo]
|
||||
|
||||
|
||||
class RAGBenchmark:
|
||||
def __init__(
|
||||
self,
|
||||
embed_model: BaseEmbedding = None,
|
||||
):
|
||||
self.evaluator = SemanticSimilarityEvaluator(
|
||||
embed_model=embed_model or get_rag_embedding(),
|
||||
)
|
||||
|
||||
def set_metrics(
|
||||
self,
|
||||
bleu_avg: float = 0.0,
|
||||
bleu_1: float = 0.0,
|
||||
bleu_2: float = 0.0,
|
||||
bleu_3: float = 0.0,
|
||||
bleu_4: float = 0.0,
|
||||
rouge_l: float = 0.0,
|
||||
semantic_similarity: float = 0.0,
|
||||
recall: float = 0.0,
|
||||
hit_rate: float = 0.0,
|
||||
mrr: float = 0.0,
|
||||
length: float = 0.0,
|
||||
generated_text: str = None,
|
||||
ground_truth_text: str = None,
|
||||
question: str = None,
|
||||
):
|
||||
metrics = {
|
||||
"bleu-avg": bleu_avg,
|
||||
"bleu-1": bleu_1,
|
||||
"bleu-2": bleu_2,
|
||||
"bleu-3": bleu_3,
|
||||
"bleu-4": bleu_4,
|
||||
"rouge-L": rouge_l,
|
||||
"semantic similarity": semantic_similarity,
|
||||
"recall": recall,
|
||||
"hit_rate": hit_rate,
|
||||
"mrr": mrr,
|
||||
"length": length,
|
||||
}
|
||||
|
||||
log = {
|
||||
"generated_text": generated_text,
|
||||
"ground_truth_text": ground_truth_text,
|
||||
"question": question,
|
||||
}
|
||||
|
||||
return {"metrics": metrics, "log": log}
|
||||
|
||||
def bleu_score(self, response: str, reference: str, with_penalty=False) -> Union[float, Tuple[float]]:
|
||||
f = lambda text: list(jieba.cut(text))
|
||||
bleu = evaluate.load(path="bleu")
|
||||
results = bleu.compute(predictions=[response], references=[[reference]], tokenizer=f)
|
||||
|
||||
bleu_avg = results["bleu"]
|
||||
bleu1 = results["precisions"][0]
|
||||
bleu2 = results["precisions"][1]
|
||||
bleu3 = results["precisions"][2]
|
||||
bleu4 = results["precisions"][3]
|
||||
brevity_penalty = results["brevity_penalty"]
|
||||
|
||||
if with_penalty:
|
||||
return bleu_avg, bleu1, bleu2, bleu3, bleu4
|
||||
else:
|
||||
return 0.0 if brevity_penalty == 0 else bleu_avg / brevity_penalty, bleu1, bleu2, bleu3, bleu4
|
||||
|
||||
def rougel_score(self, response: str, reference: str) -> float:
|
||||
# pip install rouge_score
|
||||
f = lambda text: list(jieba.cut(text))
|
||||
rouge = evaluate.load(path="rouge")
|
||||
|
||||
results = rouge.compute(predictions=[response], references=[[reference]], tokenizer=f, rouge_types=["rougeL"])
|
||||
score = results["rougeL"]
|
||||
return score
|
||||
|
||||
def recall(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
|
||||
if nodes:
|
||||
total_recall = sum(any(node.text in doc for node in nodes) for doc in reference_docs)
|
||||
return total_recall / len(reference_docs)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def hit_rate(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
|
||||
if nodes:
|
||||
return 1.0 if any(node.text in doc for doc in reference_docs for node in nodes) else 0.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def mean_reciprocal_rank(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
|
||||
mrr_sum = 0.0
|
||||
|
||||
for i, doc in enumerate(reference_docs, start=1):
|
||||
for node in nodes:
|
||||
if node.text in doc:
|
||||
mrr_sum += 1.0 / i
|
||||
break
|
||||
|
||||
return mrr_sum / len(reference_docs) if reference_docs else 0.0
|
||||
|
||||
async def semantic_similarity(self, response: str, reference: str) -> float:
|
||||
result = await self.evaluator.aevaluate(
|
||||
response=response,
|
||||
reference=reference,
|
||||
)
|
||||
|
||||
return result.score
|
||||
|
||||
async def compute_metric(
|
||||
self,
|
||||
response: str = None,
|
||||
reference: str = None,
|
||||
nodes: list[NodeWithScore] = None,
|
||||
reference_doc: list[str] = None,
|
||||
question: str = None,
|
||||
):
|
||||
recall = self.recall(nodes, reference_doc)
|
||||
bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.bleu_score(response, reference)
|
||||
rouge_l = self.rougel_score(response, reference)
|
||||
hit_rate = self.hit_rate(nodes, reference_doc)
|
||||
mrr = self.mean_reciprocal_rank(nodes, reference_doc)
|
||||
|
||||
similarity = await self.semantic_similarity(response, reference)
|
||||
|
||||
result = self.set_metrics(
|
||||
bleu_avg,
|
||||
bleu1,
|
||||
bleu2,
|
||||
bleu3,
|
||||
bleu4,
|
||||
rouge_l,
|
||||
similarity,
|
||||
recall,
|
||||
hit_rate,
|
||||
mrr,
|
||||
len(response),
|
||||
response,
|
||||
reference,
|
||||
question,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def load_dataset(ds_names: list[str] = ["all"]):
|
||||
infos = read_json_file((EXAMPLE_BENCHMARK_PATH / "dataset_info.json").as_posix())
|
||||
dataset_config = DatasetConfig(
|
||||
datasets=[
|
||||
DatasetInfo(
|
||||
name=name,
|
||||
document_files=[
|
||||
(EXAMPLE_BENCHMARK_PATH / name / file).as_posix() for file in info["document_file"]
|
||||
],
|
||||
gt_info=read_json_file((EXAMPLE_BENCHMARK_PATH / name / info["gt_file"]).as_posix()),
|
||||
)
|
||||
for dataset_info in infos
|
||||
for name, info in dataset_info.items()
|
||||
if name in ds_names or "all" in ds_names
|
||||
]
|
||||
)
|
||||
|
||||
return dataset_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark = RAGBenchmark()
|
||||
answer = "是的,根据提供的信息,2023年7月20日,应急管理部和财政部确实联合发布了《因灾倒塌、损坏住房恢复重建救助工作规范》的通知。这份《规范》旨在进一步规范因灾倒塌、损坏住房的恢复重建救助相关工作。它明确了地方各级政府负责实施救助工作,应急管理部和财政部则负责统筹指导。地方财政应安排足够的资金,中央财政也会提供适当的补助。救助资金将通过专账管理,并采取特定的管理方式。救助对象是那些因自然灾害导致住房倒塌或损坏,并向政府提出申请且符合条件的受灾家庭。相关部门将组织调查统计救助对象信息,并建立档案。此外,《规范》还强调了资金发放的具体方式和公开透明的要求。"
|
||||
ground_truth = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"
|
||||
bleu_avg, bleu1, bleu2, bleu3, bleu4 = benchmark.bleu_score(answer, ground_truth)
|
||||
rougeL_score = benchmark.rougel_score(answer, ground_truth)
|
||||
similarity = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth))
|
||||
|
||||
logger.info(
|
||||
f"BLEU Scores: bleu_avg = {bleu_avg}, bleu1 = {bleu1}, bleu2 = {bleu2}, bleu3 = {bleu3}, bleu4 = {bleu4}, "
|
||||
f"RougeL Score: {rougeL_score}, "
|
||||
f"Semantic Similarity: {similarity}"
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue