add object ranker

This commit is contained in:
seehi 2024-03-26 16:36:45 +08:00
parent aaae00441b
commit a22d7d8983
4 changed files with 134 additions and 2 deletions

View file

@ -6,14 +6,24 @@ from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.colbert_rerank import ColbertRerank
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import BaseRankerConfig, ColbertRerankConfig, LLMRankerConfig
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
from metagpt.rag.schema import (
BaseRankerConfig,
ColbertRerankConfig,
LLMRankerConfig,
ObjectRankerConfig,
)
class RankerFactory(ConfigBasedFactory):
"""Modify creators for dynamically instance implementation."""
def __init__(self):
    """Register one creator method per supported ranker config type.

    The mapping is passed to the parent factory's ``__init__``; each key is a
    ranker config class and each value is the bound method used to build the
    matching ranker.
    """
    creators = {
        LLMRankerConfig: self._create_llm_ranker,
        ColbertRerankConfig: self._create_colbert_ranker,
        ObjectRankerConfig: self._create_object_ranker,
    }
    super().__init__(creators)
def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
@ -30,6 +40,9 @@ class RankerFactory(ConfigBasedFactory):
def _create_colbert_ranker(self, config: ColbertRerankConfig, **kwargs) -> ColbertRerank:
    """Build a ``ColbertRerank`` from the dumped fields of ``config``.

    Args:
        config: Colbert rerank settings; every dumped field is forwarded as a
            keyword argument to ``ColbertRerank``.
        **kwargs: Accepted but not used by this creator.

    Returns:
        The constructed ``ColbertRerank`` instance.
    """
    return ColbertRerank(**config.model_dump())
def _create_object_ranker(self, config: ObjectRankerConfig, **kwargs) -> ObjectSortPostprocessor:
    """Build an ``ObjectSortPostprocessor`` from the dumped fields of ``config``.

    Args:
        config: Object ranker settings; every dumped field is forwarded as a
            keyword argument to ``ObjectSortPostprocessor``.
        **kwargs: Accepted but not used by this creator.

    Returns:
        The constructed ``ObjectSortPostprocessor`` instance.
    """
    return ObjectSortPostprocessor(**config.model_dump())
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
    """Return the ``"llm"`` value resolved by ``_val_from_config_or_kwargs``."""
    llm = self._val_from_config_or_kwargs("llm", config, **kwargs)
    return llm

View file

@ -0,0 +1,54 @@
"""Object ranker."""
import heapq
import json
from typing import Literal, Optional
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle
from pydantic import Field
from metagpt.rag.schema import ObjectNode
class ObjectSortPostprocessor(BaseNodePostprocessor):
    """Sort nodes by one field of the JSON object stored in their metadata.

    Assumes nodes is a list of ObjectNode with score. Each node must carry a
    JSON object under ``metadata["obj_json"]``; the value found at
    ``field_name`` is used as the sort key, so those values must be mutually
    comparable.
    """

    field_name: str = Field(..., description="field name of the object, field's value must can be compared.")
    order: Literal["desc", "asc"] = Field(default="desc", description="the direction of order.")
    # Maximum number of nodes returned after sorting.
    top_n: int = 5

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this postprocessor class."""
        return "ObjectSortPostprocessor"

    def _postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> list[NodeWithScore]:
        """Return at most ``top_n`` nodes ordered by the configured object field.

        Args:
            nodes: Candidate nodes; each must hold a JSON object in
                ``metadata["obj_json"]`` that contains ``field_name``.
            query_bundle: Must be provided, although its content is not read.

        Returns:
            The selected nodes, largest first for ``"desc"`` and smallest
            first for ``"asc"``. An empty input yields an empty list.

        Raises:
            ValueError: If ``query_bundle`` is None, or if any node has
                invalid object JSON or lacks ``field_name``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")

        if not nodes:
            return []

        # Decode and validate every node exactly once, so a malformed node
        # anywhere in the list fails with a clear ValueError rather than a raw
        # KeyError/JSONDecodeError raised from inside the heap selection.
        keyed = [(self._sort_value(item.node), item) for item in nodes]
        selected = self._get_sort_func()(self.top_n, keyed, key=lambda pair: pair[0])
        return [item for _, item in selected]

    def _get_sort_func(self):
        """Return ``heapq.nlargest`` for ``"desc"`` and ``heapq.nsmallest`` otherwise."""
        return heapq.nlargest if self.order == "desc" else heapq.nsmallest

    def _check_metadata(self, node: ObjectNode):
        """Validate that ``node`` holds object JSON containing ``field_name``.

        Raises:
            ValueError: If the JSON is invalid or the field is missing.
        """
        self._sort_value(node)

    def _sort_value(self, node: ObjectNode):
        """Decode the node's object JSON and return the value at ``field_name``.

        Raises:
            ValueError: If ``metadata["obj_json"]`` is absent or not valid
                JSON, if it does not decode to a JSON object, or if the object
                lacks ``field_name``.
        """
        try:
            obj_dict = json.loads(node.metadata.get("obj_json"))
        except (AttributeError, TypeError, ValueError) as e:
            # TypeError covers a missing key (None); JSONDecodeError is a ValueError.
            raise ValueError(f"Invalid object json in metadata: {node.metadata}, error: {e}") from e

        # A JSON list or scalar cannot be indexed by field name; reject it here.
        if not isinstance(obj_dict, dict) or self.field_name not in obj_dict:
            raise ValueError(f"Field '{self.field_name}' not found in object: {obj_dict}")

        return obj_dict[self.field_name]

