mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-28 10:26:32 +02:00
add object ranker
This commit is contained in:
parent
aaae00441b
commit
a22d7d8983
4 changed files with 134 additions and 2 deletions
60
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
60
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import ObjectNode
|
||||
|
||||
|
||||
class Record(BaseModel):
|
||||
score: int
|
||||
|
||||
|
||||
class TestObjectSortPostprocessor:
|
||||
@pytest.fixture
|
||||
def nodes_with_scores(self):
|
||||
nodes = [
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=5).model_dump_json()}), score=5),
|
||||
]
|
||||
return nodes
|
||||
|
||||
@pytest.fixture
|
||||
def query_bundle(self, mocker):
|
||||
return mocker.MagicMock(spec=QueryBundle)
|
||||
|
||||
def test_sort_descending(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [20, 10, 5]
|
||||
|
||||
def test_sort_ascending(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [5, 10, 20]
|
||||
|
||||
def test_top_n_limit(self, nodes_with_scores, query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
|
||||
assert len(sorted_nodes) == 2
|
||||
assert [node.score for node in sorted_nodes] == [20, 10]
|
||||
|
||||
def test_invalid_json_metadata(self, query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes, query_bundle)
|
||||
|
||||
def test_missing_query_bundle(self, nodes_with_scores):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
|
||||
|
||||
def test_field_not_found_in_object(self):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes)
|
||||
Loading…
Add table
Add a link
Reference in a new issue