mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-29 15:59:42 +02:00
Merge branch 'geekan:main' into main
This commit is contained in:
commit
4c77d6c454
14 changed files with 3991 additions and 0 deletions
|
|
@ -51,6 +51,7 @@ DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
|
|||
EXAMPLE_PATH = METAGPT_ROOT / "examples"
|
||||
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
|
||||
DATA_PATH = METAGPT_ROOT / "data"
|
||||
EXAMPLE_BENCHMARK_PATH = EXAMPLE_PATH / "data/rag_bm"
|
||||
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
|
||||
RESEARCH_PATH = DATA_PATH / "research"
|
||||
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
|
||||
|
|
|
|||
3
metagpt/rag/benchmark/__init__.py
Normal file
3
metagpt/rag/benchmark/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from metagpt.rag.benchmark.base import RAGBenchmark
|
||||
|
||||
__all__ = ["RAGBenchmark"]
|
||||
201
metagpt/rag/benchmark/base.py
Normal file
201
metagpt/rag/benchmark/base.py
Normal file
|
|
@ -0,0 +1,201 @@
|
|||
import asyncio
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import evaluate
|
||||
import jieba
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.evaluation import SemanticSimilarityEvaluator
|
||||
from llama_index.core.schema import NodeWithScore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import EXAMPLE_BENCHMARK_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.factories import get_rag_embedding
|
||||
from metagpt.utils.common import read_json_file
|
||||
|
||||
|
||||
class DatasetInfo(BaseModel):
|
||||
name: str
|
||||
document_files: List[str]
|
||||
gt_info: List[dict]
|
||||
|
||||
|
||||
class DatasetConfig(BaseModel):
|
||||
datasets: List[DatasetInfo]
|
||||
|
||||
|
||||
class RAGBenchmark:
|
||||
def __init__(
|
||||
self,
|
||||
embed_model: BaseEmbedding = None,
|
||||
):
|
||||
self.evaluator = SemanticSimilarityEvaluator(
|
||||
embed_model=embed_model or get_rag_embedding(),
|
||||
)
|
||||
|
||||
def set_metrics(
|
||||
self,
|
||||
bleu_avg: float = 0.0,
|
||||
bleu_1: float = 0.0,
|
||||
bleu_2: float = 0.0,
|
||||
bleu_3: float = 0.0,
|
||||
bleu_4: float = 0.0,
|
||||
rouge_l: float = 0.0,
|
||||
semantic_similarity: float = 0.0,
|
||||
recall: float = 0.0,
|
||||
hit_rate: float = 0.0,
|
||||
mrr: float = 0.0,
|
||||
length: float = 0.0,
|
||||
generated_text: str = None,
|
||||
ground_truth_text: str = None,
|
||||
question: str = None,
|
||||
):
|
||||
metrics = {
|
||||
"bleu-avg": bleu_avg,
|
||||
"bleu-1": bleu_1,
|
||||
"bleu-2": bleu_2,
|
||||
"bleu-3": bleu_3,
|
||||
"bleu-4": bleu_4,
|
||||
"rouge-L": rouge_l,
|
||||
"semantic similarity": semantic_similarity,
|
||||
"recall": recall,
|
||||
"hit_rate": hit_rate,
|
||||
"mrr": mrr,
|
||||
"length": length,
|
||||
}
|
||||
|
||||
log = {
|
||||
"generated_text": generated_text,
|
||||
"ground_truth_text": ground_truth_text,
|
||||
"question": question,
|
||||
}
|
||||
|
||||
return {"metrics": metrics, "log": log}
|
||||
|
||||
def bleu_score(self, response: str, reference: str, with_penalty=False) -> Union[float, Tuple[float]]:
|
||||
f = lambda text: list(jieba.cut(text))
|
||||
bleu = evaluate.load(path="bleu")
|
||||
results = bleu.compute(predictions=[response], references=[[reference]], tokenizer=f)
|
||||
|
||||
bleu_avg = results["bleu"]
|
||||
bleu1 = results["precisions"][0]
|
||||
bleu2 = results["precisions"][1]
|
||||
bleu3 = results["precisions"][2]
|
||||
bleu4 = results["precisions"][3]
|
||||
brevity_penalty = results["brevity_penalty"]
|
||||
|
||||
if with_penalty:
|
||||
return bleu_avg, bleu1, bleu2, bleu3, bleu4
|
||||
else:
|
||||
return 0.0 if brevity_penalty == 0 else bleu_avg / brevity_penalty, bleu1, bleu2, bleu3, bleu4
|
||||
|
||||
def rougel_score(self, response: str, reference: str) -> float:
|
||||
# pip install rouge_score
|
||||
f = lambda text: list(jieba.cut(text))
|
||||
rouge = evaluate.load(path="rouge")
|
||||
|
||||
results = rouge.compute(predictions=[response], references=[[reference]], tokenizer=f, rouge_types=["rougeL"])
|
||||
score = results["rougeL"]
|
||||
return score
|
||||
|
||||
def recall(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
|
||||
if nodes:
|
||||
total_recall = sum(any(node.text in doc for node in nodes) for doc in reference_docs)
|
||||
return total_recall / len(reference_docs)
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def hit_rate(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
|
||||
if nodes:
|
||||
return 1.0 if any(node.text in doc for doc in reference_docs for node in nodes) else 0.0
|
||||
else:
|
||||
return 0.0
|
||||
|
||||
def mean_reciprocal_rank(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
|
||||
mrr_sum = 0.0
|
||||
|
||||
for i, doc in enumerate(reference_docs, start=1):
|
||||
for node in nodes:
|
||||
if node.text in doc:
|
||||
mrr_sum += 1.0 / i
|
||||
break
|
||||
|
||||
return mrr_sum / len(reference_docs) if reference_docs else 0.0
|
||||
|
||||
async def semantic_similarity(self, response: str, reference: str) -> float:
|
||||
result = await self.evaluator.aevaluate(
|
||||
response=response,
|
||||
reference=reference,
|
||||
)
|
||||
|
||||
return result.score
|
||||
|
||||
async def compute_metric(
|
||||
self,
|
||||
response: str = None,
|
||||
reference: str = None,
|
||||
nodes: list[NodeWithScore] = None,
|
||||
reference_doc: list[str] = None,
|
||||
question: str = None,
|
||||
):
|
||||
recall = self.recall(nodes, reference_doc)
|
||||
bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.bleu_score(response, reference)
|
||||
rouge_l = self.rougel_score(response, reference)
|
||||
hit_rate = self.hit_rate(nodes, reference_doc)
|
||||
mrr = self.mean_reciprocal_rank(nodes, reference_doc)
|
||||
|
||||
similarity = await self.semantic_similarity(response, reference)
|
||||
|
||||
result = self.set_metrics(
|
||||
bleu_avg,
|
||||
bleu1,
|
||||
bleu2,
|
||||
bleu3,
|
||||
bleu4,
|
||||
rouge_l,
|
||||
similarity,
|
||||
recall,
|
||||
hit_rate,
|
||||
mrr,
|
||||
len(response),
|
||||
response,
|
||||
reference,
|
||||
question,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def load_dataset(ds_names: list[str] = ["all"]):
|
||||
infos = read_json_file((EXAMPLE_BENCHMARK_PATH / "dataset_info.json").as_posix())
|
||||
dataset_config = DatasetConfig(
|
||||
datasets=[
|
||||
DatasetInfo(
|
||||
name=name,
|
||||
document_files=[
|
||||
(EXAMPLE_BENCHMARK_PATH / name / file).as_posix() for file in info["document_file"]
|
||||
],
|
||||
gt_info=read_json_file((EXAMPLE_BENCHMARK_PATH / name / info["gt_file"]).as_posix()),
|
||||
)
|
||||
for dataset_info in infos
|
||||
for name, info in dataset_info.items()
|
||||
if name in ds_names or "all" in ds_names
|
||||
]
|
||||
)
|
||||
|
||||
return dataset_config
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark = RAGBenchmark()
|
||||
answer = "是的,根据提供的信息,2023年7月20日,应急管理部和财政部确实联合发布了《因灾倒塌、损坏住房恢复重建救助工作规范》的通知。这份《规范》旨在进一步规范因灾倒塌、损坏住房的恢复重建救助相关工作。它明确了地方各级政府负责实施救助工作,应急管理部和财政部则负责统筹指导。地方财政应安排足够的资金,中央财政也会提供适当的补助。救助资金将通过专账管理,并采取特定的管理方式。救助对象是那些因自然灾害导致住房倒塌或损坏,并向政府提出申请且符合条件的受灾家庭。相关部门将组织调查统计救助对象信息,并建立档案。此外,《规范》还强调了资金发放的具体方式和公开透明的要求。"
|
||||
ground_truth = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"
|
||||
bleu_avg, bleu1, bleu2, bleu3, bleu4 = benchmark.bleu_score(answer, ground_truth)
|
||||
rougeL_score = benchmark.rougel_score(answer, ground_truth)
|
||||
similarity = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth))
|
||||
|
||||
logger.info(
|
||||
f"BLEU Scores: bleu_avg = {bleu_avg}, bleu1 = {bleu1}, bleu2 = {bleu2}, bleu3 = {bleu3}, bleu4 = {bleu4}, "
|
||||
f"RougeL Score: {rougeL_score}, "
|
||||
f"Semantic Similarity: {similarity}"
|
||||
)
|
||||
|
|
@ -11,6 +11,8 @@ from metagpt.rag.schema import (
|
|||
ColbertRerankConfig,
|
||||
LLMRankerConfig,
|
||||
ObjectRankerConfig,
|
||||
CohereRerankConfig,
|
||||
BGERerankConfig
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -22,6 +24,8 @@ class RankerFactory(ConfigBasedFactory):
|
|||
LLMRankerConfig: self._create_llm_ranker,
|
||||
ColbertRerankConfig: self._create_colbert_ranker,
|
||||
ObjectRankerConfig: self._create_object_ranker,
|
||||
CohereRerankConfig: self._create_cohere_rerank,
|
||||
BGERerankConfig: self._create_bge_rerank,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
|
|
@ -45,6 +49,24 @@ class RankerFactory(ConfigBasedFactory):
|
|||
)
|
||||
return ColbertRerank(**config.model_dump())
|
||||
|
||||
def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank:
|
||||
try:
|
||||
from llama_index.postprocessor.cohere_rerank import CohereRerank
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`"
|
||||
)
|
||||
return CohereRerank(**config.model_dump())
|
||||
|
||||
def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank:
|
||||
try:
|
||||
from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`"
|
||||
)
|
||||
return FlagEmbeddingReranker(**config.model_dump())
|
||||
|
||||
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
|
||||
return ObjectSortPostprocessor(**config.model_dump())
|
||||
|
||||
|
|
|
|||
|
|
@ -119,6 +119,16 @@ class ColbertRerankConfig(BaseRankerConfig):
|
|||
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
|
||||
|
||||
|
||||
class CohereRerankConfig(BaseRankerConfig):
|
||||
model: str = Field(default="rerank-english-v3.0")
|
||||
api_key: str = Field(default="YOUR_COHERE_API")
|
||||
|
||||
|
||||
class BGERerankConfig(BaseRankerConfig):
|
||||
model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.")
|
||||
use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.")
|
||||
|
||||
|
||||
class ObjectRankerConfig(BaseRankerConfig):
|
||||
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
|
||||
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue