mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-26 15:49:42 +02:00
Merge branch 'geekan:dev' into dev
This commit is contained in:
commit
9efb9e6e80
50 changed files with 549 additions and 635 deletions
|
|
@ -94,6 +94,10 @@ class Context:
|
|||
|
||||
@property
|
||||
def llm_api(self):
|
||||
# 1. 初始化llm,带有缓存结果
|
||||
# 2. 如果缓存query,那么直接返回缓存结果
|
||||
# 3. 如果没有缓存query,那么调用llm_api,返回结果
|
||||
# 4. 如果有缓存query,那么更新缓存结果
|
||||
return self._llm_api
|
||||
|
||||
|
||||
|
|
|
|||
92
tests/data/demo_project/game.py
Normal file
92
tests/data/demo_project/game.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
## game.py
|
||||
|
||||
import random
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class Game:
|
||||
def __init__(self):
|
||||
self.grid: List[List[int]] = [[0 for _ in range(4)] for _ in range(4)]
|
||||
self.score: int = 0
|
||||
self.game_over: bool = False
|
||||
|
||||
def reset_game(self):
|
||||
self.grid = [[0 for _ in range(4)] for _ in range(4)]
|
||||
self.score = 0
|
||||
self.game_over = False
|
||||
self.add_new_tile()
|
||||
self.add_new_tile()
|
||||
|
||||
def move(self, direction: str):
|
||||
if direction == "up":
|
||||
self._move_up()
|
||||
elif direction == "down":
|
||||
self._move_down()
|
||||
elif direction == "left":
|
||||
self._move_left()
|
||||
elif direction == "right":
|
||||
self._move_right()
|
||||
|
||||
def is_game_over(self) -> bool:
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
if self.grid[i][j] == 0:
|
||||
return False
|
||||
if j < 3 and self.grid[i][j] == self.grid[i][j + 1]:
|
||||
return False
|
||||
if i < 3 and self.grid[i][j] == self.grid[i + 1][j]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_empty_cells(self) -> List[Tuple[int, int]]:
|
||||
empty_cells = []
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
if self.grid[i][j] == 0:
|
||||
empty_cells.append((i, j))
|
||||
return empty_cells
|
||||
|
||||
def add_new_tile(self):
|
||||
empty_cells = self.get_empty_cells()
|
||||
if empty_cells:
|
||||
x, y = random.choice(empty_cells)
|
||||
self.grid[x][y] = 2 if random.random() < 0.9 else 4
|
||||
|
||||
def get_score(self) -> int:
|
||||
return self.score
|
||||
|
||||
def _move_up(self):
|
||||
for j in range(4):
|
||||
for i in range(1, 4):
|
||||
if self.grid[i][j] != 0:
|
||||
for k in range(i, 0, -1):
|
||||
if self.grid[k - 1][j] == 0:
|
||||
self.grid[k - 1][j] = self.grid[k][j]
|
||||
self.grid[k][j] = 0
|
||||
|
||||
def _move_down(self):
|
||||
for j in range(4):
|
||||
for i in range(2, -1, -1):
|
||||
if self.grid[i][j] != 0:
|
||||
for k in range(i, 3):
|
||||
if self.grid[k + 1][j] == 0:
|
||||
self.grid[k + 1][j] = self.grid[k][j]
|
||||
self.grid[k][j] = 0
|
||||
|
||||
def _move_left(self):
|
||||
for i in range(4):
|
||||
for j in range(1, 4):
|
||||
if self.grid[i][j] != 0:
|
||||
for k in range(j, 0, -1):
|
||||
if self.grid[i][k - 1] == 0:
|
||||
self.grid[i][k - 1] = self.grid[i][k]
|
||||
self.grid[i][k] = 0
|
||||
|
||||
def _move_right(self):
|
||||
for i in range(4):
|
||||
for j in range(2, -1, -1):
|
||||
if self.grid[i][j] != 0:
|
||||
for k in range(j, 3):
|
||||
if self.grid[i][k + 1] == 0:
|
||||
self.grid[i][k + 1] = self.grid[i][k]
|
||||
self.grid[i][k] = 0
|
||||
|
|
@ -5,6 +5,8 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_action.py
|
||||
"""
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action, ActionType, WritePRD, WriteTest
|
||||
|
||||
|
||||
|
|
@ -18,3 +20,22 @@ def test_action_type():
|
|||
assert ActionType.WRITE_TEST.value == WriteTest
|
||||
assert ActionType.WRITE_PRD.name == "WRITE_PRD"
|
||||
assert ActionType.WRITE_TEST.name == "WRITE_TEST"
|
||||
|
||||
|
||||
def test_simple_action():
|
||||
action = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
assert action.name == "AlexSay"
|
||||
assert action.node.instruction == "Express your opinion with emotion and don't repeat it"
|
||||
|
||||
|
||||
def test_empty_action():
|
||||
action = Action()
|
||||
assert action.name == "Action"
|
||||
assert not action.node
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_action_exception():
|
||||
action = Action()
|
||||
with pytest.raises(NotImplementedError):
|
||||
await action.run()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
|
@ -20,35 +21,35 @@ from metagpt.team import Team
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debate_two_roles():
|
||||
action1 = Action(name="BidenSay", instruction="Express opinions and argue vigorously, and strive to gain votes")
|
||||
action2 = Action(name="TrumpSay", instruction="Express opinions and argue vigorously, and strive to gain votes")
|
||||
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(
|
||||
name="Biden", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
|
||||
name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
|
||||
)
|
||||
trump = Role(
|
||||
name="Trump", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
|
||||
name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
|
||||
)
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden, trump])
|
||||
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
|
||||
assert "Biden" in history
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debate_one_role_in_env():
|
||||
action = Action(name="Debate", instruction="Express opinions and argue vigorously, and strive to gain votes")
|
||||
biden = Role(name="Biden", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden])
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Biden", n_round=3)
|
||||
assert "Biden" in history
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debate_one_role():
|
||||
action = Action(name="Debate", instruction="Express opinions and argue vigorously, and strive to gain votes")
|
||||
biden = Role(name="Biden", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
msg: Message = await biden.run("Topic: climate change. Under 80 words per message.")
|
||||
|
||||
assert len(msg.content) > 10
|
||||
|
|
@ -113,6 +114,10 @@ t_dict = {
|
|||
"Anything UNCLEAR": "We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?",
|
||||
}
|
||||
|
||||
t_dict_min = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
}
|
||||
|
||||
WRITE_TASKS_OUTPUT_MAPPING = {
|
||||
"Required Python third-party packages": (str, ...),
|
||||
"Required Other language third-party packages": (str, ...),
|
||||
|
|
@ -139,11 +144,19 @@ def test_create_model_class():
|
|||
assert output.schema()["properties"]["Full API spec"]
|
||||
|
||||
|
||||
def test_create_model_class_missing():
|
||||
def test_create_model_class_with_fields_unrecognized():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
_ = test_class(**t_dict) # 这里应该要挂掉
|
||||
_ = test_class(**t_dict) # just warning
|
||||
|
||||
|
||||
def test_create_model_class_with_fields_missing():
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
_ = test_class(**t_dict_min)
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
|
|
|
|||
|
|
@ -30,3 +30,11 @@ async def test_search_xlsx():
|
|||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
logger.info(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.docstore
|
||||
assert _faiss_store.index
|
||||
|
|
|
|||
|
|
@ -15,20 +15,24 @@ from metagpt.learn.text_to_image import text_to_image
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test():
|
||||
async def test_metagpt_llm():
|
||||
# Prerequisites
|
||||
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
assert CONFIG.OPENAI_API_KEY
|
||||
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
key = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None
|
||||
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
finally:
|
||||
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = key
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -27,13 +27,16 @@ async def test_text_to_speech():
|
|||
assert "base64" in data or "http" in data
|
||||
|
||||
# test iflytek
|
||||
key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
|
||||
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = ""
|
||||
## Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = ""
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
data = await text_to_speech("panda emoji")
|
||||
assert "base64" in data or "http" in data
|
||||
finally:
|
||||
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -14,11 +14,14 @@ from metagpt.provider.general_api_base import (
|
|||
APIRequestor,
|
||||
ApiType,
|
||||
OpenAIResponse,
|
||||
_aiohttp_proxies_arg,
|
||||
_build_api_url,
|
||||
_make_session,
|
||||
_requests_proxies_arg,
|
||||
log_debug,
|
||||
log_info,
|
||||
log_warn,
|
||||
logfmt,
|
||||
parse_stream,
|
||||
parse_stream_helper,
|
||||
)
|
||||
|
|
@ -36,6 +39,10 @@ def test_basic():
|
|||
log_warn("warn")
|
||||
log_info("info")
|
||||
|
||||
logfmt({"k1": b"v1", "k2": 1, "k3": "a b"})
|
||||
|
||||
_build_api_url(url="http://www.baidu.com/s?wd=", query="baidu")
|
||||
|
||||
|
||||
def test_openai_response():
|
||||
resp = OpenAIResponse(data=[], headers={"retry-after": 3})
|
||||
|
|
@ -53,11 +60,18 @@ def test_proxy():
|
|||
assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy}
|
||||
proxy_dict = {"http": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
assert _aiohttp_proxies_arg(proxy_dict) == proxy
|
||||
proxy_dict = {"https": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
assert _aiohttp_proxies_arg(proxy_dict) == proxy
|
||||
|
||||
assert _make_session() is not None
|
||||
|
||||
assert _aiohttp_proxies_arg(None) is None
|
||||
assert _aiohttp_proxies_arg("test") == "test"
|
||||
with pytest.raises(ValueError):
|
||||
_aiohttp_proxies_arg(-1)
|
||||
|
||||
|
||||
def test_parse_stream():
|
||||
assert parse_stream_helper(None) is None
|
||||
|
|
@ -83,6 +97,29 @@ async def mock_interpret_async_response(
|
|||
return b"baidu", True
|
||||
|
||||
|
||||
def test_requestor_headers():
|
||||
# validate_headers
|
||||
headers = api_requestor._validate_headers(None)
|
||||
assert not headers
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers(-1)
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers({1: 2})
|
||||
with pytest.raises(Exception):
|
||||
api_requestor._validate_headers({"test": 1})
|
||||
supplied_headers = {"test": "test"}
|
||||
assert api_requestor._validate_headers(supplied_headers) == supplied_headers
|
||||
|
||||
api_requestor.organization = "test"
|
||||
api_requestor.api_version = "test123"
|
||||
api_requestor.api_type = ApiType.OPEN_AI
|
||||
request_id = "test123"
|
||||
headers = api_requestor.request_headers(method="post", extra={}, request_id=request_id)
|
||||
assert headers["LLM-Organization"] == api_requestor.organization
|
||||
assert headers["LLM-Version"] == api_requestor.api_version
|
||||
assert headers["X-Request-Id"] == request_id
|
||||
|
||||
|
||||
def test_api_requestor(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response)
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
|
|
|
|||
|
|
@ -7,23 +7,25 @@ import pytest
|
|||
from metagpt.provider.human_provider import HumanProvider
|
||||
|
||||
resp_content = "test"
|
||||
|
||||
|
||||
def mock_llm_ask(msg: str, timeout: int = 3) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
|
||||
return mock_llm_ask(msg)
|
||||
resp_exit = "exit"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
|
||||
mocker.patch("builtins.input", lambda _: resp_content)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
resp = human_provider.ask(resp_content)
|
||||
assert resp == resp_content
|
||||
resp = await human_provider.aask(None)
|
||||
assert resp_content == resp
|
||||
|
||||
mocker.patch("builtins.input", lambda _: resp_exit)
|
||||
with pytest.raises(SystemExit):
|
||||
human_provider.ask(resp_exit)
|
||||
|
||||
resp = await human_provider.acompletion([])
|
||||
assert not resp
|
||||
|
||||
resp = await human_provider.acompletion_text([])
|
||||
assert resp == ""
|
||||
|
|
|
|||
|
|
@ -17,10 +17,23 @@ prompt_msg = "who are you"
|
|||
resp_content = "I'm Spark"
|
||||
|
||||
|
||||
def test_get_msg_from_web():
|
||||
class MockWebSocketApp(object):
|
||||
def __init__(self, ws_url, on_message=None, on_error=None, on_close=None, on_open=None):
|
||||
pass
|
||||
|
||||
def run_forever(self, sslopt=None):
|
||||
pass
|
||||
|
||||
|
||||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(text=prompt_msg)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
assert ret == ""
|
||||
|
||||
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_content
|
||||
|
|
@ -29,6 +42,7 @@ def mock_spark_get_msg_from_web_run(self) -> str:
|
|||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
spark_gpt = SparkLLM()
|
||||
|
||||
resp = await spark_gpt.acompletion([])
|
||||
|
|
|
|||
|
|
@ -16,3 +16,11 @@ async def test_async_sse_client():
|
|||
async_sse_client = AsyncSSEClient(event_source=Iterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert event.data, "test_value"
|
||||
|
||||
class InvalidIterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"invalid: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert not event
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
|||
api_key = "xxx.xxx"
|
||||
zhipuai.api_key = api_key
|
||||
|
||||
default_resp = {"result": "test response"}
|
||||
default_resp = b'{"result": "test response"}'
|
||||
|
||||
|
||||
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
||||
|
|
@ -39,3 +39,6 @@ async def test_zhipu_model_api(mocker):
|
|||
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
|
||||
)
|
||||
assert result == default_resp
|
||||
|
||||
result = await ZhiPuModelAPI.ainvoke()
|
||||
assert result["result"] == "test response"
|
||||
|
|
|
|||
|
|
@ -7,19 +7,39 @@
|
|||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteDesign, WritePRD
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PRDS_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, awrite
|
||||
from tests.metagpt.roles.mock import MockMessages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("llm_mock")
|
||||
async def test_architect():
|
||||
# FIXME: make git as env? Or should we support
|
||||
# Prerequisites
|
||||
filename = uuid.uuid4().hex + ".json"
|
||||
await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)
|
||||
|
||||
role = Architect()
|
||||
role.put_message(MockMessages.req)
|
||||
rsp = await role.run(MockMessages.prd)
|
||||
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
assert rsp.cause_by == any_to_str(WriteDesign)
|
||||
|
||||
# test update
|
||||
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
|
||||
assert rsp
|
||||
assert rsp.cause_by == any_to_str(WriteDesign)
|
||||
assert len(rsp.content) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -5,3 +5,59 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_qa_engineer.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import DebugError, RunCode, WriteTest
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.roles import QaEngineer
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, aread, awrite
|
||||
|
||||
|
||||
async def test_qa():
|
||||
# Prerequisites
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "qa/game_2048"
|
||||
data = await aread(filename=demo_path / "game.py", encoding="utf-8")
|
||||
await awrite(filename=CONFIG.src_workspace / "game.py", data=data, encoding="utf-8")
|
||||
await awrite(filename=Path(CONFIG.git_repo.workdir) / "requirements.txt", data="")
|
||||
|
||||
class MockEnv(Environment):
|
||||
msgs: List[Message] = Field(default_factory=list)
|
||||
|
||||
def publish_message(self, message: Message, peekable: bool = True) -> bool:
|
||||
self.msgs.append(message)
|
||||
return True
|
||||
|
||||
env = MockEnv()
|
||||
|
||||
role = QaEngineer()
|
||||
role.set_env(env)
|
||||
await role.run(with_message=Message(content="", cause_by=SummarizeCode))
|
||||
assert env.msgs
|
||||
assert env.msgs[0].cause_by == any_to_str(WriteTest)
|
||||
msg = env.msgs[0]
|
||||
env.msgs.clear()
|
||||
await role.run(with_message=msg)
|
||||
assert env.msgs
|
||||
assert env.msgs[0].cause_by == any_to_str(RunCode)
|
||||
msg = env.msgs[0]
|
||||
env.msgs.clear()
|
||||
await role.run(with_message=msg)
|
||||
assert env.msgs
|
||||
assert env.msgs[0].cause_by == any_to_str(DebugError)
|
||||
msg = env.msgs[0]
|
||||
env.msgs.clear()
|
||||
role.test_round_allowed = 1
|
||||
rsp = await role.run(with_message=msg)
|
||||
assert "Exceeding" in rsp.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -28,6 +28,6 @@ async def test_action_deserialize():
|
|||
|
||||
new_action = Action(**serialized_data)
|
||||
|
||||
assert new_action.name == ""
|
||||
assert new_action.name == "Action"
|
||||
assert isinstance(new_action.llm, type(LLM()))
|
||||
assert len(await new_action._aask("who are you")) > 0
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ async def test_write_design_deserialize():
|
|||
action = WriteDesign()
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WriteDesign(**serialized_data)
|
||||
assert new_action.name == ""
|
||||
assert new_action.name == "WriteDesign"
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
||||
|
||||
|
|
@ -37,5 +37,5 @@ async def test_write_task_deserialize():
|
|||
action = WriteTasks()
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WriteTasks(**serialized_data)
|
||||
assert new_action.name == "CreateTasks"
|
||||
assert new_action.name == "WriteTasks"
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ async def test_action_deserialize(style: str, part: str):
|
|||
|
||||
new_action = WriteDocstring(**serialized_data)
|
||||
|
||||
assert not new_action.name
|
||||
assert new_action.name == "WriteDocstring"
|
||||
assert new_action.desc == "Write docstring for code."
|
||||
ret = await new_action.run(code, style=style)
|
||||
assert part in ret
|
||||
|
|
|
|||
33
tests/metagpt/test_document.py
Normal file
33
tests/metagpt/test_document.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/2 21:00
|
||||
@Author : alexanderwu
|
||||
@File : test_document.py
|
||||
"""
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.document import Repo
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def set_existing_repo(path):
|
||||
repo1 = Repo.from_path(path)
|
||||
repo1.set("doc/wtf_file.md", "wtf content")
|
||||
repo1.set("code/wtf_file.py", "def hello():\n print('hello')")
|
||||
logger.info(repo1) # check doc
|
||||
|
||||
|
||||
def load_existing_repo(path):
|
||||
repo = Repo.from_path(path)
|
||||
logger.info(repo)
|
||||
logger.info(repo.eda())
|
||||
|
||||
assert repo
|
||||
assert repo.get("doc/wtf_file.md").content == "wtf content"
|
||||
assert repo.get("code/wtf_file.py").content == "def hello():\n print('hello')"
|
||||
|
||||
|
||||
def test_repo_set_load():
|
||||
repo_path = CONFIG.workspace_path / "test_repo"
|
||||
set_existing_repo(repo_path)
|
||||
load_existing_repo(repo_path)
|
||||
|
|
@ -10,15 +10,17 @@
|
|||
functionality is to be consolidated into the `Environment` class.
|
||||
"""
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import Action, ActionOutput, UserRequirement
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.common import any_to_name, any_to_str
|
||||
|
||||
|
||||
class MockAction(Action):
|
||||
|
|
@ -96,7 +98,7 @@ async def test_react():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_msg_to():
|
||||
async def test_send_to():
|
||||
m = Message(content="a", send_to=["a", MockRole, Message])
|
||||
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
|
||||
|
||||
|
|
@ -107,5 +109,50 @@ async def test_msg_to():
|
|||
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
|
||||
|
||||
|
||||
def test_init_action():
|
||||
role = Role()
|
||||
role.init_actions([MockAction, MockAction])
|
||||
assert role.action_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover():
|
||||
# Mock LLM actions
|
||||
mock_llm = MagicMock(spec=BaseLLM)
|
||||
mock_llm.aask.side_effect = ["1"]
|
||||
|
||||
role = Role()
|
||||
assert role.is_watch(any_to_str(UserRequirement))
|
||||
role.put_message(None)
|
||||
role.publish_message(None)
|
||||
|
||||
role.llm = mock_llm
|
||||
role.init_actions([MockAction, MockAction])
|
||||
role.recovered = True
|
||||
role.latest_observed_msg = Message(content="recover_test")
|
||||
role.rc.state = 0
|
||||
assert role.todo == any_to_name(MockAction)
|
||||
|
||||
rsp = await role.run()
|
||||
assert rsp.cause_by == any_to_str(MockAction)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_think_act():
|
||||
# Mock LLM actions
|
||||
mock_llm = MagicMock(spec=BaseLLM)
|
||||
mock_llm.aask.side_effect = ["ok"]
|
||||
|
||||
role = Role()
|
||||
role.init_actions([MockAction])
|
||||
await role.think()
|
||||
role.rc.memory.add(Message("run"))
|
||||
assert len(role.get_memories()) == 1
|
||||
rsp = await role.act()
|
||||
assert rsp
|
||||
assert isinstance(rsp, ActionOutput)
|
||||
assert rsp.content == "run"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -16,8 +16,10 @@ from metagpt.actions import Action
|
|||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.schema import (
|
||||
AIMessage,
|
||||
CodeSummarizeContext,
|
||||
Document,
|
||||
Message,
|
||||
MessageQueue,
|
||||
|
|
@ -61,6 +63,8 @@ def test_message():
|
|||
assert m.role == "b"
|
||||
assert m.send_to == {"c"}
|
||||
assert m.cause_by == "c"
|
||||
m.sent_from = "e"
|
||||
assert m.sent_from == "e"
|
||||
|
||||
m.cause_by = "Message"
|
||||
assert m.cause_by == "Message"
|
||||
|
|
@ -121,6 +125,8 @@ def test_document():
|
|||
@pytest.mark.asyncio
|
||||
async def test_message_queue():
|
||||
mq = MessageQueue()
|
||||
val = await mq.dump()
|
||||
assert val == "[]"
|
||||
mq.push(Message(content="1"))
|
||||
mq.push(Message(content="2中文测试aaa"))
|
||||
msg = mq.pop()
|
||||
|
|
@ -132,5 +138,23 @@ async def test_message_queue():
|
|||
assert new_mq.pop_all() == mq.pop_all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("file_list", "want"),
|
||||
[
|
||||
(
|
||||
[f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", f"{TASK_FILE_REPO}/b.txt"],
|
||||
CodeSummarizeContext(
|
||||
design_filename=f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", task_filename=f"{TASK_FILE_REPO}/b.txt"
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_CodeSummarizeContext(file_list, want):
|
||||
ctx = CodeSummarizeContext.loads(file_list)
|
||||
assert ctx == want
|
||||
m = {ctx: ctx}
|
||||
assert want in m
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from metagpt.config import CONFIG
|
|||
@pytest.mark.asyncio
|
||||
async def test_hello():
|
||||
workdir = Path(__file__).parent.parent.parent.parent
|
||||
script_pathname = workdir / "metagpt/tools/hello.py"
|
||||
script_pathname = workdir / "metagpt/tools/openapi_v3_hello.py"
|
||||
env = CONFIG.new_environ()
|
||||
env["PYTHONPATH"] = str(workdir) + ":" + env.get("PYTHONPATH", "")
|
||||
process = subprocess.Popen(["python", str(script_pathname)], cwd=workdir, env=env)
|
||||
|
|
|
|||
|
|
@ -91,6 +91,10 @@ class TestGetProjectRoot:
|
|||
x=(TutorialAssistant, RunCode(), "a"),
|
||||
want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"},
|
||||
),
|
||||
Input(
|
||||
x={"a": TutorialAssistant, "b": RunCode(), "c": "a"},
|
||||
want={"a", "metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode"},
|
||||
),
|
||||
]
|
||||
for i in inputs:
|
||||
v = any_to_str_set(i.x)
|
||||
|
|
|
|||
|
|
@ -119,95 +119,7 @@ def test_extract_struct(
|
|||
case()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
t_text = '''
|
||||
## Required Python third-party packages
|
||||
```python
|
||||
"""
|
||||
flask==1.1.2
|
||||
pygame==2.0.1
|
||||
"""
|
||||
```
|
||||
|
||||
## Required Other language third-party packages
|
||||
```python
|
||||
"""
|
||||
No third-party packages required for other languages.
|
||||
"""
|
||||
```
|
||||
|
||||
## Full API spec
|
||||
```python
|
||||
"""
|
||||
openapi: 3.0.0
|
||||
info:
|
||||
title: Web Snake Game API
|
||||
version: 1.0.0
|
||||
paths:
|
||||
/game:
|
||||
get:
|
||||
summary: Get the current game state
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON object of the game state
|
||||
post:
|
||||
summary: Send a command to the game
|
||||
requestBody:
|
||||
required: true
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
command:
|
||||
type: string
|
||||
responses:
|
||||
'200':
|
||||
description: A JSON object of the updated game state
|
||||
"""
|
||||
```
|
||||
|
||||
## Logic Analysis
|
||||
```python
|
||||
[
|
||||
("app.py", "Main entry point for the Flask application. Handles HTTP requests and responses."),
|
||||
("game.py", "Contains the Game and Snake classes. Handles the game logic."),
|
||||
("static/js/script.js", "Handles user interactions and updates the game UI."),
|
||||
("static/css/styles.css", "Defines the styles for the game UI."),
|
||||
("templates/index.html", "The main page of the web application. Displays the game UI.")
|
||||
]
|
||||
```
|
||||
|
||||
## Task list
|
||||
```python
|
||||
[
|
||||
"game.py",
|
||||
"app.py",
|
||||
"static/css/styles.css",
|
||||
"static/js/script.js",
|
||||
"templates/index.html"
|
||||
]
|
||||
```
|
||||
|
||||
## Shared Knowledge
|
||||
```python
|
||||
"""
|
||||
'game.py' contains the Game and Snake classes which are responsible for the game logic. The Game class uses an instance of the Snake class.
|
||||
|
||||
'app.py' is the main entry point for the Flask application. It creates an instance of the Game class and handles HTTP requests and responses.
|
||||
|
||||
'static/js/script.js' is responsible for handling user interactions and updating the game UI based on the game state returned by 'app.py'.
|
||||
|
||||
'static/css/styles.css' defines the styles for the game UI.
|
||||
|
||||
'templates/index.html' is the main page of the web application. It displays the game UI and loads 'static/js/script.js' and 'static/css/styles.css'.
|
||||
"""
|
||||
```
|
||||
|
||||
## Anything UNCLEAR
|
||||
We need clarification on how the high score should be stored. Should it persist across sessions (stored in a database or a file) or should it reset every time the game is restarted? Also, should the game speed increase as the snake grows, or should it remain constant throughout the game?
|
||||
'''
|
||||
|
||||
def test_parse_with_markdown_mapping():
|
||||
OUTPUT_MAPPING = {
|
||||
"Original Requirements": (str, ...),
|
||||
"Product Goals": (List[str], ...),
|
||||
|
|
@ -218,7 +130,7 @@ We need clarification on how the high score should be stored. Should it persist
|
|||
"Requirement Pool": (List[Tuple[str, str]], ...),
|
||||
"Anything UNCLEAR": (str, ...),
|
||||
}
|
||||
t_text1 = """## Original Requirements:
|
||||
t_text_with_content_tag = """[CONTENT]## Original Requirements:
|
||||
|
||||
The user wants to create a web-based version of the game "Fly Bird".
|
||||
|
||||
|
|
@ -286,8 +198,11 @@ The product should be a web-based version of the game "Fly Bird" that is engagin
|
|||
## Anything UNCLEAR:
|
||||
|
||||
There are no unclear points.
|
||||
"""
|
||||
d = OutputParser.parse_data_with_mapping(t_text1, OUTPUT_MAPPING)
|
||||
[/CONTENT]"""
|
||||
t_text_raw = t_text_with_content_tag.replace("[CONTENT]", "").replace("[/CONTENT]", "")
|
||||
d = OutputParser.parse_data_with_mapping(t_text_with_content_tag, OUTPUT_MAPPING)
|
||||
|
||||
import json
|
||||
|
||||
print(json.dumps(d))
|
||||
assert d["Original Requirements"] == t_text_raw.split("## Original Requirements:")[1].split("##")[0].strip()
|
||||
|
|
|
|||
|
|
@ -27,6 +27,19 @@ async def test_redis():
|
|||
assert await conn.get("test") == b"test"
|
||||
await conn.close()
|
||||
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["REDIS_HOST"] = "YOUR_REDIS_HOST"
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
conn = Redis()
|
||||
await conn.set("test", "test", timeout_sec=0)
|
||||
assert not await conn.get("test") == b"test"
|
||||
await conn.close()
|
||||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -41,17 +41,18 @@ async def test_s3():
|
|||
res = await conn.cache(data, ".bak", "script")
|
||||
assert "http" in res
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_no_error():
|
||||
conn = S3()
|
||||
key = conn.auth_config["aws_secret_access_key"]
|
||||
conn.auth_config["aws_secret_access_key"] = ""
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY"
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
conn = S3()
|
||||
assert not conn.is_valid
|
||||
res = await conn.cache("ABC", ".bak", "script")
|
||||
assert not res
|
||||
finally:
|
||||
conn.auth_config["aws_secret_access_key"] = key
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue