add README.md

This commit is contained in:
seehi 2024-07-16 19:13:10 +08:00
parent 0b604c42b5
commit 5693594afe
13 changed files with 238 additions and 401 deletions

View file

@ -3,9 +3,8 @@ import pytest
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
from metagpt.configs.llm_config import LLMConfig
from metagpt.exp_pool.manager import ExperienceManager
from metagpt.exp_pool.schema import Experience
from metagpt.rag.engines import SimpleEngine
from metagpt.exp_pool.manager import Experience, ExperienceManager
from metagpt.exp_pool.schema import QueryType
class TestExperienceManager:
@ -15,50 +14,65 @@ class TestExperienceManager:
@pytest.fixture
def mock_storage(self, mocker):
engine = mocker.MagicMock(spec=SimpleEngine)
engine = mocker.MagicMock()
engine.add_objs = mocker.MagicMock()
engine.aretrieve = mocker.AsyncMock(return_value=[])
engine._retriever = mocker.MagicMock()
engine._retriever._vector_store = mocker.MagicMock()
engine._retriever._vector_store._get = mocker.MagicMock(return_value=mocker.MagicMock(ids=[]))
engine._retriever._vector_store._collection = mocker.MagicMock()
engine._retriever._vector_store._collection.count = mocker.MagicMock(return_value=10)
return engine
@pytest.fixture
def mock_experience_manager(self, mock_config, mock_storage):
return ExperienceManager(config=mock_config, storage=mock_storage)
def exp_manager(self, mock_config, mock_storage):
manager = ExperienceManager(config=mock_config)
manager._storage = mock_storage
return manager
@pytest.fixture
def mock_experience(self):
return Experience(req="req", resp="resp")
def test_initialize_storage(self, mock_experience_manager, mock_storage):
assert mock_experience_manager.storage is mock_storage
def test_create_exp(self, mock_experience_manager, mock_experience):
mock_experience_manager.create_exp(mock_experience)
mock_experience_manager.storage.add_objs.assert_called_with([mock_experience])
def test_create_exp_write_disabled(self, mock_experience_manager, mock_experience, mock_config):
mock_config.exp_pool.enable_write = False
mock_experience_manager.create_exp(mock_experience)
mock_experience_manager.storage.add_objs.assert_not_called()
def test_vector_store_property(self, exp_manager):
assert exp_manager.vector_store == exp_manager.storage._retriever._vector_store
@pytest.mark.asyncio
async def test_query_exps(self, mock_experience_manager, mocker):
req = "req"
resp = "resp"
tag = "test"
experiences = [Experience(req=req, resp=resp, tag="test"), Experience(req=req, resp=resp, tag="other")]
mock_experience_manager.storage.aretrieve.return_value = [
mocker.MagicMock(metadata={"obj": exp}) for exp in experiences
]
async def test_query_exps_with_exact_match(self, exp_manager, mocker):
req = "exact query"
exp1 = Experience(req=req, resp="response1")
exp2 = Experience(req="different query", resp="response2")
result = await mock_experience_manager.query_exps(req, tag)
mock_node1 = mocker.MagicMock(metadata={"obj": exp1})
mock_node2 = mocker.MagicMock(metadata={"obj": exp2})
exp_manager.storage.aretrieve.return_value = [mock_node1, mock_node2]
result = await exp_manager.query_exps(req, query_type=QueryType.EXACT)
assert len(result) == 1
assert result[0].tag == "test"
assert result[0].req == req
@pytest.mark.asyncio
async def test_query_exps_no_read_permission(self, mock_experience_manager, mock_config):
async def test_query_exps_with_tag_filter(self, exp_manager, mocker):
tag = "test_tag"
exp1 = Experience(req="query1", resp="response1", tag=tag)
exp2 = Experience(req="query2", resp="response2", tag="other_tag")
mock_node1 = mocker.MagicMock(metadata={"obj": exp1})
mock_node2 = mocker.MagicMock(metadata={"obj": exp2})
exp_manager.storage.aretrieve.return_value = [mock_node1, mock_node2]
result = await exp_manager.query_exps("query", tag=tag)
assert len(result) == 1
assert result[0].tag == tag
def test_get_exps_count(self, exp_manager):
assert exp_manager.get_exps_count() == 10
def test_create_exp_write_disabled(self, exp_manager, mock_config):
mock_config.exp_pool.enable_write = False
exp = Experience(req="test", resp="response")
exp_manager.create_exp(exp)
exp_manager.storage.add_objs.assert_not_called()
@pytest.mark.asyncio
async def test_query_exps_read_disabled(self, exp_manager, mock_config):
mock_config.exp_pool.enable_read = False
result = await mock_experience_manager.query_exps("query")
result = await exp_manager.query_exps("query")
assert result == []

