mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-01 11:56:24 +02:00
add object ranker
This commit is contained in:
parent
aaae00441b
commit
a22d7d8983
4 changed files with 134 additions and 2 deletions
|
|
@ -6,14 +6,24 @@ from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
|||
from llama_index.postprocessor.colbert_rerank import ColbertRerank
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.schema import BaseRankerConfig, ColbertRerankConfig, LLMRankerConfig
|
||||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import (
|
||||
BaseRankerConfig,
|
||||
ColbertRerankConfig,
|
||||
LLMRankerConfig,
|
||||
ObjectRankerConfig,
|
||||
)
|
||||
|
||||
|
||||
class RankerFactory(ConfigBasedFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {LLMRankerConfig: self._create_llm_ranker, ColbertRerankConfig: self._create_colbert_ranker}
|
||||
creators = {
|
||||
LLMRankerConfig: self._create_llm_ranker,
|
||||
ColbertRerankConfig: self._create_colbert_ranker,
|
||||
ObjectRankerConfig: self._create_object_ranker,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
|
||||
|
|
@ -30,6 +40,9 @@ class RankerFactory(ConfigBasedFactory):
|
|||
def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> LLMRerank:
|
||||
return ColbertRerank(**config.model_dump())
|
||||
|
||||
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> LLMRerank:
|
||||
return ObjectSortPostprocessor(**config.model_dump())
|
||||
|
||||
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
|
||||
return self._val_from_config_or_kwargs("llm", config, **kwargs)
|
||||
|
||||
|
|
|
|||
54
metagpt/rag/rankers/object_ranker.py
Normal file
54
metagpt/rag/rankers/object_ranker.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""Object ranker."""
|
||||
|
||||
import heapq
|
||||
import json
|
||||
from typing import Literal, Optional
|
||||
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.rag.schema import ObjectNode
|
||||
|
||||
|
||||
class ObjectSortPostprocessor(BaseNodePostprocessor):
|
||||
"""Sorted by object's field, desc or asc.
|
||||
|
||||
Assumes nodes is list of ObjectNode with score.
|
||||
"""
|
||||
|
||||
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
|
||||
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
|
||||
top_n: int = 5
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "ObjectSortPostprocessor"
|
||||
|
||||
def _postprocess_nodes(
|
||||
self,
|
||||
nodes: list[NodeWithScore],
|
||||
query_bundle: Optional[QueryBundle] = None,
|
||||
) -> list[NodeWithScore]:
|
||||
"""Postprocess nodes."""
|
||||
if query_bundle is None:
|
||||
raise ValueError("Missing query bundle in extra info.")
|
||||
|
||||
if not nodes:
|
||||
return []
|
||||
|
||||
self._check_metadata(nodes[0].node)
|
||||
sort_key = lambda node: json.loads(node.node.metadata["obj_json"])[self.field_name]
|
||||
return self._get_sort_func()(self.top_n, nodes, key=sort_key)
|
||||
|
||||
def _get_sort_func(self):
|
||||
return heapq.nlargest if self.order == "desc" else heapq.nsmallest
|
||||
|
||||
def _check_metadata(self, node: ObjectNode):
|
||||
try:
|
||||
obj_dict = json.loads(node.metadata.get("obj_json"))
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid object json in metadata: {node.metadata}, error: {e}")
|
||||
|
||||
if self.field_name not in obj_dict:
|
||||
raise ValueError(f"Field '{self.field_name}' not found in object: {obj_dict}")
|
||||
|
|
@ -101,6 +101,11 @@ class ColbertRerankConfig(BaseRankerConfig):
|
|||
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
|
||||
|
||||
|
||||
class ObjectRankerConfig(BaseRankerConfig):
|
||||
field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
|
||||
order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
|
||||
|
||||
|
||||
class BaseIndexConfig(BaseModel):
|
||||
"""Common config for index.
|
||||
|
||||
|
|
|
|||
60
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
60
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import ObjectNode
|
||||
|
||||
|
||||
class Record(BaseModel):
|
||||
score: int
|
||||
|
||||
|
||||
class TestObjectSortPostprocessor:
|
||||
@pytest.fixture
|
||||
def nodes_with_scores(self):
|
||||
nodes = [
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=5).model_dump_json()}), score=5),
|
||||
]
|
||||
return nodes
|
||||
|
||||
@pytest.fixture
|
||||
def query_bundle(self, mocker):
|
||||
return mocker.MagicMock(spec=QueryBundle)
|
||||
|
||||
def test_sort_descending(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [20, 10, 5]
|
||||
|
||||
def test_sort_ascending(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [5, 10, 20]
|
||||
|
||||
def test_top_n_limit(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert len(sorted_nodes) == 2
|
||||
assert [node.score for node in sorted_nodes] == [20, 10]
|
||||
|
||||
def test_invalid_json_metadata(self, query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes, query_bundle)
|
||||
|
||||
def test_missing_query_bundle(self, nodes_with_scores):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
|
||||
|
||||
def test_field_not_found_in_object(self):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes)
|
||||
Loading…
Add table
Add a link
Reference in a new issue