Merge branch 'geekan:main' into main

This commit is contained in:
usamimeri_renko 2024-04-26 17:33:11 +08:00 committed by GitHub
commit 4c77d6c454
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 3991 additions and 0 deletions

View file

@ -51,6 +51,7 @@ DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
EXAMPLE_PATH = METAGPT_ROOT / "examples"
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
DATA_PATH = METAGPT_ROOT / "data"
EXAMPLE_BENCHMARK_PATH = EXAMPLE_PATH / "data/rag_bm"
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
RESEARCH_PATH = DATA_PATH / "research"
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"

View file

@ -0,0 +1,3 @@
from metagpt.rag.benchmark.base import RAGBenchmark
__all__ = ["RAGBenchmark"]

View file

@ -0,0 +1,201 @@
import asyncio
from typing import List, Tuple, Union
import evaluate
import jieba
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.evaluation import SemanticSimilarityEvaluator
from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel
from metagpt.const import EXAMPLE_BENCHMARK_PATH
from metagpt.logs import logger
from metagpt.rag.factories import get_rag_embedding
from metagpt.utils.common import read_json_file
class DatasetInfo(BaseModel):
name: str
document_files: List[str]
gt_info: List[dict]
class DatasetConfig(BaseModel):
datasets: List[DatasetInfo]
class RAGBenchmark:
def __init__(
self,
embed_model: BaseEmbedding = None,
):
self.evaluator = SemanticSimilarityEvaluator(
embed_model=embed_model or get_rag_embedding(),
)
def set_metrics(
self,
bleu_avg: float = 0.0,
bleu_1: float = 0.0,
bleu_2: float = 0.0,
bleu_3: float = 0.0,
bleu_4: float = 0.0,
rouge_l: float = 0.0,
semantic_similarity: float = 0.0,
recall: float = 0.0,
hit_rate: float = 0.0,
mrr: float = 0.0,
length: float = 0.0,
generated_text: str = None,
ground_truth_text: str = None,
question: str = None,
):
metrics = {
"bleu-avg": bleu_avg,
"bleu-1": bleu_1,
"bleu-2": bleu_2,
"bleu-3": bleu_3,
"bleu-4": bleu_4,
"rouge-L": rouge_l,
"semantic similarity": semantic_similarity,
"recall": recall,
"hit_rate": hit_rate,
"mrr": mrr,
"length": length,
}
log = {
"generated_text": generated_text,
"ground_truth_text": ground_truth_text,
"question": question,
}
return {"metrics": metrics, "log": log}
def bleu_score(self, response: str, reference: str, with_penalty=False) -> Union[float, Tuple[float]]:
f = lambda text: list(jieba.cut(text))
bleu = evaluate.load(path="bleu")
results = bleu.compute(predictions=[response], references=[[reference]], tokenizer=f)
bleu_avg = results["bleu"]
bleu1 = results["precisions"][0]
bleu2 = results["precisions"][1]
bleu3 = results["precisions"][2]
bleu4 = results["precisions"][3]
brevity_penalty = results["brevity_penalty"]
if with_penalty:
return bleu_avg, bleu1, bleu2, bleu3, bleu4
else:
return 0.0 if brevity_penalty == 0 else bleu_avg / brevity_penalty, bleu1, bleu2, bleu3, bleu4
def rougel_score(self, response: str, reference: str) -> float:
# pip install rouge_score
f = lambda text: list(jieba.cut(text))
rouge = evaluate.load(path="rouge")
results = rouge.compute(predictions=[response], references=[[reference]], tokenizer=f, rouge_types=["rougeL"])
score = results["rougeL"]
return score
def recall(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
if nodes:
total_recall = sum(any(node.text in doc for node in nodes) for doc in reference_docs)
return total_recall / len(reference_docs)
else:
return 0.0
def hit_rate(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
if nodes:
return 1.0 if any(node.text in doc for doc in reference_docs for node in nodes) else 0.0
else:
return 0.0
def mean_reciprocal_rank(self, nodes: list[NodeWithScore], reference_docs: list[str]) -> float:
mrr_sum = 0.0
for i, doc in enumerate(reference_docs, start=1):
for node in nodes:
if node.text in doc:
mrr_sum += 1.0 / i
break
return mrr_sum / len(reference_docs) if reference_docs else 0.0
async def semantic_similarity(self, response: str, reference: str) -> float:
result = await self.evaluator.aevaluate(
response=response,
reference=reference,
)
return result.score
async def compute_metric(
self,
response: str = None,
reference: str = None,
nodes: list[NodeWithScore] = None,
reference_doc: list[str] = None,
question: str = None,
):
recall = self.recall(nodes, reference_doc)
bleu_avg, bleu1, bleu2, bleu3, bleu4 = self.bleu_score(response, reference)
rouge_l = self.rougel_score(response, reference)
hit_rate = self.hit_rate(nodes, reference_doc)
mrr = self.mean_reciprocal_rank(nodes, reference_doc)
similarity = await self.semantic_similarity(response, reference)
result = self.set_metrics(
bleu_avg,
bleu1,
bleu2,
bleu3,
bleu4,
rouge_l,
similarity,
recall,
hit_rate,
mrr,
len(response),
response,
reference,
question,
)
return result
@staticmethod
def load_dataset(ds_names: list[str] = ["all"]):
infos = read_json_file((EXAMPLE_BENCHMARK_PATH / "dataset_info.json").as_posix())
dataset_config = DatasetConfig(
datasets=[
DatasetInfo(
name=name,
document_files=[
(EXAMPLE_BENCHMARK_PATH / name / file).as_posix() for file in info["document_file"]
],
gt_info=read_json_file((EXAMPLE_BENCHMARK_PATH / name / info["gt_file"]).as_posix()),
)
for dataset_info in infos
for name, info in dataset_info.items()
if name in ds_names or "all" in ds_names
]
)
return dataset_config
if __name__ == "__main__":
benchmark = RAGBenchmark()
answer = "是的根据提供的信息2023年7月20日应急管理部和财政部确实联合发布了《因灾倒塌、损坏住房恢复重建救助工作规范》的通知。这份《规范》旨在进一步规范因灾倒塌、损坏住房的恢复重建救助相关工作。它明确了地方各级政府负责实施救助工作应急管理部和财政部则负责统筹指导。地方财政应安排足够的资金中央财政也会提供适当的补助。救助资金将通过专账管理并采取特定的管理方式。救助对象是那些因自然灾害导致住房倒塌或损坏并向政府提出申请且符合条件的受灾家庭。相关部门将组织调查统计救助对象信息并建立档案。此外《规范》还强调了资金发放的具体方式和公开透明的要求。"
ground_truth = "“启明行动”是为了防控儿童青少年的近视问题,并发布了《防控儿童青少年近视核心知识十条》。"
bleu_avg, bleu1, bleu2, bleu3, bleu4 = benchmark.bleu_score(answer, ground_truth)
rougeL_score = benchmark.rougel_score(answer, ground_truth)
similarity = asyncio.run(benchmark.SemanticSimilarity(answer, ground_truth))
logger.info(
f"BLEU Scores: bleu_avg = {bleu_avg}, bleu1 = {bleu1}, bleu2 = {bleu2}, bleu3 = {bleu3}, bleu4 = {bleu4}, "
f"RougeL Score: {rougeL_score}, "
f"Semantic Similarity: {similarity}"
)

View file

@ -11,6 +11,8 @@ from metagpt.rag.schema import (
ColbertRerankConfig,
LLMRankerConfig,
ObjectRankerConfig,
CohereRerankConfig,
BGERerankConfig
)
@ -22,6 +24,8 @@ class RankerFactory(ConfigBasedFactory):
LLMRankerConfig: self._create_llm_ranker,
ColbertRerankConfig: self._create_colbert_ranker,
ObjectRankerConfig: self._create_object_ranker,
CohereRerankConfig: self._create_cohere_rerank,
BGERerankConfig: self._create_bge_rerank,
}
super().__init__(creators)
@ -45,6 +49,24 @@ class RankerFactory(ConfigBasedFactory):
)
return ColbertRerank(**config.model_dump())
def _create_cohere_rerank(self, config: CohereRerankConfig, **kwargs) -> LLMRerank:
try:
from llama_index.postprocessor.cohere_rerank import CohereRerank
except ImportError:
raise ImportError(
"`llama-index-postprocessor-cohere-rerank` package not found, please run `pip install llama-index-postprocessor-cohere-rerank`"
)
return CohereRerank(**config.model_dump())
def _create_bge_rerank(self, config: BGERerankConfig, **kwargs) -> LLMRerank:
try:
from llama_index.postprocessor.flag_embedding_reranker import FlagEmbeddingReranker
except ImportError:
raise ImportError(
"`llama-index-postprocessor-flag-embedding-reranker` package not found, please run `pip install llama-index-postprocessor-flag-embedding-reranker`"
)
return FlagEmbeddingReranker(**config.model_dump())
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
return ObjectSortPostprocessor(**config.model_dump())

View file

@ -119,6 +119,16 @@ class ColbertRerankConfig(BaseRankerConfig):
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
class CohereRerankConfig(BaseRankerConfig):
model: str = Field(default="rerank-english-v3.0")
api_key: str = Field(default="YOUR_COHERE_API")
class BGERerankConfig(BaseRankerConfig):
model: str = Field(default="BAAI/bge-reranker-large", description="BAAI Reranker model name.")
use_fp16: bool = Field(default=True, description="Whether to use fp16 for inference.")
class ObjectRankerConfig(BaseRankerConfig):
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")