View file

@ -22,7 +22,7 @@ class TestRoleZeroSerializer:
return [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
def test_serialize_req_empty_input(self, serializer: RoleZeroSerializer):
assert serializer.serialize_req([]) == ""
assert serializer.serialize_req(req=[]) == ""
def test_serialize_req_with_content(self, serializer: RoleZeroSerializer, last_item: dict):
req = [
@ -33,7 +33,7 @@ class TestRoleZeroSerializer:
expected_output = json.dumps(
[{"role": "user", "content": "Command Editor.read executed: file_path=test.py"}, last_item]
)
assert serializer.serialize_req(req) == expected_output
assert serializer.serialize_req(req=req) == expected_output
def test_filter_req(self, serializer: RoleZeroSerializer):
req = [

View file

@ -8,28 +8,28 @@ class TestSimpleSerializer:
def serializer(self):
return SimpleSerializer()
def test_serialize_req(self, serializer):
def test_serialize_req(self, serializer: SimpleSerializer):
# Test with different types of input
assert serializer.serialize_req(123) == "123"
assert serializer.serialize_req("test") == "test"
assert serializer.serialize_req([1, 2, 3]) == "[1, 2, 3]"
assert serializer.serialize_req({"a": 1}) == "{'a': 1}"
assert serializer.serialize_req(req=123) == "123"
assert serializer.serialize_req(req="test") == "test"
assert serializer.serialize_req(req=[1, 2, 3]) == "[1, 2, 3]"
assert serializer.serialize_req(req={"a": 1}) == "{'a': 1}"
def test_serialize_resp(self, serializer):
def test_serialize_resp(self, serializer: SimpleSerializer):
# Test with different types of input
assert serializer.serialize_resp(456) == "456"
assert serializer.serialize_resp("response") == "response"
assert serializer.serialize_resp([4, 5, 6]) == "[4, 5, 6]"
assert serializer.serialize_resp({"b": 2}) == "{'b': 2}"
def test_deserialize_resp(self, serializer):
def test_deserialize_resp(self, serializer: SimpleSerializer):
# Test with different types of input
assert serializer.deserialize_resp("789") == "789"
assert serializer.deserialize_resp("test_response") == "test_response"
assert serializer.deserialize_resp("[7, 8, 9]") == "[7, 8, 9]"
assert serializer.deserialize_resp("{'c': 3}") == "{'c': 3}"
def test_roundtrip(self, serializer):
def test_roundtrip(self, serializer: SimpleSerializer):
# Test serialization and deserialization roundtrip
original = "test_roundtrip"
serialized = serializer.serialize_resp(original)
@ -37,8 +37,8 @@ class TestSimpleSerializer:
assert deserialized == original
@pytest.mark.parametrize("input_value", [123, "test", [1, 2, 3], {"a": 1}, None])
def test_serialize_req_types(self, serializer, input_value):
def test_serialize_req_types(self, serializer: SimpleSerializer, input_value):
# Test serialize_req with various input types
result = serializer.serialize_req(input_value)
result = serializer.serialize_req(req=input_value)
assert isinstance(result, str)
assert result == str(input_value)