mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 11:26:23 +02:00
Merge branch 'mgx_ops' into di_mgx
This commit is contained in:
commit
3e10d34468
304 changed files with 10747 additions and 662 deletions
|
|
@ -18,6 +18,7 @@ from metagpt.utils.git_repository import ChangeType
|
|||
from metagpt.utils.graph_repository import SPO
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild(context, mocker):
|
||||
# Mock
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from metagpt.document_store.chromadb_store import ChromaStore
|
|||
def test_chroma_store():
|
||||
"""FIXME:chroma使用感觉很诡异,一用Python就挂,测试用例里也是"""
|
||||
# 创建 ChromaStore 实例,使用 'sample_collection' 集合
|
||||
document_store = ChromaStore("sample_collection_1")
|
||||
document_store = ChromaStore("sample_collection_1", get_or_create=True)
|
||||
|
||||
# 使用 write 方法添加多个文档
|
||||
document_store.write(
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@
|
|||
@File : test_faiss_store.py
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
|
@ -17,18 +15,24 @@ from metagpt.logs import logger
|
|||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
|
||||
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
|
||||
num = len(texts)
|
||||
embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim
|
||||
embeds = (embeds - embeds.mean(axis=0)) / (embeds.std(axis=0))
|
||||
return embeds
|
||||
embeds = (embeds - embeds.mean(axis=0)) / embeds.std(axis=0)
|
||||
return embeds.tolist()
|
||||
|
||||
|
||||
def mock_openai_embed_document(self, text: str) -> list[float]:
|
||||
embeds = mock_openai_embed_documents(self, [text])
|
||||
return embeds[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_json(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -37,9 +41,10 @@ async def test_search_json(mocker):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_xlsx(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -48,9 +53,10 @@ async def test_search_xlsx(mocker):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.docstore
|
||||
assert _faiss_store.index
|
||||
assert _faiss_store.storage_context.docstore
|
||||
assert _faiss_store.storage_context.vector_store.client
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.environment.android_env.android_ext_env import AndroidExtEnv
|
||||
from metagpt.environment.android_env.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.android_ext_env import AndroidExtEnv
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
|
||||
|
||||
def mock_device_shape(self, adb_cmd: str) -> str:
|
||||
|
|
@ -34,9 +34,7 @@ def mock_write_read_operation(self, adb_cmd: str) -> str:
|
|||
|
||||
def test_android_ext_env(mocker):
|
||||
device_id = "emulator-5554"
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape
|
||||
)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape)
|
||||
|
||||
ext_env = AndroidExtEnv(device_id=device_id, screenshot_dir="/data2/", xml_dir="/data2/")
|
||||
assert ext_env.adb_prefix == f"adb -s {device_id} "
|
||||
|
|
@ -46,25 +44,21 @@ def test_android_ext_env(mocker):
|
|||
assert ext_env.device_shape == (720, 1080)
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape_invalid
|
||||
"metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape_invalid
|
||||
)
|
||||
assert ext_env.device_shape == (0, 0)
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_list_devices
|
||||
)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_list_devices)
|
||||
assert ext_env.list_devices() == [device_id]
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot
|
||||
)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot)
|
||||
assert ext_env.get_screenshot("screenshot_xxxx-xx-xx", "/data/") == Path("/data/screenshot_xxxx-xx-xx.png")
|
||||
|
||||
mocker.patch("metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_xml)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_xml)
|
||||
assert ext_env.get_xml("xml_xxxx-xx-xx", "/data/") == Path("/data/xml_xxxx-xx-xx.xml")
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_write_read_operation
|
||||
"metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_write_read_operation
|
||||
)
|
||||
res = "OK"
|
||||
assert ext_env.system_back() == res
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of MincraftExtEnv
|
||||
# @Desc : the unittest of MinecraftExtEnv
|
||||
|
||||
|
||||
from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
|
||||
from metagpt.environment.minecraft.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft.minecraft_ext_env import MinecraftExtEnv
|
||||
|
||||
|
||||
def test_mincraft_ext_env():
|
||||
ext_env = MincraftExtEnv()
|
||||
def test_minecraft_ext_env():
|
||||
ext_env = MinecraftExtEnv()
|
||||
assert ext_env.server, f"{ext_env.server_host}:{ext_env.server_port}"
|
||||
assert MC_CKPT_DIR.joinpath("skill/code").exists()
|
||||
assert ext_env.warm_up.get("optional_inventory_items") == 7
|
||||
|
|
@ -4,12 +4,18 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.environment.stanford_town_env.stanford_town_ext_env import (
|
||||
StanfordTownExtEnv,
|
||||
from metagpt.environment.stanford_town.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
)
|
||||
from metagpt.environment.stanford_town.stanford_town_ext_env import StanfordTownExtEnv
|
||||
|
||||
maze_asset_path = (
|
||||
Path(__file__).absolute().parent.joinpath("..", "..", "..", "data", "environment", "stanford_town", "the_ville")
|
||||
Path(__file__)
|
||||
.absolute()
|
||||
.parent.joinpath("..", "..", "..", "..", "metagpt/ext/stanford_town/static_dirs/assets/the_ville")
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -27,7 +33,6 @@ def test_stanford_town_ext_env():
|
|||
assert len(ext_env.get_nearby_tiles(tile=tile, vision_r=5)) == 121
|
||||
|
||||
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
|
||||
ext_env.add_tiles_event(tile[1], tile[0], event=event)
|
||||
ext_env.add_event_from_tile(event, tile)
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
|
||||
|
||||
|
|
@ -38,3 +43,22 @@ def test_stanford_town_ext_env():
|
|||
|
||||
ext_env.remove_subject_events_from_tile(subject=event[0], tile=tile)
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 0
|
||||
|
||||
|
||||
def test_stanford_town_ext_env_observe_step():
|
||||
ext_env = StanfordTownExtEnv(maze_asset_path=maze_asset_path)
|
||||
obs, info = ext_env.reset()
|
||||
assert len(info) == 0
|
||||
assert len(obs["address_tiles"]) == 306
|
||||
|
||||
tile = (58, 9)
|
||||
obs = ext_env.observe(obs_params=EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="world"))
|
||||
assert obs == "the Ville"
|
||||
|
||||
action = ext_env.action_space.sample()
|
||||
assert len(action) == 4
|
||||
assert len(action["event"]) == 4
|
||||
|
||||
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
|
||||
obs, _, _, _, _ = ext_env.step(action=EnvAction(action_type=EnvActionType.ADD_TILE_EVENT, coord=tile, event=event))
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of ExtEnv&Env
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
|
|
@ -12,11 +14,26 @@ from metagpt.environment.base_env import (
|
|||
mark_as_readable,
|
||||
mark_as_writeable,
|
||||
)
|
||||
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
|
||||
|
||||
|
||||
class ForTestEnv(Environment):
|
||||
value: int = 0
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
pass
|
||||
|
||||
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
|
||||
pass
|
||||
|
||||
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@mark_as_readable
|
||||
def read_api_no_param(self):
|
||||
return self.value
|
||||
|
|
@ -44,11 +61,11 @@ async def test_ext_env():
|
|||
assert len(apis) > 0
|
||||
assert len(apis["read_api"]) == 3
|
||||
|
||||
_ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
|
||||
_ = await env.write_thru_api(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
|
||||
assert env.value == 15
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await env.observe("not_exist_api")
|
||||
await env.read_from_api("not_exist_api")
|
||||
|
||||
assert await env.observe("read_api_no_param") == 15
|
||||
assert await env.observe(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
|
||||
assert await env.read_from_api("read_api_no_param") == 15
|
||||
assert await env.read_from_api(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of WerewolfExtEnv
|
||||
|
||||
from metagpt.environment.werewolf_env.werewolf_ext_env import RoleState, WerewolfExtEnv
|
||||
from metagpt.environment.werewolf.werewolf_ext_env import RoleState, WerewolfExtEnv
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
|
|
|
|||
3
tests/metagpt/ext/__init__.py
Normal file
3
tests/metagpt/ext/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
tests/metagpt/ext/stanford_town/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
tests/metagpt/ext/stanford_town/actions/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/actions/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of actions/gen_action_details.py
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
from metagpt.ext.stanford_town.actions.gen_action_details import (
|
||||
GenActionArena,
|
||||
GenActionDetails,
|
||||
GenActionObject,
|
||||
GenActionSector,
|
||||
GenActObjDescription,
|
||||
)
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gen_action_details():
|
||||
role = STRole(
|
||||
name="Klaus Mueller",
|
||||
start_time="February 13, 2023",
|
||||
curr_time="February 13, 2023, 00:00:00",
|
||||
sim_code="base_the_ville_isabella_maria_klaus",
|
||||
)
|
||||
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
|
||||
await role.init_curr_tile()
|
||||
|
||||
act_desp = "sleeping"
|
||||
act_dura = "120"
|
||||
|
||||
access_tile = await role.rc.env.read_from_api(
|
||||
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile})
|
||||
)
|
||||
act_world = access_tile["world"]
|
||||
assert act_world == "the Ville"
|
||||
|
||||
sector = await GenActionSector().run(role, access_tile, act_desp)
|
||||
arena = await GenActionArena().run(role, act_desp, act_world, sector)
|
||||
temp_address = f"{act_world}:{sector}:{arena}"
|
||||
obj = await GenActionObject().run(role, act_desp, temp_address)
|
||||
|
||||
act_obj_desp = await GenActObjDescription().run(role, obj, act_desp)
|
||||
|
||||
result_dict = await GenActionDetails().run(role, act_desp, act_dura)
|
||||
|
||||
# gen_action_sector
|
||||
assert isinstance(sector, str)
|
||||
assert sector in role.s_mem.get_str_accessible_sectors(act_world)
|
||||
|
||||
# gen_action_arena
|
||||
assert isinstance(arena, str)
|
||||
assert arena in role.s_mem.get_str_accessible_sector_arenas(f"{act_world}:{sector}")
|
||||
|
||||
# gen_action_obj
|
||||
assert isinstance(obj, str)
|
||||
assert obj in role.s_mem.get_str_accessible_arena_game_objects(temp_address)
|
||||
|
||||
if result_dict:
|
||||
for key in [
|
||||
"action_address",
|
||||
"action_duration",
|
||||
"action_description",
|
||||
"action_pronunciatio",
|
||||
"action_event",
|
||||
"chatting_with",
|
||||
"chat",
|
||||
"chatting_with_buffer",
|
||||
"chatting_end_time",
|
||||
"act_obj_description",
|
||||
"act_obj_pronunciatio",
|
||||
"act_obj_event",
|
||||
]:
|
||||
assert key in result_dict
|
||||
assert result_dict["action_address"] == f"{temp_address}:{obj}"
|
||||
assert result_dict["action_duration"] == int(act_dura)
|
||||
assert result_dict["act_obj_description"] == act_obj_desp
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of actions/summarize_conv
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.actions.summarize_conv import SummarizeConv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_conv():
|
||||
conv = [("Role_A", "what's the weather today?"), ("Role_B", "It looks pretty good, and I will take a walk then.")]
|
||||
|
||||
output = await SummarizeConv().run(conv)
|
||||
assert "weather" in output
|
||||
3
tests/metagpt/ext/stanford_town/memory/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/memory/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
89
tests/metagpt/ext/stanford_town/memory/test_agent_memory.py
Normal file
89
tests/metagpt/ext/stanford_town/memory/test_agent_memory.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of AgentMemory
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.memory.agent_memory import AgentMemory
|
||||
from metagpt.ext.stanford_town.memory.retrieve import agent_retrieve
|
||||
from metagpt.ext.stanford_town.utils.const import STORAGE_PATH
|
||||
from metagpt.logs import logger
|
||||
|
||||
"""
|
||||
memory测试思路
|
||||
1. Basic Memory测试
|
||||
2. Agent Memory测试
|
||||
2.1 Load & Save方法测试; Load方法中使用了add方法,验证Load即可验证所有add
|
||||
2.2 Get方法测试
|
||||
"""
|
||||
memory_easy_storage_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/associative_memory",
|
||||
)
|
||||
memroy_chat_storage_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/associative_memory",
|
||||
)
|
||||
memory_save_easy_test_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/test_memory",
|
||||
)
|
||||
memory_save_chat_test_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/test_memory",
|
||||
)
|
||||
|
||||
|
||||
class TestAgentMemory:
|
||||
@pytest.fixture
|
||||
def agent_memory(self):
|
||||
# 创建一个AgentMemory实例并返回,可以在所有测试用例中共享
|
||||
test_agent_memory = AgentMemory()
|
||||
test_agent_memory.set_mem_path(memroy_chat_storage_path)
|
||||
return test_agent_memory
|
||||
|
||||
def test_load(self, agent_memory):
|
||||
logger.info(f"存储路径为:{agent_memory.memory_saved}")
|
||||
logger.info(f"存储记忆条数为:{len(agent_memory.storage)}")
|
||||
logger.info(f"kw_strength为{agent_memory.kw_strength_event},{agent_memory.kw_strength_thought}")
|
||||
logger.info(f"embeeding.json条数为{len(agent_memory.embeddings)}")
|
||||
|
||||
assert agent_memory.embeddings is not None
|
||||
|
||||
def test_save(self, agent_memory):
|
||||
try:
|
||||
agent_memory.save(memory_save_chat_test_path)
|
||||
logger.info("成功存储")
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_summary_function(self, agent_memory):
|
||||
logger.info(f"event长度为{len(agent_memory.event_list)}")
|
||||
logger.info(f"thought长度为{len(agent_memory.thought_list)}")
|
||||
logger.info(f"chat长度为{len(agent_memory.chat_list)}")
|
||||
result1 = agent_memory.get_summarized_latest_events(4)
|
||||
logger.info(f"总结最近事件结果为:{result1}")
|
||||
|
||||
def test_get_last_chat_function(self, agent_memory):
|
||||
result2 = agent_memory.get_last_chat("customers")
|
||||
logger.info(f"上一次对话是{result2}")
|
||||
|
||||
def test_retrieve_function(self, agent_memory):
|
||||
focus_points = ["who i love?"]
|
||||
retrieved = dict()
|
||||
for focal_pt in focus_points:
|
||||
nodes = [
|
||||
[i.last_accessed, i]
|
||||
for i in agent_memory.event_list + agent_memory.thought_list
|
||||
if "idle" not in i.embedding_key
|
||||
]
|
||||
nodes = sorted(nodes, key=lambda x: x[0])
|
||||
nodes = [i for created, i in nodes]
|
||||
results = agent_retrieve(agent_memory, datetime.now() - timedelta(days=120), 0.99, focal_pt, nodes, 5)
|
||||
final_result = []
|
||||
for n in results:
|
||||
for i in agent_memory.storage:
|
||||
if i.memory_id == n:
|
||||
i.last_accessed = datetime.now() - timedelta(days=120)
|
||||
final_result.append(i)
|
||||
|
||||
retrieved[focal_pt] = final_result
|
||||
logger.info(f"检索结果为{retrieved}")
|
||||
76
tests/metagpt/ext/stanford_town/memory/test_basic_memory.py
Normal file
76
tests/metagpt/ext/stanford_town/memory/test_basic_memory.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of BasicMemory
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory
|
||||
from metagpt.logs import logger
|
||||
|
||||
"""
|
||||
memory测试思路
|
||||
1. Basic Memory测试
|
||||
2. Agent Memory测试
|
||||
2.1 Load & Save方法测试
|
||||
2.2 Add方法测试
|
||||
2.3 Get方法测试
|
||||
"""
|
||||
|
||||
# Create some sample BasicMemory instances
|
||||
memory1 = BasicMemory(
|
||||
memory_id="1",
|
||||
memory_count=1,
|
||||
type_count=1,
|
||||
memory_type="event",
|
||||
depth=1,
|
||||
created=datetime.now(),
|
||||
expiration=datetime.now() + timedelta(days=30),
|
||||
subject="Subject1",
|
||||
predicate="Predicate1",
|
||||
object="Object1",
|
||||
content="This is content 1",
|
||||
embedding_key="embedding_key_1",
|
||||
poignancy=1,
|
||||
keywords=["keyword1", "keyword2"],
|
||||
filling=["memory_id_2"],
|
||||
)
|
||||
memory2 = BasicMemory(
|
||||
memory_id="2",
|
||||
memory_count=2,
|
||||
type_count=2,
|
||||
memory_type="thought",
|
||||
depth=2,
|
||||
created=datetime.now(),
|
||||
expiration=datetime.now() + timedelta(days=30),
|
||||
subject="Subject2",
|
||||
predicate="Predicate2",
|
||||
object="Object2",
|
||||
content="This is content 2",
|
||||
embedding_key="embedding_key_2",
|
||||
poignancy=2,
|
||||
keywords=["keyword3", "keyword4"],
|
||||
filling=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_mem_set():
|
||||
basic_mem2 = memory2
|
||||
yield basic_mem2
|
||||
|
||||
|
||||
def test_basic_mem_function(basic_mem_set):
|
||||
a, b, c = basic_mem_set.summary()
|
||||
logger.info(f"{a}{b}{c}")
|
||||
assert a == "Subject2"
|
||||
|
||||
|
||||
def test_basic_mem_save(basic_mem_set):
|
||||
result = basic_mem_set.save_to_dict()
|
||||
logger.info(f"save结果为{result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of MemoryTree
|
||||
|
||||
from metagpt.ext.stanford_town.memory.spatial_memory import MemoryTree
|
||||
from metagpt.ext.stanford_town.utils.const import STORAGE_PATH
|
||||
|
||||
|
||||
def test_spatial_memory():
|
||||
f_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/spatial_memory.json"
|
||||
)
|
||||
x = MemoryTree()
|
||||
x.set_mem_path(f_path)
|
||||
assert x.tree
|
||||
assert "the Ville" in x.tree
|
||||
assert "Isabella Rodriguez's apartment" in x.get_str_accessible_sectors("the Ville")
|
||||
3
tests/metagpt/ext/stanford_town/plan/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/plan/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
67
tests/metagpt/ext/stanford_town/plan/test_conversation.py
Normal file
67
tests/metagpt/ext/stanford_town/plan/test_conversation.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of roles conversation
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.ext.stanford_town.plan.converse import agent_conversation
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH, STORAGE_PATH
|
||||
from metagpt.ext.stanford_town.utils.mg_ga_transform import get_reverie_meta
|
||||
from metagpt.ext.stanford_town.utils.utils import copy_folder
|
||||
|
||||
|
||||
async def init_two_roles(fork_sim_code: str = "base_the_ville_isabella_maria_klaus") -> Tuple["STRole"]:
|
||||
sim_code = "unittest_sim"
|
||||
|
||||
copy_folder(str(STORAGE_PATH.joinpath(fork_sim_code)), str(STORAGE_PATH.joinpath(sim_code)))
|
||||
|
||||
reverie_meta = get_reverie_meta(fork_sim_code)
|
||||
role_ir_name = "Isabella Rodriguez"
|
||||
role_km_name = "Klaus Mueller"
|
||||
|
||||
env = StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH)
|
||||
|
||||
role_ir = STRole(
|
||||
name=role_ir_name,
|
||||
sim_code=sim_code,
|
||||
profile=role_ir_name,
|
||||
step=reverie_meta.get("step"),
|
||||
start_time=reverie_meta.get("start_date"),
|
||||
curr_time=reverie_meta.get("curr_time"),
|
||||
sec_per_step=reverie_meta.get("sec_per_step"),
|
||||
)
|
||||
role_ir.set_env(env)
|
||||
await role_ir.init_curr_tile()
|
||||
|
||||
role_km = STRole(
|
||||
name=role_km_name,
|
||||
sim_code=sim_code,
|
||||
profile=role_km_name,
|
||||
step=reverie_meta.get("step"),
|
||||
start_time=reverie_meta.get("start_date"),
|
||||
curr_time=reverie_meta.get("curr_time"),
|
||||
sec_per_step=reverie_meta.get("sec_per_step"),
|
||||
)
|
||||
role_km.set_env(env)
|
||||
await role_km.init_curr_tile()
|
||||
|
||||
return role_ir, role_km
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_conversation():
|
||||
role_ir, role_km = await init_two_roles()
|
||||
|
||||
curr_chat = await agent_conversation(role_ir, role_km, conv_rounds=2)
|
||||
assert len(curr_chat) % 2 == 0
|
||||
|
||||
meet = False
|
||||
for conv in curr_chat:
|
||||
if "Valentine's Day party" in conv[1]:
|
||||
# conv[0] speaker, conv[1] utterance
|
||||
meet = True
|
||||
assert meet
|
||||
25
tests/metagpt/ext/stanford_town/plan/test_st_plan.py
Normal file
25
tests/metagpt/ext/stanford_town/plan/test_st_plan.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of st_plan
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.plan.st_plan import _choose_retrieved, _should_react
|
||||
from tests.metagpt.ext.stanford_town.plan.test_conversation import init_two_roles
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_react():
|
||||
role_ir, role_km = await init_two_roles()
|
||||
roles = {role_ir.name: role_ir, role_km.name: role_km}
|
||||
role_ir.scratch.act_address = "mock data"
|
||||
|
||||
observed = await role_ir.observe()
|
||||
retrieved = role_ir.retrieve(observed)
|
||||
|
||||
focused_event = _choose_retrieved(role_ir.name, retrieved)
|
||||
|
||||
if focused_event:
|
||||
reaction_mode = await _should_react(role_ir, focused_event, roles) # chat with Isabella Rodriguez
|
||||
assert not reaction_mode
|
||||
3
tests/metagpt/ext/stanford_town/roles/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/roles/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
26
tests/metagpt/ext/stanford_town/roles/test_st_role.py
Normal file
26
tests/metagpt/ext/stanford_town/roles/test_st_role.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of STRole
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_observe():
|
||||
role = STRole(
|
||||
sim_code="base_the_ville_isabella_maria_klaus",
|
||||
start_time="February 13, 2023",
|
||||
curr_time="February 13, 2023, 00:00:00",
|
||||
)
|
||||
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
|
||||
await role.init_curr_tile()
|
||||
|
||||
ret_events = await role.observe()
|
||||
assert ret_events
|
||||
for event in ret_events:
|
||||
assert isinstance(event, BasicMemory)
|
||||
47
tests/metagpt/ext/stanford_town/test_reflect.py
Normal file
47
tests/metagpt/ext/stanford_town/test_reflect.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of reflection
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.ext.stanford_town.actions.run_reflect_action import (
|
||||
AgentEventTriple,
|
||||
AgentFocusPt,
|
||||
AgentInsightAndGuidance,
|
||||
)
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflect():
|
||||
"""
|
||||
init STRole form local json, set sim_code(path),curr_time & start_time
|
||||
"""
|
||||
role = STRole(
|
||||
sim_code="base_the_ville_isabella_maria_klaus",
|
||||
start_time="February 13, 2023",
|
||||
curr_time="February 13, 2023, 00:00:00",
|
||||
)
|
||||
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
|
||||
role.init_curr_tile()
|
||||
|
||||
run_focus = AgentFocusPt()
|
||||
statements = ""
|
||||
await run_focus.run(role, statements, n=3)
|
||||
|
||||
"""
|
||||
这里有通过测试的结果,但是更多时候LLM生成的结果缺少了because of;考虑修改一下prompt
|
||||
result = {'Klaus Mueller and Maria Lopez have a close relationship because they have been friends for a long time and have a strong bond': [1, 2, 5, 9, 11, 14], 'Klaus Mueller has a crush on Maria Lopez': [8, 15, 24], 'Klaus Mueller is academically inclined and actively researching a topic': [13, 20], 'Klaus Mueller is socially active and acquainted with Isabella Rodriguez': [17, 21, 22], 'Klaus Mueller is organized and prepared': [19]}
|
||||
"""
|
||||
run_insight = AgentInsightAndGuidance()
|
||||
statements = "[user: Klaus Mueller has a close relationship with Maria Lopez, user:s Mueller and Maria Lopez have a close relationship, user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller and Maria Lopez have a strong relationship, user: Klaus Mueller is a dormmate of Maria Lopez., user: Klaus Mueller and Maria Lopez have a strong bond, user: Klaus Mueller has a crush on Maria Lopez, user: Klaus Mueller and Maria Lopez have been friends for more than 2 years., user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller Maria Lopez is heading off to college., user: Klaus Mueller and Maria Lopez have a close relationship, user: Klaus Mueller is actively researching a topic, user: Klaus Mueller is close friends and classmates with Maria Lopez., user: Klaus Mueller is socially active, user: Klaus Mueller has a crush on Maria Lopez., user: Klaus Mueller and Maria Lopez have been friends for a long time, user: Klaus Mueller is academically inclined, user: For Klaus Mueller's planning: should remember to ask Maria Lopez about her research paper, as she found it interesting that he mentioned it., user: Klaus Mueller is acquainted with Isabella Rodriguez, user: Klaus Mueller is organized and prepared, user: Maria Lopez is conversing about conversing about Maria's research paper mentioned by Klaus, user: Klaus Mueller is conversing about conversing about Maria's research paper mentioned by Klaus, user: Klaus Mueller is a student, user: Klaus Mueller is a student, user: Klaus Mueller is conversing about two friends named Klaus Mueller and Maria Lopez discussing their morning plans and progress on a research paper before Maria heads off to college., user: Klaus Mueller is socially active, user: Klaus Mueller is socially active, user: Klaus Mueller is socially active and acquainted with Isabella Rodriguez, user: Klaus Mueller has a crush on Maria Lopez]"
|
||||
await run_insight.run(role, statements, n=5)
|
||||
|
||||
run_triple = AgentEventTriple()
|
||||
statements = "(Klaus Mueller is academically inclined)"
|
||||
await run_triple.run(statements, role)
|
||||
|
||||
role.scratch.importance_trigger_curr = -1
|
||||
role.reflect()
|
||||
|
|
@ -2,32 +2,41 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
dim = 1536 # openai embedding dim
|
||||
embed_zeros_arrr = np.zeros(shape=[1, dim]).tolist()
|
||||
embed_ones_arrr = np.ones(shape=[1, dim]).tolist()
|
||||
|
||||
text_embed_arr = [
|
||||
{"text": "Write a cli snake game", "embed": np.zeros(shape=[1, dim])}, # mock data, same as below
|
||||
{"text": "Write a game of cli snake", "embed": np.zeros(shape=[1, dim])},
|
||||
{"text": "Write a 2048 web game", "embed": np.ones(shape=[1, dim])},
|
||||
{"text": "Write a Battle City", "embed": np.ones(shape=[1, dim])},
|
||||
{"text": "Write a cli snake game", "embed": embed_zeros_arrr}, # mock data, same as below
|
||||
{"text": "Write a game of cli snake", "embed": embed_zeros_arrr},
|
||||
{"text": "Write a 2048 web game", "embed": embed_ones_arrr},
|
||||
{"text": "Write a Battle City", "embed": embed_ones_arrr},
|
||||
{
|
||||
"text": "The user has requested the creation of a command-line interface (CLI) snake game",
|
||||
"embed": np.zeros(shape=[1, dim]),
|
||||
"embed": embed_zeros_arrr,
|
||||
},
|
||||
{"text": "The request is command-line interface (CLI) snake game", "embed": np.zeros(shape=[1, dim])},
|
||||
{"text": "The request is command-line interface (CLI) snake game", "embed": embed_zeros_arrr},
|
||||
{
|
||||
"text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
|
||||
"embed": np.ones(shape=[1, dim]),
|
||||
"embed": embed_ones_arrr,
|
||||
},
|
||||
]
|
||||
|
||||
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
|
||||
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
|
||||
idx = text_idx_dict.get(texts[0])
|
||||
embed = text_embed_arr[idx].get("embed")
|
||||
return embed
|
||||
|
||||
|
||||
def mock_openai_embed_document(self, text: str) -> list[float]:
|
||||
embeds = mock_openai_embed_documents(self, [text])
|
||||
return embeds[0]
|
||||
|
||||
|
||||
async def mock_openai_aembed_document(self, text: str) -> list[float]:
|
||||
return mock_openai_embed_document(self, text)
|
||||
|
|
|
|||
|
|
@ -12,13 +12,20 @@ from metagpt.memory.longterm_memory import LongTermMemory
|
|||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_aembed_document,
|
||||
mock_openai_embed_document,
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_ltm_search(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
@pytest.mark.asyncio
|
||||
async def test_ltm_search(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
mocker.patch(
|
||||
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
|
||||
)
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
from metagpt.environment import Environment
|
||||
|
|
@ -31,39 +38,24 @@ def test_ltm_search(mocker):
|
|||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([message])
|
||||
news = await ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
news = await ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([new_message])
|
||||
news = await ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
ltm.add(new_message)
|
||||
|
||||
# restore from local index
|
||||
ltm_new = LongTermMemory()
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.find_news([message])
|
||||
assert len(news) == 0
|
||||
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = text_embed_arr[3].get("text", "Write a Battle City")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
||||
ltm_new.clear()
|
||||
ltm.clear()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -8,19 +8,28 @@ import shutil
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_aembed_document,
|
||||
mock_openai_embed_document,
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_idea_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
@pytest.mark.asyncio
|
||||
async def test_idea_message(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
mocker.patch(
|
||||
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
|
||||
)
|
||||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
|
|
@ -29,28 +38,32 @@ def test_idea_message(mocker):
|
|||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
memory_storage.recover_memory(role_id)
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
new_messages = await memory_storage.search_similar(sim_message)
|
||||
assert len(new_messages) == 1 # similar, return []
|
||||
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
new_messages = await memory_storage.search_similar(new_message)
|
||||
assert len(new_messages) == 0
|
||||
|
||||
memory_storage.clean()
|
||||
assert memory_storage.is_initialized is False
|
||||
|
||||
|
||||
def test_actionout_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
@pytest.mark.asyncio
|
||||
async def test_actionout_message(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
mocker.patch(
|
||||
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
|
||||
)
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
|
|
@ -67,23 +80,22 @@ def test_actionout_message(mocker):
|
|||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
memory_storage.recover_memory(role_id)
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
new_messages = await memory_storage.search_similar(sim_message)
|
||||
assert len(new_messages) == 1 # similar, return []
|
||||
|
||||
new_conent = text_embed_arr[6].get(
|
||||
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
)
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
new_messages = await memory_storage.search_similar(new_message)
|
||||
assert len(new_messages) == 0
|
||||
|
||||
memory_storage.clean()
|
||||
assert memory_storage.is_initialized is False
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ import pytest
|
|||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
default_resp_cont,
|
||||
get_part_chat_completion,
|
||||
|
|
@ -22,7 +23,7 @@ name = "GPT"
|
|||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
pass
|
||||
self.config = config or mock_llm_config
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return get_part_chat_completion(name)
|
||||
|
|
|
|||
310
tests/metagpt/rag/engines/test_simple.py
Normal file
310
tests/metagpt/rag/engines/test_simple.py
Normal file
|
|
@ -0,0 +1,310 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
from llama_index.core.llms import MockLLM
|
||||
from llama_index.core.schema import Document, NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
|
||||
|
||||
|
||||
class TestSimpleEngine:
|
||||
@pytest.fixture
|
||||
def mock_llm(self):
|
||||
return MockLLM()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_retriever(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_rankers(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_rankers")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_response_synthesizer(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
|
||||
|
||||
def test_from_docs(
|
||||
self,
|
||||
mocker,
|
||||
mock_simple_directory_reader,
|
||||
mock_vector_store_index,
|
||||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
):
|
||||
# Mock
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
Document(text="document1"),
|
||||
Document(text="document2"),
|
||||
]
|
||||
mock_get_retriever.return_value = mocker.MagicMock()
|
||||
mock_get_rankers.return_value = [mocker.MagicMock()]
|
||||
mock_get_response_synthesizer.return_value = mocker.MagicMock()
|
||||
|
||||
# Setup
|
||||
input_dir = "test_dir"
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
transformations = [mocker.MagicMock()]
|
||||
embed_model = mocker.MagicMock()
|
||||
llm = mocker.MagicMock()
|
||||
retriever_configs = [mocker.MagicMock()]
|
||||
ranker_configs = [mocker.MagicMock()]
|
||||
|
||||
# Exec
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_dir=input_dir,
|
||||
input_files=input_files,
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
configs=retriever_configs, index=mock_vector_store_index.return_value
|
||||
)
|
||||
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
def test_from_docs_without_file(self):
|
||||
with pytest.raises(ValueError):
|
||||
SimpleEngine.from_docs()
|
||||
|
||||
def test_from_objs(self, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
class MockRAGObject:
|
||||
def rag_key(self):
|
||||
return "key"
|
||||
|
||||
def model_dump_json(self):
|
||||
return "{}"
|
||||
|
||||
objs = [MockRAGObject()]
|
||||
|
||||
# Setup
|
||||
retriever_configs = []
|
||||
ranker_configs = []
|
||||
|
||||
# Exec
|
||||
engine = SimpleEngine.from_objs(
|
||||
objs=objs,
|
||||
llm=mock_llm,
|
||||
embed_model=mock_embedding,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is not None
|
||||
|
||||
def test_from_objs_with_bm25_config(self):
|
||||
# Setup
|
||||
retriever_configs = [BM25RetrieverConfig()]
|
||||
|
||||
# Exec
|
||||
with pytest.raises(ValueError):
|
||||
SimpleEngine.from_objs(
|
||||
objs=[],
|
||||
llm=MockLLM(),
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=[],
|
||||
)
|
||||
|
||||
def test_from_index(self, mocker, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
|
||||
mock_get_index.return_value = mock_index
|
||||
|
||||
# Exec
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=mock_index,
|
||||
embed_model=mock_embedding,
|
||||
llm=mock_llm,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is mock_index
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
# Mock
|
||||
test_query = "test query"
|
||||
expected_result = "expected result"
|
||||
mock_aquery = mocker.AsyncMock(return_value=expected_result)
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
engine.aquery = mock_aquery
|
||||
|
||||
# Exec
|
||||
result = await engine.asearch(test_query)
|
||||
|
||||
# Assert
|
||||
mock_aquery.assert_called_once_with(test_query)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aretrieve(self, mocker):
|
||||
# Mock
|
||||
mock_query_bundle = mocker.patch("metagpt.rag.engines.simple.QueryBundle", return_value="query_bundle")
|
||||
mock_super_aretrieve = mocker.patch(
|
||||
"metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock
|
||||
)
|
||||
mock_super_aretrieve.return_value = [TextNode(text="node_with_score", metadata={"is_obj": False})]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
test_query = "test query"
|
||||
|
||||
# Exec
|
||||
result = await engine.aretrieve(test_query)
|
||||
|
||||
# Assert
|
||||
mock_query_bundle.assert_called_once_with(test_query)
|
||||
mock_super_aretrieve.assert_called_once_with("query_bundle")
|
||||
assert result[0].text == "node_with_score"
|
||||
|
||||
def test_add_docs(self, mocker):
|
||||
# Mock
|
||||
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
Document(text="document1"),
|
||||
Document(text="document2"),
|
||||
]
|
||||
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index._transformations = mocker.MagicMock()
|
||||
|
||||
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
|
||||
mock_run_transformations.return_value = ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Exec
|
||||
engine.add_docs(input_files=input_files)
|
||||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
|
||||
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
|
||||
|
||||
def test_add_objs(self, mocker):
|
||||
# Mock
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
# Setup
|
||||
class CustomTextNode(TextNode):
|
||||
def rag_key(self):
|
||||
return ""
|
||||
|
||||
def model_dump_json(self):
|
||||
return ""
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
|
||||
# Exec
|
||||
engine.add_objs(objs=objs)
|
||||
|
||||
# Assert
|
||||
assert mock_retriever.add_nodes.call_count == 1
|
||||
for node in mock_retriever.add_nodes.call_args[0][0]:
|
||||
assert isinstance(node, TextNode)
|
||||
assert "is_obj" in node.metadata
|
||||
|
||||
def test_persist_successfully(self, mocker):
|
||||
# Mock
|
||||
mock_retriever = mocker.MagicMock(spec=PersistableRAGRetriever)
|
||||
mock_retriever.persist.return_value = mocker.MagicMock()
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Exec
|
||||
engine.persist(persist_dir="")
|
||||
|
||||
def test_ensure_retriever_of_type(self, mocker):
|
||||
# Mock
|
||||
class MyRetriever:
|
||||
def add_nodes(self):
|
||||
...
|
||||
|
||||
mock_retriever = mocker.MagicMock(spec=SimpleHybridRetriever)
|
||||
mock_retriever.retrievers = [MyRetriever()]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Assert
|
||||
engine._ensure_retriever_of_type(ModifiableRAGRetriever)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
engine._ensure_retriever_of_type(PersistableRAGRetriever)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
other_engine = SimpleEngine(retriever=mocker.MagicMock(spec=ModifiableRAGRetriever))
|
||||
other_engine._ensure_retriever_of_type(PersistableRAGRetriever)
|
||||
|
||||
def test_with_obj_metadata(self, mocker):
|
||||
# Mock
|
||||
node = NodeWithScore(
|
||||
node=ObjectNode(
|
||||
text="example",
|
||||
metadata={
|
||||
"is_obj": True,
|
||||
"obj_cls_name": "ExampleObject",
|
||||
"obj_mod_name": "__main__",
|
||||
"obj_json": json.dumps({"key": "test_key", "value": "test_value"}),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
class ExampleObject:
|
||||
def __init__(self, key, value):
|
||||
self.key = key
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.key == other.key and self.value == other.value
|
||||
|
||||
mock_import_class = mocker.patch("metagpt.rag.engines.simple.import_class")
|
||||
mock_import_class.return_value = ExampleObject
|
||||
|
||||
# Setup
|
||||
SimpleEngine._try_reconstruct_obj([node])
|
||||
|
||||
# Exec
|
||||
expected_obj = ExampleObject(key="test_key", value="test_value")
|
||||
|
||||
# Assert
|
||||
assert "obj" in node.node.metadata
|
||||
assert node.node.metadata["obj"] == expected_obj
|
||||
102
tests/metagpt/rag/factories/test_base.py
Normal file
102
tests/metagpt/rag/factories/test_base.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory, GenericFactory
|
||||
|
||||
|
||||
class TestGenericFactory:
|
||||
@pytest.fixture
|
||||
def creators(self):
|
||||
return {
|
||||
"type1": lambda name: f"Instance of type1 with {name}",
|
||||
"type2": lambda name: f"Instance of type2 with {name}",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self, creators):
|
||||
return GenericFactory(creators=creators)
|
||||
|
||||
def test_get_instance_success(self, factory):
|
||||
# Test successful retrieval of an instance
|
||||
key = "type1"
|
||||
instance = factory.get_instance(key, name="TestName")
|
||||
assert instance == "Instance of type1 with TestName"
|
||||
|
||||
def test_get_instance_failure(self, factory):
|
||||
# Test failure to retrieve an instance due to unregistered key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
factory.get_instance("unknown_key")
|
||||
assert "Creator not registered for key: unknown_key" in str(exc_info.value)
|
||||
|
||||
def test_get_instances_success(self, factory):
|
||||
# Test successful retrieval of multiple instances
|
||||
keys = ["type1", "type2"]
|
||||
instances = factory.get_instances(keys, name="TestName")
|
||||
expected = ["Instance of type1 with TestName", "Instance of type2 with TestName"]
|
||||
assert instances == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"keys,expected_exception_message",
|
||||
[
|
||||
(["unknown_key"], "Creator not registered for key: unknown_key"),
|
||||
(["type1", "unknown_key"], "Creator not registered for key: unknown_key"),
|
||||
],
|
||||
)
|
||||
def test_get_instances_with_failure(self, factory, keys, expected_exception_message):
|
||||
# Test failure to retrieve instances due to at least one unregistered key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
factory.get_instances(keys, name="TestName")
|
||||
assert expected_exception_message in str(exc_info.value)
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
"""A dummy config class for testing."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class TestConfigBasedFactory:
|
||||
@pytest.fixture
|
||||
def config_creators(self):
|
||||
return {
|
||||
DummyConfig: lambda config, **kwargs: f"Processed {config.name} with {kwargs.get('extra', 'no extra')}",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def config_factory(self, config_creators):
|
||||
return ConfigBasedFactory(creators=config_creators)
|
||||
|
||||
def test_get_instance_success(self, config_factory):
|
||||
# Test successful retrieval of an instance
|
||||
config = DummyConfig(name="TestConfig")
|
||||
instance = config_factory.get_instance(config, extra="additional data")
|
||||
assert instance == "Processed TestConfig with additional data"
|
||||
|
||||
def test_get_instance_failure(self, config_factory):
|
||||
# Test failure to retrieve an instance due to unknown config type
|
||||
class UnknownConfig:
|
||||
pass
|
||||
|
||||
config = UnknownConfig()
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
config_factory.get_instance(config)
|
||||
assert "Unknown config:" in str(exc_info.value)
|
||||
|
||||
def test_val_from_config_or_kwargs_priority(self):
|
||||
# Test that the value from the config object has priority over kwargs
|
||||
config = DummyConfig(name="ConfigName")
|
||||
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
|
||||
assert result == "ConfigName"
|
||||
|
||||
def test_val_from_config_or_kwargs_fallback_to_kwargs(self):
|
||||
# Test fallback to kwargs when config object does not have the value
|
||||
config = DummyConfig(name=None)
|
||||
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
|
||||
assert result == "KwargsName"
|
||||
|
||||
def test_val_from_config_or_kwargs_key_error(self):
|
||||
# Test KeyError when the key is not found in both config object and kwargs
|
||||
config = DummyConfig(name=None)
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
|
||||
43
tests/metagpt/rag/factories/test_embedding.py
Normal file
43
tests/metagpt/rag/factories/test_embedding.py
Normal file
|
|
@ -0,0 +1,43 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
||||
|
||||
class TestRAGEmbeddingFactory:
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_embedding_factory(self):
|
||||
self.embedding_factory = RAGEmbeddingFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_embedding(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_embedding(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
|
||||
|
||||
def test_get_rag_embedding_openai(self, mock_openai_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_azure(self, mock_azure_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
|
||||
|
||||
# Assert
|
||||
mock_azure_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
|
||||
# Mock
|
||||
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
mock_config.llm.api_type = LLMType.OPENAI
|
||||
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding()
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
89
tests/metagpt/rag/factories/test_index.py
Normal file
89
tests/metagpt/rag/factories/test_index.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
import pytest
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
|
||||
from metagpt.rag.factories.index import RAGIndexFactory
|
||||
from metagpt.rag.schema import (
|
||||
BM25IndexConfig,
|
||||
ChromaIndexConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestRAGIndexFactory:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self):
|
||||
self.index_factory = RAGIndexFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_config(self):
|
||||
return FAISSIndexConfig(persist_path="")
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_config(self):
|
||||
return ChromaIndexConfig(persist_path="", collection_name="")
|
||||
|
||||
@pytest.fixture
|
||||
def bm25_config(self):
|
||||
return BM25IndexConfig(persist_path="")
|
||||
|
||||
@pytest.fixture
|
||||
def es_config(self, mocker):
|
||||
return ElasticsearchIndexConfig(store_config=ElasticsearchStoreConfig())
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage_context(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.index.StorageContext.from_defaults")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_load_index_from_storage(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.index.load_index_from_storage")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_from_vector_store(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.index.VectorStoreIndex.from_vector_store")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
def test_create_faiss_index(
|
||||
self, mocker, faiss_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
|
||||
):
|
||||
# Mock
|
||||
mock_faiss_store = mocker.patch("metagpt.rag.factories.index.FaissVectorStore.from_persist_dir")
|
||||
|
||||
# Exec
|
||||
self.index_factory.get_index(faiss_config, embed_model=mock_embedding)
|
||||
|
||||
# Assert
|
||||
mock_faiss_store.assert_called_once()
|
||||
|
||||
def test_create_bm25_index(
|
||||
self, mocker, bm25_config, mock_storage_context, mock_load_index_from_storage, mock_embedding
|
||||
):
|
||||
self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
|
||||
|
||||
def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
|
||||
# Mock
|
||||
mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")
|
||||
mock_chroma_db.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
|
||||
mock_chroma_store = mocker.patch("metagpt.rag.factories.index.ChromaVectorStore")
|
||||
|
||||
# Exec
|
||||
self.index_factory.get_index(chroma_config, embed_model=mock_embedding)
|
||||
|
||||
# Assert
|
||||
mock_chroma_store.assert_called_once()
|
||||
|
||||
def test_create_es_index(self, mocker, es_config, mock_from_vector_store, mock_embedding):
|
||||
# Mock
|
||||
mock_es_store = mocker.patch("metagpt.rag.factories.index.ElasticsearchStore")
|
||||
|
||||
# Exec
|
||||
self.index_factory.get_index(es_config, embed_model=mock_embedding)
|
||||
|
||||
# Assert
|
||||
mock_es_store.assert_called_once()
|
||||
71
tests/metagpt/rag/factories/test_llm.py
Normal file
71
tests/metagpt/rag/factories/test_llm.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
import pytest
|
||||
from llama_index.core.llms import LLMMetadata
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.rag.factories.llm import RAGLLM, get_rag_llm
|
||||
|
||||
|
||||
class MockLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig):
|
||||
...
|
||||
|
||||
async def _achat_completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
"""_achat_completion implemented by inherited class"""
|
||||
|
||||
async def acompletion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return "ok"
|
||||
|
||||
def completion(self, messages: list[dict], timeout=USE_CONFIG_TIMEOUT):
|
||||
return "ok"
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
|
||||
"""_achat_completion_stream implemented by inherited class"""
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: Union[str, list[dict[str, str]]],
|
||||
system_msgs: Optional[list[str]] = None,
|
||||
format_msgs: Optional[list[dict[str, str]]] = None,
|
||||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
stream=True,
|
||||
) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
class TestRAGLLM:
|
||||
@pytest.fixture
|
||||
def mock_model_infer(self):
|
||||
return MockLLM(config=LLMConfig())
|
||||
|
||||
@pytest.fixture
|
||||
def rag_llm(self, mock_model_infer):
|
||||
return RAGLLM(model_infer=mock_model_infer)
|
||||
|
||||
def test_metadata(self, rag_llm):
|
||||
metadata = rag_llm.metadata
|
||||
assert isinstance(metadata, LLMMetadata)
|
||||
assert metadata.context_window == rag_llm.context_window
|
||||
assert metadata.num_output == rag_llm.num_output
|
||||
assert metadata.model_name == rag_llm.model_name
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_acomplete(self, rag_llm, mock_model_infer):
|
||||
response = await rag_llm.acomplete("question")
|
||||
assert response.text == "ok"
|
||||
|
||||
def test_complete(self, rag_llm, mock_model_infer):
|
||||
response = rag_llm.complete("question")
|
||||
assert response.text == "ok"
|
||||
|
||||
def test_stream_complete(self, rag_llm, mock_model_infer):
|
||||
rag_llm.stream_complete("question")
|
||||
|
||||
|
||||
def test_get_rag_llm():
|
||||
result = get_rag_llm(MockLLM(config=LLMConfig()))
|
||||
assert isinstance(result, RAGLLM)
|
||||
60
tests/metagpt/rag/factories/test_ranker.py
Normal file
60
tests/metagpt/rag/factories/test_ranker.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
import contextlib
|
||||
|
||||
import pytest
|
||||
from llama_index.core.llms import MockLLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.factories.ranker import RankerFactory
|
||||
from metagpt.rag.schema import ColbertRerankConfig, LLMRankerConfig, ObjectRankerConfig
|
||||
|
||||
|
||||
class TestRankerFactory:
|
||||
@pytest.fixture(autouse=True)
|
||||
def ranker_factory(self):
|
||||
self.ranker_factory: RankerFactory = RankerFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm(self):
|
||||
return MockLLM()
|
||||
|
||||
def test_get_rankers_with_no_configs(self, mock_llm, mocker):
|
||||
mocker.patch.object(self.ranker_factory, "_extract_llm", return_value=mock_llm)
|
||||
default_rankers = self.ranker_factory.get_rankers()
|
||||
assert len(default_rankers) == 0
|
||||
|
||||
def test_get_rankers_with_configs(self, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
rankers = self.ranker_factory.get_rankers(configs=[mock_config])
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
|
||||
def test_extract_llm_from_config(self, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
extracted_llm = self.ranker_factory._extract_llm(config=mock_config)
|
||||
assert extracted_llm == mock_llm
|
||||
|
||||
def test_extract_llm_from_kwargs(self, mock_llm):
|
||||
extracted_llm = self.ranker_factory._extract_llm(llm=mock_llm)
|
||||
assert extracted_llm == mock_llm
|
||||
|
||||
def test_create_llm_ranker(self, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
ranker = self.ranker_factory._create_llm_ranker(mock_config)
|
||||
assert isinstance(ranker, LLMRerank)
|
||||
|
||||
def test_create_colbert_ranker(self, mocker, mock_llm):
|
||||
with contextlib.suppress(ImportError):
|
||||
mocker.patch("llama_index.postprocessor.colbert_rerank.ColbertRerank", return_value="colbert")
|
||||
|
||||
mock_config = ColbertRerankConfig(llm=mock_llm)
|
||||
ranker = self.ranker_factory._create_colbert_ranker(mock_config)
|
||||
|
||||
assert ranker == "colbert"
|
||||
|
||||
def test_create_object_ranker(self, mocker, mock_llm):
|
||||
mocker.patch("metagpt.rag.factories.ranker.ObjectSortPostprocessor", return_value="object")
|
||||
|
||||
mock_config = ObjectRankerConfig(field_name="fake", llm=mock_llm)
|
||||
ranker = self.ranker_factory._create_object_ranker(mock_config)
|
||||
|
||||
assert ranker == "object"
|
||||
113
tests/metagpt/rag/factories/test_retriever.py
Normal file
113
tests/metagpt/rag/factories/test_retriever.py
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
||||
from metagpt.rag.factories.retriever import RetrieverFactory
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSRetrieverConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestRetrieverFactory:
|
||||
@pytest.fixture(autouse=True)
|
||||
def retriever_factory(self):
|
||||
self.retriever_factory: RetrieverFactory = RetrieverFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_index(self, mocker):
|
||||
return mocker.MagicMock(spec=faiss.IndexFlatL2)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
mock = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock._embed_model = mocker.MagicMock()
|
||||
mock.docstore.docs.values.return_value = []
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_chroma_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ChromaVectorStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_es_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ElasticsearchStore)
|
||||
|
||||
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
|
||||
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
|
||||
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
|
||||
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, ChromaRetriever)
|
||||
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
|
||||
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
|
||||
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, ElasticsearchRetriever)
|
||||
|
||||
def test_create_default_retriever(self, mocker, mock_vector_store_index):
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mock_vector_store_index.as_retriever = mocker.MagicMock()
|
||||
|
||||
retriever = self.retriever_factory.get_retriever()
|
||||
|
||||
mock_vector_store_index.as_retriever.assert_called_once()
|
||||
assert retriever is mock_vector_store_index.as_retriever.return_value
|
||||
|
||||
def test_extract_index_from_config(self, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
|
||||
|
||||
extracted_index = self.retriever_factory._extract_index(config=mock_config)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_extract_index_from_kwargs(self, mock_vector_store_index):
|
||||
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
23
tests/metagpt/rag/rankers/test_base_ranker.py
Normal file
23
tests/metagpt/rag/rankers/test_base_ranker.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
|
||||
|
||||
from metagpt.rag.rankers.base import RAGRanker
|
||||
|
||||
|
||||
class SimpleRAGRanker(RAGRanker):
|
||||
def _postprocess_nodes(self, nodes, query_bundle=None):
|
||||
return [NodeWithScore(node=node.node, score=node.score + 1) for node in nodes]
|
||||
|
||||
|
||||
class TestSimpleRAGRanker:
|
||||
@pytest.fixture
|
||||
def ranker(self):
|
||||
return SimpleRAGRanker()
|
||||
|
||||
def test_postprocess_nodes_increases_scores(self, ranker):
|
||||
nodes = [NodeWithScore(node=TextNode(text="a"), score=10), NodeWithScore(node=TextNode(text="b"), score=20)]
|
||||
query_bundle = QueryBundle(query_str="test query")
|
||||
|
||||
processed_nodes = ranker._postprocess_nodes(nodes, query_bundle)
|
||||
|
||||
assert all(node.score == original_node.score + 1 for node, original_node in zip(processed_nodes, nodes))
|
||||
69
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
69
tests/metagpt/rag/rankers/test_object_ranker.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.rag.rankers.object_ranker import ObjectSortPostprocessor
|
||||
from metagpt.rag.schema import ObjectNode
|
||||
|
||||
|
||||
class Record(BaseModel):
|
||||
score: int
|
||||
|
||||
|
||||
class TestObjectSortPostprocessor:
|
||||
@pytest.fixture
|
||||
def mock_nodes_with_scores(self):
|
||||
nodes = [
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=10).model_dump_json()}), score=10),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=20).model_dump_json()}), score=20),
|
||||
NodeWithScore(node=ObjectNode(metadata={"obj_json": Record(score=5).model_dump_json()}), score=5),
|
||||
]
|
||||
return nodes
|
||||
|
||||
@pytest.fixture
|
||||
def mock_query_bundle(self, mocker):
|
||||
return mocker.MagicMock(spec=QueryBundle)
|
||||
|
||||
def test_sort_descending(self, mock_nodes_with_scores, mock_query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [20, 10, 5]
|
||||
|
||||
def test_sort_ascending(self, mock_nodes_with_scores, mock_query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="asc")
|
||||
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
|
||||
assert [node.score for node in sorted_nodes] == [5, 10, 20]
|
||||
|
||||
def test_top_n_limit(self, mock_nodes_with_scores, mock_query_bundle):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc", top_n=2)
|
||||
sorted_nodes = postprocessor._postprocess_nodes(mock_nodes_with_scores, mock_query_bundle)
|
||||
assert len(sorted_nodes) == 2
|
||||
assert [node.score for node in sorted_nodes] == [20, 10]
|
||||
|
||||
def test_invalid_json_metadata(self, mock_query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": "invalid_json"}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes, mock_query_bundle)
|
||||
|
||||
def test_missing_query_bundle(self, mock_nodes_with_scores):
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(mock_nodes_with_scores, query_bundle=None)
|
||||
|
||||
def test_field_not_found_in_object(self, mock_query_bundle):
|
||||
nodes = [NodeWithScore(node=ObjectNode(metadata={"obj_json": json.dumps({"not_score": 10})}), score=10)]
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
with pytest.raises(ValueError):
|
||||
postprocessor._postprocess_nodes(nodes, query_bundle=mock_query_bundle)
|
||||
|
||||
def test_not_nodes(self, mock_query_bundle):
|
||||
nodes = []
|
||||
postprocessor = ObjectSortPostprocessor(field_name="score", order="desc")
|
||||
result = postprocessor._postprocess_nodes(nodes, mock_query_bundle)
|
||||
assert result == []
|
||||
|
||||
def test_class_name(self):
|
||||
assert ObjectSortPostprocessor.class_name() == "ObjectSortPostprocessor"
|
||||
21
tests/metagpt/rag/retrievers/test_base_retriever.py
Normal file
21
tests/metagpt/rag/retrievers/test_base_retriever.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
|
||||
|
||||
class SubModifiableRAGRetriever(ModifiableRAGRetriever):
|
||||
...
|
||||
|
||||
|
||||
class SubPersistableRAGRetriever(PersistableRAGRetriever):
|
||||
...
|
||||
|
||||
|
||||
class TestModifiableRAGRetriever:
|
||||
def test_subclasshook(self):
|
||||
result = SubModifiableRAGRetriever.__subclasshook__(SubModifiableRAGRetriever)
|
||||
assert result is NotImplemented
|
||||
|
||||
|
||||
class TestPersistableRAGRetriever:
|
||||
def test_subclasshook(self):
|
||||
result = SubPersistableRAGRetriever.__subclasshook__(SubPersistableRAGRetriever)
|
||||
assert result is NotImplemented
|
||||
37
tests/metagpt/rag/retrievers/test_bm25_retriever.py
Normal file
37
tests/metagpt/rag/retrievers/test_bm25_retriever.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
|
||||
|
||||
class TestDynamicBM25Retriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc1.get_content.return_value = "Document content 1"
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.doc2.get_content.return_value = "Document content 2"
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
index.storage_context.persist.return_value = "ok"
|
||||
|
||||
mock_nodes = []
|
||||
mock_tokenizer = mocker.MagicMock()
|
||||
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
|
||||
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
|
||||
|
||||
def test_add_docs_updates_nodes_and_corpus(self):
|
||||
# Exec
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
# Assert
|
||||
assert len(self.retriever._nodes) == len(self.mock_nodes)
|
||||
assert len(self.retriever._corpus) == len(self.mock_nodes)
|
||||
self.retriever._tokenizer.assert_called()
|
||||
self.mock_bm25okapi.assert_called()
|
||||
|
||||
def test_persist(self):
|
||||
self.retriever.persist("")
|
||||
20
tests/metagpt/rag/retrievers/test_chroma_retriever.py
Normal file
20
tests/metagpt/rag/retrievers/test_chroma_retriever.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
|
||||
|
||||
class TestChromaRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = ChromaRetriever(self.mock_index)
|
||||
|
||||
def test_add_nodes(self):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
self.mock_index.insert_nodes.assert_called()
|
||||
20
tests/metagpt/rag/retrievers/test_es_retriever.py
Normal file
20
tests/metagpt/rag/retrievers/test_es_retriever.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
|
||||
|
||||
|
||||
class TestElasticsearchRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = ElasticsearchRetriever(self.mock_index)
|
||||
|
||||
def test_add_nodes(self):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
self.mock_index.insert_nodes.assert_called()
|
||||
25
tests/metagpt/rag/retrievers/test_faiss_retriever.py
Normal file
25
tests/metagpt/rag/retrievers/test_faiss_retriever.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
|
||||
|
||||
class TestFAISSRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = FAISSRetriever(self.mock_index)
|
||||
|
||||
def test_add_docs_calls_insert_for_each_document(self):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
self.mock_index.insert_nodes.assert_called()
|
||||
|
||||
def test_persist(self):
|
||||
self.retriever.persist("")
|
||||
|
||||
self.mock_index.storage_context.persist.assert_called()
|
||||
57
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
57
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
|
||||
|
||||
class TestSimpleHybridRetriever:
|
||||
@pytest.fixture
|
||||
def mock_retriever(self, mocker):
|
||||
return mocker.MagicMock()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_hybrid_retriever(self, mock_retriever) -> SimpleHybridRetriever:
|
||||
return SimpleHybridRetriever(mock_retriever)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_node(self):
|
||||
return NodeWithScore(node=TextNode(id_="2"), score=0.95)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aretrieve(self, mocker):
|
||||
question = "test query"
|
||||
|
||||
# Create mock retrievers
|
||||
mock_retriever1 = mocker.AsyncMock()
|
||||
mock_retriever1.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="1"), score=1.0),
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
]
|
||||
|
||||
mock_retriever2 = mocker.AsyncMock()
|
||||
mock_retriever2.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
NodeWithScore(node=TextNode(id_="3"), score=0.8),
|
||||
]
|
||||
|
||||
# Instantiate the SimpleHybridRetriever with the mock retrievers
|
||||
hybrid_retriever = SimpleHybridRetriever(mock_retriever1, mock_retriever2)
|
||||
|
||||
# Call the _aretrieve method
|
||||
results = await hybrid_retriever._aretrieve(question)
|
||||
|
||||
# Check if the results are as expected
|
||||
assert len(results) == 3 # Should be 3 unique nodes
|
||||
assert set(node.node.node_id for node in results) == {"1", "2", "3"}
|
||||
|
||||
# Check if the scores are correct (assuming you want the highest score)
|
||||
node_scores = {node.node.node_id: node.score for node in results}
|
||||
assert node_scores["2"] == 0.95
|
||||
|
||||
def test_add_nodes(self, mock_hybrid_retriever: SimpleHybridRetriever, mock_node):
|
||||
mock_hybrid_retriever.add_nodes([mock_node])
|
||||
mock_hybrid_retriever.retrievers[0].add_nodes.assert_called_once()
|
||||
|
||||
def test_persist(self, mock_hybrid_retriever: SimpleHybridRetriever):
|
||||
mock_hybrid_retriever.persist("")
|
||||
mock_hybrid_retriever.retrievers[0].persist.assert_called_once()
|
||||
|
|
@ -6,11 +6,11 @@
|
|||
@File : test_tutorial_assistant.py
|
||||
"""
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.const import TUTORIAL_PATH
|
||||
from metagpt.roles.tutorial_assistant import TutorialAssistant
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -20,9 +20,8 @@ async def test_tutorial_assistant(language: str, topic: str, context):
|
|||
msg = await role.run(topic)
|
||||
assert TUTORIAL_PATH.exists()
|
||||
filename = msg.content
|
||||
async with aiofiles.open(filename, mode="r", encoding="utf-8") as reader:
|
||||
content = await reader.read()
|
||||
assert "pip" in content
|
||||
content = await aread(filename=filename)
|
||||
assert "pip" in content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ from metagpt.tools.libs.web_scraping import scrape_web_playwright
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scrape_web_playwright():
|
||||
test_url = "https://www.deepwisdom.ai"
|
||||
async def test_scrape_web_playwright(http_server):
|
||||
server, test_url = await http_server()
|
||||
|
||||
result = await scrape_web_playwright(test_url)
|
||||
|
||||
|
|
@ -21,3 +21,4 @@ async def test_scrape_web_playwright():
|
|||
assert not result["inner_text"].endswith(" ")
|
||||
assert not result["html"].startswith(" ")
|
||||
assert not result["html"].endswith(" ")
|
||||
await server.stop()
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from typing import Callable
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.search_config import SearchConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
|
@ -38,6 +37,7 @@ class MockSearchEnine:
|
|||
(SearchEngineType.SERPER_GOOGLE, None, 6, False),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 8, True),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
|
||||
(SearchEngineType.BING, None, 6, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
],
|
||||
|
|
@ -53,14 +53,11 @@ async def test_search_engine(
|
|||
search_engine_config = {"engine": search_engine_type, "run_func": run_func}
|
||||
|
||||
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-serpapi-key"
|
||||
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-google-key"
|
||||
search_engine_config["cse_id"] = "mock-google-cse"
|
||||
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["api_key"] = "mock-serper-key"
|
||||
|
||||
async def test(search_engine):
|
||||
|
|
|
|||
|
|
@ -9,14 +9,16 @@ from metagpt.utils.parse_html import WebPage
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"browser_type, url, urls",
|
||||
"browser_type",
|
||||
[
|
||||
(WebBrowserEngineType.PLAYWRIGHT, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
(WebBrowserEngineType.SELENIUM, "https://deepwisdom.ai", ("https://deepwisdom.ai",)),
|
||||
WebBrowserEngineType.PLAYWRIGHT,
|
||||
WebBrowserEngineType.SELENIUM,
|
||||
],
|
||||
ids=["playwright", "selenium"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, url, urls):
|
||||
async def test_scrape_web_page(browser_type, http_server):
|
||||
server, url = await http_server()
|
||||
urls = [url, url, url]
|
||||
browser = web_browser_engine.WebBrowserEngine(engine=browser_type)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
|
|
@ -27,6 +29,7 @@ async def test_scrape_web_page(browser_type, url, urls):
|
|||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
await server.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -9,18 +9,28 @@ from metagpt.utils.parse_html import WebPage
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"browser_type, use_proxy, kwagrs, url, urls",
|
||||
"browser_type, use_proxy, kwagrs,",
|
||||
[
|
||||
("chromium", {"proxy": True}, {}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
|
||||
("firefox", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
|
||||
("webkit", {}, {"ignore_https_errors": True}, "https://www.deepwisdom.ai", ("https://www.deepwisdom.ai",)),
|
||||
("chromium", {"proxy": True}, {}),
|
||||
(
|
||||
"firefox",
|
||||
{},
|
||||
{"ignore_https_errors": True},
|
||||
),
|
||||
(
|
||||
"webkit",
|
||||
{},
|
||||
{"ignore_https_errors": True},
|
||||
),
|
||||
],
|
||||
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd):
|
||||
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, proxy, capfd, http_server):
|
||||
server, url = await http_server()
|
||||
urls = [url, url, url]
|
||||
proxy_url = None
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
proxy_server, proxy_url = await proxy()
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, proxy=proxy_url, **kwagrs)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
|
|
@ -32,8 +42,10 @@ async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy
|
|||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
proxy_server.close()
|
||||
await proxy_server.wait_closed()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
await server.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
import browsers
|
||||
import pytest
|
||||
|
||||
|
|
@ -10,51 +11,48 @@ from metagpt.utils.parse_html import WebPage
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"browser_type, use_proxy, url, urls",
|
||||
"browser_type, use_proxy,",
|
||||
[
|
||||
pytest.param(
|
||||
"chrome",
|
||||
True,
|
||||
"https://deepwisdom.ai",
|
||||
("https://deepwisdom.ai",),
|
||||
False,
|
||||
marks=pytest.mark.skipif(not browsers.get("chrome"), reason="chrome browser not found"),
|
||||
),
|
||||
pytest.param(
|
||||
"firefox",
|
||||
False,
|
||||
"https://deepwisdom.ai",
|
||||
("https://deepwisdom.ai",),
|
||||
marks=pytest.mark.skipif(not browsers.get("firefox"), reason="firefox browser not found"),
|
||||
),
|
||||
pytest.param(
|
||||
"edge",
|
||||
False,
|
||||
"https://deepwisdom.ai",
|
||||
("https://deepwisdom.ai",),
|
||||
marks=pytest.mark.skipif(not browsers.get("msedge"), reason="edge browser not found"),
|
||||
),
|
||||
],
|
||||
ids=["chrome-normal", "firefox-normal", "edge-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd):
|
||||
async def test_scrape_web_page(browser_type, use_proxy, proxy, capfd, http_server):
|
||||
# Prerequisites
|
||||
# firefox, chrome, Microsoft Edge
|
||||
server, url = await http_server()
|
||||
urls = [url, url, url]
|
||||
proxy_url = None
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
proxy_server, proxy_url = await proxy()
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type, proxy=proxy_url)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
proxy_server.close()
|
||||
await proxy_server.wait_closed()
|
||||
assert "Proxy: localhost" in capfd.readouterr().out
|
||||
await server.stop()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import uuid
|
|||
from pathlib import Path
|
||||
from typing import Any, Set
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -125,9 +124,7 @@ class TestGetProjectRoot:
|
|||
async def test_parse_data_exception(self, filename, want):
|
||||
pathname = Path(__file__).parent.parent.parent / "data/output_parser" / filename
|
||||
assert pathname.exists()
|
||||
async with aiofiles.open(str(pathname), mode="r") as reader:
|
||||
data = await reader.read()
|
||||
|
||||
data = await aread(filename=pathname)
|
||||
result = OutputParser.parse_data(data=data)
|
||||
assert want in result
|
||||
|
||||
|
|
@ -198,12 +195,25 @@ class TestGetProjectRoot:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_write(self):
|
||||
pathname = Path(__file__).parent / uuid.uuid4().hex / "test.tmp"
|
||||
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.tmp"
|
||||
await awrite(pathname, "ABC")
|
||||
data = await aread(pathname)
|
||||
assert data == "ABC"
|
||||
pathname.unlink(missing_ok=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_write_error_charset(self):
|
||||
pathname = Path(__file__).parent / f"../../../workspace/unittest/{uuid.uuid4().hex}" / "test.txt"
|
||||
content = "中国abc123\u27f6"
|
||||
await awrite(filename=pathname, data=content)
|
||||
data = await aread(filename=pathname)
|
||||
assert data == content
|
||||
|
||||
content = "GB18030 是中国国家标准局发布的新一代中文字符集标准,是 GBK 的升级版,支持更广泛的字符范围。"
|
||||
await awrite(filename=pathname, data=content, encoding="gb2312")
|
||||
data = await aread(filename=pathname, encoding="utf-8")
|
||||
assert data == content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -10,15 +10,14 @@
|
|||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.common import awrite
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
||||
async def mock_file(filename, content=""):
|
||||
async with aiofiles.open(str(filename), mode="w") as file:
|
||||
await file.write(content)
|
||||
await awrite(filename=filename, data=content)
|
||||
|
||||
|
||||
async def mock_repo(local_path) -> (GitRepository, Path):
|
||||
|
|
|
|||
25
tests/metagpt/utils/test_repo_to_markdown.py
Normal file
25
tests/metagpt/utils/test_repo_to_markdown.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.repo_to_markdown import repo_to_markdown
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
["repo_path", "output"],
|
||||
[(Path(__file__).parent.parent, Path(__file__).parent.parent.parent / f"workspace/unittest/{uuid.uuid4().hex}.md")],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_repo_to_markdown(repo_path: Path, output: Path):
|
||||
markdown = await repo_to_markdown(repo_path=repo_path, output=output)
|
||||
assert output.exists()
|
||||
assert markdown
|
||||
|
||||
output.unlink(missing_ok=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -9,7 +9,6 @@ import uuid
|
|||
from pathlib import Path
|
||||
|
||||
import aioboto3
|
||||
import aiofiles
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import Config
|
||||
|
|
@ -46,7 +45,7 @@ async def test_s3(mocker):
|
|||
conn = S3(s3)
|
||||
object_name = "unittest.bak"
|
||||
await conn.upload_file(bucket=s3.bucket, local_path=__file__, object_name=object_name)
|
||||
pathname = (Path(__file__).parent / uuid.uuid4().hex).with_suffix(".bak")
|
||||
pathname = (Path(__file__).parent / "../../../workspace/unittest" / uuid.uuid4().hex).with_suffix(".bak")
|
||||
pathname.unlink(missing_ok=True)
|
||||
await conn.download_file(bucket=s3.bucket, object_name=object_name, local_path=str(pathname))
|
||||
assert pathname.exists()
|
||||
|
|
@ -54,8 +53,7 @@ async def test_s3(mocker):
|
|||
assert url
|
||||
bin_data = await conn.get_object(bucket=s3.bucket, object_name=object_name)
|
||||
assert bin_data
|
||||
async with aiofiles.open(__file__, mode="r", encoding="utf-8") as reader:
|
||||
data = await reader.read()
|
||||
data = await aread(filename=__file__)
|
||||
res = await conn.cache(data, ".bak", "script")
|
||||
assert "http" in res
|
||||
|
||||
|
|
@ -69,8 +67,6 @@ async def test_s3(mocker):
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
await reader.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def _paragraphs(n):
|
|||
@pytest.mark.parametrize(
|
||||
"msgs, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(_msgs(), "gpt-3.5-turbo", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-0613", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "Hello," * 1000, 3000, 5),
|
||||
(_msgs(), "gpt-4", "System", 2000, 3),
|
||||
|
|
@ -32,22 +32,23 @@ def _paragraphs(n):
|
|||
],
|
||||
)
|
||||
def test_reduce_message_length(msgs, model_name, system_text, reserved, expected):
|
||||
assert len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000 == expected
|
||||
length = len(reduce_message_length(msgs, model_name, system_text, reserved)) / (len("Hello,")) / 1000
|
||||
assert length == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, prompt_template, model_name, system_text, reserved, expected",
|
||||
[
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
|
||||
(" ".join("Hello World." for _ in range(4000)), "Prompt: {}", "gpt-4", "System", 2000, 2),
|
||||
(" ".join("Hello World." for _ in range(8000)), "Prompt: {}", "gpt-4-32k", "System", 4000, 1),
|
||||
(" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo", "System", 1000, 8),
|
||||
(" ".join("Hello World" for _ in range(8000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1000, 8),
|
||||
],
|
||||
)
|
||||
def test_generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved, expected):
|
||||
ret = list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved))
|
||||
assert len(ret) == expected
|
||||
chunk = len(list(generate_prompt_chunk(text, prompt_template, model_name, system_text, reserved)))
|
||||
assert chunk == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
64
tests/metagpt/utils/test_tree.py
Normal file
64
tests/metagpt/utils/test_tree.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.tree import _print_tree, tree
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("root", "rules"),
|
||||
[
|
||||
(str(Path(__file__).parent / "../.."), None),
|
||||
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
|
||||
],
|
||||
)
|
||||
def test_tree(root: str, rules: str):
|
||||
v = tree(root=root, gitignore=rules)
|
||||
assert v
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("root", "rules"),
|
||||
[
|
||||
(str(Path(__file__).parent / "../.."), None),
|
||||
(str(Path(__file__).parent / "../.."), str(Path(__file__).parent / "../../../.gitignore")),
|
||||
],
|
||||
)
|
||||
def test_tree_command(root: str, rules: str):
|
||||
v = tree(root=root, gitignore=rules, run_command=True)
|
||||
assert v
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("tree", "want"),
|
||||
[
|
||||
({"a": {"b": {}, "c": {}}}, ["a", "+-- b", "+-- c"]),
|
||||
({"a": {"b": {}, "c": {"d": {}}}}, ["a", "+-- b", "+-- c", " +-- d"]),
|
||||
(
|
||||
{"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}},
|
||||
["a", "+-- b", "| +-- e", "| +-- f", "| +-- g", "+-- c", " +-- d"],
|
||||
),
|
||||
(
|
||||
{"h": {"a": {"b": {"e": {"f": {}, "g": {}}}, "c": {"d": {}}}, "i": {}}},
|
||||
[
|
||||
"h",
|
||||
"+-- a",
|
||||
"| +-- b",
|
||||
"| | +-- e",
|
||||
"| | +-- f",
|
||||
"| | +-- g",
|
||||
"| +-- c",
|
||||
"| +-- d",
|
||||
"+-- i",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test__print_tree(tree: dict, want: List[str]):
|
||||
v = _print_tree(tree)
|
||||
assert v == want
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
Loading…
Add table
Add a link
Reference in a new issue