mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-27 01:36:29 +02:00
feat: merge main
This commit is contained in:
commit
8a92fa0f36
404 changed files with 20076 additions and 1163 deletions
|
|
@ -38,7 +38,7 @@ DESIGN = {
|
|||
|
||||
|
||||
TASK = {
|
||||
"Required Python packages": ["pygame==2.0.1"],
|
||||
"Required packages": ["pygame==2.0.1"],
|
||||
"Required Other language third-party packages": ["No third-party dependencies required"],
|
||||
"Logic Analysis": [
|
||||
["game.py", "Contains Game class and related functions for game logic"],
|
||||
|
|
|
|||
0
tests/metagpt/configs/__init__.py
Normal file
0
tests/metagpt/configs/__init__.py
Normal file
34
tests/metagpt/configs/test_models_config.py
Normal file
34
tests/metagpt/configs/test_models_config.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.configs.models_config import ModelsConfig
|
||||
from metagpt.const import METAGPT_ROOT, TEST_DATA_PATH
|
||||
from metagpt.utils.common import aread, awrite
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_configs(context):
|
||||
default_model = ModelsConfig.default()
|
||||
assert default_model is not None
|
||||
|
||||
models = ModelsConfig.from_yaml_file(TEST_DATA_PATH / "config/config2.yaml")
|
||||
assert models
|
||||
|
||||
default_models = ModelsConfig.default()
|
||||
backup = ""
|
||||
if not default_models.models:
|
||||
backup = await aread(filename=METAGPT_ROOT / "config/config2.yaml")
|
||||
test_data = await aread(filename=TEST_DATA_PATH / "config/config2.yaml")
|
||||
await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=test_data)
|
||||
|
||||
try:
|
||||
action = TalkAction(context=context, i_context="who are you?", llm_name_or_type="YOUR_MODEL_NAME_1")
|
||||
assert action.private_llm.config.model == "YOUR_MODEL_NAME_1"
|
||||
assert context.config.llm.model != "YOUR_MODEL_NAME_1"
|
||||
finally:
|
||||
if backup:
|
||||
await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=backup)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -4,8 +4,8 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.environment.android_env.android_ext_env import AndroidExtEnv
|
||||
from metagpt.environment.android_env.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.android_ext_env import AndroidExtEnv
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
|
||||
|
||||
def mock_device_shape(self, adb_cmd: str) -> str:
|
||||
|
|
@ -16,8 +16,8 @@ def mock_device_shape_invalid(self, adb_cmd: str) -> str:
|
|||
return ADB_EXEC_FAIL
|
||||
|
||||
|
||||
def mock_list_devices(self, adb_cmd: str) -> str:
|
||||
return "devices\nemulator-5554"
|
||||
def mock_list_devices(self) -> str:
|
||||
return ["emulator-5554"]
|
||||
|
||||
|
||||
def mock_get_screenshot(self, adb_cmd: str) -> str:
|
||||
|
|
@ -34,9 +34,8 @@ def mock_write_read_operation(self, adb_cmd: str) -> str:
|
|||
|
||||
def test_android_ext_env(mocker):
|
||||
device_id = "emulator-5554"
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape
|
||||
)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.list_devices", mock_list_devices)
|
||||
|
||||
ext_env = AndroidExtEnv(device_id=device_id, screenshot_dir="/data2/", xml_dir="/data2/")
|
||||
assert ext_env.adb_prefix == f"adb -s {device_id} "
|
||||
|
|
@ -46,25 +45,20 @@ def test_android_ext_env(mocker):
|
|||
assert ext_env.device_shape == (720, 1080)
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape_invalid
|
||||
"metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape_invalid
|
||||
)
|
||||
assert ext_env.device_shape == (0, 0)
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_list_devices
|
||||
)
|
||||
assert ext_env.list_devices() == [device_id]
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot
|
||||
)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot)
|
||||
assert ext_env.get_screenshot("screenshot_xxxx-xx-xx", "/data/") == Path("/data/screenshot_xxxx-xx-xx.png")
|
||||
|
||||
mocker.patch("metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_xml)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_xml)
|
||||
assert ext_env.get_xml("xml_xxxx-xx-xx", "/data/") == Path("/data/xml_xxxx-xx-xx.xml")
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_write_read_operation
|
||||
"metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_write_read_operation
|
||||
)
|
||||
res = "OK"
|
||||
assert ext_env.system_back() == res
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
# @Desc : the unittest of MinecraftExtEnv
|
||||
|
||||
|
||||
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
|
||||
from metagpt.environment.minecraft.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft.minecraft_ext_env import MinecraftExtEnv
|
||||
|
||||
|
||||
def test_minecraft_ext_env():
|
||||
|
|
|
|||
|
|
@ -4,12 +4,18 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.environment.stanford_town_env.stanford_town_ext_env import (
|
||||
StanfordTownExtEnv,
|
||||
from metagpt.environment.stanford_town.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
)
|
||||
from metagpt.environment.stanford_town.stanford_town_ext_env import StanfordTownExtEnv
|
||||
|
||||
maze_asset_path = (
|
||||
Path(__file__).absolute().parent.joinpath("..", "..", "..", "data", "environment", "stanford_town", "the_ville")
|
||||
Path(__file__)
|
||||
.absolute()
|
||||
.parent.joinpath("..", "..", "..", "..", "metagpt/ext/stanford_town/static_dirs/assets/the_ville")
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -27,7 +33,6 @@ def test_stanford_town_ext_env():
|
|||
assert len(ext_env.get_nearby_tiles(tile=tile, vision_r=5)) == 121
|
||||
|
||||
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
|
||||
ext_env.add_tiles_event(tile[1], tile[0], event=event)
|
||||
ext_env.add_event_from_tile(event, tile)
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
|
||||
|
||||
|
|
@ -38,3 +43,22 @@ def test_stanford_town_ext_env():
|
|||
|
||||
ext_env.remove_subject_events_from_tile(subject=event[0], tile=tile)
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 0
|
||||
|
||||
|
||||
def test_stanford_town_ext_env_observe_step():
|
||||
ext_env = StanfordTownExtEnv(maze_asset_path=maze_asset_path)
|
||||
obs, info = ext_env.reset()
|
||||
assert len(info) == 0
|
||||
assert len(obs["address_tiles"]) == 306
|
||||
|
||||
tile = (58, 9)
|
||||
obs = ext_env.observe(obs_params=EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="world"))
|
||||
assert obs == "the Ville"
|
||||
|
||||
action = ext_env.action_space.sample()
|
||||
assert len(action) == 4
|
||||
assert len(action["event"]) == 4
|
||||
|
||||
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
|
||||
obs, _, _, _, _ = ext_env.step(action=EnvAction(action_type=EnvActionType.ADD_TILE_EVENT, coord=tile, event=event))
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of ExtEnv&Env
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
|
|
@ -12,11 +14,26 @@ from metagpt.environment.base_env import (
|
|||
mark_as_readable,
|
||||
mark_as_writeable,
|
||||
)
|
||||
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
|
||||
|
||||
|
||||
class ForTestEnv(Environment):
|
||||
value: int = 0
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
pass
|
||||
|
||||
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
|
||||
pass
|
||||
|
||||
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@mark_as_readable
|
||||
def read_api_no_param(self):
|
||||
return self.value
|
||||
|
|
@ -44,11 +61,11 @@ async def test_ext_env():
|
|||
assert len(apis) > 0
|
||||
assert len(apis["read_api"]) == 3
|
||||
|
||||
_ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
|
||||
_ = await env.write_thru_api(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
|
||||
assert env.value == 15
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await env.observe("not_exist_api")
|
||||
with pytest.raises(KeyError):
|
||||
await env.read_from_api("not_exist_api")
|
||||
|
||||
assert await env.observe("read_api_no_param") == 15
|
||||
assert await env.observe(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
|
||||
assert await env.read_from_api("read_api_no_param") == 15
|
||||
assert await env.read_from_api(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
|
||||
|
|
|
|||
|
|
@ -2,33 +2,34 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of WerewolfExtEnv
|
||||
|
||||
from metagpt.environment.werewolf_env.werewolf_ext_env import RoleState, WerewolfExtEnv
|
||||
from metagpt.environment.werewolf.const import RoleState, RoleType
|
||||
from metagpt.environment.werewolf.werewolf_ext_env import WerewolfExtEnv
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
class Werewolf(Role):
|
||||
profile: str = "Werewolf"
|
||||
profile: str = RoleType.WEREWOLF.value
|
||||
|
||||
|
||||
class Villager(Role):
|
||||
profile: str = "Villager"
|
||||
profile: str = RoleType.VILLAGER.value
|
||||
|
||||
|
||||
class Witch(Role):
|
||||
profile: str = "Witch"
|
||||
profile: str = RoleType.WITCH.value
|
||||
|
||||
|
||||
class Guard(Role):
|
||||
profile: str = "Guard"
|
||||
profile: str = RoleType.GUARD.value
|
||||
|
||||
|
||||
def test_werewolf_ext_env():
|
||||
players_state = {
|
||||
"Player0": ("Werewolf", RoleState.ALIVE),
|
||||
"Player1": ("Werewolf", RoleState.ALIVE),
|
||||
"Player2": ("Villager", RoleState.ALIVE),
|
||||
"Player3": ("Witch", RoleState.ALIVE),
|
||||
"Player4": ("Guard", RoleState.ALIVE),
|
||||
"Player0": (RoleType.WEREWOLF.value, RoleState.ALIVE),
|
||||
"Player1": (RoleType.WEREWOLF.value, RoleState.ALIVE),
|
||||
"Player2": (RoleType.VILLAGER.value, RoleState.ALIVE),
|
||||
"Player3": (RoleType.WITCH.value, RoleState.ALIVE),
|
||||
"Player4": (RoleType.GUARD.value, RoleState.ALIVE),
|
||||
}
|
||||
ext_env = WerewolfExtEnv(players_state=players_state, step_idx=4, special_role_players=["Player3", "Player4"])
|
||||
|
||||
|
|
@ -41,9 +42,9 @@ def test_werewolf_ext_env():
|
|||
assert "Werewolves, please open your eyes" in curr_instr["content"]
|
||||
|
||||
# current step_idx = 5
|
||||
ext_env.wolf_kill_someone(wolf=Role(name="Player10"), player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player0"), player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player1"), player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf_name="Player10", player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf_name="Player0", player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf_name="Player1", player_name="Player4")
|
||||
assert ext_env.player_hunted == "Player4"
|
||||
assert len(ext_env.living_players) == 5 # hunted but can be saved by witch
|
||||
|
||||
|
|
@ -52,11 +53,11 @@ def test_werewolf_ext_env():
|
|||
|
||||
# current step_idx = 18
|
||||
assert ext_env.step_idx == 18
|
||||
ext_env.vote_kill_someone(voteer=Werewolf(name="Player0"), player_name="Player2")
|
||||
ext_env.vote_kill_someone(voteer=Werewolf(name="Player1"), player_name="Player3")
|
||||
ext_env.vote_kill_someone(voteer=Villager(name="Player2"), player_name="Player3")
|
||||
ext_env.vote_kill_someone(voteer=Witch(name="Player3"), player_name="Player4")
|
||||
ext_env.vote_kill_someone(voteer=Guard(name="Player4"), player_name="Player2")
|
||||
ext_env.vote_kill_someone(voter_name="Player0", player_name="Player2")
|
||||
ext_env.vote_kill_someone(voter_name="Player1", player_name="Player3")
|
||||
ext_env.vote_kill_someone(voter_name="Player2", player_name="Player3")
|
||||
ext_env.vote_kill_someone(voter_name="Player3", player_name="Player4")
|
||||
ext_env.vote_kill_someone(voter_name="Player4", player_name="Player2")
|
||||
assert ext_env.player_current_dead == "Player2"
|
||||
assert len(ext_env.living_players) == 4
|
||||
|
||||
|
|
|
|||
3
tests/metagpt/ext/__init__.py
Normal file
3
tests/metagpt/ext/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
tests/metagpt/ext/android_assistant/__init__.py
Normal file
3
tests/metagpt/ext/android_assistant/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
95
tests/metagpt/ext/android_assistant/test_an.py
Normal file
95
tests/metagpt/ext/android_assistant/test_an.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : test on android emulator action. After Modify Role Test, this script is discarded.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import metagpt
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
from metagpt.environment.android.android_env import AndroidEnv
|
||||
from metagpt.ext.android_assistant.actions.manual_record import ManualRecord
|
||||
from metagpt.ext.android_assistant.actions.parse_record import ParseRecord
|
||||
from metagpt.ext.android_assistant.actions.screenshot_parse import ScreenshotParse
|
||||
from metagpt.ext.android_assistant.actions.self_learn_and_reflect import (
|
||||
SelfLearnAndReflect,
|
||||
)
|
||||
from tests.metagpt.environment.android_env.test_android_ext_env import (
|
||||
mock_device_shape,
|
||||
mock_list_devices,
|
||||
)
|
||||
|
||||
TASK_PATH = TEST_DATA_PATH.joinpath("andriod_assistant/unitest_Contacts")
|
||||
TASK_PATH.mkdir(parents=True, exist_ok=True)
|
||||
DEMO_NAME = str(time.time())
|
||||
SELF_EXPLORE_DOC_PATH = TASK_PATH.joinpath("auto_docs")
|
||||
PARSE_RECORD_DOC_PATH = TASK_PATH.joinpath("demo_docs")
|
||||
|
||||
device_id = "emulator-5554"
|
||||
xml_dir = Path("/sdcard")
|
||||
screenshot_dir = Path("/sdcard/Pictures/Screenshots")
|
||||
|
||||
|
||||
metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd = mock_device_shape
|
||||
metagpt.environment.android.android_ext_env.AndroidExtEnv.list_devices = mock_list_devices
|
||||
|
||||
|
||||
test_env_self_learn_android = AndroidEnv(
|
||||
device_id=device_id,
|
||||
xml_dir=xml_dir,
|
||||
screenshot_dir=screenshot_dir,
|
||||
)
|
||||
test_self_learning = SelfLearnAndReflect()
|
||||
|
||||
test_env_manual_learn_android = AndroidEnv(
|
||||
device_id=device_id,
|
||||
xml_dir=xml_dir,
|
||||
screenshot_dir=screenshot_dir,
|
||||
)
|
||||
test_manual_record = ManualRecord()
|
||||
test_manual_parse = ParseRecord()
|
||||
|
||||
test_env_screenshot_parse_android = AndroidEnv(
|
||||
device_id=device_id,
|
||||
xml_dir=xml_dir,
|
||||
screenshot_dir=screenshot_dir,
|
||||
)
|
||||
test_screenshot_parse = ScreenshotParse()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
test_action_list = [
|
||||
test_self_learning.run(
|
||||
round_count=20,
|
||||
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
|
||||
last_act="",
|
||||
task_dir=TASK_PATH / "demos" / f"self_learning_{DEMO_NAME}",
|
||||
docs_dir=SELF_EXPLORE_DOC_PATH,
|
||||
env=test_env_self_learn_android,
|
||||
),
|
||||
test_manual_record.run(
|
||||
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}",
|
||||
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
|
||||
env=test_env_manual_learn_android,
|
||||
),
|
||||
test_manual_parse.run(
|
||||
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}", # 修要修改
|
||||
docs_dir=PARSE_RECORD_DOC_PATH, # 需要修改
|
||||
env=test_env_manual_learn_android,
|
||||
),
|
||||
test_screenshot_parse.run(
|
||||
round_count=20,
|
||||
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
|
||||
last_act="",
|
||||
task_dir=TASK_PATH / f"act_{DEMO_NAME}",
|
||||
docs_dir=PARSE_RECORD_DOC_PATH,
|
||||
env=test_env_screenshot_parse_android,
|
||||
grid_on=False,
|
||||
),
|
||||
]
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*test_action_list))
|
||||
loop.close()
|
||||
29
tests/metagpt/ext/android_assistant/test_parse_record.py
Normal file
29
tests/metagpt/ext/android_assistant/test_parse_record.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : test case (imgs from appagent's)
|
||||
|
||||
import asyncio
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
from metagpt.ext.android_assistant.actions.parse_record import ParseRecord
|
||||
|
||||
TASK_PATH = TEST_DATA_PATH.joinpath("andriod_assistant/demo_Contacts")
|
||||
TEST_BEFORE_PATH = TASK_PATH.joinpath("labeled_screenshots/0_labeled.png")
|
||||
TEST_AFTER_PATH = TASK_PATH.joinpath("labeled_screenshots/1_labeled.png")
|
||||
RECORD_PATH = TASK_PATH.joinpath("record.txt")
|
||||
TASK_DESC_PATH = TASK_PATH.joinpath("task_desc.txt")
|
||||
DOCS_DIR = TASK_PATH.joinpath("storage")
|
||||
|
||||
test_action = Action(name="test")
|
||||
|
||||
|
||||
async def manual_learn_test():
|
||||
parse_record = ParseRecord()
|
||||
await parse_record.run(app_name="demo_Contacts", task_dir=TASK_PATH, docs_dir=DOCS_DIR)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(manual_learn_test())
|
||||
loop.close()
|
||||
3
tests/metagpt/ext/stanford_town/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
tests/metagpt/ext/stanford_town/actions/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/actions/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of actions/gen_action_details.py
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
from metagpt.ext.stanford_town.actions.gen_action_details import (
|
||||
GenActionArena,
|
||||
GenActionDetails,
|
||||
GenActionObject,
|
||||
GenActionSector,
|
||||
GenActObjDescription,
|
||||
)
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gen_action_details():
|
||||
role = STRole(
|
||||
name="Klaus Mueller",
|
||||
start_time="February 13, 2023",
|
||||
curr_time="February 13, 2023, 00:00:00",
|
||||
sim_code="base_the_ville_isabella_maria_klaus",
|
||||
)
|
||||
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
|
||||
await role.init_curr_tile()
|
||||
|
||||
act_desp = "sleeping"
|
||||
act_dura = "120"
|
||||
|
||||
access_tile = await role.rc.env.read_from_api(
|
||||
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile})
|
||||
)
|
||||
act_world = access_tile["world"]
|
||||
assert act_world == "the Ville"
|
||||
|
||||
sector = await GenActionSector().run(role, access_tile, act_desp)
|
||||
arena = await GenActionArena().run(role, act_desp, act_world, sector)
|
||||
temp_address = f"{act_world}:{sector}:{arena}"
|
||||
obj = await GenActionObject().run(role, act_desp, temp_address)
|
||||
|
||||
act_obj_desp = await GenActObjDescription().run(role, obj, act_desp)
|
||||
|
||||
result_dict = await GenActionDetails().run(role, act_desp, act_dura)
|
||||
|
||||
# gen_action_sector
|
||||
assert isinstance(sector, str)
|
||||
assert sector in role.s_mem.get_str_accessible_sectors(act_world)
|
||||
|
||||
# gen_action_arena
|
||||
assert isinstance(arena, str)
|
||||
assert arena in role.s_mem.get_str_accessible_sector_arenas(f"{act_world}:{sector}")
|
||||
|
||||
# gen_action_obj
|
||||
assert isinstance(obj, str)
|
||||
assert obj in role.s_mem.get_str_accessible_arena_game_objects(temp_address)
|
||||
|
||||
if result_dict:
|
||||
for key in [
|
||||
"action_address",
|
||||
"action_duration",
|
||||
"action_description",
|
||||
"action_pronunciatio",
|
||||
"action_event",
|
||||
"chatting_with",
|
||||
"chat",
|
||||
"chatting_with_buffer",
|
||||
"chatting_end_time",
|
||||
"act_obj_description",
|
||||
"act_obj_pronunciatio",
|
||||
"act_obj_event",
|
||||
]:
|
||||
assert key in result_dict
|
||||
assert result_dict["action_address"] == f"{temp_address}:{obj}"
|
||||
assert result_dict["action_duration"] == int(act_dura)
|
||||
assert result_dict["act_obj_description"] == act_obj_desp
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of actions/summarize_conv
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.actions.summarize_conv import SummarizeConv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_conv():
|
||||
conv = [("Role_A", "what's the weather today?"), ("Role_B", "It looks pretty good, and I will take a walk then.")]
|
||||
|
||||
output = await SummarizeConv().run(conv)
|
||||
assert "weather" in output
|
||||
3
tests/metagpt/ext/stanford_town/memory/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/memory/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
89
tests/metagpt/ext/stanford_town/memory/test_agent_memory.py
Normal file
89
tests/metagpt/ext/stanford_town/memory/test_agent_memory.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of AgentMemory
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.memory.agent_memory import AgentMemory
|
||||
from metagpt.ext.stanford_town.memory.retrieve import agent_retrieve
|
||||
from metagpt.ext.stanford_town.utils.const import STORAGE_PATH
|
||||
from metagpt.logs import logger
|
||||
|
||||
"""
|
||||
memory测试思路
|
||||
1. Basic Memory测试
|
||||
2. Agent Memory测试
|
||||
2.1 Load & Save方法测试; Load方法中使用了add方法,验证Load即可验证所有add
|
||||
2.2 Get方法测试
|
||||
"""
|
||||
memory_easy_storage_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/associative_memory",
|
||||
)
|
||||
memroy_chat_storage_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/associative_memory",
|
||||
)
|
||||
memory_save_easy_test_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/test_memory",
|
||||
)
|
||||
memory_save_chat_test_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/test_memory",
|
||||
)
|
||||
|
||||
|
||||
class TestAgentMemory:
|
||||
@pytest.fixture
|
||||
def agent_memory(self):
|
||||
# 创建一个AgentMemory实例并返回,可以在所有测试用例中共享
|
||||
test_agent_memory = AgentMemory()
|
||||
test_agent_memory.set_mem_path(memroy_chat_storage_path)
|
||||
return test_agent_memory
|
||||
|
||||
def test_load(self, agent_memory):
|
||||
logger.info(f"存储路径为:{agent_memory.memory_saved}")
|
||||
logger.info(f"存储记忆条数为:{len(agent_memory.storage)}")
|
||||
logger.info(f"kw_strength为{agent_memory.kw_strength_event},{agent_memory.kw_strength_thought}")
|
||||
logger.info(f"embeeding.json条数为{len(agent_memory.embeddings)}")
|
||||
|
||||
assert agent_memory.embeddings is not None
|
||||
|
||||
def test_save(self, agent_memory):
|
||||
try:
|
||||
agent_memory.save(memory_save_chat_test_path)
|
||||
logger.info("成功存储")
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_summary_function(self, agent_memory):
|
||||
logger.info(f"event长度为{len(agent_memory.event_list)}")
|
||||
logger.info(f"thought长度为{len(agent_memory.thought_list)}")
|
||||
logger.info(f"chat长度为{len(agent_memory.chat_list)}")
|
||||
result1 = agent_memory.get_summarized_latest_events(4)
|
||||
logger.info(f"总结最近事件结果为:{result1}")
|
||||
|
||||
def test_get_last_chat_function(self, agent_memory):
|
||||
result2 = agent_memory.get_last_chat("customers")
|
||||
logger.info(f"上一次对话是{result2}")
|
||||
|
||||
def test_retrieve_function(self, agent_memory):
|
||||
focus_points = ["who i love?"]
|
||||
retrieved = dict()
|
||||
for focal_pt in focus_points:
|
||||
nodes = [
|
||||
[i.last_accessed, i]
|
||||
for i in agent_memory.event_list + agent_memory.thought_list
|
||||
if "idle" not in i.embedding_key
|
||||
]
|
||||
nodes = sorted(nodes, key=lambda x: x[0])
|
||||
nodes = [i for created, i in nodes]
|
||||
results = agent_retrieve(agent_memory, datetime.now() - timedelta(days=120), 0.99, focal_pt, nodes, 5)
|
||||
final_result = []
|
||||
for n in results:
|
||||
for i in agent_memory.storage:
|
||||
if i.memory_id == n:
|
||||
i.last_accessed = datetime.now() - timedelta(days=120)
|
||||
final_result.append(i)
|
||||
|
||||
retrieved[focal_pt] = final_result
|
||||
logger.info(f"检索结果为{retrieved}")
|
||||
76
tests/metagpt/ext/stanford_town/memory/test_basic_memory.py
Normal file
76
tests/metagpt/ext/stanford_town/memory/test_basic_memory.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of BasicMemory
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory
|
||||
from metagpt.logs import logger
|
||||
|
||||
"""
|
||||
memory测试思路
|
||||
1. Basic Memory测试
|
||||
2. Agent Memory测试
|
||||
2.1 Load & Save方法测试
|
||||
2.2 Add方法测试
|
||||
2.3 Get方法测试
|
||||
"""
|
||||
|
||||
# Create some sample BasicMemory instances
|
||||
memory1 = BasicMemory(
|
||||
memory_id="1",
|
||||
memory_count=1,
|
||||
type_count=1,
|
||||
memory_type="event",
|
||||
depth=1,
|
||||
created=datetime.now(),
|
||||
expiration=datetime.now() + timedelta(days=30),
|
||||
subject="Subject1",
|
||||
predicate="Predicate1",
|
||||
object="Object1",
|
||||
content="This is content 1",
|
||||
embedding_key="embedding_key_1",
|
||||
poignancy=1,
|
||||
keywords=["keyword1", "keyword2"],
|
||||
filling=["memory_id_2"],
|
||||
)
|
||||
memory2 = BasicMemory(
|
||||
memory_id="2",
|
||||
memory_count=2,
|
||||
type_count=2,
|
||||
memory_type="thought",
|
||||
depth=2,
|
||||
created=datetime.now(),
|
||||
expiration=datetime.now() + timedelta(days=30),
|
||||
subject="Subject2",
|
||||
predicate="Predicate2",
|
||||
object="Object2",
|
||||
content="This is content 2",
|
||||
embedding_key="embedding_key_2",
|
||||
poignancy=2,
|
||||
keywords=["keyword3", "keyword4"],
|
||||
filling=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_mem_set():
|
||||
basic_mem2 = memory2
|
||||
yield basic_mem2
|
||||
|
||||
|
||||
def test_basic_mem_function(basic_mem_set):
|
||||
a, b, c = basic_mem_set.summary()
|
||||
logger.info(f"{a}{b}{c}")
|
||||
assert a == "Subject2"
|
||||
|
||||
|
||||
def test_basic_mem_save(basic_mem_set):
|
||||
result = basic_mem_set.save_to_dict()
|
||||
logger.info(f"save结果为{result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of MemoryTree
|
||||
|
||||
from metagpt.ext.stanford_town.memory.spatial_memory import MemoryTree
|
||||
from metagpt.ext.stanford_town.utils.const import STORAGE_PATH
|
||||
|
||||
|
||||
def test_spatial_memory():
|
||||
f_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/spatial_memory.json"
|
||||
)
|
||||
x = MemoryTree()
|
||||
x.set_mem_path(f_path)
|
||||
assert x.tree
|
||||
assert "the Ville" in x.tree
|
||||
assert "Isabella Rodriguez's apartment" in x.get_str_accessible_sectors("the Ville")
|
||||
3
tests/metagpt/ext/stanford_town/plan/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/plan/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
67
tests/metagpt/ext/stanford_town/plan/test_conversation.py
Normal file
67
tests/metagpt/ext/stanford_town/plan/test_conversation.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of roles conversation
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.ext.stanford_town.plan.converse import agent_conversation
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH, STORAGE_PATH
|
||||
from metagpt.ext.stanford_town.utils.mg_ga_transform import get_reverie_meta
|
||||
from metagpt.ext.stanford_town.utils.utils import copy_folder
|
||||
|
||||
|
||||
async def init_two_roles(fork_sim_code: str = "base_the_ville_isabella_maria_klaus") -> Tuple["STRole"]:
|
||||
sim_code = "unittest_sim"
|
||||
|
||||
copy_folder(str(STORAGE_PATH.joinpath(fork_sim_code)), str(STORAGE_PATH.joinpath(sim_code)))
|
||||
|
||||
reverie_meta = get_reverie_meta(fork_sim_code)
|
||||
role_ir_name = "Isabella Rodriguez"
|
||||
role_km_name = "Klaus Mueller"
|
||||
|
||||
env = StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH)
|
||||
|
||||
role_ir = STRole(
|
||||
name=role_ir_name,
|
||||
sim_code=sim_code,
|
||||
profile=role_ir_name,
|
||||
step=reverie_meta.get("step"),
|
||||
start_time=reverie_meta.get("start_date"),
|
||||
curr_time=reverie_meta.get("curr_time"),
|
||||
sec_per_step=reverie_meta.get("sec_per_step"),
|
||||
)
|
||||
role_ir.set_env(env)
|
||||
await role_ir.init_curr_tile()
|
||||
|
||||
role_km = STRole(
|
||||
name=role_km_name,
|
||||
sim_code=sim_code,
|
||||
profile=role_km_name,
|
||||
step=reverie_meta.get("step"),
|
||||
start_time=reverie_meta.get("start_date"),
|
||||
curr_time=reverie_meta.get("curr_time"),
|
||||
sec_per_step=reverie_meta.get("sec_per_step"),
|
||||
)
|
||||
role_km.set_env(env)
|
||||
await role_km.init_curr_tile()
|
||||
|
||||
return role_ir, role_km
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_conversation():
|
||||
role_ir, role_km = await init_two_roles()
|
||||
|
||||
curr_chat = await agent_conversation(role_ir, role_km, conv_rounds=2)
|
||||
assert len(curr_chat) % 2 == 0
|
||||
|
||||
meet = False
|
||||
for conv in curr_chat:
|
||||
if "Valentine's Day party" in conv[1]:
|
||||
# conv[0] speaker, conv[1] utterance
|
||||
meet = True
|
||||
assert meet
|
||||
25
tests/metagpt/ext/stanford_town/plan/test_st_plan.py
Normal file
25
tests/metagpt/ext/stanford_town/plan/test_st_plan.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of st_plan
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.plan.st_plan import _choose_retrieved, _should_react
|
||||
from tests.metagpt.ext.stanford_town.plan.test_conversation import init_two_roles
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_react():
|
||||
role_ir, role_km = await init_two_roles()
|
||||
roles = {role_ir.name: role_ir, role_km.name: role_km}
|
||||
role_ir.scratch.act_address = "mock data"
|
||||
|
||||
observed = await role_ir.observe()
|
||||
retrieved = role_ir.retrieve(observed)
|
||||
|
||||
focused_event = _choose_retrieved(role_ir.name, retrieved)
|
||||
|
||||
if focused_event:
|
||||
reaction_mode = await _should_react(role_ir, focused_event, roles) # chat with Isabella Rodriguez
|
||||
assert not reaction_mode
|
||||
3
tests/metagpt/ext/stanford_town/roles/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/roles/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
26
tests/metagpt/ext/stanford_town/roles/test_st_role.py
Normal file
26
tests/metagpt/ext/stanford_town/roles/test_st_role.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of STRole
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_observe():
|
||||
role = STRole(
|
||||
sim_code="base_the_ville_isabella_maria_klaus",
|
||||
start_time="February 13, 2023",
|
||||
curr_time="February 13, 2023, 00:00:00",
|
||||
)
|
||||
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
|
||||
await role.init_curr_tile()
|
||||
|
||||
ret_events = await role.observe()
|
||||
assert ret_events
|
||||
for event in ret_events:
|
||||
assert isinstance(event, BasicMemory)
|
||||
47
tests/metagpt/ext/stanford_town/test_reflect.py
Normal file
47
tests/metagpt/ext/stanford_town/test_reflect.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of reflection
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.ext.stanford_town.actions.run_reflect_action import (
|
||||
AgentEventTriple,
|
||||
AgentFocusPt,
|
||||
AgentInsightAndGuidance,
|
||||
)
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflect():
|
||||
"""
|
||||
init STRole form local json, set sim_code(path),curr_time & start_time
|
||||
"""
|
||||
role = STRole(
|
||||
sim_code="base_the_ville_isabella_maria_klaus",
|
||||
start_time="February 13, 2023",
|
||||
curr_time="February 13, 2023, 00:00:00",
|
||||
)
|
||||
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
|
||||
role.init_curr_tile()
|
||||
|
||||
run_focus = AgentFocusPt()
|
||||
statements = ""
|
||||
await run_focus.run(role, statements, n=3)
|
||||
|
||||
"""
|
||||
这里有通过测试的结果,但是更多时候LLM生成的结果缺少了because of;考虑修改一下prompt
|
||||
result = {'Klaus Mueller and Maria Lopez have a close relationship because they have been friends for a long time and have a strong bond': [1, 2, 5, 9, 11, 14], 'Klaus Mueller has a crush on Maria Lopez': [8, 15, 24], 'Klaus Mueller is academically inclined and actively researching a topic': [13, 20], 'Klaus Mueller is socially active and acquainted with Isabella Rodriguez': [17, 21, 22], 'Klaus Mueller is organized and prepared': [19]}
|
||||
"""
|
||||
run_insight = AgentInsightAndGuidance()
|
||||
statements = "[user: Klaus Mueller has a close relationship with Maria Lopez, user:s Mueller and Maria Lopez have a close relationship, user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller and Maria Lopez have a strong relationship, user: Klaus Mueller is a dormmate of Maria Lopez., user: Klaus Mueller and Maria Lopez have a strong bond, user: Klaus Mueller has a crush on Maria Lopez, user: Klaus Mueller and Maria Lopez have been friends for more than 2 years., user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller Maria Lopez is heading off to college., user: Klaus Mueller and Maria Lopez have a close relationship, user: Klaus Mueller is actively researching a topic, user: Klaus Mueller is close friends and classmates with Maria Lopez., user: Klaus Mueller is socially active, user: Klaus Mueller has a crush on Maria Lopez., user: Klaus Mueller and Maria Lopez have been friends for a long time, user: Klaus Mueller is academically inclined, user: For Klaus Mueller's planning: should remember to ask Maria Lopez about her research paper, as she found it interesting that he mentioned it., user: Klaus Mueller is acquainted with Isabella Rodriguez, user: Klaus Mueller is organized and prepared, user: Maria Lopez is conversing about conversing about Maria's research paper mentioned by Klaus, user: Klaus Mueller is conversing about conversing about Maria's research paper mentioned by Klaus, user: Klaus Mueller is a student, user: Klaus Mueller is a student, user: Klaus Mueller is conversing about two friends named Klaus Mueller and Maria Lopez discussing their morning plans and progress on a research paper before Maria heads off to college., user: Klaus Mueller is socially active, user: Klaus Mueller is socially active, user: Klaus Mueller is socially active and acquainted with Isabella Rodriguez, user: Klaus Mueller has a crush on Maria Lopez]"
|
||||
await run_insight.run(role, statements, n=5)
|
||||
|
||||
run_triple = AgentEventTriple()
|
||||
statements = "(Klaus Mueller is academically inclined)"
|
||||
await run_triple.run(statements, role)
|
||||
|
||||
role.scratch.importance_trigger_curr = -1
|
||||
role.reflect()
|
||||
3
tests/metagpt/ext/werewolf/__init__.py
Normal file
3
tests/metagpt/ext/werewolf/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
tests/metagpt/ext/werewolf/actions/__init__.py
Normal file
3
tests/metagpt/ext/werewolf/actions/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
164
tests/metagpt/ext/werewolf/actions/test_experience_operation.py
Normal file
164
tests/metagpt/ext/werewolf/actions/test_experience_operation.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.ext.werewolf.actions import AddNewExperiences, RetrieveExperiences
|
||||
from metagpt.ext.werewolf.schema import RoleExperience
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class TestExperiencesOperation:
|
||||
collection_name = "test"
|
||||
test_round_id = "test_01"
|
||||
version = "test"
|
||||
samples_to_add = [
|
||||
RoleExperience(
|
||||
profile="Witch",
|
||||
reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. "
|
||||
"Player4's behavior is suspicious.",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version,
|
||||
),
|
||||
RoleExperience(
|
||||
profile="Witch",
|
||||
reflection="The game is in a critical state with only three players left, "
|
||||
"and I need to make a wise decision to save Player7 or not.",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version,
|
||||
),
|
||||
RoleExperience(
|
||||
profile="Seer",
|
||||
reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, "
|
||||
"sided with him. I, as the real Seer, am under suspicion.",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version,
|
||||
),
|
||||
RoleExperience(
|
||||
profile="TestRole",
|
||||
reflection="Some test reflection1",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version + "_01-10",
|
||||
),
|
||||
RoleExperience(
|
||||
profile="TestRole",
|
||||
reflection="Some test reflection2",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version + "_11-20",
|
||||
),
|
||||
RoleExperience(
|
||||
profile="TestRole",
|
||||
reflection="Some test reflection3",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version + "_21-30",
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add(self):
|
||||
saved_file = DEFAULT_WORKSPACE_ROOT.joinpath(
|
||||
f"werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
|
||||
)
|
||||
if saved_file.exists():
|
||||
saved_file.unlink()
|
||||
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.run(self.samples_to_add)
|
||||
|
||||
# test insertion
|
||||
inserted = action.engine.retriever._index._vector_store._collection.get()
|
||||
assert len(inserted["documents"]) == len(self.samples_to_add)
|
||||
|
||||
# test if we record the samples correctly to local file
|
||||
# & test if we could recover a embedding db from the file
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.add_from_file(saved_file)
|
||||
inserted = action.engine.retriever._index._vector_store._collection.get()
|
||||
assert len(inserted["documents"]) == len(self.samples_to_add)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "one player claimed to be Seer and the other Witch"
|
||||
results = action.run(query, profile="Witch")
|
||||
results = json.loads(results)
|
||||
|
||||
assert len(results) == 2, "Witch should have 2 experiences"
|
||||
assert "The game is intense with two players" in results[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_filtering(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "some test query"
|
||||
profile = "TestRole"
|
||||
|
||||
excluded_version = ""
|
||||
results = action.run(query, profile=profile, excluded_version=excluded_version)
|
||||
results = json.loads(results)
|
||||
assert len(results) == 3
|
||||
|
||||
excluded_version = self.version + "_21-30"
|
||||
results = action.run(query, profile=profile, excluded_version=excluded_version)
|
||||
results = json.loads(results)
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
class TestActualRetrieve:
|
||||
collection_name = "role_reflection"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_experience_pool(self):
|
||||
logger.info("check experience pool")
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
if action.engine:
|
||||
all_experiences = action.engine.retriever._index._vector_store._collection.get()
|
||||
logger.info(f"{len(all_experiences['metadatas'])=}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_werewolf_experience(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "there are conflicts"
|
||||
|
||||
logger.info(f"test retrieval with {query=}")
|
||||
action.run(query, "Werewolf")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_villager_experience(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "there are conflicts"
|
||||
|
||||
logger.info(f"test retrieval with {query=}")
|
||||
results = action.run(query, "Seer")
|
||||
assert "conflict" not in results # 相似局面应该需要包含conflict关键词
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_villager_experience_filtering(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "there are conflicts"
|
||||
|
||||
excluded_version = "01-10"
|
||||
logger.info(f"test retrieval with {excluded_version=}")
|
||||
results_01_10 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
|
||||
|
||||
excluded_version = "11-20"
|
||||
logger.info(f"test retrieval with {excluded_version=}")
|
||||
results_11_20 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
|
||||
|
||||
assert results_01_10 == results_11_20
|
||||
|
|
@ -60,3 +60,14 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model
|
|||
mock_llm_config_anthropic = LLMConfig(
|
||||
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
|
||||
)
|
||||
|
||||
mock_llm_config_bedrock = LLMConfig(
|
||||
api_type="bedrock",
|
||||
model="gpt-100",
|
||||
region_name="somewhere",
|
||||
access_key="123abc",
|
||||
secret_key="123abc",
|
||||
max_token=10000,
|
||||
)
|
||||
|
||||
mock_llm_config_ark = LLMConfig(api_type="ark", api_key="eyxxx", base_url="xxx", model="ep-xxx")
|
||||
|
|
|
|||
|
|
@ -183,3 +183,90 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
|
|||
|
||||
resp = await llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
|
||||
# For Amazon Bedrock
|
||||
# Check the API documentation of each model
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide
|
||||
BEDROCK_PROVIDER_REQUEST_BODY = {
|
||||
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
|
||||
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
|
||||
"ai21": {
|
||||
"prompt": "",
|
||||
"temperature": 0.0,
|
||||
"topP": 0.0,
|
||||
"maxTokens": 0,
|
||||
"stopSequences": [],
|
||||
"countPenalty": {"scale": 0.0},
|
||||
"presencePenalty": {"scale": 0.0},
|
||||
"frequencyPenalty": {"scale": 0.0},
|
||||
},
|
||||
"cohere": {
|
||||
"prompt": "",
|
||||
"temperature": 0.0,
|
||||
"p": 0.0,
|
||||
"k": 0.0,
|
||||
"max_tokens": 0,
|
||||
"stop_sequences": [],
|
||||
"return_likelihoods": "NONE",
|
||||
"stream": False,
|
||||
"num_generations": 0,
|
||||
"logit_bias": {},
|
||||
"truncate": "NONE",
|
||||
},
|
||||
"anthropic": {
|
||||
"anthropic_version": "bedrock-2023-05-31",
|
||||
"max_tokens": 0,
|
||||
"system": "",
|
||||
"messages": [{"role": "", "content": ""}],
|
||||
"temperature": 0.0,
|
||||
"top_p": 0.0,
|
||||
"top_k": 0,
|
||||
"stop_sequences": [],
|
||||
},
|
||||
"amazon": {
|
||||
"inputText": "",
|
||||
"textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []},
|
||||
},
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY = {
|
||||
"mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
|
||||
"meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
|
||||
"ai21": {
|
||||
"id": "",
|
||||
"prompt": {"text": "Hello World", "tokens": []},
|
||||
"completions": [
|
||||
{"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
|
||||
],
|
||||
},
|
||||
"cohere": {
|
||||
"generations": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"id": "",
|
||||
"text": "Hello World",
|
||||
"likelihood": 0.0,
|
||||
"token_likelihoods": [{"token": 0.0}],
|
||||
"is_finished": True,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"id": "",
|
||||
"prompt": "",
|
||||
},
|
||||
"anthropic": {
|
||||
"id": "",
|
||||
"model": "",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Hello World"}],
|
||||
"stop_reason": "",
|
||||
"stop_sequence": "",
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0},
|
||||
},
|
||||
"amazon": {
|
||||
"inputTextTokenCount": 0,
|
||||
"results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}],
|
||||
},
|
||||
}
|
||||
|
|
|
|||
85
tests/metagpt/provider/test_ark.py
Normal file
85
tests/metagpt/provider/test_ark.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""
|
||||
用于火山方舟Python SDK V3的测试用例
|
||||
API文档:https://www.volcengine.com/docs/82379/1263482
|
||||
"""
|
||||
|
||||
from typing import AsyncIterator, List, Union
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
|
||||
|
||||
from metagpt.provider.ark_api import ArkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_ark
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_openai_chat_completion,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "AI assistant"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
USAGE = {"completion_tokens": 1000, "prompt_tokens": 1000, "total_tokens": 2000}
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
default_resp.model = "doubao-pro-32k-240515"
|
||||
default_resp.usage = USAGE
|
||||
|
||||
|
||||
def create_chat_completion_chunk(
|
||||
content: str, finish_reason: str = None, choices: List[Choice] = None
|
||||
) -> ChatCompletionChunk:
|
||||
if choices is None:
|
||||
choices = [
|
||||
Choice(
|
||||
delta=ChoiceDelta(content=content, function_call=None, role="assistant", tool_calls=None),
|
||||
finish_reason=finish_reason,
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
]
|
||||
|
||||
return ChatCompletionChunk(
|
||||
id="012",
|
||||
choices=choices,
|
||||
created=1716278586,
|
||||
model="doubao-pro-32k-240515",
|
||||
object="chat.completion.chunk",
|
||||
system_fingerprint=None,
|
||||
usage=None if choices else USAGE,
|
||||
)
|
||||
|
||||
|
||||
ark_resp_chunk = create_chat_completion_chunk(content="")
|
||||
ark_resp_chunk_finish = create_chat_completion_chunk(content=resp_cont, finish_reason="stop")
|
||||
ark_resp_chunk_last = create_chat_completion_chunk(content="", choices=[])
|
||||
|
||||
|
||||
async def chunk_iterator(chunks: List[ChatCompletionChunk]) -> AsyncIterator[ChatCompletionChunk]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
async def mock_ark_acompletions_create(
|
||||
self, stream: bool = False, **kwargs
|
||||
) -> Union[ChatCompletionChunk, ChatCompletion]:
|
||||
if stream:
|
||||
chunks = [ark_resp_chunk, ark_resp_chunk_finish, ark_resp_chunk_last]
|
||||
return chunk_iterator(chunks)
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ark_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_ark_acompletions_create)
|
||||
|
||||
llm = ArkLLM(mock_llm_config_ark)
|
||||
|
||||
resp = await llm.acompletion(messages)
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
assert resp.choices[0].message.content == resp_cont
|
||||
assert resp.usage == USAGE
|
||||
|
||||
await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)
|
||||
109
tests/metagpt/provider/test_bedrock_api.py
Normal file
109
tests/metagpt/provider/test_bedrock_api.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
NOT_SUUPORT_STREAM_MODELS,
|
||||
SUPPORT_STREAM_MODELS,
|
||||
)
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
BEDROCK_PROVIDER_REQUEST_BODY,
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY,
|
||||
)
|
||||
|
||||
# all available model from bedrock
|
||||
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
|
||||
messages = [{"role": "user", "content": "Hi!"}]
|
||||
usage = {
|
||||
"prompt_tokens": 1000000,
|
||||
"completion_tokens": 1000000,
|
||||
}
|
||||
|
||||
|
||||
def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
provider = self.config.model.split(".")[0]
|
||||
self._update_costs(usage, self.config.model)
|
||||
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
|
||||
|
||||
|
||||
def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
# use json object to mock EventStream
|
||||
def dict2bytes(x):
|
||||
return json.dumps(x).encode("utf-8")
|
||||
|
||||
provider = self.config.model.split(".")[0]
|
||||
|
||||
if provider == "amazon":
|
||||
response_body_bytes = dict2bytes({"outputText": "Hello World"})
|
||||
elif provider == "anthropic":
|
||||
response_body_bytes = dict2bytes(
|
||||
{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}}
|
||||
)
|
||||
elif provider == "cohere":
|
||||
response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
|
||||
else:
|
||||
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
|
||||
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
|
||||
self._update_costs(usage, self.config.model)
|
||||
return response_body_stream
|
||||
|
||||
|
||||
def get_bedrock_request_body(model_id) -> dict:
|
||||
provider = model_id.split(".")[0]
|
||||
return BEDROCK_PROVIDER_REQUEST_BODY[provider]
|
||||
|
||||
|
||||
def is_subset(subset, superset) -> bool:
|
||||
"""Ensure all fields in request body are allowed.
|
||||
|
||||
```python
|
||||
subset = {"prompt": "hello","kwargs": {"temperature": 0.9,"p": 0.0}}
|
||||
superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}}
|
||||
is_subset(subset, superset)
|
||||
```
|
||||
|
||||
"""
|
||||
for key, value in subset.items():
|
||||
if key not in superset:
|
||||
return False
|
||||
if isinstance(value, dict):
|
||||
if not isinstance(superset[key], dict):
|
||||
return False
|
||||
if not is_subset(value, superset[key]):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture(scope="class", params=models)
|
||||
def bedrock_api(request) -> BedrockLLM:
|
||||
model_id = request.param
|
||||
mock_llm_config_bedrock.model = model_id
|
||||
api = BedrockLLM(mock_llm_config_bedrock)
|
||||
return api
|
||||
|
||||
|
||||
class TestBedrockAPI:
|
||||
def _patch_invoke_model(self, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model)
|
||||
|
||||
def _patch_invoke_model_stream(self, mocker):
|
||||
mocker.patch(
|
||||
"metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
|
||||
mock_invoke_model_stream,
|
||||
)
|
||||
|
||||
def test_get_request_body(self, bedrock_api: BedrockLLM):
|
||||
"""Ensure request body has correct format"""
|
||||
provider = bedrock_api.provider
|
||||
request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
|
||||
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
self._patch_invoke_model_stream(mocker)
|
||||
assert await bedrock_api.aask(messages, stream=False) == "Hello World"
|
||||
assert await bedrock_api.aask(messages, stream=True) == "Hello World"
|
||||
|
|
@ -1,62 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
"""
|
||||
用于讯飞星火SDK的测试用例
|
||||
文档:https://www.xfyun.cn/doc/spark/Web.html
|
||||
"""
|
||||
|
||||
|
||||
from typing import AsyncIterator, List
|
||||
|
||||
import pytest
|
||||
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
|
||||
from sparkai.core.outputs.chat_generation import ChatGeneration
|
||||
from sparkai.core.outputs.llm_result import LLMResult
|
||||
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_spark,
|
||||
)
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="Spark")
|
||||
USAGE = {
|
||||
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
|
||||
}
|
||||
spark_agenerate_result = LLMResult(
|
||||
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
|
||||
)
|
||||
|
||||
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
|
||||
|
||||
|
||||
class MockWebSocketApp(object):
|
||||
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
|
||||
pass
|
||||
|
||||
def run_forever(self, sslopt=None):
|
||||
pass
|
||||
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
assert ret == ""
|
||||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_aask(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await llm.aask("Hello!")
|
||||
assert resp == resp_cont
|
||||
async def mock_spark_acreate(self, messages, stream):
|
||||
if stream:
|
||||
return chunk_iterator(chunks)
|
||||
else:
|
||||
return spark_agenerate_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate)
|
||||
|
||||
spark_llm = SparkLLM(mock_llm_config)
|
||||
spark_llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await spark_llm.acompletion([])
|
||||
assert resp == resp_cont
|
||||
resp = await spark_llm.acompletion([messages])
|
||||
assert spark_llm.get_choice_text(resp) == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM
|
|||
from llama_index.core.schema import Document, NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
|
||||
|
|
@ -25,10 +26,6 @@ class TestSimpleEngine:
|
|||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_retriever(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
|
|
@ -41,14 +38,18 @@ class TestSimpleEngine:
|
|||
def mock_get_response_synthesizer(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_file_extractor(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor")
|
||||
|
||||
def test_from_docs(
|
||||
self,
|
||||
mocker,
|
||||
mock_simple_directory_reader,
|
||||
mock_vector_store_index,
|
||||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
mock_get_file_extractor,
|
||||
):
|
||||
# Mock
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
|
|
@ -58,6 +59,8 @@ class TestSimpleEngine:
|
|||
mock_get_retriever.return_value = mocker.MagicMock()
|
||||
mock_get_rankers.return_value = [mocker.MagicMock()]
|
||||
mock_get_response_synthesizer.return_value = mocker.MagicMock()
|
||||
file_extractor = mocker.MagicMock()
|
||||
mock_get_file_extractor.return_value = file_extractor
|
||||
|
||||
# Setup
|
||||
input_dir = "test_dir"
|
||||
|
|
@ -80,12 +83,11 @@ class TestSimpleEngine:
|
|||
)
|
||||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
configs=retriever_configs, index=mock_vector_store_index.return_value
|
||||
mock_simple_directory_reader.assert_called_once_with(
|
||||
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
|
||||
)
|
||||
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
|
||||
mock_get_retriever.assert_called_once()
|
||||
mock_get_rankers.assert_called_once()
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
|
|
@ -119,7 +121,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is not None
|
||||
assert engine._transformations is not None
|
||||
|
||||
def test_from_objs_with_bm25_config(self):
|
||||
# Setup
|
||||
|
|
@ -137,6 +139,7 @@ class TestSimpleEngine:
|
|||
def test_from_index(self, mocker, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index.as_retriever.return_value = "retriever"
|
||||
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
|
||||
mock_get_index.return_value = mock_index
|
||||
|
||||
|
|
@ -149,7 +152,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is mock_index
|
||||
assert engine._retriever == "retriever"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
|
|
@ -200,14 +203,11 @@ class TestSimpleEngine:
|
|||
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index._transformations = mocker.MagicMock()
|
||||
|
||||
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
|
||||
mock_run_transformations.return_value = ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Exec
|
||||
|
|
@ -230,7 +230,7 @@ class TestSimpleEngine:
|
|||
return ""
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Exec
|
||||
engine.add_objs(objs=objs)
|
||||
|
|
@ -308,3 +308,17 @@ class TestSimpleEngine:
|
|||
# Assert
|
||||
assert "obj" in node.node.metadata
|
||||
assert node.node.metadata["obj"] == expected_obj
|
||||
|
||||
def test_get_file_extractor(self, mocker):
|
||||
# mock no omniparse config
|
||||
mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)
|
||||
mock_omniparse_config.base_url = ""
|
||||
|
||||
file_extractor = SimpleEngine._get_file_extractor()
|
||||
assert file_extractor == {}
|
||||
|
||||
# mock have omniparse config
|
||||
mock_omniparse_config.base_url = "http://localhost:8000"
|
||||
file_extractor = SimpleEngine._get_file_extractor()
|
||||
assert ".pdf" in file_extractor
|
||||
assert isinstance(file_extractor[".pdf"], OmniParse)
|
||||
|
|
|
|||
|
|
@ -97,6 +97,5 @@ class TestConfigBasedFactory:
|
|||
def test_val_from_config_or_kwargs_key_error(self):
|
||||
# Test KeyError when the key is not found in both config object and kwargs
|
||||
config = DummyConfig(name=None)
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
|
||||
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert val is None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
||||
|
|
@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory:
|
|||
self.embedding_factory = RAGEmbeddingFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_embedding(self, mocker):
|
||||
def mock_config(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
|
||||
@staticmethod
|
||||
def mock_openai_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_embedding(self, mocker):
|
||||
@staticmethod
|
||||
def mock_azure_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
|
||||
|
||||
def test_get_rag_embedding_openai(self, mock_openai_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
|
||||
@staticmethod
|
||||
def mock_gemini_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding")
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
@staticmethod
|
||||
def mock_ollama_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding")
|
||||
|
||||
def test_get_rag_embedding_azure(self, mock_azure_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
|
||||
|
||||
# Assert
|
||||
mock_azure_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
|
||||
@pytest.mark.parametrize(
|
||||
("mock_func", "embedding_type"),
|
||||
[
|
||||
(mock_openai_embedding, LLMType.OPENAI),
|
||||
(mock_azure_embedding, LLMType.AZURE),
|
||||
(mock_openai_embedding, EmbeddingType.OPENAI),
|
||||
(mock_azure_embedding, EmbeddingType.AZURE),
|
||||
(mock_gemini_embedding, EmbeddingType.GEMINI),
|
||||
(mock_ollama_embedding, EmbeddingType.OLLAMA),
|
||||
],
|
||||
)
|
||||
def test_get_rag_embedding(self, mock_func, embedding_type, mocker):
|
||||
# Mock
|
||||
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
mock = mock_func(mocker)
|
||||
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(embedding_type)
|
||||
|
||||
# Assert
|
||||
mock.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_config):
|
||||
# Mock
|
||||
mock_openai_embedding = self.mock_openai_embedding(mocker)
|
||||
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.OPENAI
|
||||
|
||||
# Exec
|
||||
|
|
@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory:
|
|||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, embed_batch_size, expected_params",
|
||||
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
|
||||
)
|
||||
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
|
||||
# Mock
|
||||
mock_config.embedding.model = model
|
||||
mock_config.embedding.embed_batch_size = embed_batch_size
|
||||
|
||||
# Setup
|
||||
test_params = {}
|
||||
|
||||
# Exec
|
||||
self.embedding_factory._try_set_model_and_batch_size(test_params)
|
||||
|
||||
# Assert
|
||||
assert test_params == expected_params
|
||||
|
||||
def test_resolve_embedding_type(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = EmbeddingType.OPENAI
|
||||
|
||||
# Exec
|
||||
embedding_type = self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
# Assert
|
||||
assert embedding_type == EmbeddingType.OPENAI
|
||||
|
||||
def test_resolve_embedding_type_exception(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.GEMINI
|
||||
|
||||
# Assert
|
||||
with pytest.raises(TypeError):
|
||||
self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
def test_raise_for_key(self):
|
||||
with pytest.raises(ValueError):
|
||||
self.embedding_factory._raise_for_key("key")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
||||
|
|
@ -43,6 +45,14 @@ class TestRetrieverFactory:
|
|||
def mock_es_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ElasticsearchStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_nodes(self, mocker):
|
||||
return [TextNode(text="msg")]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
|
|
@ -52,42 +62,40 @@ class TestRetrieverFactory:
|
|||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_nodes):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes)
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=1)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
retriever = self.retriever_factory.get_retriever(
|
||||
configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding
|
||||
)
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding):
|
||||
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
|
||||
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
|
||||
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ChromaRetriever)
|
||||
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
|
||||
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
|
||||
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ElasticsearchRetriever)
|
||||
|
||||
|
|
@ -111,3 +119,19 @@ class TestRetrieverFactory:
|
|||
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_get_or_build_when_get(self, mocker):
|
||||
want = "existing_index"
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
||||
def test_get_or_build_when_build(self, mocker):
|
||||
want = "call_build_es_index"
|
||||
mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
|
|
|||
118
tests/metagpt/rag/parser/test_omniparse.py
Normal file
118
tests/metagpt/rag/parser/test_omniparse.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
import pytest
|
||||
from llama_index.core import Document
|
||||
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.schema import (
|
||||
OmniParsedResult,
|
||||
OmniParseOptions,
|
||||
OmniParseType,
|
||||
ParseResultType,
|
||||
)
|
||||
from metagpt.utils.omniparse_client import OmniParseClient
|
||||
|
||||
# test data
|
||||
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
|
||||
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
|
||||
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
|
||||
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
|
||||
|
||||
|
||||
class TestOmniParseClient:
|
||||
parse_client = OmniParseClient()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_pdf(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest content"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
|
||||
assert parse_ret == mock_parsed_ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_document(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_document"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
|
||||
with open(TEST_DOCX, "rb") as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# bytes data must provide bytes_filename
|
||||
await self.parse_client.parse_document(file_bytes)
|
||||
|
||||
parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
|
||||
assert parse_ret == mock_parsed_ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_video(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_video"
|
||||
mock_request_parse.return_value = {
|
||||
"text": mock_content,
|
||||
"metadata": {},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
# Wrong file extension test
|
||||
await self.parse_client.parse_video(TEST_DOCX)
|
||||
|
||||
parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
|
||||
assert "text" in parse_ret and "metadata" in parse_ret
|
||||
assert parse_ret["text"] == mock_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_audio(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_audio"
|
||||
mock_request_parse.return_value = {
|
||||
"text": mock_content,
|
||||
"metadata": {},
|
||||
}
|
||||
parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
|
||||
assert "text" in parse_ret and "metadata" in parse_ret
|
||||
assert parse_ret["text"] == mock_content
|
||||
|
||||
|
||||
class TestOmniParse:
|
||||
@pytest.fixture
|
||||
def mock_omniparse(self):
|
||||
parser = OmniParse(
|
||||
parse_options=OmniParseOptions(
|
||||
parse_type=OmniParseType.PDF,
|
||||
result_type=ParseResultType.MD,
|
||||
max_timeout=120,
|
||||
num_workers=3,
|
||||
)
|
||||
)
|
||||
return parser
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_data(self, mock_omniparse, mock_request_parse):
|
||||
# mock
|
||||
mock_content = "#test title\ntest content"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
|
||||
# single file
|
||||
documents = mock_omniparse.load_data(file_path=TEST_PDF)
|
||||
doc = documents[0]
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
|
||||
|
||||
# multi files
|
||||
file_paths = [TEST_DOCX, TEST_PDF]
|
||||
mock_omniparse.parse_type = OmniParseType.DOCUMENT
|
||||
documents = await mock_omniparse.aload_data(file_path=file_paths)
|
||||
doc = documents[0]
|
||||
|
||||
# assert
|
||||
assert isinstance(doc, Document)
|
||||
assert len(documents) == len(file_paths)
|
||||
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
|
||||
|
|
@ -10,7 +10,6 @@ import json
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -30,11 +29,7 @@ async def test_product_manager(new_filename):
|
|||
rsp = await product_manager.run(MockMessages.req)
|
||||
assert context.git_repo
|
||||
assert context.repo
|
||||
assert rsp.cause_by == any_to_str(PrepareDocuments)
|
||||
assert REQUIREMENT_FILENAME in context.repo.docs.changed_files
|
||||
|
||||
# write prd
|
||||
rsp = await product_manager.run(rsp)
|
||||
assert rsp.cause_by == any_to_str(WritePRD)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager
|
||||
from metagpt.team import Team
|
||||
|
|
@ -146,5 +147,21 @@ async def test_team_recover_multi_roles_save(mocker, context):
|
|||
await new_company.run(n_round=4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context(context):
|
||||
context.kwargs.set("a", "a")
|
||||
context.cost_manager.max_budget = 9
|
||||
company = Team(context=context)
|
||||
|
||||
save_to = context.repo.workdir / "serial"
|
||||
company.serialize(save_to)
|
||||
|
||||
company.deserialize(save_to, Context())
|
||||
assert company.env.context.repo
|
||||
assert company.env.context.repo.workdir == context.repo.workdir
|
||||
assert company.env.context.kwargs.a == "a"
|
||||
assert company.env.context.cost_manager.max_budget == context.cost_manager.max_budget
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -105,11 +105,11 @@ def test_config_mixin_4_multi_inheritance_override_config():
|
|||
async def test_config_priority():
|
||||
"""If action's config is set, then its llm will be set, otherwise, it will use the role's llm"""
|
||||
home_dir = Path.home() / CONFIG_ROOT
|
||||
gpt4t = Config.from_home("gpt-4-1106-preview.yaml")
|
||||
gpt4t = Config.from_home("gpt-4-turbo.yaml")
|
||||
if not home_dir.exists():
|
||||
assert gpt4t is None
|
||||
gpt35 = Config.default()
|
||||
gpt35.llm.model = "gpt-3.5-turbo-1106"
|
||||
gpt35.llm.model = "gpt-4-turbo"
|
||||
gpt4 = Config.default()
|
||||
gpt4.llm.model = "gpt-4-0613"
|
||||
|
||||
|
|
@ -127,8 +127,8 @@ async def test_config_priority():
|
|||
env = Environment(desc="US election live broadcast")
|
||||
Team(investment=10.0, env=env, roles=[A, B, C])
|
||||
|
||||
assert a1.llm.model == "gpt-4-1106-preview" if Path(home_dir / "gpt-4-1106-preview.yaml").exists() else "gpt-4-0613"
|
||||
assert a1.llm.model == "gpt-4-turbo" if Path(home_dir / "gpt-4-turbo.yaml").exists() else "gpt-4-0613"
|
||||
assert a2.llm.model == "gpt-4-0613"
|
||||
assert a3.llm.model == "gpt-3.5-turbo-1106"
|
||||
assert a3.llm.model == "gpt-4-turbo"
|
||||
|
||||
# history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="a1", n_round=3)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def test_add_role(env: Environment):
|
|||
name="Alice", profile="product manager", goal="create a new product", constraints="limited resources"
|
||||
)
|
||||
env.add_role(role)
|
||||
assert env.get_role(role.profile) == role
|
||||
assert env.get_role(str(role._setting)) == role
|
||||
|
||||
|
||||
def test_get_roles(env: Environment):
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class MockSearchEnine:
|
|||
(SearchEngineType.SERPER_GOOGLE, None, 6, False),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 8, True),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
|
||||
(SearchEngineType.BING, None, 6, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
],
|
||||
|
|
|
|||
|
|
@ -2,7 +2,10 @@ from typing import Literal, Union
|
|||
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
from metagpt.tools.tool_convert import (
|
||||
convert_code_to_tool_schema,
|
||||
convert_code_to_tool_schema_ast,
|
||||
)
|
||||
|
||||
|
||||
class DummyClass:
|
||||
|
|
@ -128,3 +131,91 @@ def test_convert_code_to_tool_schema_function():
|
|||
def test_convert_code_to_tool_schema_async_function():
|
||||
schema = convert_code_to_tool_schema(dummy_async_fn)
|
||||
assert schema.get("type") == "async_function"
|
||||
|
||||
|
||||
TEST_CODE_FILE_TEXT = '''
|
||||
import pandas as pd # imported obj should not be parsed
|
||||
from some_module1 import some_imported_function, SomeImportedClass # imported obj should not be parsed
|
||||
from ..some_module2 import some_imported_function2 # relative import should not result in an error
|
||||
|
||||
class MyClass:
|
||||
"""This is a MyClass docstring."""
|
||||
def __init__(self, arg1):
|
||||
"""This is the constructor docstring."""
|
||||
self.arg1 = arg1
|
||||
|
||||
def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:
|
||||
"""
|
||||
This is a method docstring.
|
||||
|
||||
Args:
|
||||
arg2 (Union[list[str], str]): A union of a list of strings and a string.
|
||||
...
|
||||
|
||||
Returns:
|
||||
Tuple[int, str]: A tuple of an integer and a string.
|
||||
"""
|
||||
return self.arg4 + arg5
|
||||
|
||||
async def my_async_method(self, some_arg) -> str:
|
||||
return "hi"
|
||||
|
||||
def _private_method(self): # private should not be parsed
|
||||
return "private"
|
||||
|
||||
def my_function(arg1, arg2) -> dict:
|
||||
"""This is a function docstring."""
|
||||
return arg1 + arg2
|
||||
|
||||
def my_async_function(arg1, arg2) -> dict:
|
||||
return arg1 + arg2
|
||||
|
||||
def _private_function(): # private should not be parsed
|
||||
return "private"
|
||||
'''
|
||||
|
||||
|
||||
def test_convert_code_to_tool_schema_ast():
|
||||
expected = {
|
||||
"MyClass": {
|
||||
"type": "class",
|
||||
"description": "This is a MyClass docstring.",
|
||||
"methods": {
|
||||
"__init__": {
|
||||
"type": "function",
|
||||
"description": "This is the constructor docstring.",
|
||||
"signature": "(self, arg1)",
|
||||
"parameters": "",
|
||||
},
|
||||
"my_method": {
|
||||
"type": "function",
|
||||
"description": "This is a method docstring. ",
|
||||
"signature": "(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal['a', 'b', 'c'] = 'a') -> Tuple[int, str]",
|
||||
"parameters": "Args: arg2 (Union[list[str], str]): A union of a list of strings and a string. ... Returns: Tuple[int, str]: A tuple of an integer and a string.",
|
||||
},
|
||||
"my_async_method": {
|
||||
"type": "async_function",
|
||||
"description": "",
|
||||
"signature": "(self, some_arg) -> str",
|
||||
"parameters": "",
|
||||
},
|
||||
},
|
||||
"code": 'class MyClass:\n """This is a MyClass docstring."""\n def __init__(self, arg1):\n """This is the constructor docstring."""\n self.arg1 = arg1\n\n def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:\n """\n This is a method docstring.\n \n Args:\n arg2 (Union[list[str], str]): A union of a list of strings and a string.\n ...\n \n Returns:\n Tuple[int, str]: A tuple of an integer and a string.\n """\n return self.arg4 + arg5\n \n async def my_async_method(self, some_arg) -> str:\n return "hi"\n \n def _private_method(self): # private should not be parsed\n return "private"',
|
||||
},
|
||||
"my_function": {
|
||||
"type": "function",
|
||||
"description": "This is a function docstring.",
|
||||
"signature": "(arg1, arg2) -> dict",
|
||||
"parameters": "",
|
||||
"code": 'def my_function(arg1, arg2) -> dict:\n """This is a function docstring."""\n return arg1 + arg2',
|
||||
},
|
||||
"my_async_function": {
|
||||
"type": "function",
|
||||
"description": "",
|
||||
"signature": "(arg1, arg2) -> dict",
|
||||
"parameters": "",
|
||||
"code": "def my_async_function(arg1, arg2) -> dict:\n return arg1 + arg2",
|
||||
},
|
||||
}
|
||||
schemas = convert_code_to_tool_schema_ast(TEST_CODE_FILE_TEXT)
|
||||
assert schemas == expected
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class TestUTWriter:
|
|||
)
|
||||
],
|
||||
created=1706710532,
|
||||
model="gpt-3.5-turbo-1106",
|
||||
model="gpt-4-turbo",
|
||||
object="chat.completion",
|
||||
system_fingerprint="fp_04f9a1eebf",
|
||||
usage=CompletionUsage(completion_tokens=35, prompt_tokens=1982, total_tokens=2017),
|
||||
|
|
|
|||
|
|
@ -12,11 +12,11 @@ from metagpt.utils.cost_manager import CostManager
|
|||
|
||||
def test_cost_manager():
|
||||
cm = CostManager(total_budget=20)
|
||||
cm.update_cost(prompt_tokens=1000, completion_tokens=100, model="gpt-4-1106-preview")
|
||||
cm.update_cost(prompt_tokens=1000, completion_tokens=100, model="gpt-4-turbo")
|
||||
assert cm.get_total_prompt_tokens() == 1000
|
||||
assert cm.get_total_completion_tokens() == 100
|
||||
assert cm.get_total_cost() == 0.013
|
||||
cm.update_cost(prompt_tokens=100, completion_tokens=10, model="gpt-4-1106-preview")
|
||||
cm.update_cost(prompt_tokens=100, completion_tokens=10, model="gpt-4-turbo")
|
||||
assert cm.get_total_prompt_tokens() == 1100
|
||||
assert cm.get_total_completion_tokens() == 110
|
||||
assert cm.get_total_cost() == 0.0143
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.utils.mermaid import MMC1, mermaid_to_file
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer
|
||||
@pytest.mark.parametrize("engine", ["nodejs", "ink", "playwright", "pyppeteer"])
|
||||
async def test_mermaid(engine, context, mermaid_mocker):
|
||||
# nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli
|
||||
# ink prerequisites: connected to internet
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ def _paragraphs(n):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"msgs, model_name, system_text, reserved, expected",
|
||||
"msgs, model, system_text, reserved, expected",
|
||||
[
|
||||
(_msgs(), "gpt-3.5-turbo-0613", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
|
||||
|
|
@ -37,7 +37,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, prompt_template, model_name, system_text, reserved, expected",
|
||||
"text, prompt_template, model, system_text, reserved, expected",
|
||||
[
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
from metagpt.utils.token_counter import count_input_tokens, count_output_tokens
|
||||
|
||||
|
||||
def test_count_message_tokens():
|
||||
|
|
@ -15,7 +15,7 @@ def test_count_message_tokens():
|
|||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert count_message_tokens(messages) == 15
|
||||
assert count_input_tokens(messages) == 15
|
||||
|
||||
|
||||
def test_count_message_tokens_with_name():
|
||||
|
|
@ -23,12 +23,12 @@ def test_count_message_tokens_with_name():
|
|||
{"role": "user", "content": "Hello", "name": "John"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert count_message_tokens(messages) == 17
|
||||
assert count_input_tokens(messages) == 17
|
||||
|
||||
|
||||
def test_count_message_tokens_empty_input():
|
||||
"""Empty input should return 3 tokens"""
|
||||
assert count_message_tokens([]) == 3
|
||||
assert count_input_tokens([]) == 3
|
||||
|
||||
|
||||
def test_count_message_tokens_invalid_model():
|
||||
|
|
@ -38,7 +38,7 @@ def test_count_message_tokens_invalid_model():
|
|||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
with pytest.raises(NotImplementedError):
|
||||
count_message_tokens(messages, model="invalid_model")
|
||||
count_input_tokens(messages, model="invalid_model")
|
||||
|
||||
|
||||
def test_count_message_tokens_gpt_4():
|
||||
|
|
@ -46,27 +46,27 @@ def test_count_message_tokens_gpt_4():
|
|||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert count_message_tokens(messages, model="gpt-4-0314") == 15
|
||||
assert count_input_tokens(messages, model="gpt-4-0314") == 15
|
||||
|
||||
|
||||
def test_count_string_tokens():
|
||||
"""Test that the string tokens are counted correctly."""
|
||||
|
||||
string = "Hello, world!"
|
||||
assert count_string_tokens(string, model_name="gpt-3.5-turbo-0301") == 4
|
||||
assert count_output_tokens(string, model="gpt-3.5-turbo-0301") == 4
|
||||
|
||||
|
||||
def test_count_string_tokens_empty_input():
|
||||
"""Test that the string tokens are counted correctly."""
|
||||
|
||||
assert count_string_tokens("", model_name="gpt-3.5-turbo-0301") == 0
|
||||
assert count_output_tokens("", model="gpt-3.5-turbo-0301") == 0
|
||||
|
||||
|
||||
def test_count_string_tokens_gpt_4():
|
||||
"""Test that the string tokens are counted correctly."""
|
||||
|
||||
string = "Hello, world!"
|
||||
assert count_string_tokens(string, model_name="gpt-4-0314") == 4
|
||||
assert count_output_tokens(string, model="gpt-4-0314") == 4
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue