fix conflicts

This commit is contained in:
seehi 2025-02-26 22:20:14 +08:00
commit 77703f1236
347 changed files with 21628 additions and 1350 deletions

View file

@ -1,7 +1,7 @@
llm:
base_url: "https://api.openai.com/v1"
api_key: "sk-xxx"
model: "gpt-3.5-turbo-1106"
model: "gpt-3.5-turbo"
search:
api_type: "serpapi"

View file

@ -0,0 +1,2 @@
!*.png
unitest_Contacts

Binary file not shown.

After

Width:  |  Height:  |  Size: 611 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 840 KiB

View file

@ -0,0 +1,2 @@
tap(9):::android.view.ViewGroup_1067_236_android.widget.TextView_183_204_Apps_2
stop

View file

@ -0,0 +1 @@
Create a contact in Contacts App named zjy with a phone number +86 18831933368

View file

@ -0,0 +1,27 @@
llm:
api_type: "openai" # or azure / ollama / groq etc.
base_url: "YOUR_gpt-3.5-turbo_BASE_URL"
api_key: "YOUR_gpt-3.5-turbo_API_KEY"
model: "gpt-3.5-turbo" # or gpt-3.5-turbo
# proxy: "YOUR_gpt-3.5-turbo_PROXY" # for LLM API requests
# timeout: 600 # Optional. If set to 0, default value is 300.
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
models:
"YOUR_MODEL_NAME_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo
api_type: "openai" # or azure / ollama / groq etc.
base_url: "YOUR_MODEL_1_BASE_URL"
api_key: "YOUR_MODEL_1_API_KEY"
# proxy: "YOUR_MODEL_1_PROXY" # for LLM API requests
# timeout: 600 # Optional. If set to 0, default value is 300.
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
"YOUR_MODEL_NAME_2": # model: "gpt-4-turbo" # or gpt-3.5-turbo
api_type: "openai" # or azure / ollama / groq etc.
base_url: "YOUR_MODEL_2_BASE_URL"
api_key: "YOUR_MODEL_2_API_KEY"
proxy: "YOUR_MODEL_2_PROXY" # for LLM API requests
# timeout: 600 # Optional. If set to 0, default value is 300.
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's

View file

@ -0,0 +1,27 @@
llm:
api_type: "openai" # or azure / ollama / groq etc.
base_url: "YOUR_gpt-3.5-turbo_BASE_URL"
api_key: "YOUR_gpt-3.5-turbo_API_KEY"
model: "gpt-3.5-turbo" # or gpt-3.5-turbo
# proxy: "YOUR_gpt-3.5-turbo_PROXY" # for LLM API requests
# timeout: 600 # Optional. If set to 0, default value is 300.
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
models:
"YOUR_MODEL_NAME_1": # model: "gpt-4-turbo" # or gpt-3.5-turbo
api_type: "openai" # or azure / ollama / groq etc.
base_url: "YOUR_MODEL_1_BASE_URL"
api_key: "YOUR_MODEL_1_API_KEY"
# proxy: "YOUR_MODEL_1_PROXY" # for LLM API requests
# timeout: 600 # Optional. If set to 0, default value is 300.
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's
"YOUR_MODEL_NAME_2": # model: "gpt-4-turbo" # or gpt-3.5-turbo
api_type: "openai" # or azure / ollama / groq etc.
base_url: "YOUR_MODEL_2_BASE_URL"
api_key: "YOUR_MODEL_2_API_KEY"
proxy: "YOUR_MODEL_2_PROXY" # for LLM API requests
# timeout: 600 # Optional. If set to 0, default value is 300.
# Details: https://azure.microsoft.com/en-us/pricing/details/cognitive-services/openai-service/
pricing_plan: "" # Optional. Use for Azure LLM when its model name is not the same as OpenAI's

File diff suppressed because one or more lines are too long

View file

@ -38,7 +38,7 @@ DESIGN = {
TASK = {
"Required Python packages": ["pygame==2.0.1"],
"Required packages": ["pygame==2.0.1"],
"Required Other language third-party packages": ["No third-party dependencies required"],
"Logic Analysis": [
["game.py", "Contains Game class and related functions for game logic"],

View file

@ -0,0 +1,45 @@
from metagpt.actions.action import Action
from metagpt.config2 import Config
from metagpt.const import TEST_DATA_PATH
from metagpt.context import Context
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.roles.role import Role
def test_set_llm():
config1 = Config.default()
config2 = Config.default()
config2.llm.model = "gpt-3.5-turbo"
context = Context(config=config1)
act = Action(context=context)
assert act.config.llm.model == config1.llm.model
llm2 = create_llm_instance(config2.llm)
act.llm = llm2
assert act.llm.model == llm2.model
role = Role(context=context)
role.set_actions([act])
assert act.llm.model == llm2.model
role1 = Role(context=context)
act1 = Action(context=context)
assert act1.config.llm.model == config1.llm.model
act1.config = config2
role1.set_actions([act1])
assert act1.llm.model == llm2.model
# multiple LLM
config3_path = TEST_DATA_PATH / "config/config2_multi_llm.yaml"
dict3 = Config.read_yaml(config3_path)
config3 = Config(**dict3)
context3 = Context(config=config3)
role3 = Role(context=context3)
act3 = Action(context=context3, llm_name_or_type="YOUR_MODEL_NAME_1")
assert act3.config.llm.model == "gpt-3.5-turbo"
assert act3.llm.model == "gpt-4-turbo"
role3.set_actions([act3])
assert act3.config.llm.model == "gpt-3.5-turbo"
assert act3.llm.model == "gpt-4-turbo"

View file

@ -6,7 +6,7 @@
@File : test_action_node.py
"""
from pathlib import Path
from typing import List, Tuple
from typing import List, Optional, Tuple
import pytest
from pydantic import BaseModel, Field, ValidationError
@ -302,5 +302,19 @@ def test_action_node_from_pydantic_and_print_everything():
assert "tasks" in code, "tasks should be in code"
def test_optional():
mapping = {
"Logic Analysis": (Optional[List[Tuple[str, str]]], Field(default=None)),
"Task list": (Optional[List[str]], None),
"Plan": (Optional[str], ""),
"Anything UNCLEAR": (Optional[str], None),
}
m = {"Anything UNCLEAR": "a"}
t = ActionNode.create_model_class("test_class_1", mapping)
t1 = t(**m)
assert t1
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

View file

@ -0,0 +1,34 @@
import pytest
from metagpt.actions.talk_action import TalkAction
from metagpt.configs.models_config import ModelsConfig
from metagpt.const import METAGPT_ROOT, TEST_DATA_PATH
from metagpt.utils.common import aread, awrite
@pytest.mark.asyncio
async def test_models_configs(context):
default_model = ModelsConfig.default()
assert default_model is not None
models = ModelsConfig.from_yaml_file(TEST_DATA_PATH / "config/config2.yaml")
assert models
default_models = ModelsConfig.default()
backup = ""
if not default_models.models:
backup = await aread(filename=METAGPT_ROOT / "config/config2.yaml")
test_data = await aread(filename=TEST_DATA_PATH / "config/config2.yaml")
await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=test_data)
try:
action = TalkAction(context=context, i_context="who are you?", llm_name_or_type="YOUR_MODEL_NAME_1")
assert action.private_llm.config.model == "YOUR_MODEL_NAME_1"
assert context.config.llm.model != "YOUR_MODEL_NAME_1"
finally:
if backup:
await awrite(filename=METAGPT_ROOT / "config/config2.yaml", data=backup)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -0,0 +1,48 @@
import random
import pytest
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
seed_value = 42
random.seed(seed_value)
vectors = [[random.random() for _ in range(8)] for _ in range(10)]
ids = [f"doc_{i}" for i in range(10)]
metadata = [{"color": "red", "rand_number": i % 10} for i in range(10)]
def assert_almost_equal(actual, expected):
delta = 1e-10
if isinstance(expected, list):
assert len(actual) == len(expected)
for ac, exp in zip(actual, expected):
assert abs(ac - exp) <= delta, f"{ac} is not within {delta} of {exp}"
else:
assert abs(actual - expected) <= delta, f"{actual} is not within {delta} of {expected}"
@pytest.mark.skip() # Skip because the pymilvus dependency is not installed by default
def test_milvus_store():
milvus_connection = MilvusConnection(uri="./milvus_local.db")
milvus_store = MilvusStore(milvus_connection)
collection_name = "TestCollection"
milvus_store.create_collection(collection_name, dim=8)
milvus_store.add(collection_name, ids, vectors, metadata)
search_results = milvus_store.search(collection_name, query=[1.0] * 8)
assert len(search_results) > 0
first_result = search_results[0]
assert first_result["id"] == "doc_0"
search_results_with_filter = milvus_store.search(collection_name, query=[1.0] * 8, filter={"rand_number": 1})
assert len(search_results_with_filter) > 0
assert search_results_with_filter[0]["id"] == "doc_1"
milvus_store.delete(collection_name, _ids=["doc_0"])
deleted_results = milvus_store.search(collection_name, query=[1.0] * 8, limit=1)
assert deleted_results[0]["id"] != "doc_0"
milvus_store.client.drop_collection(collection_name)

View file

@ -16,8 +16,8 @@ def mock_device_shape_invalid(self, adb_cmd: str) -> str:
return ADB_EXEC_FAIL
def mock_list_devices(self, adb_cmd: str) -> str:
return "devices\nemulator-5554"
def mock_list_devices(self) -> str:
return ["emulator-5554"]
def mock_get_screenshot(self, adb_cmd: str) -> str:
@ -35,6 +35,7 @@ def mock_write_read_operation(self, adb_cmd: str) -> str:
def test_android_ext_env(mocker):
device_id = "emulator-5554"
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_device_shape)
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.list_devices", mock_list_devices)
ext_env = AndroidExtEnv(device_id=device_id, screenshot_dir="/data2/", xml_dir="/data2/")
assert ext_env.adb_prefix == f"adb -s {device_id} "
@ -48,7 +49,6 @@ def test_android_ext_env(mocker):
)
assert ext_env.device_shape == (0, 0)
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_list_devices)
assert ext_env.list_devices() == [device_id]
mocker.patch("metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd", mock_get_screenshot)

View file

@ -64,7 +64,7 @@ async def test_ext_env():
_ = await env.write_thru_api(EnvAPIAbstract(api_name="write_api", kwargs={"a": 5, "b": 10}))
assert env.value == 15
with pytest.raises(ValueError):
with pytest.raises(KeyError):
await env.read_from_api("not_exist_api")
assert await env.read_from_api("read_api_no_param") == 15

View file

@ -2,33 +2,34 @@
# -*- coding: utf-8 -*-
# @Desc : the unittest of WerewolfExtEnv
from metagpt.environment.werewolf.werewolf_ext_env import RoleState, WerewolfExtEnv
from metagpt.environment.werewolf.const import RoleState, RoleType
from metagpt.environment.werewolf.werewolf_ext_env import WerewolfExtEnv
from metagpt.roles.role import Role
class Werewolf(Role):
profile: str = "Werewolf"
profile: str = RoleType.WEREWOLF.value
class Villager(Role):
profile: str = "Villager"
profile: str = RoleType.VILLAGER.value
class Witch(Role):
profile: str = "Witch"
profile: str = RoleType.WITCH.value
class Guard(Role):
profile: str = "Guard"
profile: str = RoleType.GUARD.value
def test_werewolf_ext_env():
players_state = {
"Player0": ("Werewolf", RoleState.ALIVE),
"Player1": ("Werewolf", RoleState.ALIVE),
"Player2": ("Villager", RoleState.ALIVE),
"Player3": ("Witch", RoleState.ALIVE),
"Player4": ("Guard", RoleState.ALIVE),
"Player0": (RoleType.WEREWOLF.value, RoleState.ALIVE),
"Player1": (RoleType.WEREWOLF.value, RoleState.ALIVE),
"Player2": (RoleType.VILLAGER.value, RoleState.ALIVE),
"Player3": (RoleType.WITCH.value, RoleState.ALIVE),
"Player4": (RoleType.GUARD.value, RoleState.ALIVE),
}
ext_env = WerewolfExtEnv(players_state=players_state, step_idx=4, special_role_players=["Player3", "Player4"])
@ -41,9 +42,9 @@ def test_werewolf_ext_env():
assert "Werewolves, please open your eyes" in curr_instr["content"]
# current step_idx = 5
ext_env.wolf_kill_someone(wolf=Role(name="Player10"), player_name="Player4")
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player0"), player_name="Player4")
ext_env.wolf_kill_someone(wolf=Werewolf(name="Player1"), player_name="Player4")
ext_env.wolf_kill_someone(wolf_name="Player10", player_name="Player4")
ext_env.wolf_kill_someone(wolf_name="Player0", player_name="Player4")
ext_env.wolf_kill_someone(wolf_name="Player1", player_name="Player4")
assert ext_env.player_hunted == "Player4"
assert len(ext_env.living_players) == 5 # hunted but can be saved by witch
@ -52,11 +53,11 @@ def test_werewolf_ext_env():
# current step_idx = 18
assert ext_env.step_idx == 18
ext_env.vote_kill_someone(voteer=Werewolf(name="Player0"), player_name="Player2")
ext_env.vote_kill_someone(voteer=Werewolf(name="Player1"), player_name="Player3")
ext_env.vote_kill_someone(voteer=Villager(name="Player2"), player_name="Player3")
ext_env.vote_kill_someone(voteer=Witch(name="Player3"), player_name="Player4")
ext_env.vote_kill_someone(voteer=Guard(name="Player4"), player_name="Player2")
ext_env.vote_kill_someone(voter_name="Player0", player_name="Player2")
ext_env.vote_kill_someone(voter_name="Player1", player_name="Player3")
ext_env.vote_kill_someone(voter_name="Player2", player_name="Player3")
ext_env.vote_kill_someone(voter_name="Player3", player_name="Player4")
ext_env.vote_kill_someone(voter_name="Player4", player_name="Player2")
assert ext_env.player_current_dead == "Player2"
assert len(ext_env.living_players) == 4

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,95 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : test on android emulator action. After Modify Role Test, this script is discarded.
import asyncio
import time
from pathlib import Path
import metagpt
from metagpt.const import TEST_DATA_PATH
from metagpt.environment.android.android_env import AndroidEnv
from metagpt.ext.android_assistant.actions.manual_record import ManualRecord
from metagpt.ext.android_assistant.actions.parse_record import ParseRecord
from metagpt.ext.android_assistant.actions.screenshot_parse import ScreenshotParse
from metagpt.ext.android_assistant.actions.self_learn_and_reflect import (
SelfLearnAndReflect,
)
from tests.metagpt.environment.android_env.test_android_ext_env import (
mock_device_shape,
mock_list_devices,
)
TASK_PATH = TEST_DATA_PATH.joinpath("andriod_assistant/unitest_Contacts")
TASK_PATH.mkdir(parents=True, exist_ok=True)
DEMO_NAME = str(time.time())
SELF_EXPLORE_DOC_PATH = TASK_PATH.joinpath("auto_docs")
PARSE_RECORD_DOC_PATH = TASK_PATH.joinpath("demo_docs")
device_id = "emulator-5554"
xml_dir = Path("/sdcard")
screenshot_dir = Path("/sdcard/Pictures/Screenshots")
metagpt.environment.android.android_ext_env.AndroidExtEnv.execute_adb_with_cmd = mock_device_shape
metagpt.environment.android.android_ext_env.AndroidExtEnv.list_devices = mock_list_devices
test_env_self_learn_android = AndroidEnv(
device_id=device_id,
xml_dir=xml_dir,
screenshot_dir=screenshot_dir,
)
test_self_learning = SelfLearnAndReflect()
test_env_manual_learn_android = AndroidEnv(
device_id=device_id,
xml_dir=xml_dir,
screenshot_dir=screenshot_dir,
)
test_manual_record = ManualRecord()
test_manual_parse = ParseRecord()
test_env_screenshot_parse_android = AndroidEnv(
device_id=device_id,
xml_dir=xml_dir,
screenshot_dir=screenshot_dir,
)
test_screenshot_parse = ScreenshotParse()
if __name__ == "__main__":
loop = asyncio.get_event_loop()
test_action_list = [
test_self_learning.run(
round_count=20,
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
last_act="",
task_dir=TASK_PATH / "demos" / f"self_learning_{DEMO_NAME}",
docs_dir=SELF_EXPLORE_DOC_PATH,
env=test_env_self_learn_android,
),
test_manual_record.run(
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}",
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
env=test_env_manual_learn_android,
),
test_manual_parse.run(
task_dir=TASK_PATH / "demos" / f"manual_record_{DEMO_NAME}", # 修要修改
docs_dir=PARSE_RECORD_DOC_PATH, # 需要修改
env=test_env_manual_learn_android,
),
test_screenshot_parse.run(
round_count=20,
task_desc="Create a contact in Contacts App named zjy with a phone number +86 18831933368 ",
last_act="",
task_dir=TASK_PATH / f"act_{DEMO_NAME}",
docs_dir=PARSE_RECORD_DOC_PATH,
env=test_env_screenshot_parse_android,
grid_on=False,
),
]
loop.run_until_complete(asyncio.gather(*test_action_list))
loop.close()

View file

@ -0,0 +1,29 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : test case (imgs from appagent's)
import asyncio
from metagpt.actions.action import Action
from metagpt.const import TEST_DATA_PATH
from metagpt.ext.android_assistant.actions.parse_record import ParseRecord
TASK_PATH = TEST_DATA_PATH.joinpath("andriod_assistant/demo_Contacts")
TEST_BEFORE_PATH = TASK_PATH.joinpath("labeled_screenshots/0_labeled.png")
TEST_AFTER_PATH = TASK_PATH.joinpath("labeled_screenshots/1_labeled.png")
RECORD_PATH = TASK_PATH.joinpath("record.txt")
TASK_DESC_PATH = TASK_PATH.joinpath("task_desc.txt")
DOCS_DIR = TASK_PATH.joinpath("storage")
test_action = Action(name="test")
async def manual_learn_test():
parse_record = ParseRecord()
await parse_record.run(app_name="demo_Contacts", task_dir=TASK_PATH, docs_dir=DOCS_DIR)
if __name__ == "__main__":
loop = asyncio.get_event_loop()
loop.run_until_complete(manual_learn_test())
loop.close()

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,164 @@
import json
import pytest
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.ext.werewolf.actions import AddNewExperiences, RetrieveExperiences
from metagpt.ext.werewolf.schema import RoleExperience
from metagpt.logs import logger
class TestExperiencesOperation:
collection_name = "test"
test_round_id = "test_01"
version = "test"
samples_to_add = [
RoleExperience(
profile="Witch",
reflection="The game is intense with two players claiming to be the Witch and one claiming to be the Seer. "
"Player4's behavior is suspicious.",
response="",
outcome="",
round_id=test_round_id,
version=version,
),
RoleExperience(
profile="Witch",
reflection="The game is in a critical state with only three players left, "
"and I need to make a wise decision to save Player7 or not.",
response="",
outcome="",
round_id=test_round_id,
version=version,
),
RoleExperience(
profile="Seer",
reflection="Player1, who is a werewolf, falsely claimed to be a Seer, and Player6, who might be a Witch, "
"sided with him. I, as the real Seer, am under suspicion.",
response="",
outcome="",
round_id=test_round_id,
version=version,
),
RoleExperience(
profile="TestRole",
reflection="Some test reflection1",
response="",
outcome="",
round_id=test_round_id,
version=version + "_01-10",
),
RoleExperience(
profile="TestRole",
reflection="Some test reflection2",
response="",
outcome="",
round_id=test_round_id,
version=version + "_11-20",
),
RoleExperience(
profile="TestRole",
reflection="Some test reflection3",
response="",
outcome="",
round_id=test_round_id,
version=version + "_21-30",
),
]
@pytest.mark.asyncio
async def test_add(self):
saved_file = DEFAULT_WORKSPACE_ROOT.joinpath(
f"werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
)
if saved_file.exists():
saved_file.unlink()
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
action.run(self.samples_to_add)
# test insertion
inserted = action.engine.retriever._index._vector_store._collection.get()
assert len(inserted["documents"]) == len(self.samples_to_add)
# test if we record the samples correctly to local file
# & test if we could recover a embedding db from the file
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
action.add_from_file(saved_file)
inserted = action.engine.retriever._index._vector_store._collection.get()
assert len(inserted["documents"]) == len(self.samples_to_add)
@pytest.mark.asyncio
async def test_retrieve(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "one player claimed to be Seer and the other Witch"
results = action.run(query, profile="Witch")
results = json.loads(results)
assert len(results) == 2, "Witch should have 2 experiences"
assert "The game is intense with two players" in results[0]
@pytest.mark.asyncio
async def test_retrieve_filtering(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "some test query"
profile = "TestRole"
excluded_version = ""
results = action.run(query, profile=profile, excluded_version=excluded_version)
results = json.loads(results)
assert len(results) == 3
excluded_version = self.version + "_21-30"
results = action.run(query, profile=profile, excluded_version=excluded_version)
results = json.loads(results)
assert len(results) == 2
class TestActualRetrieve:
collection_name = "role_reflection"
@pytest.mark.asyncio
async def test_check_experience_pool(self):
logger.info("check experience pool")
action = RetrieveExperiences(collection_name=self.collection_name)
if action.engine:
all_experiences = action.engine.retriever._index._vector_store._collection.get()
logger.info(f"{len(all_experiences['metadatas'])=}")
@pytest.mark.asyncio
async def test_retrieve_werewolf_experience(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "there are conflicts"
logger.info(f"test retrieval with {query=}")
action.run(query, "Werewolf")
@pytest.mark.asyncio
async def test_retrieve_villager_experience(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "there are conflicts"
logger.info(f"test retrieval with {query=}")
results = action.run(query, "Seer")
assert "conflict" not in results # 相似局面应该需要包含conflict关键词
@pytest.mark.asyncio
async def test_retrieve_villager_experience_filtering(self):
action = RetrieveExperiences(collection_name=self.collection_name)
query = "there are conflicts"
excluded_version = "01-10"
logger.info(f"test retrieval with {excluded_version=}")
results_01_10 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
excluded_version = "11-20"
logger.info(f"test retrieval with {excluded_version=}")
results_11_20 = action.run(query, profile="Seer", excluded_version=excluded_version, verbose=True)
assert results_01_10 == results_11_20

View file

@ -1,7 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/16 20:03
@Author : femto Zheng
@File : __init__.py
"""

View file

@ -1,32 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/16 20:03
@Author : femto Zheng
@File : test_basic_planner.py
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import pytest
from semantic_kernel.core_skills import FileIOSkill, MathSkill, TextSkill, TimeSkill
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
from metagpt.actions import UserRequirement
from metagpt.roles.sk_agent import SkAgent
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_action_planner():
role = SkAgent(planner_cls=ActionPlanner)
# let's give the agent 4 skills
role.import_skill(MathSkill(), "math")
role.import_skill(FileIOSkill(), "fileIO")
role.import_skill(TimeSkill(), "time")
role.import_skill(TextSkill(), "text")
task = "What is the sum of 110 and 990?"
role.put_message(Message(content=task, cause_by=UserRequirement))
await role._observe()
await role._think() # it will choose mathskill.Add
assert "1100" == (await role._act()).content

View file

@ -1,37 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/9/16 20:03
@Author : femto Zheng
@File : test_basic_planner.py
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import pytest
from semantic_kernel.core_skills import TextSkill
from metagpt.actions import UserRequirement
from metagpt.const import SKILL_DIRECTORY
from metagpt.roles.sk_agent import SkAgent
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_basic_planner():
task = """
Tomorrow is Valentine's day. I need to come up with a few date ideas. She speaks French so write it in French.
Convert the text to uppercase"""
role = SkAgent()
# let's give the agent some skills
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "SummarizeSkill")
role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "WriterSkill")
role.import_skill(TextSkill(), "TextSkill")
# using BasicPlanner
role.put_message(Message(content=task, cause_by=UserRequirement))
await role._observe()
await role._think()
# assuming sk_agent will think he needs WriterSkill.Brainstorm and WriterSkill.Translate
assert "WriterSkill.Brainstorm" in role.plan.generated_plan.result
assert "WriterSkill.Translate" in role.plan.generated_plan.result
# assert "SALUT" in (await role._act()).content #content will be some French

View file

@ -60,3 +60,14 @@ mock_llm_config_dashscope = LLMConfig(api_type="dashscope", api_key="xxx", model
mock_llm_config_anthropic = LLMConfig(
api_type="anthropic", api_key="xxx", base_url="https://api.anthropic.com", model="claude-3-opus-20240229"
)
mock_llm_config_bedrock = LLMConfig(
api_type="bedrock",
model="gpt-100",
region_name="somewhere",
access_key="123abc",
secret_key="123abc",
max_token=10000,
)
mock_llm_config_ark = LLMConfig(api_type="ark", api_key="eyxxx", base_url="xxx", model="ep-xxx")

View file

@ -183,3 +183,90 @@ async def llm_general_chat_funcs_test(llm: BaseLLM, prompt: str, messages: list[
resp = await llm.acompletion_text(messages, stream=True)
assert resp == resp_cont
# For Amazon Bedrock
# Check the API documentation of each model
# https://docs.aws.amazon.com/bedrock/latest/userguide
BEDROCK_PROVIDER_REQUEST_BODY = {
"mistral": {"prompt": "", "max_tokens": 0, "stop": [], "temperature": 0.0, "top_p": 0.0, "top_k": 0},
"meta": {"prompt": "", "temperature": 0.0, "top_p": 0.0, "max_gen_len": 0},
"ai21": {
"prompt": "",
"temperature": 0.0,
"topP": 0.0,
"maxTokens": 0,
"stopSequences": [],
"countPenalty": {"scale": 0.0},
"presencePenalty": {"scale": 0.0},
"frequencyPenalty": {"scale": 0.0},
},
"cohere": {
"prompt": "",
"temperature": 0.0,
"p": 0.0,
"k": 0.0,
"max_tokens": 0,
"stop_sequences": [],
"return_likelihoods": "NONE",
"stream": False,
"num_generations": 0,
"logit_bias": {},
"truncate": "NONE",
},
"anthropic": {
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": 0,
"system": "",
"messages": [{"role": "", "content": ""}],
"temperature": 0.0,
"top_p": 0.0,
"top_k": 0,
"stop_sequences": [],
},
"amazon": {
"inputText": "",
"textGenerationConfig": {"temperature": 0.0, "topP": 0.0, "maxTokenCount": 0, "stopSequences": []},
},
}
BEDROCK_PROVIDER_RESPONSE_BODY = {
"mistral": {"outputs": [{"text": "Hello World", "stop_reason": ""}]},
"meta": {"generation": "Hello World", "prompt_token_count": 0, "generation_token_count": 0, "stop_reason": ""},
"ai21": {
"id": "",
"prompt": {"text": "Hello World", "tokens": []},
"completions": [
{"data": {"text": "Hello World", "tokens": []}, "finishReason": {"reason": "length", "length": 2}}
],
},
"cohere": {
"generations": [
{
"finish_reason": "",
"id": "",
"text": "Hello World",
"likelihood": 0.0,
"token_likelihoods": [{"token": 0.0}],
"is_finished": True,
"index": 0,
}
],
"id": "",
"prompt": "",
},
"anthropic": {
"id": "",
"model": "",
"type": "message",
"role": "assistant",
"content": [{"type": "text", "text": "Hello World"}],
"stop_reason": "",
"stop_sequence": "",
"usage": {"input_tokens": 0, "output_tokens": 0},
},
"amazon": {
"inputTextTokenCount": 0,
"results": [{"tokenCount": 0, "outputText": "Hello World", "completionReason": ""}],
},
}

View file

@ -0,0 +1,85 @@
"""
用于火山方舟Python SDK V3的测试用例
API文档https://www.volcengine.com/docs/82379/1263482
"""
from typing import AsyncIterator, List, Union
import pytest
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
from metagpt.provider.ark_api import ArkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_ark
from tests.metagpt.provider.req_resp_const import (
get_openai_chat_completion,
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
name = "AI assistant"
resp_cont = resp_cont_tmpl.format(name=name)
USAGE = {"completion_tokens": 1000, "prompt_tokens": 1000, "total_tokens": 2000}
default_resp = get_openai_chat_completion(name)
default_resp.model = "doubao-pro-32k-240515"
default_resp.usage = USAGE
def create_chat_completion_chunk(
content: str, finish_reason: str = None, choices: List[Choice] = None
) -> ChatCompletionChunk:
if choices is None:
choices = [
Choice(
delta=ChoiceDelta(content=content, function_call=None, role="assistant", tool_calls=None),
finish_reason=finish_reason,
index=0,
logprobs=None,
)
]
return ChatCompletionChunk(
id="012",
choices=choices,
created=1716278586,
model="doubao-pro-32k-240515",
object="chat.completion.chunk",
system_fingerprint=None,
usage=None if choices else USAGE,
)
ark_resp_chunk = create_chat_completion_chunk(content="")
ark_resp_chunk_finish = create_chat_completion_chunk(content=resp_cont, finish_reason="stop")
ark_resp_chunk_last = create_chat_completion_chunk(content="", choices=[])
async def chunk_iterator(chunks: List[ChatCompletionChunk]) -> AsyncIterator[ChatCompletionChunk]:
for chunk in chunks:
yield chunk
async def mock_ark_acompletions_create(
self, stream: bool = False, **kwargs
) -> Union[ChatCompletionChunk, ChatCompletion]:
if stream:
chunks = [ark_resp_chunk, ark_resp_chunk_finish, ark_resp_chunk_last]
return chunk_iterator(chunks)
else:
return default_resp
@pytest.mark.asyncio
async def test_ark_acompletion(mocker):
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_ark_acompletions_create)
llm = ArkLLM(mock_llm_config_ark)
resp = await llm.acompletion(messages)
assert resp.choices[0].finish_reason == "stop"
assert resp.choices[0].message.content == resp_cont
assert resp.usage == USAGE
await llm_general_chat_funcs_test(llm, prompt, messages, resp_cont)

View file

@ -0,0 +1,109 @@
import json
import pytest
from metagpt.provider.bedrock.utils import (
NOT_SUPPORT_STREAM_MODELS,
SUPPORT_STREAM_MODELS,
)
from metagpt.provider.bedrock_api import BedrockLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_bedrock
from tests.metagpt.provider.req_resp_const import (
BEDROCK_PROVIDER_REQUEST_BODY,
BEDROCK_PROVIDER_RESPONSE_BODY,
)
# all available model from bedrock
models = SUPPORT_STREAM_MODELS | NOT_SUPPORT_STREAM_MODELS
messages = [{"role": "user", "content": "Hi!"}]
usage = {
"prompt_tokens": 1000000,
"completion_tokens": 1000000,
}
def mock_invoke_model(self: BedrockLLM, *args, **kwargs) -> dict:
provider = self.config.model.split(".")[0]
self._update_costs(usage, self.config.model)
return BEDROCK_PROVIDER_RESPONSE_BODY[provider]
def mock_invoke_model_stream(self: BedrockLLM, *args, **kwargs) -> dict:
# use json object to mock EventStream
def dict2bytes(x):
return json.dumps(x).encode("utf-8")
provider = self.config.model.split(".")[0]
if provider == "amazon":
response_body_bytes = dict2bytes({"outputText": "Hello World"})
elif provider == "anthropic":
response_body_bytes = dict2bytes(
{"type": "content_block_delta", "index": 0, "delta": {"type": "text_delta", "text": "Hello World"}}
)
elif provider == "cohere":
response_body_bytes = dict2bytes({"is_finished": False, "text": "Hello World"})
else:
response_body_bytes = dict2bytes(BEDROCK_PROVIDER_RESPONSE_BODY[provider])
response_body_stream = {"body": [{"chunk": {"bytes": response_body_bytes}}]}
self._update_costs(usage, self.config.model)
return response_body_stream
def get_bedrock_request_body(model_id) -> dict:
provider = model_id.split(".")[0]
return BEDROCK_PROVIDER_REQUEST_BODY[provider]
def is_subset(subset, superset) -> bool:
"""Ensure all fields in request body are allowed.
```python
subset = {"prompt": "hello","kwargs": {"temperature": 0.9,"p": 0.0}}
superset = {"prompt": "hello", "kwargs": {"temperature": 0.0, "top-p": 0.0}}
is_subset(subset, superset)
```
"""
for key, value in subset.items():
if key not in superset:
return False
if isinstance(value, dict):
if not isinstance(superset[key], dict):
return False
if not is_subset(value, superset[key]):
return False
return True
@pytest.fixture(scope="class", params=models)
def bedrock_api(request) -> BedrockLLM:
model_id = request.param
mock_llm_config_bedrock.model = model_id
api = BedrockLLM(mock_llm_config_bedrock)
return api
class TestBedrockAPI:
def _patch_invoke_model(self, mocker):
mocker.patch("metagpt.provider.bedrock_api.BedrockLLM.invoke_model", mock_invoke_model)
def _patch_invoke_model_stream(self, mocker):
mocker.patch(
"metagpt.provider.bedrock_api.BedrockLLM.invoke_model_with_response_stream",
mock_invoke_model_stream,
)
def test_get_request_body(self, bedrock_api: BedrockLLM):
"""Ensure request body has correct format"""
provider = bedrock_api.provider
request_body = json.loads(provider.get_request_body(messages, bedrock_api._const_kwargs))
assert is_subset(request_body, get_bedrock_request_body(bedrock_api.config.model))
@pytest.mark.asyncio
async def test_aask(self, bedrock_api: BedrockLLM, mocker):
self._patch_invoke_model(mocker)
self._patch_invoke_model_stream(mocker)
assert await bedrock_api.aask(messages, stream=False) == "Hello World"
assert await bedrock_api.aask(messages, stream=True) == "Hello World"

View file

@ -3,11 +3,11 @@
# @Desc : the unittest of ollama api
import json
from typing import Any, Tuple
from typing import Any, AsyncGenerator, Tuple
import pytest
from metagpt.provider.ollama_api import OllamaLLM
from metagpt.provider.ollama_api import OllamaLLM, OpenAIResponse
from tests.metagpt.provider.mock_llm_config import mock_llm_config
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
@ -23,21 +23,19 @@ default_resp = {"message": {"role": "assistant", "content": resp_cont}}
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
if stream:
class Iterator(object):
async def async_event_generator() -> AsyncGenerator[Any, None]:
events = [
b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}',
b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}',
]
for event in events:
yield OpenAIResponse(event, {})
async def __aiter__(self):
for event in self.events:
yield event
return Iterator(), None, None
return async_event_generator(), None, None
else:
raw_default_resp = default_resp.copy()
raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20})
return json.dumps(raw_default_resp).encode(), None, None
return OpenAIResponse(json.dumps(raw_default_resp).encode(), {}), None, None
@pytest.mark.asyncio

View file

@ -1,62 +1,55 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the unittest of spark api
"""
用于讯飞星火SDK的测试用例
文档https://www.xfyun.cn/doc/spark/Web.html
"""
from typing import AsyncIterator, List
import pytest
from sparkai.core.messages.ai import AIMessage, AIMessageChunk
from sparkai.core.outputs.chat_generation import ChatGeneration
from sparkai.core.outputs.llm_result import LLMResult
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
from tests.metagpt.provider.mock_llm_config import (
mock_llm_config,
mock_llm_config_spark,
)
from metagpt.provider.spark_api import SparkLLM
from tests.metagpt.provider.mock_llm_config import mock_llm_config_spark
from tests.metagpt.provider.req_resp_const import (
llm_general_chat_funcs_test,
messages,
prompt,
resp_cont_tmpl,
)
resp_cont = resp_cont_tmpl.format(name="Spark")
USAGE = {
"token_usage": {"question_tokens": 1000, "prompt_tokens": 1000, "completion_tokens": 1000, "total_tokens": 2000}
}
spark_agenerate_result = LLMResult(
generations=[[ChatGeneration(text=resp_cont, message=AIMessage(content=resp_cont, additional_kwargs=USAGE))]]
)
chunks = [AIMessageChunk(content=resp_cont), AIMessageChunk(content="", additional_kwargs=USAGE)]
class MockWebSocketApp(object):
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
pass
def run_forever(self, sslopt=None):
pass
async def chunk_iterator(chunks: List[AIMessageChunk]) -> AsyncIterator[AIMessageChunk]:
for chunk in chunks:
yield chunk
def test_get_msg_from_web(mocker):
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
get_msg_from_web = GetMessageFromWeb(prompt, mock_llm_config)
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
ret = get_msg_from_web.run()
assert ret == ""
def mock_spark_get_msg_from_web_run(self) -> str:
return resp_cont
@pytest.mark.asyncio
async def test_spark_aask(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
llm = SparkLLM(mock_llm_config_spark)
resp = await llm.aask("Hello!")
assert resp == resp_cont
async def mock_spark_acreate(self, messages, stream):
if stream:
return chunk_iterator(chunks)
else:
return spark_agenerate_result
@pytest.mark.asyncio
async def test_spark_acompletion(mocker):
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
mocker.patch("metagpt.provider.spark_api.SparkLLM.acreate", mock_spark_acreate)
spark_llm = SparkLLM(mock_llm_config)
spark_llm = SparkLLM(mock_llm_config_spark)
resp = await spark_llm.acompletion([])
assert resp == resp_cont
resp = await spark_llm.acompletion([messages])
assert spark_llm.get_choice_text(resp) == resp_cont
await llm_general_chat_funcs_test(spark_llm, prompt, prompt, resp_cont)
await llm_general_chat_funcs_test(spark_llm, prompt, messages, resp_cont)

View file

@ -7,6 +7,7 @@ from llama_index.core.llms import MockLLM
from llama_index.core.schema import Document, NodeWithScore, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.parsers import OmniParse
from metagpt.rag.retrievers import SimpleHybridRetriever
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.schema import BM25RetrieverConfig, ObjectNode
@ -37,6 +38,10 @@ class TestSimpleEngine:
def mock_get_response_synthesizer(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
@pytest.fixture
def mock_get_file_extractor(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleEngine._get_file_extractor")
def test_from_docs(
self,
mocker,
@ -44,6 +49,7 @@ class TestSimpleEngine:
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
mock_get_file_extractor,
):
# Mock
mock_simple_directory_reader.return_value.load_data.return_value = [
@ -53,6 +59,8 @@ class TestSimpleEngine:
mock_get_retriever.return_value = mocker.MagicMock()
mock_get_rankers.return_value = [mocker.MagicMock()]
mock_get_response_synthesizer.return_value = mocker.MagicMock()
file_extractor = mocker.MagicMock()
mock_get_file_extractor.return_value = file_extractor
# Setup
input_dir = "test_dir"
@ -75,7 +83,9 @@ class TestSimpleEngine:
)
# Assert
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files, fs=None)
mock_simple_directory_reader.assert_called_once_with(
input_dir=input_dir, input_files=input_files, file_extractor=file_extractor, fs=None
)
mock_get_retriever.assert_called_once()
mock_get_rankers.assert_called_once()
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
@ -298,3 +308,17 @@ class TestSimpleEngine:
# Assert
assert "obj" in node.node.metadata
assert node.node.metadata["obj"] == expected_obj
def test_get_file_extractor(self, mocker):
# mock no omniparse config
mock_omniparse_config = mocker.patch("metagpt.rag.engines.simple.config.omniparse", autospec=True)
mock_omniparse_config.base_url = ""
file_extractor = SimpleEngine._get_file_extractor()
assert file_extractor == {}
# mock have omniparse config
mock_omniparse_config.base_url = "http://localhost:8000"
file_extractor = SimpleEngine._get_file_extractor()
assert ".pdf" in file_extractor
assert isinstance(file_extractor[".pdf"], OmniParse)

View file

@ -8,6 +8,7 @@ from metagpt.rag.schema import (
ElasticsearchIndexConfig,
ElasticsearchStoreConfig,
FAISSIndexConfig,
MilvusIndexConfig,
)
@ -20,6 +21,10 @@ class TestRAGIndexFactory:
def faiss_config(self):
return FAISSIndexConfig(persist_path="")
@pytest.fixture
def milvus_config(self):
return MilvusIndexConfig(uri="", collection_name="")
@pytest.fixture
def chroma_config(self):
return ChromaIndexConfig(persist_path="", collection_name="")
@ -65,6 +70,16 @@ class TestRAGIndexFactory:
):
self.index_factory.get_index(bm25_config, embed_model=mock_embedding)
def test_create_milvus_index(self, mocker, milvus_config, mock_from_vector_store, mock_embedding):
# Mock
mock_milvus_store = mocker.patch("metagpt.rag.factories.index.MilvusVectorStore")
# Exec
self.index_factory.get_index(milvus_config, embed_model=mock_embedding)
# Assert
mock_milvus_store.assert_called_once()
def test_create_chroma_index(self, mocker, chroma_config, mock_from_vector_store, mock_embedding):
# Mock
mock_chroma_db = mocker.patch("metagpt.rag.factories.index.chromadb.PersistentClient")

View file

@ -5,6 +5,7 @@ from llama_index.core.embeddings import MockEmbedding
from llama_index.core.schema import TextNode
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.vector_stores.elasticsearch import ElasticsearchStore
from llama_index.vector_stores.milvus import MilvusVectorStore
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
@ -12,12 +13,14 @@ from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.es_retriever import ElasticsearchRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.retrievers.milvus_retriever import MilvusRetriever
from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaRetrieverConfig,
ElasticsearchRetrieverConfig,
ElasticsearchStoreConfig,
FAISSRetrieverConfig,
MilvusRetrieverConfig,
)
@ -41,6 +44,10 @@ class TestRetrieverFactory:
def mock_chroma_vector_store(self, mocker):
return mocker.MagicMock(spec=ChromaVectorStore)
@pytest.fixture
def mock_milvus_vector_store(self, mocker):
return mocker.MagicMock(spec=MilvusVectorStore)
@pytest.fixture
def mock_es_vector_store(self, mocker):
return mocker.MagicMock(spec=ElasticsearchStore)
@ -91,6 +98,14 @@ class TestRetrieverFactory:
assert isinstance(retriever, ChromaRetriever)
def test_get_retriever_with_milvus_config(self, mocker, mock_milvus_vector_store, mock_embedding):
mock_config = MilvusRetrieverConfig(uri="/path/to/milvus.db", collection_name="test_collection")
mocker.patch("metagpt.rag.factories.retriever.MilvusVectorStore", return_value=mock_milvus_vector_store)
retriever = self.retriever_factory.get_retriever(configs=[mock_config], nodes=[], embed_model=mock_embedding)
assert isinstance(retriever, MilvusRetriever)
def test_get_retriever_with_es_config(self, mocker, mock_es_vector_store, mock_embedding):
mock_config = ElasticsearchRetrieverConfig(store_config=ElasticsearchStoreConfig())
mocker.patch("metagpt.rag.factories.retriever.ElasticsearchStore", return_value=mock_es_vector_store)

View file

@ -0,0 +1,118 @@
import pytest
from llama_index.core import Document
from metagpt.const import EXAMPLE_DATA_PATH
from metagpt.rag.parsers import OmniParse
from metagpt.rag.schema import (
OmniParsedResult,
OmniParseOptions,
OmniParseType,
ParseResultType,
)
from metagpt.utils.omniparse_client import OmniParseClient
# test data
TEST_DOCX = EXAMPLE_DATA_PATH / "omniparse/test01.docx"
TEST_PDF = EXAMPLE_DATA_PATH / "omniparse/test02.pdf"
TEST_VIDEO = EXAMPLE_DATA_PATH / "omniparse/test03.mp4"
TEST_AUDIO = EXAMPLE_DATA_PATH / "omniparse/test04.mp3"
class TestOmniParseClient:
parse_client = OmniParseClient()
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_parse_pdf(self, mock_request_parse):
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
parse_ret = await self.parse_client.parse_pdf(TEST_PDF)
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_document(self, mock_request_parse):
mock_content = "#test title\ntest_parse_document"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
with open(TEST_DOCX, "rb") as f:
file_bytes = f.read()
with pytest.raises(ValueError):
# bytes data must provide bytes_filename
await self.parse_client.parse_document(file_bytes)
parse_ret = await self.parse_client.parse_document(file_bytes, bytes_filename="test.docx")
assert parse_ret == mock_parsed_ret
@pytest.mark.asyncio
async def test_parse_video(self, mock_request_parse):
mock_content = "#test title\ntest_parse_video"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
with pytest.raises(ValueError):
# Wrong file extension test
await self.parse_client.parse_video(TEST_DOCX)
parse_ret = await self.parse_client.parse_video(TEST_VIDEO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
@pytest.mark.asyncio
async def test_parse_audio(self, mock_request_parse):
mock_content = "#test title\ntest_parse_audio"
mock_request_parse.return_value = {
"text": mock_content,
"metadata": {},
}
parse_ret = await self.parse_client.parse_audio(TEST_AUDIO)
assert "text" in parse_ret and "metadata" in parse_ret
assert parse_ret["text"] == mock_content
class TestOmniParse:
@pytest.fixture
def mock_omniparse(self):
parser = OmniParse(
parse_options=OmniParseOptions(
parse_type=OmniParseType.PDF,
result_type=ParseResultType.MD,
max_timeout=120,
num_workers=3,
)
)
return parser
@pytest.fixture
def mock_request_parse(self, mocker):
return mocker.patch("metagpt.rag.parsers.omniparse.OmniParseClient._request_parse")
@pytest.mark.asyncio
async def test_load_data(self, mock_omniparse, mock_request_parse):
# mock
mock_content = "#test title\ntest content"
mock_parsed_ret = OmniParsedResult(text=mock_content, markdown=mock_content)
mock_request_parse.return_value = mock_parsed_ret.model_dump()
# single file
documents = mock_omniparse.load_data(file_path=TEST_PDF)
doc = documents[0]
assert isinstance(doc, Document)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown
# multi files
file_paths = [TEST_DOCX, TEST_PDF]
mock_omniparse.parse_type = OmniParseType.DOCUMENT
documents = await mock_omniparse.aload_data(file_path=file_paths)
doc = documents[0]
# assert
assert isinstance(doc, Document)
assert len(documents) == len(file_paths)
assert doc.text == mock_parsed_ret.text == mock_parsed_ret.markdown

View file

@ -14,8 +14,8 @@ from metagpt.context import Context
from metagpt.logs import logger
from metagpt.roles import ProductManager
from metagpt.utils.common import any_to_str
from tests.metagpt.roles.mock import MockMessages
from metagpt.utils.git_repository import GitRepository
from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio

View file

@ -1,3 +1,4 @@
import tempfile
from pathlib import Path
from random import random
from tempfile import TemporaryDirectory
@ -6,6 +7,7 @@ import pytest
from metagpt.actions.research import CollectLinks
from metagpt.roles import researcher
from metagpt.team import Team
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@ -57,5 +59,13 @@ def test_write_report(mocker, context):
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")
@pytest.mark.asyncio
async def test_serialize():
team = Team()
team.hire([researcher.Researcher()])
with tempfile.TemporaryDirectory() as dirname:
team.serialize(Path(dirname) / "team.json")
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -5,6 +5,7 @@ import pytest
from metagpt.provider.human_provider import HumanProvider
from metagpt.roles.role import Role
from metagpt.schema import Message, UserMessage
def test_role_desc():
@ -18,5 +19,15 @@ def test_role_human(context):
assert isinstance(role.llm, HumanProvider)
@pytest.mark.asyncio
async def test_recovered():
role = Role(profile="Tester", desc="Tester", recovered=True)
role.put_message(UserMessage(content="2"))
role.latest_observed_msg = Message(content="1")
await role._observe()
await role._observe()
assert role.rc.msg_buffer.empty()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
@ -55,6 +55,7 @@ def test_environment_serdeser(context):
assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise
assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch
def test_environment_serdeser_v2(context):
@ -69,6 +70,7 @@ def test_environment_serdeser_v2(context):
assert isinstance(role, ProjectManager)
assert isinstance(role.actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
assert list(new_env.roles.values())[0].rc.watch == pm.rc.watch
def test_environment_serdeser_save(context):
@ -85,3 +87,8 @@ def test_environment_serdeser_save(context):
new_env: Environment = Environment(**env_dict, context=context)
assert len(new_env.roles) == 1
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
assert list(new_env.roles.values())[0].rc.watch == role_c.rc.watch
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -28,9 +28,9 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
def test_roles(context):
role_a = RoleA()
assert len(role_a.rc.watch) == 1
assert len(role_a.rc.watch) == 2
role_b = RoleB()
assert len(role_a.rc.watch) == 1
assert len(role_a.rc.watch) == 2
assert len(role_b.rc.watch) == 1
role_d = RoleD(actions=[ActionOK()])

View file

@ -8,9 +8,9 @@ from typing import Optional
from pydantic import BaseModel, Field
from metagpt.actions import Action, ActionOutput
from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.fix_bug import FixBug
from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
@ -68,7 +68,7 @@ class RoleA(Role):
def __init__(self, **kwargs):
super(RoleA, self).__init__(**kwargs)
self.set_actions([ActionPass])
self._watch([UserRequirement])
self._watch([FixBug, UserRequirement])
class RoleB(Role):
@ -93,7 +93,7 @@ class RoleC(Role):
def __init__(self, **kwargs):
super(RoleC, self).__init__(**kwargs)
self.set_actions([ActionOK, ActionRaise])
self._watch([UserRequirement])
self._watch([FixBug, UserRequirement])
self.rc.react_mode = RoleReactMode.BY_ORDER
self.rc.memory.ignore_id = True

View file

@ -1,17 +0,0 @@
# -*- coding: utf-8 -*-
# @Desc :
import pytest
from metagpt.roles.sk_agent import SkAgent
@pytest.mark.asyncio
async def test_sk_agent_serdeser():
role = SkAgent()
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
assert "name" in ser_role_dict
assert "planner" in ser_role_dict
new_role = SkAgent(**ser_role_dict)
assert new_role.name == "Sunshine"
assert len(new_role.actions) == 1

View file

@ -29,3 +29,7 @@ def div(a: int, b: int = 0):
assert new_action.name == "WriteCodeReview"
await new_action.run()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -14,8 +14,8 @@ from tests.metagpt.provider.mock_llm_config import mock_llm_config
def test_config_1():
cfg = Config.default()
llm = cfg.get_openai_llm()
assert llm is not None
assert llm.api_type == LLMType.OPENAI
if cfg.llm.api_type == LLMType.OPENAI:
assert llm is not None
def test_config_from_dict():

View file

@ -53,8 +53,8 @@ def test_context_1():
def test_context_2():
ctx = Context()
llm = ctx.config.get_openai_llm()
assert llm is not None
assert llm.api_type == LLMType.OPENAI
if ctx.config.llm.api_type == LLMType.OPENAI:
assert llm is not None
kwargs = ctx.kwargs
assert kwargs is not None

View file

@ -105,11 +105,11 @@ def test_config_mixin_4_multi_inheritance_override_config():
async def test_config_priority():
"""If action's config is set, then its llm will be set, otherwise, it will use the role's llm"""
home_dir = Path.home() / CONFIG_ROOT
gpt4t = Config.from_home("gpt-4-1106-preview.yaml")
gpt4t = Config.from_home("gpt-4-turbo.yaml")
if not home_dir.exists():
assert gpt4t is None
gpt35 = Config.default()
gpt35.llm.model = "gpt-3.5-turbo-1106"
gpt35.llm.model = "gpt-4-turbo"
gpt4 = Config.default()
gpt4.llm.model = "gpt-4-0613"
@ -127,8 +127,8 @@ async def test_config_priority():
env = Environment(desc="US election live broadcast")
Team(investment=10.0, env=env, roles=[A, B, C])
assert a1.llm.model == "gpt-4-1106-preview" if Path(home_dir / "gpt-4-1106-preview.yaml").exists() else "gpt-4-0613"
assert a1.llm.model == "gpt-4-turbo" if Path(home_dir / "gpt-4-turbo.yaml").exists() else "gpt-4-0613"
assert a2.llm.model == "gpt-4-0613"
assert a3.llm.model == "gpt-3.5-turbo-1106"
assert a3.llm.model == "gpt-4-turbo"
# history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="a1", n_round=3)

View file

@ -57,7 +57,7 @@ def test_add_role(env: Environment):
name="Alice", profile="product manager", goal="create a new product", constraints="limited resources"
)
env.add_role(role)
assert env.get_role(role.profile) == role
assert env.get_role(str(role._setting)) == role
def test_get_roles(env: Environment):

View file

@ -30,5 +30,12 @@ def test_software_company(new_filename):
logger.info(result.output)
def test_software_company_with_run_tests():
args = ["Make a cli snake game", "--run-tests", "--n-round=8"]
result = runner.invoke(app, args)
logger.info(result.output)
# assert "unittest" in result.output.lower() or "pytest" in result.output.lower()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -58,7 +58,7 @@ class TestUTWriter:
)
],
created=1706710532,
model="gpt-3.5-turbo-1106",
model="gpt-4-turbo",
object="chat.completion",
system_fingerprint="fp_04f9a1eebf",
usage=CompletionUsage(completion_tokens=35, prompt_tokens=1982, total_tokens=2017),

View file

@ -12,11 +12,11 @@ from metagpt.utils.cost_manager import CostManager
def test_cost_manager():
cm = CostManager(total_budget=20)
cm.update_cost(prompt_tokens=1000, completion_tokens=100, model="gpt-4-1106-preview")
cm.update_cost(prompt_tokens=1000, completion_tokens=100, model="gpt-4-turbo")
assert cm.get_total_prompt_tokens() == 1000
assert cm.get_total_completion_tokens() == 100
assert cm.get_total_cost() == 0.013
cm.update_cost(prompt_tokens=100, completion_tokens=10, model="gpt-4-1106-preview")
cm.update_cost(prompt_tokens=100, completion_tokens=10, model="gpt-4-turbo")
assert cm.get_total_prompt_tokens() == 1100
assert cm.get_total_completion_tokens() == 110
assert cm.get_total_cost() == 0.0143

View file

@ -20,7 +20,7 @@ def _paragraphs(n):
@pytest.mark.parametrize(
"msgs, model_name, system_text, reserved, expected",
"msgs, model, system_text, reserved, expected",
[
(_msgs(), "gpt-3.5-turbo-0613", "System", 1500, 1),
(_msgs(), "gpt-3.5-turbo-16k", "System", 3000, 6),
@ -37,7 +37,7 @@ def test_reduce_message_length(msgs, model_name, system_text, reserved, expected
@pytest.mark.parametrize(
"text, prompt_template, model_name, system_text, reserved, expected",
"text, prompt_template, model, system_text, reserved, expected",
[
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-0613", "System", 1500, 2),
(" ".join("Hello World." for _ in range(1000)), "Prompt: {}", "gpt-3.5-turbo-16k", "System", 3000, 1),

View file

@ -7,7 +7,7 @@
"""
import pytest
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
from metagpt.utils.token_counter import count_message_tokens, count_output_tokens
def test_count_message_tokens():
@ -53,20 +53,20 @@ def test_count_string_tokens():
"""Test that the string tokens are counted correctly."""
string = "Hello, world!"
assert count_string_tokens(string, model_name="gpt-3.5-turbo-0301") == 4
assert count_output_tokens(string, model="gpt-3.5-turbo-0301") == 4
def test_count_string_tokens_empty_input():
"""Test that the string tokens are counted correctly."""
assert count_string_tokens("", model_name="gpt-3.5-turbo-0301") == 0
assert count_output_tokens("", model="gpt-3.5-turbo-0301") == 0
def test_count_string_tokens_gpt_4():
"""Test that the string tokens are counted correctly."""
string = "Hello, world!"
assert count_string_tokens(string, model_name="gpt-4-0314") == 4
assert count_output_tokens(string, model="gpt-4-0314") == 4
if __name__ == "__main__":

View file

@ -117,7 +117,6 @@ class MockLLM(OriginalLLM):
raise ValueError(
"In current test setting, api call is not allowed, you should properly mock your tests, "
"or add expected api response in tests/data/rsp_cache.json. "
f"The prompt you want for api call: {msg_key}"
)
# Call the original unmocked method
rsp = await ask_func(*args, **kwargs)