mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 11:26:23 +02:00
feat: merge main
This commit is contained in:
commit
8a92fa0f36
404 changed files with 20076 additions and 1163 deletions
|
|
@ -1,7 +1,7 @@
|
|||
llm:
|
||||
base_url: "https://api.openai.com/v1"
|
||||
api_key: "sk-xxx"
|
||||
model: "gpt-3.5-turbo-1106"
|
||||
model: "gpt-3.5-turbo"
|
||||
|
||||
search:
|
||||
api_type: "serpapi"
|
||||
|
|
|
|||
2
tests/data/andriod_assistant/.gitignore
vendored
Normal file
2
tests/data/andriod_assistant/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
!*.png
|
||||
unitest_Contacts
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 611 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 840 KiB |
2
tests/data/andriod_assistant/demo_Contacts/record.txt
Normal file
2
tests/data/andriod_assistant/demo_Contacts/record.txt
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
tap(9):::android.view.ViewGroup_1067_236_android.widget.TextView_183_204_Apps_2
|
||||
stop
|
||||
1
tests/data/andriod_assistant/demo_Contacts/task_desc.txt
Normal file
1
tests/data/andriod_assistant/demo_Contacts/task_desc.txt
Normal file
|
|
@ -0,0 +1 @@
|
|||
Create a contact in Contacts App named zjy with a phone number +86 18831933368
|
||||
27
tests/data/config/config2.yaml
Normal file
27
tests/data/config/config2.yaml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
llm:
|
||||
api_type: "openai" # or azure / ollama / groq etc.
|
||||
base_url: "YOUR_gpt-3.5-turbo_BASE_URL"
|
||||
api_key: "YOUR_gpt-3.5-turbo_API_KEY"
|
||||
model: "gpt-3.5-turbo" # or gpt-3.5-turbo
|
||||
# proxy: "YOUR_gpt-3.5-turbo_PROXY" # for LLM API requests
|
||||
# timeout: 600 # Optional. If set to 0, default value is 300.
|
||||
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
|
||||
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
|
||||
|
||||
models:
|
||||
"YOUR_MODEL_NAME_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo
|
||||
api_type: "openai" # or azure / ollama / groq etc.
|
||||
base_url: "YOUR_MODEL_1_BASE_URL"
|
||||
api_key: "YOUR_MODEL_1_API_KEY"
|
||||
# proxy: "YOUR_MODEL_1_PROXY" # for LLM API requests
|
||||
# timeout: 600 # Optional. If set to 0, default value is 300.
|
||||
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
|
||||
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
|
||||
"YOUR_MODEL_NAME_2": # model: "gpt-4-turbo" # or gpt-3.5-turbo
|
||||
api_type: "openai" # or azure / ollama / groq etc.
|
||||
base_url: "YOUR_MODEL_2_BASE_URL"
|
||||
api_key: "YOUR_MODEL_2_API_KEY"
|
||||
proxy: "YOUR_MODEL_2_PROXY" # for LLM API requests
|
||||
# timeout: 600 # Optional. If set to 0, default value is 300.
|
||||
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
|
||||
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
|
||||
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -1,5 +0,0 @@
|
|||
{"world_name": "the ville",
|
||||
"maze_width": 140,
|
||||
"maze_height": 100,
|
||||
"sq_tile_size": 32,
|
||||
"special_constraint": ""}
|
||||
|
|
@ -1,63 +0,0 @@
|
|||
32138, the Ville, artist's co-living space, Latoya Williams's room
|
||||
32148, the Ville, artist's co-living space, Latoya Williams's bathroom
|
||||
32158, the Ville, artist's co-living space, Rajiv Patel's room
|
||||
32168, the Ville, artist's co-living space, Rajiv Patel's bathroom
|
||||
32178, the Ville, artist's co-living space, Abigail Chen's room
|
||||
32188, the Ville, artist's co-living space, Abigail Chen's bathroom
|
||||
32198, the Ville, artist's co-living space, Francisco Lopez's room
|
||||
32139, the Ville, artist's co-living space, Francisco Lopez's bathroom
|
||||
32149, the Ville, artist's co-living space, Hailey Johnson's room
|
||||
32159, the Ville, artist's co-living space, Hailey Johnson's bathroom
|
||||
32179, the Ville, artist's co-living space, common room
|
||||
32189, the Ville, artist's co-living space, kitchen
|
||||
32199, the Ville, Arthur Burton's apartment, main room
|
||||
32140, the Ville, Arthur Burton's apartment, bathroom
|
||||
32150, the Ville, Ryan Park's apartment, main room
|
||||
32160, the Ville, Ryan Park's apartment, bathroom
|
||||
32170, the Ville, Isabella Rodriguez's apartment, main room
|
||||
32180, the Ville, Isabella Rodriguez's apartment, bathroom
|
||||
32190, the Ville, Giorgio Rossi's apartment, main room
|
||||
32200, the Ville, Giorgio Rossi's apartment, bathroom
|
||||
32141, the Ville, Carlos Gomez's apartment, main room
|
||||
32151, the Ville, Carlos Gomez's apartment, bathroom
|
||||
32161, the Ville, The Rose and Crown Pub, pub
|
||||
32171, the Ville, Hobbs Cafe, cafe
|
||||
32181, the Ville, Oak Hill College, classroom
|
||||
32191, the Ville, Oak Hill College, library
|
||||
32201, the Ville, Oak Hill College, hallway
|
||||
32142, the Ville, Johnson Park, park
|
||||
32152, the Ville, Harvey Oak Supply Store, supply store
|
||||
32162, the Ville, The Willows Market and Pharmacy, store
|
||||
32193, the Ville, Adam Smith's house, main room
|
||||
32203, the Ville, Adam Smith's house, bathroom
|
||||
32174, the Ville, Yuriko Yamamoto's house, main room
|
||||
32184, the Ville, Yuriko Yamamoto's house, bathroom
|
||||
32194, the Ville, Moore family's house, main room
|
||||
32204, the Ville, Moore family's house, bathroom
|
||||
32172, the Ville, Dorm for Oak Hill College, Klaus Mueller's room
|
||||
32182, the Ville, Dorm for Oak Hill College, Maria Lopez's room
|
||||
32192, the Ville, Dorm for Oak Hill College, Ayesha Khan's room
|
||||
32202, the Ville, Dorm for Oak Hill College, Wolfgang Schulz's room
|
||||
32143, the Ville, Dorm for Oak Hill College, man's bathroom
|
||||
32153, the Ville, Dorm for Oak Hill College, woman's bathroom
|
||||
32163, the Ville, Dorm for Oak Hill College, common room
|
||||
32173, the Ville, Dorm for Oak Hill College, kitchen
|
||||
32183, the Ville, Dorm for Oak Hill College, garden
|
||||
32205, the Ville, Tamara Taylor and Carmen Ortiz's house, Tamara Taylor's room
|
||||
32215, the Ville, Tamara Taylor and Carmen Ortiz's house, Carmen Ortiz's room
|
||||
32225, the Ville, Tamara Taylor and Carmen Ortiz's house, common room
|
||||
32235, the Ville, Tamara Taylor and Carmen Ortiz's house, kitchen
|
||||
32245, the Ville, Tamara Taylor and Carmen Ortiz's house, bathroom
|
||||
32255, the Ville, Tamara Taylor and Carmen Ortiz's house, garden
|
||||
32265, the Ville, Moreno family's house, Tom and Jane Moreno's bedroom
|
||||
32275, the Ville, Moreno family's house, empty bedroom
|
||||
32206, the Ville, Moreno family's house, common room
|
||||
32216, the Ville, Moreno family's house, kitchen
|
||||
32226, the Ville, Moreno family's house, bathroom
|
||||
32236, the Ville, Moreno family's house, garden
|
||||
32246, the Ville, Lin family's house, Mei and John Lin's bedroom
|
||||
32256, the Ville, Lin family's house, Eddy Lin's bedroom
|
||||
32266, the Ville, Lin family's house, common room
|
||||
32276, the Ville, Lin family's house, kitchen
|
||||
32207, the Ville, Lin family's house, bathroom
|
||||
32217, the Ville, Lin family's house, garden
|
||||
|
|
|
@ -1,46 +0,0 @@
|
|||
32227, the Ville, <all>, bed
|
||||
32237, the Ville, <all>, desk
|
||||
32247, the Ville, <all>, closet
|
||||
32257, the Ville, <all>, shelf
|
||||
32267, the Ville, <all>, easel
|
||||
32277, the Ville, <all>, bathroom sink
|
||||
32208, the Ville, <all>, shower
|
||||
32218, the Ville, <all>, toilet
|
||||
32228, the Ville, <all>, kitchen sink
|
||||
32238, the Ville, <all>, refrigerator
|
||||
32248, the Ville, <all>, toaster
|
||||
32258, the Ville, <all>, cooking area
|
||||
32268, the Ville, <all>, common room table
|
||||
32278, the Ville, <all>, common room sofa
|
||||
32209, the Ville, <all>, guitar
|
||||
32219, the Ville, <all>, microphone
|
||||
32229, the Ville, <all>, bar customer seating
|
||||
32239, the Ville, <all>, behind the bar counter
|
||||
32249, the Ville, <all>, behind the cafe counter
|
||||
32259, the Ville, <all>, cafe customer seating
|
||||
32269, the Ville, <all>, piano
|
||||
32279, the Ville, <all>, blackboard
|
||||
32210, the Ville, <all>, game console
|
||||
32220, the Ville, <all>, computer desk
|
||||
32230, the Ville, <all>, computer
|
||||
32240, the Ville, <all>, library sofa
|
||||
32250, the Ville, <all>, bookshelf
|
||||
32260, the Ville, <all>, library table
|
||||
32270, the Ville, <all>, classroom student seating
|
||||
32280, the Ville, <all>, classroom podium
|
||||
32211, the Ville, <all>, behind the pharmacy counter
|
||||
32221, the Ville, <all>, behind the grocery counter
|
||||
32231, the Ville, <all>, pharmacy store shelf
|
||||
32241, the Ville, <all>, grocery store shelf
|
||||
32251, the Ville, <all>, pharmacy store counter
|
||||
32261, the Ville, <all>, grocery store counter
|
||||
32271, the Ville, <all>, supply store product shelf
|
||||
32281, the Ville, <all>, behind the supply store counter
|
||||
32212, the Ville, <all>, supply store counter
|
||||
32222, the Ville, <all>, dorm garden
|
||||
32232, the Ville, <all>, house garden
|
||||
32242, the Ville, <all>, garden chair
|
||||
32252, the Ville, <all>, park garden
|
||||
32262, the Ville, <all>, harp
|
||||
32272, the Ville, <all>, lifting weight
|
||||
32282, the Ville, <all>, pool table
|
||||
|
|
|
@ -1,19 +0,0 @@
|
|||
32135, the Ville, artist's co-living space
|
||||
32145, the Ville, Arthur Burton's apartment
|
||||
32155, the Ville, Ryan Park's apartment
|
||||
32165, the Ville, Isabella Rodriguez's apartment
|
||||
32175, the Ville, Giorgio Rossi's apartment
|
||||
32185, the Ville, Carlos Gomez's apartment
|
||||
32195, the Ville, The Rose and Crown Pub
|
||||
32136, the Ville, Hobbs Cafe
|
||||
32146, the Ville, Oak Hill College
|
||||
32156, the Ville, Johnson Park
|
||||
32166, the Ville, Harvey Oak Supply Store
|
||||
32176, the Ville, The Willows Market and Pharmacy
|
||||
32186, the Ville, Adam Smith's house
|
||||
32196, the Ville, Yuriko Yamamoto's house
|
||||
32137, the Ville, Moore family's house
|
||||
32147, the Ville, Tamara Taylor and Carmen Ortiz's house
|
||||
32157, the Ville, Moreno family's house
|
||||
32167, the Ville, Lin family's house
|
||||
32177, the Ville, Dorm for Oak Hill College
|
||||
|
|
|
@ -1,40 +0,0 @@
|
|||
32285, the Ville, artist's co-living space, Latoya Williams's room, sp-A
|
||||
32295, the Ville, artist's co-living space, Rajiv Patel's room, sp-A
|
||||
32305, the Ville, artist's co-living space, Rajiv Patel's room, sp-B
|
||||
32315, the Ville, artist's co-living space, Abigail Chen's room, sp-A
|
||||
32286, the Ville, artist's co-living space, Francisco Lopez's room, sp-A
|
||||
32296, the Ville, artist's co-living space, Hailey Johnson's room, sp-A
|
||||
32306, the Ville, Arthur Burton's apartment, main room, sp-A
|
||||
32316, the Ville, Arthur Burton's apartment, main room, sp-B
|
||||
32287, the Ville, Ryan Park's apartment, main room, sp-A
|
||||
32297, the Ville, Ryan Park's apartment, main room, sp-B
|
||||
32307, the Ville, Isabella Rodriguez's apartment, main room, sp-A
|
||||
32317, the Ville, Isabella Rodriguez's apartment, main room, sp-B
|
||||
32288, the Ville, Giorgio Rossi's apartment, main room, sp-A
|
||||
32298, the Ville, Giorgio Rossi's apartment, main room, sp-B
|
||||
32308, the Ville, Carlos Gomez's apartment, main room, sp-A
|
||||
32318, the Ville, Carlos Gomez's apartment, main room, sp-B
|
||||
32289, the Ville, Adam Smith's house, main room, sp-A
|
||||
32299, the Ville, Adam Smith's house, main room, sp-B
|
||||
32309, the Ville, Yuriko Yamamoto's house, main room, sp-A
|
||||
32319, the Ville, Yuriko Yamamoto's house, main room, sp-B
|
||||
32290, the Ville, Moore family's house, main room, sp-A
|
||||
32300, the Ville, Moore family's house, main room, sp-B
|
||||
32310, the Ville, Tamara Taylor and Carmen Ortiz's house, Tamara Taylor's room, sp-A
|
||||
32320, the Ville, Tamara Taylor and Carmen Ortiz's house, Tamara Taylor's room, sp-B
|
||||
32291, the Ville, Tamara Taylor and Carmen Ortiz's house, Carmen Ortiz's room, sp-A
|
||||
32301, the Ville, Tamara Taylor and Carmen Ortiz's house, Carmen Ortiz's room, sp-B
|
||||
32311, the Ville, Moreno family's house, Tom and Jane Moreno's bedroom, sp-A
|
||||
32321, the Ville, Moreno family's house, Tom and Jane Moreno's bedroom, sp-B
|
||||
32292, the Ville, Moreno family's house, empty bedroom, sp-A
|
||||
32302, the Ville, Moreno family's house, empty bedroom, sp-B
|
||||
32312, the Ville, Lin family's house, Mei and John Lin's bedroom, sp-A
|
||||
32322, the Ville, Lin family's house, Mei and John Lin's bedroom, sp-B
|
||||
32293, the Ville, Lin family's house, Eddy Lin's bedroom, sp-A
|
||||
32303, the Ville, Lin family's house, Eddy Lin's bedroom, sp-B
|
||||
32313, the Ville, Dorm for Oak Hill College, Klaus Mueller's room, sp-A
|
||||
32323, the Ville, Dorm for Oak Hill College, Klaus Mueller's room, sp-B
|
||||
32294, the Ville, Dorm for Oak Hill College, Maria Lopez's room, sp-A
|
||||
32304, the Ville, Dorm for Oak Hill College, Ayesha Khan's room, sp-A
|
||||
32314, the Ville, Dorm for Oak Hill College, Ayesha Khan's room, sp-B
|
||||
32324, the Ville, Dorm for Oak Hill College, Wolfgang Schulz's room, sp-A
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
32134, the Ville
|
||||
|
File diff suppressed because one or more lines are too long
|
|
@ -38,7 +38,7 @@ DESIGN = {
|
|||
|
||||
|
||||
TASK = {
|
||||
"Required Python packages": ["pygame==2.0.1"],
|
||||
"Required packages": ["pygame==2.0.1"],
|
||||
"Required Other language third-party packages": ["No third-party dependencies required"],
|
||||
"Logic Analysis": [
|
||||
["game.py", "Contains Game class and related functions for game logic"],
|
||||
|
|
|
|||
0
tests/metagpt/configs/__init__.py
Normal file
0
tests/metagpt/configs/__init__.py
Normal file
34
tests/metagpt/configs/test_models_config.py
Normal file
34
tests/metagpt/configs/test_models_config.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.configs.models_config import ModelsConfig
|
||||
from metagpt.const import METAGPT_ROOT, TEST_DATA_PATH
|
||||
from metagpt.utils.common import aread, awrite
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_models_configs(context):
|
||||
default_model = ModelsConfig.default()
|
||||
assert default_model is not None
|
||||
|
||||
models = ModelsConfig.from_yaml_file(TEST_DATA_PATH / "config/config2.yaml")
|
||||
assert models
|
||||
|
||||
default_models = ModelsConfig.default()
|
||||
backup = ""
|
||||
if not default_models.models:
|
||||
backup = await aread(filename=METAGPT_ROOT / "config/config2.yaml")
|
||||
test_data = await aread(filename=TEST_DATA_PATH / "config/config2.yaml")
|
||||
await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=test_data)
|
||||
|
||||
try:
|
||||
action = TalkAction(context=context, i_context="who are you?", llm_name_or_type="YOUR_MODEL_NAME_1")
|
||||
assert action.private_llm.config.model == "YOUR_MODEL_NAME_1"
|
||||
assert context.config.llm.model != "YOUR_MODEL_NAME_1"
|
||||
finally:
|
||||
if backup:
|
||||
await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=backup)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -4,8 +4,8 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.environment.android_env.android_ext_env import AndroidExtEnv
|
||||
from metagpt.environment.android_env.const import ADB_EXEC_FAIL
|
||||
from metagpt.environment.android.android_ext_env import AndroidExtEnv
|
||||
from metagpt.environment.android.const import ADB_EXEC_FAIL
|
||||
|
||||
|
||||
def mock_device_shape(self, adb_cmd: str) -> str:
|
||||
|
|
@ -16,8 +16,8 @@ def mock_device_shape_invalid(self, adb_cmd: str) -> str:
|
|||
return ADB_EXEC_FAIL
|
||||
|
||||
|
||||
def mock_list_devices(self, adb_cmd: str) -> str:
|
||||
return "devices\nemulator-5554"
|
||||
def mock_list_devices(self) -> str:
|
||||
return ["emulator-5554"]
|
||||
|
||||
|
||||
def mock_get_screenshot(self, adb_cmd: str) -> str:
|
||||
|
|
@ -34,9 +34,8 @@ def mock_write_read_operation(self, adb_cmd: str) -> str:
|
|||
|
||||
def test_android_ext_env(mocker):
|
||||
device_id = "emulator-5554"
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape
|
||||
)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.list_devices", mock_list_devices)
|
||||
|
||||
ext_env = AndroidExtEnv(device_id=device_id, screenshot_dir="/data2/", xml_dir="/data2/")
|
||||
assert ext_env.adb_prefix == f"adb -s {device_id} "
|
||||
|
|
@ -46,25 +45,20 @@ def test_android_ext_env(mocker):
|
|||
assert ext_env.device_shape == (720, 1080)
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape_invalid
|
||||
"metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape_invalid
|
||||
)
|
||||
assert ext_env.device_shape == (0, 0)
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_list_devices
|
||||
)
|
||||
assert ext_env.list_devices() == [device_id]
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot
|
||||
)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot)
|
||||
assert ext_env.get_screenshot("screenshot_xxxx-xx-xx", "/data/") == Path("/data/screenshot_xxxx-xx-xx.png")
|
||||
|
||||
mocker.patch("metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_xml)
|
||||
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_xml)
|
||||
assert ext_env.get_xml("xml_xxxx-xx-xx", "/data/") == Path("/data/xml_xxxx-xx-xx.xml")
|
||||
|
||||
mocker.patch(
|
||||
"metagpt.environment.android_env.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_write_read_operation
|
||||
"metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_write_read_operation
|
||||
)
|
||||
res = "OK"
|
||||
assert ext_env.system_back() == res
|
||||
|
|
|
|||
|
|
@ -3,8 +3,8 @@
|
|||
# @Desc : the unittest of MinecraftExtEnv
|
||||
|
||||
|
||||
from metagpt.environment.minecraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft_env.minecraft_ext_env import MinecraftExtEnv
|
||||
from metagpt.environment.minecraft.const import MC_CKPT_DIR
|
||||
from metagpt.environment.minecraft.minecraft_ext_env import MinecraftExtEnv
|
||||
|
||||
|
||||
def test_minecraft_ext_env():
|
||||
|
|
|
|||
|
|
@ -4,12 +4,18 @@
|
|||
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.environment.stanford_town_env.stanford_town_ext_env import (
|
||||
StanfordTownExtEnv,
|
||||
from metagpt.environment.stanford_town.env_space import (
|
||||
EnvAction,
|
||||
EnvActionType,
|
||||
EnvObsParams,
|
||||
EnvObsType,
|
||||
)
|
||||
from metagpt.environment.stanford_town.stanford_town_ext_env import StanfordTownExtEnv
|
||||
|
||||
maze_asset_path = (
|
||||
Path(__file__).absolute().parent.joinpath("..", "..", "..", "data", "environment", "stanford_town", "the_ville")
|
||||
Path(__file__)
|
||||
.absolute()
|
||||
.parent.joinpath("..", "..", "..", "..", "metagpt/ext/stanford_town/static_dirs/assets/the_ville")
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -27,7 +33,6 @@ def test_stanford_town_ext_env():
|
|||
assert len(ext_env.get_nearby_tiles(tile=tile, vision_r=5)) == 121
|
||||
|
||||
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
|
||||
ext_env.add_tiles_event(tile[1], tile[0], event=event)
|
||||
ext_env.add_event_from_tile(event, tile)
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
|
||||
|
||||
|
|
@ -38,3 +43,22 @@ def test_stanford_town_ext_env():
|
|||
|
||||
ext_env.remove_subject_events_from_tile(subject=event[0], tile=tile)
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 0
|
||||
|
||||
|
||||
def test_stanford_town_ext_env_observe_step():
|
||||
ext_env = StanfordTownExtEnv(maze_asset_path=maze_asset_path)
|
||||
obs, info = ext_env.reset()
|
||||
assert len(info) == 0
|
||||
assert len(obs["address_tiles"]) == 306
|
||||
|
||||
tile = (58, 9)
|
||||
obs = ext_env.observe(obs_params=EnvObsParams(obs_type=EnvObsType.TILE_PATH, coord=tile, level="world"))
|
||||
assert obs == "the Ville"
|
||||
|
||||
action = ext_env.action_space.sample()
|
||||
assert len(action) == 4
|
||||
assert len(action["event"]) == 4
|
||||
|
||||
event = ("double studio:double studio:bedroom 2:bed", None, None, None)
|
||||
obs, _, _, _, _ = ext_env.step(action=EnvAction(action_type=EnvActionType.ADD_TILE_EVENT, coord=tile, event=event))
|
||||
assert len(ext_env.tiles[tile[1]][tile[0]]["events"]) == 1
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of ExtEnv&Env
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
|
|
@ -12,11 +14,26 @@ from metagpt.environment.base_env import (
|
|||
mark_as_readable,
|
||||
mark_as_writeable,
|
||||
)
|
||||
from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvObsParams
|
||||
|
||||
|
||||
class ForTestEnv(Environment):
|
||||
value: int = 0
|
||||
|
||||
def reset(
|
||||
self,
|
||||
*,
|
||||
seed: Optional[int] = None,
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
pass
|
||||
|
||||
def observe(self, obs_params: Optional[BaseEnvObsParams] = None) -> Any:
|
||||
pass
|
||||
|
||||
def step(self, action: BaseEnvAction) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@mark_as_readable
|
||||
def read_api_no_param(self):
|
||||
return self.value
|
||||
|
|
@ -44,11 +61,11 @@ async def test_ext_env():
|
|||
assert len(apis) > 0
|
||||
assert len(apis["read_api"]) == 3
|
||||
|
||||
_ = await env.step(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
|
||||
_ = await env.write_thru_api(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
|
||||
assert env.value == 15
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await env.observe("not_exist_api")
|
||||
with pytest.raises(KeyError):
|
||||
await env.read_from_api("not_exist_api")
|
||||
|
||||
assert await env.observe("read_api_no_param") == 15
|
||||
assert await env.observe(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
|
||||
assert await env.read_from_api("read_api_no_param") == 15
|
||||
assert await env.read_from_api(EnvAPIAbstract(api_name="read_api", kwargs={"a": 5, "b": 5})) == 10
|
||||
|
|
|
|||
|
|
@ -2,33 +2,34 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of WerewolfExtEnv
|
||||
|
||||
from metagpt.environment.werewolf_env.werewolf_ext_env import RoleState, WerewolfExtEnv
|
||||
from metagpt.environment.werewolf.const import RoleState, RoleType
|
||||
from metagpt.environment.werewolf.werewolf_ext_env import WerewolfExtEnv
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
class Werewolf(Role):
|
||||
profile: str = "Werewolf"
|
||||
profile: str = RoleType.WEREWOLF.value
|
||||
|
||||
|
||||
class Villager(Role):
|
||||
profile: str = "Villager"
|
||||
profile: str = RoleType.VILLAGER.value
|
||||
|
||||
|
||||
class Witch(Role):
|
||||
profile: str = "Witch"
|
||||
profile: str = RoleType.WITCH.value
|
||||
|
||||
|
||||
class Guard(Role):
|
||||
profile: str = "Guard"
|
||||
profile: str = RoleType.GUARD.value
|
||||
|
||||
|
||||
def test_werewolf_ext_env():
|
||||
players_state = {
|
||||
"Player0": ("Werewolf", RoleState.ALIVE),
|
||||
"Player1": ("Werewolf", RoleState.ALIVE),
|
||||
"Player2": ("Villager", RoleState.ALIVE),
|
||||
"Player3": ("Witch", RoleState.ALIVE),
|
||||
"Player4": ("Guard", RoleState.ALIVE),
|
||||
"Player0": (RoleType.WEREWOLF.value, RoleState.ALIVE),
|
||||
"Player1": (RoleType.WEREWOLF.value, RoleState.ALIVE),
|
||||
"Player2": (RoleType.VILLAGER.value, RoleState.ALIVE),
|
||||
"Player3": (RoleType.WITCH.value, RoleState.ALIVE),
|
||||
"Player4": (RoleType.GUARD.value, RoleState.ALIVE),
|
||||
}
|
||||
ext_env = WerewolfExtEnv(players_state=players_state, step_idx=4, special_role_players=["Player3", "Player4"])
|
||||
|
||||
|
|
@ -41,9 +42,9 @@ def test_werewolf_ext_env():
|
|||
assert "Werewolves, please open your eyes" in curr_instr["content"]
|
||||
|
||||
# current step_idx = 5
|
||||
ext_env.wolf_kill_someone(wolf=Role(name="Player10"), player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player0"), player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player1"), player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf_name="Player10", player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf_name="Player0", player_name="Player4")
|
||||
ext_env.wolf_kill_someone(wolf_name="Player1", player_name="Player4")
|
||||
assert ext_env.player_hunted == "Player4"
|
||||
assert len(ext_env.living_players) == 5 # hunted but can be saved by witch
|
||||
|
||||
|
|
@ -52,11 +53,11 @@ def test_werewolf_ext_env():
|
|||
|
||||
# current step_idx = 18
|
||||
assert ext_env.step_idx == 18
|
||||
ext_env.vote_kill_someone(voteer=Werewolf(name="Player0"), player_name="Player2")
|
||||
ext_env.vote_kill_someone(voteer=Werewolf(name="Player1"), player_name="Player3")
|
||||
ext_env.vote_kill_someone(voteer=Villager(name="Player2"), player_name="Player3")
|
||||
ext_env.vote_kill_someone(voteer=Witch(name="Player3"), player_name="Player4")
|
||||
ext_env.vote_kill_someone(voteer=Guard(name="Player4"), player_name="Player2")
|
||||
ext_env.vote_kill_someone(voter_name="Player0", player_name="Player2")
|
||||
ext_env.vote_kill_someone(voter_name="Player1", player_name="Player3")
|
||||
ext_env.vote_kill_someone(voter_name="Player2", player_name="Player3")
|
||||
ext_env.vote_kill_someone(voter_name="Player3", player_name="Player4")
|
||||
ext_env.vote_kill_someone(voter_name="Player4", player_name="Player2")
|
||||
assert ext_env.player_current_dead == "Player2"
|
||||
assert len(ext_env.living_players) == 4
|
||||
|
||||
|
|
|
|||
3
tests/metagpt/ext/__init__.py
Normal file
3
tests/metagpt/ext/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
tests/metagpt/ext/android_assistant/__init__.py
Normal file
3
tests/metagpt/ext/android_assistant/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
95
tests/metagpt/ext/android_assistant/test_an.py
Normal file
95
tests/metagpt/ext/android_assistant/test_an.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : test on android emulator action. After Modify Role Test, this script is discarded.
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import metagpt
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
from metagpt.environment.android.android_env import AndroidEnv
|
||||
from metagpt.ext.android_assistant.actions.manual_record import ManualRecord
|
||||
from metagpt.ext.android_assistant.actions.parse_record import ParseRecord
|
||||
from metagpt.ext.android_assistant.actions.screenshot_parse import ScreenshotParse
|
||||
from metagpt.ext.android_assistant.actions.self_learn_and_reflect import (
|
||||
SelfLearnAndReflect,
|
||||
)
|
||||
from tests.metagpt.environment.android_env.test_android_ext_env import (
|
||||
mock_device_shape,
|
||||
mock_list_devices,
|
||||
)
|
||||
|
||||
TASK_PATH = TEST_DATA_PATH.joinpath("andriod_assistant/unitest_Contacts")
|
||||
TASK_PATH.mkdir(parents=True, exist_ok=True)
|
||||
DEMO_NAME = str(time.time())
|
||||
SELF_EXPLORE_DOC_PATH = TASK_PATH.joinpath("auto_docs")
|
||||
PARSE_RECORD_DOC_PATH = TASK_PATH.joinpath("demo_docs")
|
||||
|
||||
device_id = "emulator-5554"
|
||||
xml_dir = Path("/sdcard")
|
||||
screenshot_dir = Path("/sdcard/Pictures/Screenshots")
|
||||
|
||||
|
||||
metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd = mock_device_shape
|
||||
metagpt.environment.android.android_ext_env.AndroidExtEnv.list_devices = mock_list_devices
|
||||
|
||||
|
||||
test_env_self_learn_android = AndroidEnv(
|
||||
device_id=device_id,
|
||||
xml_dir=xml_dir,
|
||||
screenshot_dir=screenshot_dir,
|
||||
)
|
||||
test_self_learning = SelfLearnAndReflect()
|
||||
|
||||
test_env_manual_learn_android = AndroidEnv(
|
||||
device_id=device_id,
|
||||
xml_dir=xml_dir,
|
||||
screenshot_dir=screenshot_dir,
|
||||
)
|
||||
test_manual_record = ManualRecord()
|
||||
test_manual_parse = ParseRecord()
|
||||
|
||||
test_env_screenshot_parse_android = AndroidEnv(
|
||||
device_id=device_id,
|
||||
xml_dir=xml_dir,
|
||||
screenshot_dir=screenshot_dir,
|
||||
)
|
||||
test_screenshot_parse = ScreenshotParse()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
test_action_list = [
|
||||
test_self_learning.run(
|
||||
round_count=20,
|
||||
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
|
||||
last_act="",
|
||||
task_dir=TASK_PATH / "demos" / f"self_learning_{DEMO_NAME}",
|
||||
docs_dir=SELF_EXPLORE_DOC_PATH,
|
||||
env=test_env_self_learn_android,
|
||||
),
|
||||
test_manual_record.run(
|
||||
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}",
|
||||
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
|
||||
env=test_env_manual_learn_android,
|
||||
),
|
||||
test_manual_parse.run(
|
||||
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}", # 修要修改
|
||||
docs_dir=PARSE_RECORD_DOC_PATH, # 需要修改
|
||||
env=test_env_manual_learn_android,
|
||||
),
|
||||
test_screenshot_parse.run(
|
||||
round_count=20,
|
||||
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
|
||||
last_act="",
|
||||
task_dir=TASK_PATH / f"act_{DEMO_NAME}",
|
||||
docs_dir=PARSE_RECORD_DOC_PATH,
|
||||
env=test_env_screenshot_parse_android,
|
||||
grid_on=False,
|
||||
),
|
||||
]
|
||||
|
||||
loop.run_until_complete(asyncio.gather(*test_action_list))
|
||||
loop.close()
|
||||
29
tests/metagpt/ext/android_assistant/test_parse_record.py
Normal file
29
tests/metagpt/ext/android_assistant/test_parse_record.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : test case (imgs from appagent's)
|
||||
|
||||
import asyncio
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
from metagpt.ext.android_assistant.actions.parse_record import ParseRecord
|
||||
|
||||
TASK_PATH = TEST_DATA_PATH.joinpath("andriod_assistant/demo_Contacts")
|
||||
TEST_BEFORE_PATH = TASK_PATH.joinpath("labeled_screenshots/0_labeled.png")
|
||||
TEST_AFTER_PATH = TASK_PATH.joinpath("labeled_screenshots/1_labeled.png")
|
||||
RECORD_PATH = TASK_PATH.joinpath("record.txt")
|
||||
TASK_DESC_PATH = TASK_PATH.joinpath("task_desc.txt")
|
||||
DOCS_DIR = TASK_PATH.joinpath("storage")
|
||||
|
||||
test_action = Action(name="test")
|
||||
|
||||
|
||||
async def manual_learn_test():
|
||||
parse_record = ParseRecord()
|
||||
await parse_record.run(app_name="demo_Contacts", task_dir=TASK_PATH, docs_dir=DOCS_DIR)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(manual_learn_test())
|
||||
loop.close()
|
||||
3
tests/metagpt/ext/stanford_town/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
tests/metagpt/ext/stanford_town/actions/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/actions/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of actions/gen_action_details.py
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.environment.api.env_api import EnvAPIAbstract
|
||||
from metagpt.ext.stanford_town.actions.gen_action_details import (
|
||||
GenActionArena,
|
||||
GenActionDetails,
|
||||
GenActionObject,
|
||||
GenActionSector,
|
||||
GenActObjDescription,
|
||||
)
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gen_action_details():
|
||||
role = STRole(
|
||||
name="Klaus Mueller",
|
||||
start_time="February 13, 2023",
|
||||
curr_time="February 13, 2023, 00:00:00",
|
||||
sim_code="base_the_ville_isabella_maria_klaus",
|
||||
)
|
||||
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
|
||||
await role.init_curr_tile()
|
||||
|
||||
act_desp = "sleeping"
|
||||
act_dura = "120"
|
||||
|
||||
access_tile = await role.rc.env.read_from_api(
|
||||
EnvAPIAbstract(api_name="access_tile", kwargs={"tile": role.scratch.curr_tile})
|
||||
)
|
||||
act_world = access_tile["world"]
|
||||
assert act_world == "the Ville"
|
||||
|
||||
sector = await GenActionSector().run(role, access_tile, act_desp)
|
||||
arena = await GenActionArena().run(role, act_desp, act_world, sector)
|
||||
temp_address = f"{act_world}:{sector}:{arena}"
|
||||
obj = await GenActionObject().run(role, act_desp, temp_address)
|
||||
|
||||
act_obj_desp = await GenActObjDescription().run(role, obj, act_desp)
|
||||
|
||||
result_dict = await GenActionDetails().run(role, act_desp, act_dura)
|
||||
|
||||
# gen_action_sector
|
||||
assert isinstance(sector, str)
|
||||
assert sector in role.s_mem.get_str_accessible_sectors(act_world)
|
||||
|
||||
# gen_action_arena
|
||||
assert isinstance(arena, str)
|
||||
assert arena in role.s_mem.get_str_accessible_sector_arenas(f"{act_world}:{sector}")
|
||||
|
||||
# gen_action_obj
|
||||
assert isinstance(obj, str)
|
||||
assert obj in role.s_mem.get_str_accessible_arena_game_objects(temp_address)
|
||||
|
||||
if result_dict:
|
||||
for key in [
|
||||
"action_address",
|
||||
"action_duration",
|
||||
"action_description",
|
||||
"action_pronunciatio",
|
||||
"action_event",
|
||||
"chatting_with",
|
||||
"chat",
|
||||
"chatting_with_buffer",
|
||||
"chatting_end_time",
|
||||
"act_obj_description",
|
||||
"act_obj_pronunciatio",
|
||||
"act_obj_event",
|
||||
]:
|
||||
assert key in result_dict
|
||||
assert result_dict["action_address"] == f"{temp_address}:{obj}"
|
||||
assert result_dict["action_duration"] == int(act_dura)
|
||||
assert result_dict["act_obj_description"] == act_obj_desp
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of actions/summarize_conv
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.actions.summarize_conv import SummarizeConv
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_conv():
|
||||
conv = [("Role_A", "what's the weather today?"), ("Role_B", "It looks pretty good, and I will take a walk then.")]
|
||||
|
||||
output = await SummarizeConv().run(conv)
|
||||
assert "weather" in output
|
||||
3
tests/metagpt/ext/stanford_town/memory/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/memory/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
89
tests/metagpt/ext/stanford_town/memory/test_agent_memory.py
Normal file
89
tests/metagpt/ext/stanford_town/memory/test_agent_memory.py
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of AgentMemory
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.memory.agent_memory import AgentMemory
|
||||
from metagpt.ext.stanford_town.memory.retrieve import agent_retrieve
|
||||
from metagpt.ext.stanford_town.utils.const import STORAGE_PATH
|
||||
from metagpt.logs import logger
|
||||
|
||||
"""
|
||||
memory测试思路
|
||||
1. Basic Memory测试
|
||||
2. Agent Memory测试
|
||||
2.1 Load & Save方法测试; Load方法中使用了add方法,验证Load即可验证所有add
|
||||
2.2 Get方法测试
|
||||
"""
|
||||
memory_easy_storage_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/associative_memory",
|
||||
)
|
||||
memroy_chat_storage_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/associative_memory",
|
||||
)
|
||||
memory_save_easy_test_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/test_memory",
|
||||
)
|
||||
memory_save_chat_test_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/test_memory",
|
||||
)
|
||||
|
||||
|
||||
class TestAgentMemory:
|
||||
@pytest.fixture
|
||||
def agent_memory(self):
|
||||
# 创建一个AgentMemory实例并返回,可以在所有测试用例中共享
|
||||
test_agent_memory = AgentMemory()
|
||||
test_agent_memory.set_mem_path(memroy_chat_storage_path)
|
||||
return test_agent_memory
|
||||
|
||||
def test_load(self, agent_memory):
|
||||
logger.info(f"存储路径为:{agent_memory.memory_saved}")
|
||||
logger.info(f"存储记忆条数为:{len(agent_memory.storage)}")
|
||||
logger.info(f"kw_strength为{agent_memory.kw_strength_event},{agent_memory.kw_strength_thought}")
|
||||
logger.info(f"embeeding.json条数为{len(agent_memory.embeddings)}")
|
||||
|
||||
assert agent_memory.embeddings is not None
|
||||
|
||||
def test_save(self, agent_memory):
|
||||
try:
|
||||
agent_memory.save(memory_save_chat_test_path)
|
||||
logger.info("成功存储")
|
||||
except:
|
||||
pass
|
||||
|
||||
def test_summary_function(self, agent_memory):
|
||||
logger.info(f"event长度为{len(agent_memory.event_list)}")
|
||||
logger.info(f"thought长度为{len(agent_memory.thought_list)}")
|
||||
logger.info(f"chat长度为{len(agent_memory.chat_list)}")
|
||||
result1 = agent_memory.get_summarized_latest_events(4)
|
||||
logger.info(f"总结最近事件结果为:{result1}")
|
||||
|
||||
def test_get_last_chat_function(self, agent_memory):
|
||||
result2 = agent_memory.get_last_chat("customers")
|
||||
logger.info(f"上一次对话是{result2}")
|
||||
|
||||
def test_retrieve_function(self, agent_memory):
|
||||
focus_points = ["who i love?"]
|
||||
retrieved = dict()
|
||||
for focal_pt in focus_points:
|
||||
nodes = [
|
||||
[i.last_accessed, i]
|
||||
for i in agent_memory.event_list + agent_memory.thought_list
|
||||
if "idle" not in i.embedding_key
|
||||
]
|
||||
nodes = sorted(nodes, key=lambda x: x[0])
|
||||
nodes = [i for created, i in nodes]
|
||||
results = agent_retrieve(agent_memory, datetime.now() - timedelta(days=120), 0.99, focal_pt, nodes, 5)
|
||||
final_result = []
|
||||
for n in results:
|
||||
for i in agent_memory.storage:
|
||||
if i.memory_id == n:
|
||||
i.last_accessed = datetime.now() - timedelta(days=120)
|
||||
final_result.append(i)
|
||||
|
||||
retrieved[focal_pt] = final_result
|
||||
logger.info(f"检索结果为{retrieved}")
|
||||
76
tests/metagpt/ext/stanford_town/memory/test_basic_memory.py
Normal file
76
tests/metagpt/ext/stanford_town/memory/test_basic_memory.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of BasicMemory
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory
|
||||
from metagpt.logs import logger
|
||||
|
||||
"""
|
||||
memory测试思路
|
||||
1. Basic Memory测试
|
||||
2. Agent Memory测试
|
||||
2.1 Load & Save方法测试
|
||||
2.2 Add方法测试
|
||||
2.3 Get方法测试
|
||||
"""
|
||||
|
||||
# Create some sample BasicMemory instances
|
||||
memory1 = BasicMemory(
|
||||
memory_id="1",
|
||||
memory_count=1,
|
||||
type_count=1,
|
||||
memory_type="event",
|
||||
depth=1,
|
||||
created=datetime.now(),
|
||||
expiration=datetime.now() + timedelta(days=30),
|
||||
subject="Subject1",
|
||||
predicate="Predicate1",
|
||||
object="Object1",
|
||||
content="This is content 1",
|
||||
embedding_key="embedding_key_1",
|
||||
poignancy=1,
|
||||
keywords=["keyword1", "keyword2"],
|
||||
filling=["memory_id_2"],
|
||||
)
|
||||
memory2 = BasicMemory(
|
||||
memory_id="2",
|
||||
memory_count=2,
|
||||
type_count=2,
|
||||
memory_type="thought",
|
||||
depth=2,
|
||||
created=datetime.now(),
|
||||
expiration=datetime.now() + timedelta(days=30),
|
||||
subject="Subject2",
|
||||
predicate="Predicate2",
|
||||
object="Object2",
|
||||
content="This is content 2",
|
||||
embedding_key="embedding_key_2",
|
||||
poignancy=2,
|
||||
keywords=["keyword3", "keyword4"],
|
||||
filling=[],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_mem_set():
|
||||
basic_mem2 = memory2
|
||||
yield basic_mem2
|
||||
|
||||
|
||||
def test_basic_mem_function(basic_mem_set):
|
||||
a, b, c = basic_mem_set.summary()
|
||||
logger.info(f"{a}{b}{c}")
|
||||
assert a == "Subject2"
|
||||
|
||||
|
||||
def test_basic_mem_save(basic_mem_set):
|
||||
result = basic_mem_set.save_to_dict()
|
||||
logger.info(f"save结果为{result}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of MemoryTree
|
||||
|
||||
from metagpt.ext.stanford_town.memory.spatial_memory import MemoryTree
|
||||
from metagpt.ext.stanford_town.utils.const import STORAGE_PATH
|
||||
|
||||
|
||||
def test_spatial_memory():
|
||||
f_path = STORAGE_PATH.joinpath(
|
||||
"base_the_ville_isabella_maria_klaus/personas/Isabella Rodriguez/bootstrap_memory/spatial_memory.json"
|
||||
)
|
||||
x = MemoryTree()
|
||||
x.set_mem_path(f_path)
|
||||
assert x.tree
|
||||
assert "the Ville" in x.tree
|
||||
assert "Isabella Rodriguez's apartment" in x.get_str_accessible_sectors("the Ville")
|
||||
3
tests/metagpt/ext/stanford_town/plan/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/plan/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
67
tests/metagpt/ext/stanford_town/plan/test_conversation.py
Normal file
67
tests/metagpt/ext/stanford_town/plan/test_conversation.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of roles conversation
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.ext.stanford_town.plan.converse import agent_conversation
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH, STORAGE_PATH
|
||||
from metagpt.ext.stanford_town.utils.mg_ga_transform import get_reverie_meta
|
||||
from metagpt.ext.stanford_town.utils.utils import copy_folder
|
||||
|
||||
|
||||
async def init_two_roles(fork_sim_code: str = "base_the_ville_isabella_maria_klaus") -> Tuple["STRole"]:
|
||||
sim_code = "unittest_sim"
|
||||
|
||||
copy_folder(str(STORAGE_PATH.joinpath(fork_sim_code)), str(STORAGE_PATH.joinpath(sim_code)))
|
||||
|
||||
reverie_meta = get_reverie_meta(fork_sim_code)
|
||||
role_ir_name = "Isabella Rodriguez"
|
||||
role_km_name = "Klaus Mueller"
|
||||
|
||||
env = StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH)
|
||||
|
||||
role_ir = STRole(
|
||||
name=role_ir_name,
|
||||
sim_code=sim_code,
|
||||
profile=role_ir_name,
|
||||
step=reverie_meta.get("step"),
|
||||
start_time=reverie_meta.get("start_date"),
|
||||
curr_time=reverie_meta.get("curr_time"),
|
||||
sec_per_step=reverie_meta.get("sec_per_step"),
|
||||
)
|
||||
role_ir.set_env(env)
|
||||
await role_ir.init_curr_tile()
|
||||
|
||||
role_km = STRole(
|
||||
name=role_km_name,
|
||||
sim_code=sim_code,
|
||||
profile=role_km_name,
|
||||
step=reverie_meta.get("step"),
|
||||
start_time=reverie_meta.get("start_date"),
|
||||
curr_time=reverie_meta.get("curr_time"),
|
||||
sec_per_step=reverie_meta.get("sec_per_step"),
|
||||
)
|
||||
role_km.set_env(env)
|
||||
await role_km.init_curr_tile()
|
||||
|
||||
return role_ir, role_km
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_conversation():
|
||||
role_ir, role_km = await init_two_roles()
|
||||
|
||||
curr_chat = await agent_conversation(role_ir, role_km, conv_rounds=2)
|
||||
assert len(curr_chat) % 2 == 0
|
||||
|
||||
meet = False
|
||||
for conv in curr_chat:
|
||||
if "Valentine's Day party" in conv[1]:
|
||||
# conv[0] speaker, conv[1] utterance
|
||||
meet = True
|
||||
assert meet
|
||||
25
tests/metagpt/ext/stanford_town/plan/test_st_plan.py
Normal file
25
tests/metagpt/ext/stanford_town/plan/test_st_plan.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of st_plan
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.ext.stanford_town.plan.st_plan import _choose_retrieved, _should_react
|
||||
from tests.metagpt.ext.stanford_town.plan.test_conversation import init_two_roles
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_react():
|
||||
role_ir, role_km = await init_two_roles()
|
||||
roles = {role_ir.name: role_ir, role_km.name: role_km}
|
||||
role_ir.scratch.act_address = "mock data"
|
||||
|
||||
observed = await role_ir.observe()
|
||||
retrieved = role_ir.retrieve(observed)
|
||||
|
||||
focused_event = _choose_retrieved(role_ir.name, retrieved)
|
||||
|
||||
if focused_event:
|
||||
reaction_mode = await _should_react(role_ir, focused_event, roles) # chat with Isabella Rodriguez
|
||||
assert not reaction_mode
|
||||
3
tests/metagpt/ext/stanford_town/roles/__init__.py
Normal file
3
tests/metagpt/ext/stanford_town/roles/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
26
tests/metagpt/ext/stanford_town/roles/test_st_role.py
Normal file
26
tests/metagpt/ext/stanford_town/roles/test_st_role.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of STRole
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.ext.stanford_town.memory.agent_memory import BasicMemory
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_observe():
|
||||
role = STRole(
|
||||
sim_code="base_the_ville_isabella_maria_klaus",
|
||||
start_time="February 13, 2023",
|
||||
curr_time="February 13, 2023, 00:00:00",
|
||||
)
|
||||
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
|
||||
await role.init_curr_tile()
|
||||
|
||||
ret_events = await role.observe()
|
||||
assert ret_events
|
||||
for event in ret_events:
|
||||
assert isinstance(event, BasicMemory)
|
||||
47
tests/metagpt/ext/stanford_town/test_reflect.py
Normal file
47
tests/metagpt/ext/stanford_town/test_reflect.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of reflection
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.environment import StanfordTownEnv
|
||||
from metagpt.ext.stanford_town.actions.run_reflect_action import (
|
||||
AgentEventTriple,
|
||||
AgentFocusPt,
|
||||
AgentInsightAndGuidance,
|
||||
)
|
||||
from metagpt.ext.stanford_town.roles.st_role import STRole
|
||||
from metagpt.ext.stanford_town.utils.const import MAZE_ASSET_PATH
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reflect():
|
||||
"""
|
||||
init STRole form local json, set sim_code(path),curr_time & start_time
|
||||
"""
|
||||
role = STRole(
|
||||
sim_code="base_the_ville_isabella_maria_klaus",
|
||||
start_time="February 13, 2023",
|
||||
curr_time="February 13, 2023, 00:00:00",
|
||||
)
|
||||
role.set_env(StanfordTownEnv(maze_asset_path=MAZE_ASSET_PATH))
|
||||
role.init_curr_tile()
|
||||
|
||||
run_focus = AgentFocusPt()
|
||||
statements = ""
|
||||
await run_focus.run(role, statements, n=3)
|
||||
|
||||
"""
|
||||
这里有通过测试的结果,但是更多时候LLM生成的结果缺少了because of;考虑修改一下prompt
|
||||
result = {'Klaus Mueller and Maria Lopez have a close relationship because they have been friends for a long time and have a strong bond': [1, 2, 5, 9, 11, 14], 'Klaus Mueller has a crush on Maria Lopez': [8, 15, 24], 'Klaus Mueller is academically inclined and actively researching a topic': [13, 20], 'Klaus Mueller is socially active and acquainted with Isabella Rodriguez': [17, 21, 22], 'Klaus Mueller is organized and prepared': [19]}
|
||||
"""
|
||||
run_insight = AgentInsightAndGuidance()
|
||||
statements = "[user: Klaus Mueller has a close relationship with Maria Lopez, user:s Mueller and Maria Lopez have a close relationship, user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller and Maria Lopez have a strong relationship, user: Klaus Mueller is a dormmate of Maria Lopez., user: Klaus Mueller and Maria Lopez have a strong bond, user: Klaus Mueller has a crush on Maria Lopez, user: Klaus Mueller and Maria Lopez have been friends for more than 2 years., user: Klaus Mueller has a close relationship with Maria Lopez, user: Klaus Mueller Maria Lopez is heading off to college., user: Klaus Mueller and Maria Lopez have a close relationship, user: Klaus Mueller is actively researching a topic, user: Klaus Mueller is close friends and classmates with Maria Lopez., user: Klaus Mueller is socially active, user: Klaus Mueller has a crush on Maria Lopez., user: Klaus Mueller and Maria Lopez have been friends for a long time, user: Klaus Mueller is academically inclined, user: For Klaus Mueller's planning: should remember to ask Maria Lopez about her research paper, as she found it interesting that he mentioned it., user: Klaus Mueller is acquainted with Isabella Rodriguez, user: Klaus Mueller is organized and prepared, user: Maria Lopez is conversing about conversing about Maria's research paper mentioned by Klaus, user: Klaus Mueller is conversing about conversing about Maria's research paper mentioned by Klaus, user: Klaus Mueller is a student, user: Klaus Mueller is a student, user: Klaus Mueller is conversing about two friends named Klaus Mueller and Maria Lopez discussing their morning plans and progress on a research paper before Maria heads off to college., user: Klaus Mueller is socially active, user: Klaus Mueller is socially active, user: Klaus Mueller is socially active and acquainted with Isabella Rodriguez, user: Klaus Mueller has a crush on Maria Lopez]"
|
||||
await run_insight.run(role, statements, n=5)
|
||||
|
||||
run_triple = AgentEventTriple()
|
||||
statements = "(Klaus Mueller is academically inclined)"
|
||||
await run_triple.run(statements, role)
|
||||
|
||||
role.scratch.importance_trigger_curr = -1
|
||||
role.reflect()
|
||||
3
tests/metagpt/ext/werewolf/__init__.py
Normal file
3
tests/metagpt/ext/werewolf/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
3
tests/metagpt/ext/werewolf/actions/__init__.py
Normal file
3
tests/metagpt/ext/werewolf/actions/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
164
tests/metagpt/ext/werewolf/actions/test_experience_operation.py
Normal file
164
tests/metagpt/ext/werewolf/actions/test_experience_operation.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.ext.werewolf.actions import AddNewExperiences, RetrieveExperiences
|
||||
from metagpt.ext.werewolf.schema import RoleExperience
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class TestExperiencesOperation:
|
||||
collection_name = "test"
|
||||
test_round_id = "test_01"
|
||||
version = "test"
|
||||
samples_to_add = [
|
||||
RoleExperience(
|
||||
profile="Witch",
|
||||
reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. "
|
||||
"Player4's behavior is suspicious.",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version,
|
||||
),
|
||||
RoleExperience(
|
||||
profile="Witch",
|
||||
reflection="The game is in a critical state with only three players left, "
|
||||
"and I need to make a wise decision to save Player7 or not.",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version,
|
||||
),
|
||||
RoleExperience(
|
||||
profile="Seer",
|
||||
reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, "
|
||||
"sided with him. I, as the real Seer, am under suspicion.",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version,
|
||||
),
|
||||
RoleExperience(
|
||||
profile="TestRole",
|
||||
reflection="Some test reflection1",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version + "_01-10",
|
||||
),
|
||||
RoleExperience(
|
||||
profile="TestRole",
|
||||
reflection="Some test reflection2",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version + "_11-20",
|
||||
),
|
||||
RoleExperience(
|
||||
profile="TestRole",
|
||||
reflection="Some test reflection3",
|
||||
response="",
|
||||
outcome="",
|
||||
round_id=test_round_id,
|
||||
version=version + "_21-30",
|
||||
),
|
||||
]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add(self):
|
||||
saved_file = DEFAULT_WORKSPACE_ROOT.joinpath(
|
||||
f"werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
|
||||
)
|
||||
if saved_file.exists():
|
||||
saved_file.unlink()
|
||||
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.run(self.samples_to_add)
|
||||
|
||||
# test insertion
|
||||
inserted = action.engine.retriever._index._vector_store._collection.get()
|
||||
assert len(inserted["documents"]) == len(self.samples_to_add)
|
||||
|
||||
# test if we record the samples correctly to local file
|
||||
# & test if we could recover a embedding db from the file
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.add_from_file(saved_file)
|
||||
inserted = action.engine.retriever._index._vector_store._collection.get()
|
||||
assert len(inserted["documents"]) == len(self.samples_to_add)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "one player claimed to be Seer and the other Witch"
|
||||
results = action.run(query, profile="Witch")
|
||||
results = json.loads(results)
|
||||
|
||||
assert len(results) == 2, "Witch should have 2 experiences"
|
||||
assert "The game is intense with two players" in results[0]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_filtering(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "some test query"
|
||||
profile = "TestRole"
|
||||
|
||||
excluded_version = ""
|
||||
results = action.run(query, profile=profile, excluded_version=excluded_version)
|
||||
results = json.loads(results)
|
||||
assert len(results) == 3
|
||||
|
||||
excluded_version = self.version + "_21-30"
|
||||
results = action.run(query, profile=profile, excluded_version=excluded_version)
|
||||
results = json.loads(results)
|
||||
assert len(results) == 2
|
||||
|
||||
|
||||
class TestActualRetrieve:
|
||||
collection_name = "role_reflection"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_check_experience_pool(self):
|
||||
logger.info("check experience pool")
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
if action.engine:
|
||||
all_experiences = action.engine.retriever._index._vector_store._collection.get()
|
||||
logger.info(f"{len(all_experiences['metadatas'])=}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_werewolf_experience(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "there are conflicts"
|
||||
|
||||
logger.info(f"test retrieval with {query=}")
|
||||
action.run(query, "Werewolf")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_villager_experience(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "there are conflicts"
|
||||
|
||||
logger.info(f"test retrieval with {query=}")
|
||||
results = action.run(query, "Seer")
|
||||
assert "conflict" not in results # 相似局面应该需要包含conflict关键词
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retrieve_villager_experience_filtering(self):
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
|
||||
query = "there are conflicts"
|
||||
|
||||
excluded_version = "01-10"
|
||||
logger.info(f"test retrieval with {excluded_version=}")
|
||||
results_01_10 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
|
||||
|
||||
excluded_version = "11-20"
|
||||
logger.info(f"test retrieval with {excluded_version=}")
|
||||
results_11_20 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
|
||||
|
||||
assert results_01_10 == results_11_20
|
||||
|
|
@ -60,3 +60,14 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model
|
|||
mock_llm_config_anthropic = LLMConfig(
|
||||
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
|
||||
)
|
||||
|
||||
mock_llm_config_bedrock = LLMConfig(
|
||||
api_type="bedrock",
|
||||
model="gpt-100",
|
||||
region_name="somewhere",
|
||||
access_key="123abc",
|
||||
secret_key="123abc",
|
||||
max_token=10000,
|
||||
)
|
||||
|
||||
mock_llm_config_ark = LLMConfig(api_type="ark", api_key="eyxxx", base_url="xxx", model="ep-xxx")
|
||||
|
|
|
|||
|
|
@ -183,3 +183,90 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
|
|||
|
||||
resp = await llm.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_cont
|
||||
|
||||
|
||||
# For Amazon Bedrock
|
||||
# Check the API documentation of each model
|
||||
# https://docs.aws.amazon.com/bedrock/latest/userguide
|
||||
BEDROCK_PROVIDER_REQUEST_BODY = {
|
||||
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
|
||||
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
|
||||
"ai21": {
|
||||
"prompt": "",
|
||||
"temperature": 0.0,
|
||||
"topP": 0.0,
|
||||
"maxTokens": 0,
|
||||
"stopSequences": [],
|
||||
"countPenalty": {"scale": 0.0},
|
||||
"presencePenalty": {"scale": 0.0},
|
||||
"frequencyPenalty": {"scale": 0.0},
|
||||
},
|
||||
"cohere": {
|
||||
"prompt": "",
|
||||
"temperature": 0.0,
|
||||
"p": 0.0,
|
||||
"k": 0.0,
|
||||
"max_tokens": 0,
|
||||
"stop_sequences": [],
|
||||
"return_likelihoods": "NONE",
|
||||
"stream": False,
|
||||
"num_generations": 0,
|
||||
"logit_bias": {},
|
||||
"truncate": "NONE",
|
||||
},
|
||||
"anthropic": {
|
||||
"anthropic_version": "bedrock-2023-05-31",
|
||||
"max_tokens": 0,
|
||||
"system": "",
|
||||
"messages": [{"role": "", "content": ""}],
|
||||
"temperature": 0.0,
|
||||
"top_p": 0.0,
|
||||
"top_k": 0,
|
||||
"stop_sequences": [],
|
||||
},
|
||||
"amazon": {
|
||||
"inputText": "",
|
||||
"textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []},
|
||||
},
|
||||
}
|
||||
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY = {
|
||||
"mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
|
||||
"meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
|
||||
"ai21": {
|
||||
"id": "",
|
||||
"prompt": {"text": "Hello World", "tokens": []},
|
||||
"completions": [
|
||||
{"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
|
||||
],
|
||||
},
|
||||
"cohere": {
|
||||
"generations": [
|
||||
{
|
||||
"finish_reason": "",
|
||||
"id": "",
|
||||
"text": "Hello World",
|
||||
"likelihood": 0.0,
|
||||
"token_likelihoods": [{"token": 0.0}],
|
||||
"is_finished": True,
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
"id": "",
|
||||
"prompt": "",
|
||||
},
|
||||
"anthropic": {
|
||||
"id": "",
|
||||
"model": "",
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"content": [{"type": "text", "text": "Hello World"}],
|
||||
"stop_reason": "",
|
||||
"stop_sequence": "",
|
||||
"usage": {"input_tokens": 0, "output_tokens": 0},
|
||||
},
|
||||
"amazon": {
|
||||
"inputTextTokenCount": 0,
|
||||
"results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}],
|
||||
},
|
||||
}
|
||||
|
|
|
|||
85
tests/metagpt/provider/test_ark.py
Normal file
85
tests/metagpt/provider/test_ark.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""
|
||||
用于火山方舟Python SDK V3的测试用例
|
||||
API文档:https://www.volcengine.com/docs/82379/1263482
|
||||
"""
|
||||
|
||||
from typing import AsyncIterator, List, Union
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
|
||||
|
||||
from metagpt.provider.ark_api import ArkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_ark
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
get_openai_chat_completion,
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
name = "AI assistant"
|
||||
resp_cont = resp_cont_tmpl.format(name=name)
|
||||
USAGE = {"completion_tokens": 1000, "prompt_tokens": 1000, "total_tokens": 2000}
|
||||
default_resp = get_openai_chat_completion(name)
|
||||
default_resp.model = "doubao-pro-32k-240515"
|
||||
default_resp.usage = USAGE
|
||||
|
||||
|
||||
def create_chat_completion_chunk(
|
||||
content: str, finish_reason: str = None, choices: List[Choice] = None
|
||||
) -> ChatCompletionChunk:
|
||||
if choices is None:
|
||||
choices = [
|
||||
Choice(
|
||||
delta=ChoiceDelta(content=content, function_call=None, role="assistant", tool_calls=None),
|
||||
finish_reason=finish_reason,
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
]
|
||||
|
||||
return ChatCompletionChunk(
|
||||
id="012",
|
||||
choices=choices,
|
||||
created=1716278586,
|
||||
model="doubao-pro-32k-240515",
|
||||
object="chat.completion.chunk",
|
||||
system_fingerprint=None,
|
||||
usage=None if choices else USAGE,
|
||||
)
|
||||
|
||||
|
||||
ark_resp_chunk = create_chat_completion_chunk(content="")
|
||||
ark_resp_chunk_finish = create_chat_completion_chunk(content=resp_cont, finish_reason="stop")
|
||||
ark_resp_chunk_last = create_chat_completion_chunk(content="", choices=[])
|
||||
|
||||
|
||||
async def chunk_iterator(chunks: List[ChatCompletionChunk]) -> AsyncIterator[ChatCompletionChunk]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
async def mock_ark_acompletions_create(
|
||||
self, stream: bool = False, **kwargs
|
||||
) -> Union[ChatCompletionChunk, ChatCompletion]:
|
||||
if stream:
|
||||
chunks = [ark_resp_chunk, ark_resp_chunk_finish, ark_resp_chunk_last]
|
||||
return chunk_iterator(chunks)
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ark_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_ark_acompletions_create)
|
||||
|
||||
llm = ArkLLM(mock_llm_config_ark)
|
||||
|
||||
resp = await llm.acompletion(messages)
|
||||
assert resp.choices[0].finish_reason == "stop"
|
||||
assert resp.choices[0].message.content == resp_cont
|
||||
assert resp.usage == USAGE
|
||||
|
||||
await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)
|
||||
109
tests/metagpt/provider/test_bedrock_api.py
Normal file
109
tests/metagpt/provider/test_bedrock_api.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.bedrock.utils import (
|
||||
NOT_SUUPORT_STREAM_MODELS,
|
||||
SUPPORT_STREAM_MODELS,
|
||||
)
|
||||
from metagpt.provider.bedrock_api import BedrockLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
BEDROCK_PROVIDER_REQUEST_BODY,
|
||||
BEDROCK_PROVIDER_RESPONSE_BODY,
|
||||
)
|
||||
|
||||
# all available model from bedrock
|
||||
models = SUPPORT_STREAM_MODELS | NOT_SUUPORT_STREAM_MODELS
|
||||
messages = [{"role": "user", "content": "Hi!"}]
|
||||
usage = {
|
||||
"prompt_tokens": 1000000,
|
||||
"completion_tokens": 1000000,
|
||||
}
|
||||
|
||||
|
||||
def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
provider = self.config.model.split(".")[0]
|
||||
self._update_costs(usage, self.config.model)
|
||||
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
|
||||
|
||||
|
||||
def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
|
||||
# use json object to mock EventStream
|
||||
def dict2bytes(x):
|
||||
return json.dumps(x).encode("utf-8")
|
||||
|
||||
provider = self.config.model.split(".")[0]
|
||||
|
||||
if provider == "amazon":
|
||||
response_body_bytes = dict2bytes({"outputText": "Hello World"})
|
||||
elif provider == "anthropic":
|
||||
response_body_bytes = dict2bytes(
|
||||
{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}}
|
||||
)
|
||||
elif provider == "cohere":
|
||||
response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
|
||||
else:
|
||||
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
|
||||
|
||||
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
|
||||
self._update_costs(usage, self.config.model)
|
||||
return response_body_stream
|
||||
|
||||
|
||||
def get_bedrock_request_body(model_id) -> dict:
|
||||
provider = model_id.split(".")[0]
|
||||
return BEDROCK_PROVIDER_REQUEST_BODY[provider]
|
||||
|
||||
|
||||
def is_subset(subset, superset) -> bool:
|
||||
"""Ensure all fields in request body are allowed.
|
||||
|
||||
```python
|
||||
subset = {"prompt": "hello","kwargs": {"temperature": 0.9,"p": 0.0}}
|
||||
superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}}
|
||||
is_subset(subset, superset)
|
||||
```
|
||||
|
||||
"""
|
||||
for key, value in subset.items():
|
||||
if key not in superset:
|
||||
return False
|
||||
if isinstance(value, dict):
|
||||
if not isinstance(superset[key], dict):
|
||||
return False
|
||||
if not is_subset(value, superset[key]):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
@pytest.fixture(scope="class", params=models)
|
||||
def bedrock_api(request) -> BedrockLLM:
|
||||
model_id = request.param
|
||||
mock_llm_config_bedrock.model = model_id
|
||||
api = BedrockLLM(mock_llm_config_bedrock)
|
||||
return api
|
||||
|
||||
|
||||
class TestBedrockAPI:
|
||||
def _patch_invoke_model(self, mocker):
|
||||
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model)
|
||||
|
||||
def _patch_invoke_model_stream(self, mocker):
|
||||
mocker.patch(
|
||||
"metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
|
||||
mock_invoke_model_stream,
|
||||
)
|
||||
|
||||
def test_get_request_body(self, bedrock_api: BedrockLLM):
|
||||
"""Ensure request body has correct format"""
|
||||
provider = bedrock_api.provider
|
||||
request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
|
||||
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask(self, bedrock_api: BedrockLLM, mocker):
|
||||
self._patch_invoke_model(mocker)
|
||||
self._patch_invoke_model_stream(mocker)
|
||||
assert await bedrock_api.aask(messages, stream=False) == "Hello World"
|
||||
assert await bedrock_api.aask(messages, stream=True) == "Hello World"
|
||||
|
|
@ -1,62 +1,55 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of spark api
|
||||
"""
|
||||
用于讯飞星火SDK的测试用例
|
||||
文档:https://www.xfyun.cn/doc/spark/Web.html
|
||||
"""
|
||||
|
||||
|
||||
from typing import AsyncIterator, List
|
||||
|
||||
import pytest
|
||||
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
|
||||
from sparkai.core.outputs.chat_generation import ChatGeneration
|
||||
from sparkai.core.outputs.llm_result import LLMResult
|
||||
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_spark,
|
||||
)
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
|
||||
from tests.metagpt.provider.req_resp_const import (
|
||||
llm_general_chat_funcs_test,
|
||||
messages,
|
||||
prompt,
|
||||
resp_cont_tmpl,
|
||||
)
|
||||
|
||||
resp_cont = resp_cont_tmpl.format(name="Spark")
|
||||
USAGE = {
|
||||
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
|
||||
}
|
||||
spark_agenerate_result = LLMResult(
|
||||
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
|
||||
)
|
||||
|
||||
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
|
||||
|
||||
|
||||
class MockWebSocketApp(object):
|
||||
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
|
||||
pass
|
||||
|
||||
def run_forever(self, sslopt=None):
|
||||
pass
|
||||
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
|
||||
for chunk in chunks:
|
||||
yield chunk
|
||||
|
||||
|
||||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
assert ret == ""
|
||||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_cont
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_aask(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await llm.aask("Hello!")
|
||||
assert resp == resp_cont
|
||||
async def mock_spark_acreate(self, messages, stream):
|
||||
if stream:
|
||||
return chunk_iterator(chunks)
|
||||
else:
|
||||
return spark_agenerate_result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate)
|
||||
|
||||
spark_llm = SparkLLM(mock_llm_config)
|
||||
spark_llm = SparkLLM(mock_llm_config_spark)
|
||||
|
||||
resp = await spark_llm.acompletion([])
|
||||
assert resp == resp_cont
|
||||
resp = await spark_llm.acompletion([messages])
|
||||
assert spark_llm.get_choice_text(resp) == resp_cont
|
||||
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
|
||||
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM
|
|||
from llama_index.core.schema import Document, NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
|
||||
|
|
@ -25,10 +26,6 @@ class TestSimpleEngine:
|
|||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_retriever(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
|
|
@ -41,14 +38,18 @@ class TestSimpleEngine:
|
|||
def mock_get_response_synthesizer(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_file_extractor(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor")
|
||||
|
||||
def test_from_docs(
|
||||
self,
|
||||
mocker,
|
||||
mock_simple_directory_reader,
|
||||
mock_vector_store_index,
|
||||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
mock_get_file_extractor,
|
||||
):
|
||||
# Mock
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
|
|
@ -58,6 +59,8 @@ class TestSimpleEngine:
|
|||
mock_get_retriever.return_value = mocker.MagicMock()
|
||||
mock_get_rankers.return_value = [mocker.MagicMock()]
|
||||
mock_get_response_synthesizer.return_value = mocker.MagicMock()
|
||||
file_extractor = mocker.MagicMock()
|
||||
mock_get_file_extractor.return_value = file_extractor
|
||||
|
||||
# Setup
|
||||
input_dir = "test_dir"
|
||||
|
|
@ -80,12 +83,11 @@ class TestSimpleEngine:
|
|||
)
|
||||
|
||||
# Assert
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
configs=retriever_configs, index=mock_vector_store_index.return_value
|
||||
mock_simple_directory_reader.assert_called_once_with(
|
||||
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor
|
||||
)
|
||||
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
|
||||
mock_get_retriever.assert_called_once()
|
||||
mock_get_rankers.assert_called_once()
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
|
|
@ -119,7 +121,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is not None
|
||||
assert engine._transformations is not None
|
||||
|
||||
def test_from_objs_with_bm25_config(self):
|
||||
# Setup
|
||||
|
|
@ -137,6 +139,7 @@ class TestSimpleEngine:
|
|||
def test_from_index(self, mocker, mock_llm, mock_embedding):
|
||||
# Mock
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index.as_retriever.return_value = "retriever"
|
||||
mock_get_index = mocker.patch("metagpt.rag.engines.simple.get_index")
|
||||
mock_get_index.return_value = mock_index
|
||||
|
||||
|
|
@ -149,7 +152,7 @@ class TestSimpleEngine:
|
|||
|
||||
# Assert
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
assert engine.index is mock_index
|
||||
assert engine._retriever == "retriever"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
|
|
@ -200,14 +203,11 @@ class TestSimpleEngine:
|
|||
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index._transformations = mocker.MagicMock()
|
||||
|
||||
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
|
||||
mock_run_transformations.return_value = ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Exec
|
||||
|
|
@ -230,7 +230,7 @@ class TestSimpleEngine:
|
|||
return ""
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
engine = SimpleEngine(retriever=mock_retriever)
|
||||
|
||||
# Exec
|
||||
engine.add_objs(objs=objs)
|
||||
|
|
@ -308,3 +308,17 @@ class TestSimpleEngine:
|
|||
# Assert
|
||||
assert "obj" in node.node.metadata
|
||||
assert node.node.metadata["obj"] == expected_obj
|
||||
|
||||
def test_get_file_extractor(self, mocker):
|
||||
# mock no omniparse config
|
||||
mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)
|
||||
mock_omniparse_config.base_url = ""
|
||||
|
||||
file_extractor = SimpleEngine._get_file_extractor()
|
||||
assert file_extractor == {}
|
||||
|
||||
# mock have omniparse config
|
||||
mock_omniparse_config.base_url = "http://localhost:8000"
|
||||
file_extractor = SimpleEngine._get_file_extractor()
|
||||
assert ".pdf" in file_extractor
|
||||
assert isinstance(file_extractor[".pdf"], OmniParse)
|
||||
|
|
|
|||
|
|
@ -97,6 +97,5 @@ class TestConfigBasedFactory:
|
|||
def test_val_from_config_or_kwargs_key_error(self):
|
||||
# Test KeyError when the key is not found in both config object and kwargs
|
||||
config = DummyConfig(name=None)
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
|
||||
val = ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert val is None
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.configs.embedding_config import EmbeddingType
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
|
||||
|
||||
|
|
@ -10,30 +11,51 @@ class TestRAGEmbeddingFactory:
|
|||
self.embedding_factory = RAGEmbeddingFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_openai_embedding(self, mocker):
|
||||
def mock_config(self, mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
|
||||
@staticmethod
|
||||
def mock_openai_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OpenAIEmbedding")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_azure_embedding(self, mocker):
|
||||
@staticmethod
|
||||
def mock_azure_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.AzureOpenAIEmbedding")
|
||||
|
||||
def test_get_rag_embedding_openai(self, mock_openai_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.OPENAI)
|
||||
@staticmethod
|
||||
def mock_gemini_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.GeminiEmbedding")
|
||||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
@staticmethod
|
||||
def mock_ollama_embedding(mocker):
|
||||
return mocker.patch("metagpt.rag.factories.embedding.OllamaEmbedding")
|
||||
|
||||
def test_get_rag_embedding_azure(self, mock_azure_embedding):
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(LLMType.AZURE)
|
||||
|
||||
# Assert
|
||||
mock_azure_embedding.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_openai_embedding):
|
||||
@pytest.mark.parametrize(
|
||||
("mock_func", "embedding_type"),
|
||||
[
|
||||
(mock_openai_embedding, LLMType.OPENAI),
|
||||
(mock_azure_embedding, LLMType.AZURE),
|
||||
(mock_openai_embedding, EmbeddingType.OPENAI),
|
||||
(mock_azure_embedding, EmbeddingType.AZURE),
|
||||
(mock_gemini_embedding, EmbeddingType.GEMINI),
|
||||
(mock_ollama_embedding, EmbeddingType.OLLAMA),
|
||||
],
|
||||
)
|
||||
def test_get_rag_embedding(self, mock_func, embedding_type, mocker):
|
||||
# Mock
|
||||
mock_config = mocker.patch("metagpt.rag.factories.embedding.config")
|
||||
mock = mock_func(mocker)
|
||||
|
||||
# Exec
|
||||
self.embedding_factory.get_rag_embedding(embedding_type)
|
||||
|
||||
# Assert
|
||||
mock.assert_called_once()
|
||||
|
||||
def test_get_rag_embedding_default(self, mocker, mock_config):
|
||||
# Mock
|
||||
mock_openai_embedding = self.mock_openai_embedding(mocker)
|
||||
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.OPENAI
|
||||
|
||||
# Exec
|
||||
|
|
@ -41,3 +63,44 @@ class TestRAGEmbeddingFactory:
|
|||
|
||||
# Assert
|
||||
mock_openai_embedding.assert_called_once()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model, embed_batch_size, expected_params",
|
||||
[("test_model", 100, {"model_name": "test_model", "embed_batch_size": 100}), (None, None, {})],
|
||||
)
|
||||
def test_try_set_model_and_batch_size(self, mock_config, model, embed_batch_size, expected_params):
|
||||
# Mock
|
||||
mock_config.embedding.model = model
|
||||
mock_config.embedding.embed_batch_size = embed_batch_size
|
||||
|
||||
# Setup
|
||||
test_params = {}
|
||||
|
||||
# Exec
|
||||
self.embedding_factory._try_set_model_and_batch_size(test_params)
|
||||
|
||||
# Assert
|
||||
assert test_params == expected_params
|
||||
|
||||
def test_resolve_embedding_type(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = EmbeddingType.OPENAI
|
||||
|
||||
# Exec
|
||||
embedding_type = self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
# Assert
|
||||
assert embedding_type == EmbeddingType.OPENAI
|
||||
|
||||
def test_resolve_embedding_type_exception(self, mock_config):
|
||||
# Mock
|
||||
mock_config.embedding.api_type = None
|
||||
mock_config.llm.api_type = LLMType.GEMINI
|
||||
|
||||
# Assert
|
||||
with pytest.raises(TypeError):
|
||||
self.embedding_factory._resolve_embedding_type()
|
||||
|
||||
def test_raise_for_key(self):
|
||||
with pytest.raises(ValueError):
|
||||
self.embedding_factory._raise_for_key("key")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.embeddings import MockEmbedding
|
||||
from llama_index.core.schema import TextNode
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
|
||||
|
||||
|
|
@ -43,6 +45,14 @@ class TestRetrieverFactory:
|
|||
def mock_es_vector_store(self, mocker):
|
||||
return mocker.MagicMock(spec=ElasticsearchStore)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_nodes(self, mocker):
|
||||
return [TextNode(text="msg")]
|
||||
|
||||
@pytest.fixture
|
||||
def mock_embedding(self):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
def test_get_retriever_with_faiss_config(self, mock_faiss_index, mocker, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
|
|
@ -52,42 +62,40 @@ class TestRetrieverFactory:
|
|||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_vector_store_index):
|
||||
def test_get_retriever_with_bm25_config(self, mocker, mock_nodes):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=mock_nodes)
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_vector_store_index):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(self, mocker, mock_nodes, mock_embedding):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=1)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
retriever = self.retriever_factory.get_retriever(
|
||||
configs=[mock_faiss_config, mock_bm25_config], nodes=mock_nodes, embed_model=mock_embedding
|
||||
)
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_vector_store_index, mock_chroma_vector_store):
|
||||
def test_get_retriever_with_chroma_config(self, mocker, mock_chroma_vector_store, mock_embedding):
|
||||
mock_config = ChromaRetrieverConfig(persist_path="/path/to/chroma", collection_name="test_collection")
|
||||
mock_chromadb = mocker.patch("metagpt.rag.factories.retriever.chromadb.PersistentClient")
|
||||
mock_chromadb.get_or_create_collection.return_value = mocker.MagicMock()
|
||||
mocker.patch("metagpt.rag.factories.retriever.ChromaVectorStore", return_value=mock_chroma_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ChromaRetriever)
|
||||
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_vector_store_index, mock_es_vector_store):
|
||||
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
|
||||
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
|
||||
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config])
|
||||
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
|
||||
|
||||
assert isinstance(retriever, ElasticsearchRetriever)
|
||||
|
||||
|
|
@ -111,3 +119,19 @@ class TestRetrieverFactory:
|
|||
extracted_index = self.retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_get_or_build_when_get(self, mocker):
|
||||
want = "existing_index"
|
||||
mocker.patch.object(self.retriever_factory, "_extract_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
||||
def test_get_or_build_when_build(self, mocker):
|
||||
want = "call_build_es_index"
|
||||
mocker.patch.object(self.retriever_factory, "_build_es_index", return_value=want)
|
||||
|
||||
got = self.retriever_factory._build_es_index(None)
|
||||
|
||||
assert got == want
|
||||
|
|
|
|||
118
tests/metagpt/rag/parser/test_omniparse.py
Normal file
118
tests/metagpt/rag/parser/test_omniparse.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
import pytest
|
||||
from llama_index.core import Document
|
||||
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.rag.parsers import OmniParse
|
||||
from metagpt.rag.schema import (
|
||||
OmniParsedResult,
|
||||
OmniParseOptions,
|
||||
OmniParseType,
|
||||
ParseResultType,
|
||||
)
|
||||
from metagpt.utils.omniparse_client import OmniParseClient
|
||||
|
||||
# test data
|
||||
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
|
||||
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
|
||||
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
|
||||
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
|
||||
|
||||
|
||||
class TestOmniParseClient:
|
||||
parse_client = OmniParseClient()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_pdf(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest content"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
|
||||
assert parse_ret == mock_parsed_ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_document(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_document"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
|
||||
with open(TEST_DOCX, "rb") as f:
|
||||
file_bytes = f.read()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# bytes data must provide bytes_filename
|
||||
await self.parse_client.parse_document(file_bytes)
|
||||
|
||||
parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
|
||||
assert parse_ret == mock_parsed_ret
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_video(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_video"
|
||||
mock_request_parse.return_value = {
|
||||
"text": mock_content,
|
||||
"metadata": {},
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
# Wrong file extension test
|
||||
await self.parse_client.parse_video(TEST_DOCX)
|
||||
|
||||
parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
|
||||
assert "text" in parse_ret and "metadata" in parse_ret
|
||||
assert parse_ret["text"] == mock_content
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_audio(self, mock_request_parse):
|
||||
mock_content = "#test title\ntest_parse_audio"
|
||||
mock_request_parse.return_value = {
|
||||
"text": mock_content,
|
||||
"metadata": {},
|
||||
}
|
||||
parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
|
||||
assert "text" in parse_ret and "metadata" in parse_ret
|
||||
assert parse_ret["text"] == mock_content
|
||||
|
||||
|
||||
class TestOmniParse:
|
||||
@pytest.fixture
|
||||
def mock_omniparse(self):
|
||||
parser = OmniParse(
|
||||
parse_options=OmniParseOptions(
|
||||
parse_type=OmniParseType.PDF,
|
||||
result_type=ParseResultType.MD,
|
||||
max_timeout=120,
|
||||
num_workers=3,
|
||||
)
|
||||
)
|
||||
return parser
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_parse(self, mocker):
|
||||
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_data(self, mock_omniparse, mock_request_parse):
|
||||
# mock
|
||||
mock_content = "#test title\ntest content"
|
||||
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
|
||||
mock_request_parse.return_value = mock_parsed_ret.model_dump()
|
||||
|
||||
# single file
|
||||
documents = mock_omniparse.load_data(file_path=TEST_PDF)
|
||||
doc = documents[0]
|
||||
assert isinstance(doc, Document)
|
||||
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
|
||||
|
||||
# multi files
|
||||
file_paths = [TEST_DOCX, TEST_PDF]
|
||||
mock_omniparse.parse_type = OmniParseType.DOCUMENT
|
||||
documents = await mock_omniparse.aload_data(file_path=file_paths)
|
||||
doc = documents[0]
|
||||
|
||||
# assert
|
||||
assert isinstance(doc, Document)
|
||||
assert len(documents) == len(file_paths)
|
||||
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
|
||||
|
|
@ -10,7 +10,6 @@ import json
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -30,11 +29,7 @@ async def test_product_manager(new_filename):
|
|||
rsp = await product_manager.run(MockMessages.req)
|
||||
assert context.git_repo
|
||||
assert context.repo
|
||||
assert rsp.cause_by == any_to_str(PrepareDocuments)
|
||||
assert REQUIREMENT_FILENAME in context.repo.docs.changed_files
|
||||
|
||||
# write prd
|
||||
rsp = await product_manager.run(rsp)
|
||||
assert rsp.cause_by == any_to_str(WritePRD)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager
|
||||
from metagpt.team import Team
|
||||
|
|
@ -146,5 +147,21 @@ async def test_team_recover_multi_roles_save(mocker, context):
|
|||
await new_company.run(n_round=4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context(context):
|
||||
context.kwargs.set("a", "a")
|
||||
context.cost_manager.max_budget = 9
|
||||
company = Team(context=context)
|
||||
|
||||
save_to = context.repo.workdir / "serial"
|
||||
company.serialize(save_to)
|
||||
|
||||
company.deserialize(save_to, Context())
|
||||
assert company.env.context.repo
|
||||
assert company.env.context.repo.workdir == context.repo.workdir
|
||||
assert company.env.context.kwargs.a == "a"
|
||||
assert company.env.context.cost_manager.max_budget == context.cost_manager.max_budget
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -105,11 +105,11 @@ def test_config_mixin_4_multi_inheritance_override_config():
|
|||
async def test_config_priority():
|
||||
"""If action's config is set, then its llm will be set, otherwise, it will use the role's llm"""
|
||||
home_dir = Path.home() / CONFIG_ROOT
|
||||
gpt4t = Config.from_home("gpt-4-1106-preview.yaml")
|
||||
gpt4t = Config.from_home("gpt-4-turbo.yaml")
|
||||
if not home_dir.exists():
|
||||
assert gpt4t is None
|
||||
gpt35 = Config.default()
|
||||
gpt35.llm.model = "gpt-3.5-turbo-1106"
|
||||
gpt35.llm.model = "gpt-4-turbo"
|
||||
gpt4 = Config.default()
|
||||
gpt4.llm.model = "gpt-4-0613"
|
||||
|
||||
|
|
@ -127,8 +127,8 @@ async def test_config_priority():
|
|||
env = Environment(desc="US election live broadcast")
|
||||
Team(investment=10.0, env=env, roles=[A, B, C])
|
||||
|
||||
assert a1.llm.model == "gpt-4-1106-preview" if Path(home_dir / "gpt-4-1106-preview.yaml").exists() else "gpt-4-0613"
|
||||
assert a1.llm.model == "gpt-4-turbo" if Path(home_dir / "gpt-4-turbo.yaml").exists() else "gpt-4-0613"
|
||||
assert a2.llm.model == "gpt-4-0613"
|
||||
assert a3.llm.model == "gpt-3.5-turbo-1106"
|
||||
assert a3.llm.model == "gpt-4-turbo"
|
||||
|
||||
# history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="a1", n_round=3)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def test_add_role(env: Environment):
|
|||
name="Alice", profile="product manager", goal="create a new product", constraints="limited resources"
|
||||
)
|
||||
env.add_role(role)
|
||||
assert env.get_role(role.profile) == role
|
||||
assert env.get_role(str(role._setting)) == role
|
||||
|
||||
|
||||
def test_get_roles(env: Environment):
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ class MockSearchEnine:
|
|||
(SearchEngineType.SERPER_GOOGLE, None, 6, False),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 8, True),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
|
||||
(SearchEngineType.BING, None, 6, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
],
|
||||
|
|
|
|||
|
|
@ -2,7 +2,10 @@ from typing import Literal, Union
|
|||
|
||||
import pandas as pd
|
||||
|
||||
from metagpt.tools.tool_convert import convert_code_to_tool_schema
|
||||
from metagpt.tools.tool_convert import (
|
||||
convert_code_to_tool_schema,
|
||||
convert_code_to_tool_schema_ast,
|
||||
)
|
||||
|
||||
|
||||
class DummyClass:
|
||||
|
|
@ -128,3 +131,91 @@ def test_convert_code_to_tool_schema_function():
|
|||
def test_convert_code_to_tool_schema_async_function():
|
||||
schema = convert_code_to_tool_schema(dummy_async_fn)
|
||||
assert schema.get("type") == "async_function"
|
||||
|
||||
|
||||
TEST_CODE_FILE_TEXT = '''
|
||||
import pandas as pd # imported obj should not be parsed
|
||||
from some_module1 import some_imported_function, SomeImportedClass # imported obj should not be parsed
|
||||
from ..some_module2 import some_imported_function2 # relative import should not result in an error
|
||||
|
||||
class MyClass:
|
||||
"""This is a MyClass docstring."""
|
||||
def __init__(self, arg1):
|
||||
"""This is the constructor docstring."""
|
||||
self.arg1 = arg1
|
||||
|
||||
def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:
|
||||
"""
|
||||
This is a method docstring.
|
||||
|
||||
Args:
|
||||
arg2 (Union[list[str], str]): A union of a list of strings and a string.
|
||||
...
|
||||
|
||||
Returns:
|
||||
Tuple[int, str]: A tuple of an integer and a string.
|
||||
"""
|
||||
return self.arg4 + arg5
|
||||
|
||||
async def my_async_method(self, some_arg) -> str:
|
||||
return "hi"
|
||||
|
||||
def _private_method(self): # private should not be parsed
|
||||
return "private"
|
||||
|
||||
def my_function(arg1, arg2) -> dict:
|
||||
"""This is a function docstring."""
|
||||
return arg1 + arg2
|
||||
|
||||
def my_async_function(arg1, arg2) -> dict:
|
||||
return arg1 + arg2
|
||||
|
||||
def _private_function(): # private should not be parsed
|
||||
return "private"
|
||||
'''
|
||||
|
||||
|
||||
def test_convert_code_to_tool_schema_ast():
|
||||
expected = {
|
||||
"MyClass": {
|
||||
"type": "class",
|
||||
"description": "This is a MyClass docstring.",
|
||||
"methods": {
|
||||
"__init__": {
|
||||
"type": "function",
|
||||
"description": "This is the constructor docstring.",
|
||||
"signature": "(self, arg1)",
|
||||
"parameters": "",
|
||||
},
|
||||
"my_method": {
|
||||
"type": "function",
|
||||
"description": "This is a method docstring. ",
|
||||
"signature": "(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal['a', 'b', 'c'] = 'a') -> Tuple[int, str]",
|
||||
"parameters": "Args: arg2 (Union[list[str], str]): A union of a list of strings and a string. ... Returns: Tuple[int, str]: A tuple of an integer and a string.",
|
||||
},
|
||||
"my_async_method": {
|
||||
"type": "async_function",
|
||||
"description": "",
|
||||
"signature": "(self, some_arg) -> str",
|
||||
"parameters": "",
|
||||
},
|
||||
},
|
||||
"code": 'class MyClass:\n """This is a MyClass docstring."""\n def __init__(self, arg1):\n """This is the constructor docstring."""\n self.arg1 = arg1\n\n def my_method(self, arg2: Union[list[str], str], arg3: pd.DataFrame, arg4: int = 1, arg5: Literal["a","b","c"] = "a") -> Tuple[int, str]:\n """\n This is a method docstring.\n \n Args:\n arg2 (Union[list[str], str]): A union of a list of strings and a string.\n ...\n \n Returns:\n Tuple[int, str]: A tuple of an integer and a string.\n """\n return self.arg4 + arg5\n \n async def my_async_method(self, some_arg) -> str:\n return "hi"\n \n def _private_method(self): # private should not be parsed\n return "private"',
|
||||
},
|
||||
"my_function": {
|
||||
"type": "function",
|
||||
"description": "This is a function docstring.",
|
||||
"signature": "(arg1, arg2) -> dict",
|
||||
"parameters": "",
|
||||
"code": 'def my_function(arg1, arg2) -> dict:\n """This is a function docstring."""\n return arg1 + arg2',
|
||||
},
|
||||
"my_async_function": {
|
||||
"type": "function",
|
||||
"description": "",
|
||||
"signature": "(arg1, arg2) -> dict",
|
||||
"parameters": "",
|
||||
"code": "def my_async_function(arg1, arg2) -> dict:\n return arg1 + arg2",
|
||||
},
|
||||
}
|
||||
schemas = convert_code_to_tool_schema_ast(TEST_CODE_FILE_TEXT)
|
||||
assert schemas == expected
|
||||
|
|
|
|||
|
|
@ -56,7 +56,7 @@ class TestUTWriter:
|
|||
)
|
||||
],
|
||||
created=1706710532,
|
||||
model="gpt-3.5-turbo-1106",
|
||||
model="gpt-4-turbo",
|
||||
object="chat.completion",
|
||||
system_fingerprint="fp_04f9a1eebf",
|
||||
usage=CompletionUsage(completion_tokens=35, prompt_tokens=1982, total_tokens=2017),
|
||||
|
|
|
|||
|
|
@ -12,11 +12,11 @@ from metagpt.utils.cost_manager import CostManager
|
|||
|
||||
def test_cost_manager():
|
||||
cm = CostManager(total_budget=20)
|
||||
cm.update_cost(prompt_tokens=1000, completion_tokens=100, model="gpt-4-1106-preview")
|
||||
cm.update_cost(prompt_tokens=1000, completion_tokens=100, model="gpt-4-turbo")
|
||||
assert cm.get_total_prompt_tokens() == 1000
|
||||
assert cm.get_total_completion_tokens() == 100
|
||||
assert cm.get_total_cost() == 0.013
|
||||
cm.update_cost(prompt_tokens=100, completion_tokens=10, model="gpt-4-1106-preview")
|
||||
cm.update_cost(prompt_tokens=100, completion_tokens=10, model="gpt-4-turbo")
|
||||
assert cm.get_total_prompt_tokens() == 1100
|
||||
assert cm.get_total_completion_tokens() == 110
|
||||
assert cm.get_total_cost() == 0.0143
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.utils.mermaid import MMC1, mermaid_to_file
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("engine", ["nodejs", "ink"]) # TODO: playwright and pyppeteer
|
||||
@pytest.mark.parametrize("engine", ["nodejs", "ink", "playwright", "pyppeteer"])
|
||||
async def test_mermaid(engine, context, mermaid_mocker):
|
||||
# nodejs prerequisites: npm install -g @mermaid-js/mermaid-cli
|
||||
# ink prerequisites: connected to internet
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ def _paragraphs(n):
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"msgs, model_name, system_text, reserved, expected",
|
||||
"msgs, model, system_text, reserved, expected",
|
||||
[
|
||||
(_msgs(), "gpt-3.5-turbo-0613", "System", 1500, 1),
|
||||
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
|
||||
|
|
@ -37,7 +37,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text, prompt_template, model_name, system_text, reserved, expected",
|
||||
"text, prompt_template, model, system_text, reserved, expected",
|
||||
[
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1500, 2),
|
||||
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
from metagpt.utils.token_counter import count_input_tokens, count_output_tokens
|
||||
|
||||
|
||||
def test_count_message_tokens():
|
||||
|
|
@ -15,7 +15,7 @@ def test_count_message_tokens():
|
|||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert count_message_tokens(messages) == 15
|
||||
assert count_input_tokens(messages) == 15
|
||||
|
||||
|
||||
def test_count_message_tokens_with_name():
|
||||
|
|
@ -23,12 +23,12 @@ def test_count_message_tokens_with_name():
|
|||
{"role": "user", "content": "Hello", "name": "John"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert count_message_tokens(messages) == 17
|
||||
assert count_input_tokens(messages) == 17
|
||||
|
||||
|
||||
def test_count_message_tokens_empty_input():
|
||||
"""Empty input should return 3 tokens"""
|
||||
assert count_message_tokens([]) == 3
|
||||
assert count_input_tokens([]) == 3
|
||||
|
||||
|
||||
def test_count_message_tokens_invalid_model():
|
||||
|
|
@ -38,7 +38,7 @@ def test_count_message_tokens_invalid_model():
|
|||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
with pytest.raises(NotImplementedError):
|
||||
count_message_tokens(messages, model="invalid_model")
|
||||
count_input_tokens(messages, model="invalid_model")
|
||||
|
||||
|
||||
def test_count_message_tokens_gpt_4():
|
||||
|
|
@ -46,27 +46,27 @@ def test_count_message_tokens_gpt_4():
|
|||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
assert count_message_tokens(messages, model="gpt-4-0314") == 15
|
||||
assert count_input_tokens(messages, model="gpt-4-0314") == 15
|
||||
|
||||
|
||||
def test_count_string_tokens():
|
||||
"""Test that the string tokens are counted correctly."""
|
||||
|
||||
string = "Hello, world!"
|
||||
assert count_string_tokens(string, model_name="gpt-3.5-turbo-0301") == 4
|
||||
assert count_output_tokens(string, model="gpt-3.5-turbo-0301") == 4
|
||||
|
||||
|
||||
def test_count_string_tokens_empty_input():
|
||||
"""Test that the string tokens are counted correctly."""
|
||||
|
||||
assert count_string_tokens("", model_name="gpt-3.5-turbo-0301") == 0
|
||||
assert count_output_tokens("", model="gpt-3.5-turbo-0301") == 0
|
||||
|
||||
|
||||
def test_count_string_tokens_gpt_4():
|
||||
"""Test that the string tokens are counted correctly."""
|
||||
|
||||
string = "Hello, world!"
|
||||
assert count_string_tokens(string, model_name="gpt-4-0314") == 4
|
||||
assert count_output_tokens(string, model="gpt-4-0314") == 4
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -114,7 +114,6 @@ class MockLLM(OriginalLLM):
|
|||
raise ValueError(
|
||||
"In current test setting, api call is not allowed, you should properly mock your tests, "
|
||||
"or add expected api response in tests/data/rsp_cache.json. "
|
||||
f"The prompt you want for api call: {msg_key}"
|
||||
)
|
||||
# Call the original unmocked method
|
||||
rsp = await ask_func(*args, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue