feat: merge main

This commit is contained in:
莘权 马 2024-03-29 11:20:10 +08:00
commit 33ca44739d
199 changed files with 7620 additions and 469 deletions

View file

@ -4,8 +4,8 @@
from pathlib import Path
from metagpt.environment.android_env.android_ext_env import AndroidExtEnv
from metagpt.environment.android_env.const import ADB_EXEC_FAIL
from metagpt.environment.android.android_ext_env import AndroidExtEnv
from metagpt.environment.android.const import ADB_EXEC_FAIL
def mock_device_shape(self, adb_cmd: str) -> str:
@ -34,9 +34,7 @@ def mock_write_read_operation(self, adb_cmd: str) -> str:
def test_android_ext_env(mocker):
device_id = "emulator-5554"
mocker.patch(
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape
)
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape)
ext_env = AndroidExtEnv(device_id=device_id, screenshot_dir="/data2/", xml_dir="/data2/")
assert ext_env.adb_prefix == f"adb -s {device_id} "
@ -46,25 +44,21 @@ def test_android_ext_env(mocker):
assert ext_env.device_shape == (720, 1080)
mocker.patch(
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape_invalid
"metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape_invalid
)
assert ext_env.device_shape == (0, 0)
mocker.patch(
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_list_devices
)
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_list_devices)
assert ext_env.list_devices() == [device_id]
mocker.patch(
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot
)
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot)
assert ext_env.get_screenshot("screenshot_xxxx-xx-xx", "/data/") == Path("/data/screenshot_xxxx-xx-xx.png")
mocker.patch("metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_xml)
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_xml)
assert ext_env.get_xml("xml_xxxx-xx-xx", "/data/") == Path("/data/xml_xxxx-xx-xx.xml")
mocker.patch(
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_write_read_operation
"metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_write_read_operation
)
res = "OK"
assert ext_env.system_back() == res

View file

@ -3,8 +3,8 @@
# @Desc : the unittest of MinecraftExtEnv
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
from metagpt.environment.minecraft.const import MC_CKPT_DIR
from metagpt.environment.minecraft.minecraft_ext_env import MinecraftExtEnv
def test_minecraft_ext_env():

View file

@ -4,12 +4,18 @@
from pathlib import Path
from metagpt.environment.stanford_town_env.stanford_town_ext_env import (
StanfordTownExtEnv,
from metagpt.environment.stanford_town.env_space import (
EnvAction,
EnvActionType,
EnvObsParams,
EnvObsType,
)
from metagpt.environment.stanford_town.stanford_town_ext_env import StanfordTownExtEnv
maze_asset_path = (
Path(__file__).absolute().parent.joinpath("..", "..", "..", "data", "environment", "stanford_town", "the_ville")
Path(__file__)
.absolute()
.parent.joinpath("..", "..", "..", "..", "metagpt/ext/stanford_town/static_dirs/assets/the_ville")
)
@ -27,7 +33,6 @@ def test_stanford_town_ext_env():
assert len(ext_env.get_nearby_tiles(tile=tile, vision_r=5)) == 121
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
ext_env.add_tiles_event(tile[1], tile[0], event=event)
ext_env.add_event_from_tile(event, tile)
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
@ -38,3 +43,22 @@ def test_stanford_town_ext_env():
ext_env.remove_subject_events_from_tile(subject=event[0], tile=tile)
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 0
def test_stanford_town_ext_env_observe_step():
ext_env = StanfordTownExtEnv(maze_asset_path=maze_asset_path)
obs, info = ext_env.reset()
assert len(info) == 0
assert len(obs["address_tiles"]) == 306
tile = (58, 9)
obs = ext_env.observe(obs_params=EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="world"))
assert obs == "the Ville"
action = ext_env.action_space.sample()
assert len(action) == 4
assert len(action["event"]) == 4
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
obs, _, _, _, _ = ext_env.step(action=EnvAction(action_type=EnvActionType.ADD_TILE_EVENT, coord=tile, event=event))
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1

View file

@ -2,6 +2,8 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of ExtEnv&Env
from typing import Any, Optional
import pytest
from metagpt.environment.api.env_api import EnvAPIAbstract
@ -12,11 +14,26 @@ from metagpt.environment.base_env import (
mark_as_readable,
mark_as_writeable,
)
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
class ForTestEnv(Environment):
value: int = 0
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict[str, Any]] = None,
) -> tuple[dict[str, Any], dict[str, Any]]:
pass
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
pass
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
pass
@mark_as_readable
def read_api_no_param(self):
return self.value
@ -44,11 +61,11 @@ async def test_ext_env():
assert len(apis) > 0
assert len(apis["read_api"]) == 3
_ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
_ = await env.write_thru_api(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
assert env.value == 15
with pytest.raises(ValueError):
await env.observe("not_exist_api")
await env.read_from_api("not_exist_api")
assert await env.observe("read_api_no_param") == 15
assert await env.observe(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
assert await env.read_from_api("read_api_no_param") == 15
assert await env.read_from_api(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10

View file

@ -2,7 +2,7 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of WerewolfExtEnv
from metagpt.environment.werewolf_env.werewolf_ext_env import RoleState, WerewolfExtEnv
from metagpt.environment.werewolf.werewolf_ext_env import RoleState, WerewolfExtEnv
from metagpt.roles.role import Role

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,79 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of actions/gen_action_details.py
import pytest
from metagpt.environment import StanfordTownEnv
from metagpt.environment.api.env_api import EnvAPIAbstract
from metagpt.ext.stanford_town.actions.gen_action_details import (
GenActionArena,
GenActionDetails,
GenActionObject,
GenActionSector,
GenActObjDescription,
)
from metagpt.ext.stanford_town.roles.st_role import STRole
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
@pytest.mark.asyncio
async def test_gen_action_details():
role = STRole(
name="Klaus Mueller",
start_time="February 13, 2023",
curr_time="February 13, 2023, 00:00:00",
sim_code="base_the_ville_isabella_maria_klaus",
)
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
await role.init_curr_tile()
act_desp = "sleeping"
act_dura = "120"
access_tile = await role.rc.env.read_from_api(
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile})
)
act_world = access_tile["world"]
assert act_world == "the Ville"
sector = await GenActionSector().run(role, access_tile, act_desp)
arena = await GenActionArena().run(role, act_desp, act_world, sector)
temp_address = f"{act_world}:{sector}:{arena}"
obj = await GenActionObject().run(role, act_desp, temp_address)
act_obj_desp = await GenActObjDescription().run(role, obj, act_desp)
result_dict = await GenActionDetails().run(role, act_desp, act_dura)
# gen_action_sector
assert isinstance(sector, str)
assert sector in role.s_mem.get_str_accessible_sectors(act_world)
# gen_action_arena
assert isinstance(arena, str)
assert arena in role.s_mem.get_str_accessible_sector_arenas(f"{act_world}:{sector}")
# gen_action_obj
assert isinstance(obj, str)
assert obj in role.s_mem.get_str_accessible_arena_game_objects(temp_address)
if result_dict:
for key in [
"action_address",
"action_duration",
"action_description",
"action_pronunciatio",
"action_event",
"chatting_with",
"chat",
"chatting_with_buffer",
"chatting_end_time",
"act_obj_description",
"act_obj_pronunciatio",
"act_obj_event",
]:
assert key in result_dict
assert result_dict["action_address"] == f"{temp_address}:{obj}"
assert result_dict["action_duration"] == int(act_dura)
assert result_dict["act_obj_description"] == act_obj_desp

View file

@ -0,0 +1,15 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of actions/summarize_conv
import pytest
from metagpt.ext.stanford_town.actions.summarize_conv import SummarizeConv
@pytest.mark.asyncio
async def test_summarize_conv():
conv = [("Role_A", "what's the weather today?"), ("Role_B", "It looks pretty good, and I will take a walk then.")]
output = await SummarizeConv().run(conv)
assert "weather" in output

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,89 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of AgentMemory
from datetime import datetime, timedelta
import pytest
from metagpt.ext.stanford_town.memory.agent_memory import AgentMemory
from metagpt.ext.stanford_town.memory.retrieve import agent_retrieve
from metagpt.ext.stanford_town.utils.const import STORAGE_PATH
from metagpt.logs import logger
"""
memory测试思路
1. Basic Memory测试
2. Agent Memory测试
2.1 Load & Save方法测试; Load方法中使用了add方法验证Load即可验证所有add
2.2 Get方法测试
"""
memory_easy_storage_path = STORAGE_PATH.joinpath(
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/associative_memory",
)
memroy_chat_storage_path = STORAGE_PATH.joinpath(
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/associative_memory",
)
memory_save_easy_test_path = STORAGE_PATH.joinpath(
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/test_memory",
)
memory_save_chat_test_path = STORAGE_PATH.joinpath(
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/test_memory",
)
class TestAgentMemory:
@pytest.fixture
def agent_memory(self):
# 创建一个AgentMemory实例并返回可以在所有测试用例中共享
test_agent_memory = AgentMemory()
test_agent_memory.set_mem_path(memroy_chat_storage_path)
return test_agent_memory
def test_load(self, agent_memory):
logger.info(f"存储路径为:{agent_memory.memory_saved}")
logger.info(f"存储记忆条数为:{len(agent_memory.storage)}")
logger.info(f"kw_strength为{agent_memory.kw_strength_event},{agent_memory.kw_strength_thought}")
logger.info(f"embeeding.json条数为{len(agent_memory.embeddings)}")
assert agent_memory.embeddings is not None
def test_save(self, agent_memory):
try:
agent_memory.save(memory_save_chat_test_path)
logger.info("成功存储")
except:
pass
def test_summary_function(self, agent_memory):
logger.info(f"event长度为{len(agent_memory.event_list)}")
logger.info(f"thought长度为{len(agent_memory.thought_list)}")
logger.info(f"chat长度为{len(agent_memory.chat_list)}")
result1 = agent_memory.get_summarized_latest_events(4)
logger.info(f"总结最近事件结果为:{result1}")
def test_get_last_chat_function(self, agent_memory):
result2 = agent_memory.get_last_chat("customers")
logger.info(f"上一次对话是{result2}")
def test_retrieve_function(self, agent_memory):
focus_points = ["who i love?"]
retrieved = dict()
for focal_pt in focus_points:
nodes = [
[i.last_accessed, i]
for i in agent_memory.event_list + agent_memory.thought_list
if "idle" not in i.embedding_key
]
nodes = sorted(nodes, key=lambda x: x[0])
nodes = [i for created, i in nodes]
results = agent_retrieve(agent_memory, datetime.now() - timedelta(days=120), 0.99, focal_pt, nodes, 5)
final_result = []
for n in results:
for i in agent_memory.storage:
if i.memory_id == n:
i.last_accessed = datetime.now() - timedelta(days=120)
final_result.append(i)
retrieved[focal_pt] = final_result
logger.info(f"检索结果为{retrieved}")

View file

@ -0,0 +1,76 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of BasicMemory
from datetime import datetime, timedelta
import pytest
from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory
from metagpt.logs import logger
"""
memory测试思路
1. Basic Memory测试
2. Agent Memory测试
2.1 Load & Save方法测试
2.2 Add方法测试
2.3 Get方法测试
"""
# Create some sample BasicMemory instances
memory1 = BasicMemory(
memory_id="1",
memory_count=1,
type_count=1,
memory_type="event",
depth=1,
created=datetime.now(),
expiration=datetime.now() + timedelta(days=30),
subject="Subject1",
predicate="Predicate1",
object="Object1",
content="This is content 1",
embedding_key="embedding_key_1",
poignancy=1,
keywords=["keyword1", "keyword2"],
filling=["memory_id_2"],
)
memory2 = BasicMemory(
memory_id="2",
memory_count=2,
type_count=2,
memory_type="thought",
depth=2,
created=datetime.now(),
expiration=datetime.now() + timedelta(days=30),
subject="Subject2",
predicate="Predicate2",
object="Object2",
content="This is content 2",
embedding_key="embedding_key_2",
poignancy=2,
keywords=["keyword3", "keyword4"],
filling=[],
)
@pytest.fixture
def basic_mem_set():
basic_mem2 = memory2
yield basic_mem2
def test_basic_mem_function(basic_mem_set):
a, b, c = basic_mem_set.summary()
logger.info(f"{a}{b}{c}")
assert a == "Subject2"
def test_basic_mem_save(basic_mem_set):
result = basic_mem_set.save_to_dict()
logger.info(f"save结果为{result}")
if __name__ == "__main__":
pytest.main()

View file

@ -0,0 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of MemoryTree
from metagpt.ext.stanford_town.memory.spatial_memory import MemoryTree
from metagpt.ext.stanford_town.utils.const import STORAGE_PATH
def test_spatial_memory():
f_path = STORAGE_PATH.joinpath(
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/spatial_memory.json"
)
x = MemoryTree()
x.set_mem_path(f_path)
assert x.tree
assert "the Ville" in x.tree
assert "Isabella Rodriguez's apartment" in x.get_str_accessible_sectors("the Ville")

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,67 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of roles conversation
from typing import Tuple
import pytest
from metagpt.environment import StanfordTownEnv
from metagpt.ext.stanford_town.plan.converse import agent_conversation
from metagpt.ext.stanford_town.roles.st_role import STRole
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH, STORAGE_PATH
from metagpt.ext.stanford_town.utils.mg_ga_transform import get_reverie_meta
from metagpt.ext.stanford_town.utils.utils import copy_folder
async def init_two_roles(fork_sim_code: str = "base_the_ville_isabella_maria_klaus") -> Tuple["STRole"]:
sim_code = "unittest_sim"
copy_folder(str(STORAGE_PATH.joinpath(fork_sim_code)), str(STORAGE_PATH.joinpath(sim_code)))
reverie_meta = get_reverie_meta(fork_sim_code)
role_ir_name = "Isabella Rodriguez"
role_km_name = "Klaus Mueller"
env = StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH)
role_ir = STRole(
name=role_ir_name,
sim_code=sim_code,
profile=role_ir_name,
step=reverie_meta.get("step"),
start_time=reverie_meta.get("start_date"),
curr_time=reverie_meta.get("curr_time"),
sec_per_step=reverie_meta.get("sec_per_step"),
)
role_ir.set_env(env)
await role_ir.init_curr_tile()
role_km = STRole(
name=role_km_name,
sim_code=sim_code,
profile=role_km_name,
step=reverie_meta.get("step"),
start_time=reverie_meta.get("start_date"),
curr_time=reverie_meta.get("curr_time"),
sec_per_step=reverie_meta.get("sec_per_step"),
)
role_km.set_env(env)
await role_km.init_curr_tile()
return role_ir, role_km
@pytest.mark.asyncio
async def test_agent_conversation():
role_ir, role_km = await init_two_roles()
curr_chat = await agent_conversation(role_ir, role_km, conv_rounds=2)
assert len(curr_chat) % 2 == 0
meet = False
for conv in curr_chat:
if "Valentine's Day party" in conv[1]:
# conv[0] speaker, conv[1] utterance
meet = True
assert meet

View file

@ -0,0 +1,25 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of st_plan
import pytest
from metagpt.ext.stanford_town.plan.st_plan import _choose_retrieved, _should_react
from tests.metagpt.ext.stanford_town.plan.test_conversation import init_two_roles
@pytest.mark.asyncio
async def test_should_react():
role_ir, role_km = await init_two_roles()
roles = {role_ir.name: role_ir, role_km.name: role_km}
role_ir.scratch.act_address = "mock data"
observed = await role_ir.observe()
retrieved = role_ir.retrieve(observed)
focused_event = _choose_retrieved(role_ir.name, retrieved)
if focused_event:
reaction_mode = await _should_react(role_ir, focused_event, roles) # chat with Isabella Rodriguez
assert not reaction_mode

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,26 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of STRole
import pytest
from metagpt.environment import StanfordTownEnv
from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory
from metagpt.ext.stanford_town.roles.st_role import STRole
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
@pytest.mark.asyncio
async def test_observe():
role = STRole(
sim_code="base_the_ville_isabella_maria_klaus",
start_time="February 13, 2023",
curr_time="February 13, 2023, 00:00:00",
)
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
await role.init_curr_tile()
ret_events = await role.observe()
assert ret_events
for event in ret_events:
assert isinstance(event, BasicMemory)

View file

@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of reflection
import pytest
from metagpt.environment import StanfordTownEnv
from metagpt.ext.stanford_town.actions.run_reflect_action import (
AgentEventTriple,
AgentFocusPt,
AgentInsightAndGuidance,
)
from metagpt.ext.stanford_town.roles.st_role import STRole
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
@pytest.mark.asyncio
async def test_reflect():
"""
init STRole form local json, set sim_code(path),curr_time & start_time
"""
role = STRole(
sim_code="base_the_ville_isabella_maria_klaus",
start_time="February 13, 2023",
curr_time="February 13, 2023, 00:00:00",
)
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
role.init_curr_tile()
run_focus = AgentFocusPt()
statements = ""
await run_focus.run(role, statements, n=3)
"""
这里有通过测试的结果但是更多时候LLM生成的结果缺少了because of考虑修改一下prompt
result = {'Klaus Mueller and Maria Lopez have a close relationship because they have been friends for a long time and have a strong bond': [1, 2, 5, 9, 11, 14], 'Klaus Mueller has a crush on Maria Lopez': [8, 15, 24], 'Klaus Mueller is academically inclined and actively researching a topic': [13, 20], 'Klaus Mueller is socially active and acquainted with Isabella Rodriguez': [17, 21, 22], 'Klaus Mueller is organized and prepared': [19]}
"""
run_insight = AgentInsightAndGuidance()
statements = "[user: Klaus Mueller has a close relationship with Maria Lopez, user:s Mueller and Maria Lopez have a close relationship, user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller and Maria Lopez have a strong relationship, user: Klaus Mueller is a dormmate of Maria Lopez., user: Klaus Mueller and Maria Lopez have a strong bond, user: Klaus Mueller has a crush on Maria Lopez, user: Klaus Mueller and Maria Lopez have been friends for more than 2 years., user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller Maria Lopez is heading off to college., user: Klaus Mueller and Maria Lopez have a close relationship, user: Klaus Mueller is actively researching a topic, user: Klaus Mueller is close friends and classmates with Maria Lopez., user: Klaus Mueller is socially active, user: Klaus Mueller has a crush on Maria Lopez., user: Klaus Mueller and Maria Lopez have been friends for a long time, user: Klaus Mueller is academically inclined, user: For Klaus Mueller's planning: should remember to ask Maria Lopez about her research paper, as she found it interesting that he mentioned it., user: Klaus Mueller is acquainted with Isabella Rodriguez, user: Klaus Mueller is organized and prepared, user: Maria Lopez is conversing about conversing about Maria's research paper mentioned by Klaus, user: Klaus Mueller is conversing about conversing about Maria's research paper mentioned by Klaus, user: Klaus Mueller is a student, user: Klaus Mueller is a student, user: Klaus Mueller is conversing about two friends named Klaus Mueller and Maria Lopez discussing their morning plans and progress on a research paper before Maria heads off to college., user: Klaus Mueller is socially active, user: Klaus Mueller is socially active, user: Klaus Mueller is socially active and acquainted with Isabella Rodriguez, user: Klaus Mueller has a crush on Maria Lopez]"
await run_insight.run(role, statements, n=5)
run_triple = AgentEventTriple()
statements = "(Klaus Mueller is academically inclined)"
await run_triple.run(statements, role)
role.scratch.importance_trigger_curr = -1
role.reflect()

View file

@ -1,12 +1,26 @@
import json
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import Document, TextNode
from llama_index.core.embeddings import MockEmbedding
from llama_index.core.llms import MockLLM
from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
from metagpt.rag.retrievers import SimpleHybridRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
class TestSimpleEngine:
@pytest.fixture
def mock_llm(self):
return MockLLM()
@pytest.fixture
def mock_embedding(self):
return MockEmbedding(embed_dim=1)
@pytest.fixture
def mock_simple_directory_reader(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
@ -54,7 +68,7 @@ class TestSimpleEngine:
retriever_configs = [mocker.MagicMock()]
ranker_configs = [mocker.MagicMock()]
# Execute
# Exec
engine = SimpleEngine.from_docs(
input_dir=input_dir,
input_files=input_files,
@ -65,7 +79,7 @@ class TestSimpleEngine:
ranker_configs=ranker_configs,
)
# Assertions
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
@ -75,6 +89,68 @@ class TestSimpleEngine:
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
def test_from_docs_without_file(self):
with pytest.raises(ValueError):
SimpleEngine.from_docs()
def test_from_objs(self, mock_llm, mock_embedding):
# Mock
class MockRAGObject:
def rag_key(self):
return "key"
def model_dump_json(self):
return "{}"
objs = [MockRAGObject()]
# Setup
retriever_configs = []
ranker_configs = []
# Exec
engine = SimpleEngine.from_objs(
objs=objs,
llm=mock_llm,
embed_model=mock_embedding,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is not None
def test_from_objs_with_bm25_config(self):
# Setup
retriever_configs = [BM25RetrieverConfig()]
# Exec
with pytest.raises(ValueError):
SimpleEngine.from_objs(
objs=[],
llm=MockLLM(),
retriever_configs=retriever_configs,
ranker_configs=[],
)
def test_from_index(self, mocker, mock_llm, mock_embedding):
# Mock
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
mock_get_index.return_value = mock_index
# Exec
engine = SimpleEngine.from_index(
index_config=mock_index,
embed_model=mock_embedding,
llm=mock_llm,
)
# Assert
assert isinstance(engine, SimpleEngine)
assert engine.index is mock_index
@pytest.mark.asyncio
async def test_asearch(self, mocker):
# Mock
@ -86,10 +162,10 @@ class TestSimpleEngine:
engine = SimpleEngine(retriever=mocker.MagicMock())
engine.aquery = mock_aquery
# Execute
# Exec
result = await engine.asearch(test_query)
# Assertions
# Assert
mock_aquery.assert_called_once_with(test_query)
assert result == expected_result
@ -106,10 +182,10 @@ class TestSimpleEngine:
engine = SimpleEngine(retriever=mocker.MagicMock())
test_query = "test query"
# Execute
# Exec
result = await engine.aretrieve(test_query)
# Assertions
# Assert
mock_query_bundle.assert_called_once_with(test_query)
mock_super_aretrieve.assert_called_once_with("query_bundle")
assert result[0].text == "node_with_score"
@ -134,10 +210,10 @@ class TestSimpleEngine:
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
input_files = ["test_file1", "test_file2"]
# Execute
# Exec
engine.add_docs(input_files=input_files)
# Assertions
# Assert
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
@ -156,11 +232,79 @@ class TestSimpleEngine:
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
# Execute
# Exec
engine.add_objs(objs=objs)
# Assertions
# Assert
assert mock_retriever.add_nodes.call_count == 1
for node in mock_retriever.add_nodes.call_args[0][0]:
assert isinstance(node, TextNode)
assert "is_obj" in node.metadata
def test_persist_successfully(self, mocker):
# Mock
mock_retriever = mocker.MagicMock(spec=PersistableRAGRetriever)
mock_retriever.persist.return_value = mocker.MagicMock()
# Setup
engine = SimpleEngine(retriever=mock_retriever)
# Exec
engine.persist(persist_dir="")
def test_ensure_retriever_of_type(self, mocker):
# Mock
class MyRetriever:
def add_nodes(self):
...
mock_retriever = mocker.MagicMock(spec=SimpleHybridRetriever)
mock_retriever.retrievers = [MyRetriever()]
# Setup
engine = SimpleEngine(retriever=mock_retriever)
# Assert
engine._ensure_retriever_of_type(ModifiableRAGRetriever)
with pytest.raises(TypeError):
engine._ensure_retriever_of_type(PersistableRAGRetriever)
with pytest.raises(TypeError):
other_engine = SimpleEngine(retriever=mocker.MagicMock(spec=ModifiableRAGRetriever))
other_engine._ensure_retriever_of_type(PersistableRAGRetriever)
def test_with_obj_metadata(self, mocker):
# Mock
node = NodeWithScore(
node=ObjectNode(
text="example",
metadata={
"is_obj": True,
"obj_cls_name": "ExampleObject",
"obj_mod_name": "__main__",
"obj_json": json.dumps({"key": "test_key", "value": "test_value"}),
},
)
)
class ExampleObject:
def __init__(self, key, value):
self.key = key
self.value = value
def __eq__(self, other):
return self.key == other.key and self.value == other.value
mock_import_class = mocker.patch("metagpt.rag.engines.simple.import_class")
mock_import_class.return_value = ExampleObject
# Setup
SimpleEngine._try_reconstruct_obj([node])
# Exec
expected_obj = ExampleObject(key="test_key", value="test_value")
# Assert
assert "obj" in node.node.metadata
assert node.node.metadata["obj"] == expected_obj

View file

@ -0,0 +1,43 @@
import pytest
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
class TestRAGEmbeddingFactory:
@pytest.fixture(autouse=True)
def mock_embedding_factory(self):
self.embedding_factory = RAGEmbeddingFactory()
@pytest.fixture
def mock_openai_embedding(self, mocker):
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
@pytest.fixture
def mock_azure_embedding(self, mocker):
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
def test_get_rag_embedding_openai(self, mock_openai_embedding):
# Exec
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
# Assert
mock_openai_embedding.assert_called_once()
def test_get_rag_embedding_azure(self, mock_azure_embedding):
# Exec
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
# Assert
mock_azure_embedding.assert_called_once()
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
# Mock
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
mock_config.llm.api_type = LLMType.OPENAI
# Exec
self.embedding_factory.get_rag_embedding()
# Assert
mock_openai_embedding.assert_called_once()

View file

@ -0,0 +1,89 @@
import pytest
from llama_index.core.embeddings import MockEmbedding
from metagpt.rag.factories.index import RAGIndexFactory
from metagpt.rag.schema import (
BM25IndexConfig,
ChromaIndexConfig,
ElasticsearchIndexConfig,
ElasticsearchStoreConfig,
FAISSIndexConfig,
)
class TestRAGIndexFactory:
@pytest.fixture(autouse=True)
def setup(self):
self.index_factory = RAGIndexFactory()
@pytest.fixture
def faiss_config(self):
return FAISSIndexConfig(persist_path="")
@pytest.fixture
def chroma_config(self):
return ChromaIndexConfig(persist_path="", collection_name="")
@pytest.fixture
def bm25_config(self):
return BM25IndexConfig(persist_path="")
@pytest.fixture
def es_config(self, mocker):
return ElasticsearchIndexConfig(store_config=ElasticsearchStoreConfig())
@pytest.fixture
def mock_storage_context(self, mocker):
return mocker.patch("metagpt.rag.factories.index.StorageContext.from_defaults")
@pytest.fixture
def mock_load_index_from_storage(self, mocker):
return mocker.patch("metagpt.rag.factories.index.load_index_from_storage")
@pytest.fixture
def mock_from_vector_store(self, mocker):
return mocker.patch("metagpt.rag.factories.index.VectorStoreIndex.from_vector_store")
@pytest.fixture
def mock_embedding(self):
return MockEmbedding(embed_dim=1)
def test_create_faiss_index(
self, mocker, faiss_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
):
# Mock
mock_faiss_store = mocker.patch("metagpt.rag.factories.index.FaissVectorStore.from_persist_dir")
# Exec
self.index_factory.get_index(faiss_config, embed_model=mock_embedding)
# Assert
mock_faiss_store.assert_called_once()
def test_create_bm25_index(
self, mocker, bm25_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
):
self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
# Mock
mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")
mock_chroma_db.get_or_create_collection.return_value = mocker.MagicMock()
mock_chroma_store = mocker.patch("metagpt.rag.factories.index.ChromaVectorStore")
# Exec
self.index_factory.get_index(chroma_config, embed_model=mock_embedding)
# Assert
mock_chroma_store.assert_called_once()
def test_create_es_index(self, mocker, es_config, mock_from_vector_store, mock_embedding):
# Mock
mock_es_store = mocker.patch("metagpt.rag.factories.index.ElasticsearchStore")
# Exec
self.index_factory.get_index(es_config, embed_model=mock_embedding)
# Assert
mock_es_store.assert_called_once()

View file

@ -0,0 +1,71 @@
from typing import Optional, Union
import pytest
from llama_index.core.llms import LLMMetadata
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.provider.base_llm import BaseLLM
from metagpt.rag.factories.llm import RAGLLM, get_rag_llm
class MockLLM(BaseLLM):
def __init__(self, config: LLMConfig):
...
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
"""_achat_completion implemented by inherited class"""
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
return "ok"
def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
return "ok"
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
"""_achat_completion_stream implemented by inherited class"""
async def aask(
self,
msg: Union[str, list[dict[str, str]]],
system_msgs: Optional[list[str]] = None,
format_msgs: Optional[list[dict[str, str]]] = None,
images: Optional[Union[str, list[str]]] = None,
timeout=USE_CONFIG_TIMEOUT,
stream=True,
) -> str:
return "ok"
class TestRAGLLM:
@pytest.fixture
def mock_model_infer(self):
return MockLLM(config=LLMConfig())
@pytest.fixture
def rag_llm(self, mock_model_infer):
return RAGLLM(model_infer=mock_model_infer)
def test_metadata(self, rag_llm):
metadata = rag_llm.metadata
assert isinstance(metadata, LLMMetadata)
assert metadata.context_window == rag_llm.context_window
assert metadata.num_output == rag_llm.num_output
assert metadata.model_name == rag_llm.model_name
@pytest.mark.asyncio
async def test_acomplete(self, rag_llm, mock_model_infer):
response = await rag_llm.acomplete("question")
assert response.text == "ok"
def test_complete(self, rag_llm, mock_model_infer):
response = rag_llm.complete("question")
assert response.text == "ok"
def test_stream_complete(self, rag_llm, mock_model_infer):
rag_llm.stream_complete("question")
def test_get_rag_llm():
result = get_rag_llm(MockLLM(config=LLMConfig()))
assert isinstance(result, RAGLLM)

View file

@ -1,41 +1,60 @@
import contextlib
import pytest
from llama_index.core.llms import LLM
from llama_index.core.llms import MockLLM
from llama_index.core.postprocessor import LLMRerank
from metagpt.rag.factories.ranker import RankerFactory
from metagpt.rag.schema import LLMRankerConfig
from metagpt.rag.schema import ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig
class TestRankerFactory:
@pytest.fixture
def ranker_factory(self) -> RankerFactory:
return RankerFactory()
@pytest.fixture(autouse=True)
def ranker_factory(self):
self.ranker_factory: RankerFactory = RankerFactory()
@pytest.fixture
def mock_llm(self, mocker):
return mocker.MagicMock(spec=LLM)
def mock_llm(self):
return MockLLM()
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
default_rankers = ranker_factory.get_rankers()
def test_get_rankers_with_no_configs(self, mock_llm, mocker):
mocker.patch.object(self.ranker_factory, "_extract_llm", return_value=mock_llm)
default_rankers = self.ranker_factory.get_rankers()
assert len(default_rankers) == 0
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
def test_get_rankers_with_configs(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
rankers = ranker_factory.get_rankers(configs=[mock_config])
rankers = self.ranker_factory.get_rankers(configs=[mock_config])
assert len(rankers) == 1
assert isinstance(rankers[0], LLMRerank)
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
def test_extract_llm_from_config(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
ranker = ranker_factory._create_llm_ranker(mock_config)
extracted_llm = self.ranker_factory._extract_llm(config=mock_config)
assert extracted_llm == mock_llm
def test_extract_llm_from_kwargs(self, mock_llm):
extracted_llm = self.ranker_factory._extract_llm(llm=mock_llm)
assert extracted_llm == mock_llm
def test_create_llm_ranker(self, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
ranker = self.ranker_factory._create_llm_ranker(mock_config)
assert isinstance(ranker, LLMRerank)
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
extracted_llm = ranker_factory._extract_llm(config=mock_config)
assert extracted_llm == mock_llm
def test_create_colbert_ranker(self, mocker, mock_llm):
with contextlib.suppress(ImportError):
mocker.patch("llama_index.postprocessor.colbert_rerank.ColbertRerank", return_value="colbert")
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
assert extracted_llm == mock_llm
mock_config = ColbertRerankConfig(llm=mock_llm)
ranker = self.ranker_factory._create_colbert_ranker(mock_config)
assert ranker == "colbert"
def test_create_object_ranker(self, mocker, mock_llm):
mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object")
mock_config = ObjectRankerConfig(field_name="fake", llm=mock_llm)
ranker = self.ranker_factory._create_object_ranker(mock_config)
assert ranker == "object"

View file

@ -1,18 +1,28 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
)
class TestRetrieverFactory:
@pytest.fixture
@pytest.fixture(autouse=True)
def retriever_factory(self):
return RetrieverFactory()
self.retriever_factory: RetrieverFactory = RetrieverFactory()
@pytest.fixture
def mock_faiss_index(self, mocker):
@ -25,55 +35,79 @@ class TestRetrieverFactory:
mock.docstore.docs.values.return_value = []
return mock
def test_get_retriever_with_faiss_config(
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
):
@pytest.fixture
def mock_chroma_vector_store(self, mocker):
return mocker.MagicMock(spec=ChromaVectorStore)
@pytest.fixture
def mock_es_vector_store(self, mocker):
return mocker.MagicMock(spec=ElasticsearchStore)
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
mock_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, DynamicBM25Retriever)
def test_get_retriever_with_multiple_configs_returns_hybrid(
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
):
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
mock_bm25_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
assert isinstance(retriever, SimpleHybridRetriever)
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, ElasticsearchRetriever)
def test_create_default_retriever(self, mocker, mock_vector_store_index):
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mock_vector_store_index.as_retriever = mocker.MagicMock()
retriever = retriever_factory.get_retriever()
retriever = self.retriever_factory.get_retriever()
mock_vector_store_index.as_retriever.assert_called_once()
assert retriever is mock_vector_store_index.as_retriever.return_value
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
def test_extract_index_from_config(self, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
extracted_index = retriever_factory._extract_index(config=mock_config)
extracted_index = self.retriever_factory._extract_index(config=mock_config)
assert extracted_index == mock_vector_store_index
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
def test_extract_index_from_kwargs(self, mock_vector_store_index):
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index

View file

@ -0,0 +1,23 @@
import pytest
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from metagpt.rag.rankers.base import RAGRanker
class SimpleRAGRanker(RAGRanker):
def _postprocess_nodes(self, nodes, query_bundle=None):
return [NodeWithScore(node=node.node, score=node.score + 1) for node in nodes]
class TestSimpleRAGRanker:
@pytest.fixture
def ranker(self):
return SimpleRAGRanker()
def test_postprocess_nodes_increases_scores(self, ranker):
nodes = [NodeWithScore(node=TextNode(text="a"), score=10), NodeWithScore(node=TextNode(text="b"), score=20)]
query_bundle = QueryBundle(query_str="test query")
processed_nodes = ranker._postprocess_nodes(nodes, query_bundle)
assert all(node.score == original_node.score + 1 for node, original_node in zip(processed_nodes, nodes))

View file

@ -14,7 +14,7 @@ class Record(BaseModel):
class TestObjectSortPostprocessor:
@pytest.fixture
def nodes_with_scores(self):
def mock_nodes_with_scores(self):
nodes = [
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
@ -23,38 +23,47 @@ class TestObjectSortPostprocessor:
return nodes
@pytest.fixture
def query_bundle(self, mocker):
def mock_query_bundle(self, mocker):
return mocker.MagicMock(spec=QueryBundle)
def test_sort_descending(self, nodes_with_scores, query_bundle):
def test_sort_descending(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert [node.score for node in sorted_nodes] == [20, 10, 5]
def test_sort_ascending(self, nodes_with_scores, query_bundle):
def test_sort_ascending(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert [node.score for node in sorted_nodes] == [5, 10, 20]
def test_top_n_limit(self, nodes_with_scores, query_bundle):
def test_top_n_limit(self, mock_nodes_with_scores, mock_query_bundle):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
sorted_nodes = postprocessor._postprocess_nodes(nodes_with_scores, query_bundle)
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
assert len(sorted_nodes) == 2
assert [node.score for node in sorted_nodes] == [20, 10]
def test_invalid_json_metadata(self, query_bundle):
def test_invalid_json_metadata(self, mock_query_bundle):
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
postprocessor._postprocess_nodes(nodes, query_bundle)
postprocessor._postprocess_nodes(nodes, mock_query_bundle)
def test_missing_query_bundle(self, nodes_with_scores):
def test_missing_query_bundle(self, mock_nodes_with_scores):
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
postprocessor._postprocess_nodes(nodes_with_scores, query_bundle=None)
postprocessor._postprocess_nodes(mock_nodes_with_scores, query_bundle=None)
def test_field_not_found_in_object(self):
def test_field_not_found_in_object(self, mock_query_bundle):
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
with pytest.raises(ValueError):
postprocessor._postprocess_nodes(nodes)
postprocessor._postprocess_nodes(nodes, query_bundle=mock_query_bundle)
def test_not_nodes(self, mock_query_bundle):
nodes = []
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
result = postprocessor._postprocess_nodes(nodes, mock_query_bundle)
assert result == []
def test_class_name(self):
assert ObjectSortPostprocessor.class_name() == "ObjectSortPostprocessor"

View file

@ -0,0 +1,21 @@
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
class SubModifiableRAGRetriever(ModifiableRAGRetriever):
...
class SubPersistableRAGRetriever(PersistableRAGRetriever):
...
class TestModifiableRAGRetriever:
def test_subclasshook(self):
result = SubModifiableRAGRetriever.__subclasshook__(SubModifiableRAGRetriever)
assert result is NotImplemented
class TestPersistableRAGRetriever:
def test_subclasshook(self):
result = SubPersistableRAGRetriever.__subclasshook__(SubPersistableRAGRetriever)
assert result is NotImplemented

View file

@ -8,30 +8,30 @@ from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
class TestDynamicBM25Retriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
# 创建模拟的Document对象
self.doc1 = mocker.MagicMock(spec=Node)
self.doc1.get_content.return_value = "Document content 1"
self.doc2 = mocker.MagicMock(spec=Node)
self.doc2.get_content.return_value = "Document content 2"
self.mock_nodes = [self.doc1, self.doc2]
# 模拟index
index = mocker.MagicMock(spec=VectorStoreIndex)
index.storage_context.persist.return_value = "ok"
# 模拟nodes和tokenizer参数
mock_nodes = []
mock_tokenizer = mocker.MagicMock()
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
# 初始化DynamicBM25Retriever对象并提供必需的参数
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
def test_add_docs_updates_nodes_and_corpus(self):
# Execute
# Exec
self.retriever.add_nodes(self.mock_nodes)
# Assertions
# Assert
assert len(self.retriever._nodes) == len(self.mock_nodes)
assert len(self.retriever._corpus) == len(self.mock_nodes)
self.retriever._tokenizer.assert_called()
self.mock_bm25okapi.assert_called()
def test_persist(self):
self.retriever.persist("")

View file

@ -0,0 +1,20 @@
import pytest
from llama_index.core.schema import Node
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
class TestChromaRetriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
self.doc1 = mocker.MagicMock(spec=Node)
self.doc2 = mocker.MagicMock(spec=Node)
self.mock_nodes = [self.doc1, self.doc2]
self.mock_index = mocker.MagicMock()
self.retriever = ChromaRetriever(self.mock_index)
def test_add_nodes(self):
self.retriever.add_nodes(self.mock_nodes)
self.mock_index.insert_nodes.assert_called()

View file

@ -0,0 +1,20 @@
import pytest
from llama_index.core.schema import Node
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
class TestElasticsearchRetriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
self.doc1 = mocker.MagicMock(spec=Node)
self.doc2 = mocker.MagicMock(spec=Node)
self.mock_nodes = [self.doc1, self.doc2]
self.mock_index = mocker.MagicMock()
self.retriever = ElasticsearchRetriever(self.mock_index)
def test_add_nodes(self):
self.retriever.add_nodes(self.mock_nodes)
self.mock_index.insert_nodes.assert_called()

View file

@ -7,16 +7,19 @@ from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
class TestFAISSRetriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
# 创建模拟的Document对象
self.doc1 = mocker.MagicMock(spec=Node)
self.doc2 = mocker.MagicMock(spec=Node)
self.mock_nodes = [self.doc1, self.doc2]
# 模拟FAISSRetriever的_index属性
self.mock_index = mocker.MagicMock()
self.retriever = FAISSRetriever(self.mock_index)
def test_add_docs_calls_insert_for_each_document(self, mocker):
def test_add_docs_calls_insert_for_each_document(self):
self.retriever.add_nodes(self.mock_nodes)
assert self.mock_index.insert_nodes.assert_called
self.mock_index.insert_nodes.assert_called()
def test_persist(self):
self.retriever.persist("")
self.mock_index.storage_context.persist.assert_called()

View file

@ -1,5 +1,3 @@
from unittest.mock import AsyncMock
import pytest
from llama_index.core.schema import NodeWithScore, TextNode
@ -7,18 +5,30 @@ from metagpt.rag.retrievers import SimpleHybridRetriever
class TestSimpleHybridRetriever:
@pytest.fixture
def mock_retriever(self, mocker):
return mocker.MagicMock()
@pytest.fixture
def mock_hybrid_retriever(self, mock_retriever) -> SimpleHybridRetriever:
return SimpleHybridRetriever(mock_retriever)
@pytest.fixture
def mock_node(self):
return NodeWithScore(node=TextNode(id_="2"), score=0.95)
@pytest.mark.asyncio
async def test_aretrieve(self):
async def test_aretrieve(self, mocker):
question = "test query"
# Create mock retrievers
mock_retriever1 = AsyncMock()
mock_retriever1 = mocker.AsyncMock()
mock_retriever1.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="1"), score=1.0),
NodeWithScore(node=TextNode(id_="2"), score=0.95),
]
mock_retriever2 = AsyncMock()
mock_retriever2 = mocker.AsyncMock()
mock_retriever2.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="2"), score=0.95),
NodeWithScore(node=TextNode(id_="3"), score=0.8),
@ -37,3 +47,11 @@ class TestSimpleHybridRetriever:
# Check if the scores are correct (assuming you want the highest score)
node_scores = {node.node.node_id: node.score for node in results}
assert node_scores["2"] == 0.95
def test_add_nodes(self, mock_hybrid_retriever: SimpleHybridRetriever, mock_node):
mock_hybrid_retriever.add_nodes([mock_node])
mock_hybrid_retriever.retrievers[0].add_nodes.assert_called_once()
def test_persist(self, mock_hybrid_retriever: SimpleHybridRetriever):
mock_hybrid_retriever.persist("")
mock_hybrid_retriever.retrievers[0].persist.assert_called_once()

View file

@ -37,6 +37,7 @@ class MockSearchEnine:
(SearchEngineType.SERPER_GOOGLE, None, 6, False),
(SearchEngineType.DUCK_DUCK_GO, None, 8, True),
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
(SearchEngineType.BING, None, 6, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
],