Merge pull request #663 from iorisa/feature/unittest

feat: +unit test
This commit is contained in:
geekan 2024-01-02 17:49:23 +08:00 committed by GitHub
commit d2260a5958
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 306 additions and 66 deletions

View file

@ -372,16 +372,6 @@ class Role(SerializationMixin, is_polymorphic_base=True):
return msg
def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]:
news = []
# Warning, remove `id` here to make it work for recover
observed_pure = [msg.dict(exclude={"id": True}) for msg in observed]
existed_pure = [msg.dict(exclude={"id": True}) for msg in existed]
for idx, new in enumerate(observed_pure):
if (new["cause_by"] in self.rc.watch or self.name in new["send_to"]) and new not in existed_pure:
news.append(observed[idx])
return news
async def _observe(self, ignore_memory=False) -> int:
"""Prepare new messages for processing from the message buffer and other sources."""
# Read unprocessed messages from the msg buffer.
@ -407,29 +397,6 @@ class Role(SerializationMixin, is_polymorphic_base=True):
logger.debug(f"{self._setting} observed: {news_text}")
return len(self.rc.news)
# async def _observe(self, ignore_memory=False) -> int:
# """Prepare new messages for processing from the message buffer and other sources."""
# # Read unprocessed messages from the msg buffer.
# news = self.rc.msg_buffer.pop_all()
# if self.recovered:
# news = [self.latest_observed_msg] if self.latest_observed_msg else []
# else:
# self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
#
# # Store the read messages in your own memory to prevent duplicate processing.
# old_messages = [] if ignore_memory else self.rc.memory.get()
# self.rc.memory.add_batch(news)
# # Filter out messages of interest.
# self.rc.news = self._find_news(news, old_messages)
#
# # Design Rules:
# # If you need to further categorize Message objects, you can do so using the Message.set_meta function.
# # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
# if news_text:
# logger.debug(f"{self._setting} observed: {news_text}")
# return len(self.rc.news)
def publish_message(self, msg):
"""If the role belongs to env, then the role's messages will be broadcast to env"""
if not msg:

View file