View file

@ -101,6 +101,11 @@ class ColbertRerankConfig(BaseRankerConfig):
keep_retrieval_score: bool = Field(default=False, description="Whether to keep the retrieval score in metadata.")
class ObjectRankerConfig(BaseRankerConfig):
    """Ranker config naming the object field to sort by and the sort direction."""

    field_name: str = Field(
        ...,
        description="field name of the object, field's value must can be compared.",
    )
    order: Literal["desc", "asc"] = Field(
        default="desc",
        description="the direction of order.",
    )
class BaseIndexConfig(BaseModel):
"""Common config for index.

View file

@ -0,0 +1,60 @@
import json
import pytest
from llama_index.core.schema import NodeWithScore, QueryBundle
from pydantic import BaseModel
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
from metagpt.rag.schema import ObjectNode
class Record(BaseModel):
    """Minimal object payload with a single sortable integer field."""

    score: int
class TestObjectSortPostprocessor:
    """Tests for ordering, truncation and error handling of ObjectSortPostprocessor."""

    @pytest.fixture
    def nodes_with_scores(self):
        """Three nodes whose object ``score`` field equals their node score."""
        nodes = [
            NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
            NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
            NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=5).model_dump_json()}), score=5),
        ]
        return nodes

    @pytest.fixture
    def query_bundle(self, mocker):
        """A stand-in query bundle; the postprocessor only checks that it is not None."""
        return mocker.MagicMock(spec=QueryBundle)

    def test_sort_descending(self, nodes_with_scores, query_bundle):
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
        sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
        assert [node.score for node in sorted_nodes] == [20, 10, 5]

    def test_sort_ascending(self, nodes_with_scores, query_bundle):
        postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
        sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
        assert [node.score for node in sorted_nodes] == [5, 10, 20]

    def test_top_n_limit(self, nodes_with_scores, query_bundle):
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
        sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
        assert len(sorted_nodes) == 2
        assert [node.score for node in sorted_nodes] == [20, 10]

    def test_invalid_json_metadata(self, query_bundle):
        nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
        # ``match`` pins the failure to the JSON check rather than any ValueError.
        with pytest.raises(ValueError, match="Invalid object json"):
            postprocessor._postprocess_nodes(nodes, query_bundle)

    def test_missing_query_bundle(self, nodes_with_scores):
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
        with pytest.raises(ValueError, match="Missing query bundle"):
            postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)

    def test_field_not_found_in_object(self, query_bundle):
        nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
        postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
        # A query bundle must be passed here: without it the missing-bundle guard
        # raises first and the field check is never reached.
        with pytest.raises(ValueError, match="Field 'score' not found"):
            postprocessor._postprocess_nodes(nodes, query_bundle)