@ -5,6 +5,7 @@
@Author : mashenquan
@File : redis.py
"""
from __future__ import annotations
import traceback
from datetime import timedelta
@ -22,7 +23,7 @@ class Redis:
async def _connect(self, force=False):
if self._client and not force:
return True
if not CONFIG.REDIS_HOST or not CONFIG.REDIS_PORT or CONFIG.REDIS_DB is None or CONFIG.REDIS_PASSWORD is None:
if not self.is_configured:
return False
try:
@ -37,7 +38,7 @@ class Redis:
logger.warning(f"Redis initialization has failed:{e}")
return False
async def get(self, key: str) -> bytes:
async def get(self, key: str) -> bytes | None:
if not await self._connect() or not key:
return None
try:
@ -65,3 +66,14 @@ class Redis:
@property
def is_valid(self) -> bool:
return self._client is not None
@property
def is_configured(self) -> bool:
return bool(
CONFIG.REDIS_HOST
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
and CONFIG.REDIS_PORT
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
and CONFIG.REDIS_DB is not None
and CONFIG.REDIS_PASSWORD is not None
)

View file

@ -154,16 +154,17 @@ class S3:
@property
def is_valid(self):
is_invalid = (
not CONFIG.S3_ACCESS_KEY
or CONFIG.S3_ACCESS_KEY == "YOUR_S3_ACCESS_KEY"
or not CONFIG.S3_SECRET_KEY
or CONFIG.S3_SECRET_KEY == "YOUR_S3_SECRET_KEY"
or not CONFIG.S3_ENDPOINT_URL
or CONFIG.S3_ENDPOINT_URL == "YOUR_S3_ENDPOINT_URL"
or not CONFIG.S3_BUCKET
or CONFIG.S3_BUCKET == "YOUR_S3_BUCKET"
return self.is_configured
@property
def is_configured(self) -> bool:
return bool(
CONFIG.S3_ACCESS_KEY
and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
and CONFIG.S3_SECRET_KEY
and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
and CONFIG.S3_ENDPOINT_URL
and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
and CONFIG.S3_BUCKET
and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
)
if is_invalid:
logger.info("S3 is invalid")
return not is_invalid

View file

@ -0,0 +1,92 @@
## game.py
import random
from typing import List, Tuple
class Game:
def __init__(self):
self.grid: List[List[int]] = [[0 for _ in range(4)] for _ in range(4)]
self.score: int = 0
self.game_over: bool = False
def reset_game(self):
self.grid = [[0 for _ in range(4)] for _ in range(4)]
self.score = 0
self.game_over = False
self.add_new_tile()
self.add_new_tile()
def move(self, direction: str):
if direction == "up":
self._move_up()
elif direction == "down":
self._move_down()
elif direction == "left":
self._move_left()
elif direction == "right":
self._move_right()
def is_game_over(self) -> bool:
for i in range(4):
for j in range(4):
if self.grid[i][j] == 0:
return False
if j < 3 and self.grid[i][j] == self.grid[i][j + 1]:
return False
if i < 3 and self.grid[i][j] == self.grid[i + 1][j]:
return False
return True
def get_empty_cells(self) -> List[Tuple[int, int]]:
empty_cells = []
for i in range(4):
for j in range(4):
if self.grid[i][j] == 0:
empty_cells.append((i, j))
return empty_cells
def add_new_tile(self):
empty_cells = self.get_empty_cells()
if empty_cells:
x, y = random.choice(empty_cells)
self.grid[x][y] = 2 if random.random() < 0.9 else 4
def get_score(self) -> int:
return self.score
def _move_up(self):
for j in range(4):
for i in range(1, 4):
if self.grid[i][j] != 0:
for k in range(i, 0, -1):
if self.grid[k - 1][j] == 0:
self.grid[k - 1][j] = self.grid[k][j]
self.grid[k][j] = 0
def _move_down(self):
for j in range(4):
for i in range(2, -1, -1):
if self.grid[i][j] != 0:
for k in range(i, 3):
if self.grid[k + 1][j] == 0:
self.grid[k + 1][j] = self.grid[k][j]
self.grid[k][j] = 0
def _move_left(self):
for i in range(4):
for j in range(1, 4):
if self.grid[i][j] != 0:
for k in range(j, 0, -1):
if self.grid[i][k - 1] == 0:
self.grid[i][k - 1] = self.grid[i][k]
self.grid[i][k] = 0
def _move_right(self):
for i in range(4):
for j in range(2, -1, -1):
if self.grid[i][j] != 0:
for k in range(j, 3):
if self.grid[i][k + 1] == 0:
self.grid[i][k + 1] = self.grid[i][k]
self.grid[i][k] = 0

View file

@ -15,20 +15,24 @@ from metagpt.learn.text_to_image import text_to_image
@pytest.mark.asyncio
async def test():
async def test_metagpt_llm():
# Prerequisites
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
assert CONFIG.OPENAI_API_KEY
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
key = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None
# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
CONFIG.set_context(new_options)
try:
data = await text_to_image("Panda emoji", size_type="512x512")
assert "base64" in data or "http" in data
finally:
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = key
CONFIG.set_context(old_options)
if __name__ == "__main__":

View file

@ -27,13 +27,16 @@ async def test_text_to_speech():
assert "base64" in data or "http" in data
# test iflytek
key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = ""
## Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = ""
CONFIG.set_context(new_options)
try:
data = await text_to_speech("panda emoji")
assert "base64" in data or "http" in data
finally:
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key
CONFIG.set_context(old_options)
if __name__ == "__main__":

View file

@ -7,18 +7,38 @@
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message handling.
"""
import uuid
import pytest
from metagpt.actions import WriteDesign, WritePRD
from metagpt.config import CONFIG
from metagpt.const import PRDS_FILE_REPO
from metagpt.logs import logger
from metagpt.roles import Architect
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, awrite
from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
async def test_architect():
# FIXME: make git as env? Or should we support
# Prerequisites
filename = uuid.uuid4().hex + ".json"
await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)
role = Architect()
role.put_message(MockMessages.req)
rsp = await role.run(MockMessages.prd)
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
logger.info(rsp)
assert len(rsp.content) > 0
assert rsp.cause_by == any_to_str(WriteDesign)
# test update
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
assert rsp
assert rsp.cause_by == any_to_str(WriteDesign)
assert len(rsp.content) > 0
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -5,3 +5,59 @@
@Author : alexanderwu
@File : test_qa_engineer.py
"""
from pathlib import Path
from typing import List
import pytest
from pydantic import Field
from metagpt.actions import DebugError, RunCode, WriteTest
from metagpt.actions.summarize_code import SummarizeCode
from metagpt.config import CONFIG
from metagpt.environment import Environment
from metagpt.roles import QaEngineer
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, aread, awrite
async def test_qa():
# Prerequisites
demo_path = Path(__file__).parent / "../../data/demo_project"
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "qa/game_2048"
data = await aread(filename=demo_path / "game.py", encoding="utf-8")
await awrite(filename=CONFIG.src_workspace / "game.py", data=data, encoding="utf-8")
await awrite(filename=Path(CONFIG.git_repo.workdir) / "requirements.txt", data="")
class MockEnv(Environment):
msgs: List[Message] = Field(default_factory=list)
def publish_message(self, message: Message, peekable: bool = True) -> bool:
self.msgs.append(message)
return True
env = MockEnv()
role = QaEngineer()
role.set_env(env)
await role.run(with_message=Message(content="", cause_by=SummarizeCode))
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(WriteTest)
msg = env.msgs[0]
env.msgs.clear()
await role.run(with_message=msg)
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(RunCode)
msg = env.msgs[0]
env.msgs.clear()
await role.run(with_message=msg)
assert env.msgs
assert env.msgs[0].cause_by == any_to_str(DebugError)
msg = env.msgs[0]
env.msgs.clear()
role.test_round_allowed = 1
rsp = await role.run(with_message=msg)
assert "Exceeding" in rsp.content
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -10,15 +10,17 @@
functionality is to be consolidated into the `Environment` class.
"""
import uuid
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel
from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.environment import Environment
from metagpt.provider.base_llm import BaseLLM
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import any_to_str
from metagpt.utils.common import any_to_name, any_to_str
class MockAction(Action):
@ -96,7 +98,7 @@ async def test_react():
@pytest.mark.asyncio
async def test_msg_to():
async def test_send_to():
m = Message(content="a", send_to=["a", MockRole, Message])
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
@ -107,5 +109,50 @@ async def test_msg_to():
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
def test_init_action():
role = Role()
role.init_actions([MockAction, MockAction])
assert role.action_count == 2
@pytest.mark.asyncio
async def test_recover():
# Mock LLM actions
mock_llm = MagicMock(spec=BaseLLM)
mock_llm.aask.side_effect = ["1"]
role = Role()
assert role.is_watch(any_to_str(UserRequirement))
role.put_message(None)
role.publish_message(None)
role.llm = mock_llm
role.init_actions([MockAction, MockAction])
role.recovered = True
role.latest_observed_msg = Message(content="recover_test")
role.rc.state = 0
assert role.todo == any_to_name(MockAction)
rsp = await role.run()
assert rsp.cause_by == any_to_str(MockAction)
@pytest.mark.asyncio
async def test_think_act():
# Mock LLM actions
mock_llm = MagicMock(spec=BaseLLM)
mock_llm.aask.side_effect = ["ok"]
role = Role()
role.init_actions([MockAction])
await role.think()
role.rc.memory.add(Message("run"))
assert len(role.get_memories()) == 1
rsp = await role.act()
assert rsp
assert isinstance(rsp, ActionOutput)
assert rsp.content == "run"
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -16,8 +16,10 @@ from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.actions.write_code import WriteCode
from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
CodeSummarizeContext,
Document,
Message,
MessageQueue,
@ -61,6 +63,8 @@ def test_message():
assert m.role == "b"
assert m.send_to == {"c"}
assert m.cause_by == "c"
m.sent_from = "e"
assert m.sent_from == "e"
m.cause_by = "Message"
assert m.cause_by == "Message"
@ -121,6 +125,8 @@ def test_document():
@pytest.mark.asyncio
async def test_message_queue():
mq = MessageQueue()
val = await mq.dump()
assert val == "[]"
mq.push(Message(content="1"))
mq.push(Message(content="2中文测试aaa"))
msg = mq.pop()
@ -132,5 +138,23 @@ async def test_message_queue():
assert new_mq.pop_all() == mq.pop_all()
@pytest.mark.parametrize(
("file_list", "want"),
[
(
[f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", f"{TASK_FILE_REPO}/b.txt"],
CodeSummarizeContext(
design_filename=f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", task_filename=f"{TASK_FILE_REPO}/b.txt"
),
)
],
)
def test_CodeSummarizeContext(file_list, want):
ctx = CodeSummarizeContext.loads(file_list)
assert ctx == want
m = {ctx: ctx}
assert want in m
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -27,6 +27,19 @@ async def test_redis():
assert await conn.get("test") == b"test"
await conn.close()
# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["REDIS_HOST"] = "YOUR_REDIS_HOST"
CONFIG.set_context(new_options)
try:
conn = Redis()
await conn.set("test", "test", timeout_sec=0)
assert not await conn.get("test") == b"test"
await conn.close()
finally:
CONFIG.set_context(old_options)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -41,17 +41,18 @@ async def test_s3():
res = await conn.cache(data, ".bak", "script")
assert "http" in res
@pytest.mark.asyncio
async def test_s3_no_error():
conn = S3()
key = conn.auth_config["aws_secret_access_key"]
conn.auth_config["aws_secret_access_key"] = ""
# Mock session env
old_options = CONFIG.options.copy()
new_options = old_options.copy()
new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY"
CONFIG.set_context(new_options)
try:
conn = S3()
assert not conn.is_valid
res = await conn.cache("ABC", ".bak", "script")
assert not res
finally:
conn.auth_config["aws_secret_access_key"] = key
CONFIG.set_context(old_options)
if __name__ == "__main__":