feat: +unit test

This commit is contained in:
莘权 马 2024-01-02 11:59:03 +08:00
parent 83ee76cca7
commit 2f3e4c7f15
6 changed files with 96 additions and 43 deletions

View file

@ -372,16 +372,6 @@ class Role(SerializationMixin, is_polymorphic_base=True):
return msg
def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]:
news = []
# Warning, remove `id` here to make it work for recover
observed_pure = [msg.dict(exclude={"id": True}) for msg in observed]
existed_pure = [msg.dict(exclude={"id": True}) for msg in existed]
for idx, new in enumerate(observed_pure):
if (new["cause_by"] in self.rc.watch or self.name in new["send_to"]) and new not in existed_pure:
news.append(observed[idx])
return news
async def _observe(self, ignore_memory=False) -> int:
"""Prepare new messages for processing from the message buffer and other sources."""
# Read unprocessed messages from the msg buffer.
@ -407,29 +397,6 @@ class Role(SerializationMixin, is_polymorphic_base=True):
logger.debug(f"{self._setting} observed: {news_text}")
return len(self.rc.news)
# async def _observe(self, ignore_memory=False) -> int:
# """Prepare new messages for processing from the message buffer and other sources."""
# # Read unprocessed messages from the msg buffer.
# news = self.rc.msg_buffer.pop_all()
# if self.recovered:
# news = [self.latest_observed_msg] if self.latest_observed_msg else []
# else:
# self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
#
# # Store the read messages in your own memory to prevent duplicate processing.
# old_messages = [] if ignore_memory else self.rc.memory.get()
# self.rc.memory.add_batch(news)
# # Filter out messages of interest.
# self.rc.news = self._find_news(news, old_messages)
#
# # Design Rules:
# # If you need to further categorize Message objects, you can do so using the Message.set_meta function.
# # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
# if news_text:
# logger.debug(f"{self._setting} observed: {news_text}")
# return len(self.rc.news)
def publish_message(self, msg):
"""If the role belongs to env, then the role's messages will be broadcast to env"""
if not msg:

View file

@ -5,6 +5,7 @@
@Author : mashenquan
@File : redis.py
"""
from __future__ import annotations
import traceback
from datetime import timedelta
@ -22,7 +23,15 @@ class Redis:
async def _connect(self, force=False):
if self._client and not force:
return True
if not CONFIG.REDIS_HOST or not CONFIG.REDIS_PORT or CONFIG.REDIS_DB is None or CONFIG.REDIS_PASSWORD is None:
is_ready = (
CONFIG.REDIS_HOST
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
and CONFIG.REDIS_PORT
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
and CONFIG.REDIS_DB is not None
and CONFIG.REDIS_PASSWORD is not None
)
if not is_ready:
return False
try:
@ -37,7 +46,7 @@ class Redis:
logger.warning(f"Redis initialization has failed:{e}")
return False
async def get(self, key: str) -> bytes:
async def get(self, key: str) -> bytes | None:
if not await self._connect() or not key:
return None
try:

View file

@ -10,15 +10,17 @@
functionality is to be consolidated into the `Environment` class.
"""
import uuid
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel
from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.environment import Environment
from metagpt.provider.base_llm import BaseLLM
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import any_to_str
from metagpt.utils.common import any_to_name, any_to_str
class MockAction(Action):
@ -96,7 +98,7 @@ async def test_react():
@pytest.mark.asyncio
async def test_msg_to():
async def test_send_to():
m = Message(content="a", send_to=["a", MockRole, Message])
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
@ -107,5 +109,50 @@ async def test_msg_to():
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
def test_init_action():
role = Role()
role.init_actions([MockAction, MockAction])
assert role.action_count == 2
@pytest.mark.asyncio
async def test_recover():
# Mock LLM actions
mock_llm = MagicMock(spec=BaseLLM)
mock_llm.aask.side_effect = ["1"]
role = Role()
assert role.is_watch(any_to_str(UserRequirement))
role.put_message(None)
role.publish_message(None)
role.llm = mock_llm
role.init_actions([MockAction, MockAction])
role.recovered = True
role.latest_observed_msg = Message(content="recover_test")
role.rc.state = 0
assert role.todo == any_to_name(MockAction)
rsp = await role.run()
assert rsp.cause_by == any_to_str(MockAction)
@pytest.mark.asyncio
async def test_think_act():
# Mock LLM actions
mock_llm = MagicMock(spec=BaseLLM)
mock_llm.aask.side_effect = ["ok"]
role = Role()
role.init_actions([MockAction])
await role.think()
role.rc.memory.add(Message("run"))
assert len(role.get_memories()) == 1
rsp = await role.act()
assert rsp
assert isinstance(rsp, ActionOutput)
assert rsp.content == "run"
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -16,8 +16,10 @@ from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.actions.write_code import WriteCode
from metagpt.config import CONFIG
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
from metagpt.schema import (
AIMessage,
CodeSummarizeContext,
Document,
Message,
MessageQueue,
@ -61,6 +63,8 @@ def test_message():
assert m.role == "b"
assert m.send_to == {"c"}
assert m.cause_by == "c"
m.sent_from = "e"
assert m.sent_from == "e"
m.cause_by = "Message"
assert m.cause_by == "Message"
@ -121,6 +125,8 @@ def test_document():
@pytest.mark.asyncio
async def test_message_queue():
mq = MessageQueue()
val = await mq.dump()
assert val == "[]"
mq.push(Message(content="1"))
mq.push(Message(content="2中文测试aaa"))
msg = mq.pop()
@ -132,5 +138,23 @@ async def test_message_queue():
assert new_mq.pop_all() == mq.pop_all()
@pytest.mark.parametrize(
("file_list", "want"),
[
(
[f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", f"{TASK_FILE_REPO}/b.txt"],
CodeSummarizeContext(
design_filename=f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", task_filename=f"{TASK_FILE_REPO}/b.txt"
),
)
],
)
def test_CodeSummarizeContext(file_list, want):
ctx = CodeSummarizeContext.loads(file_list)
assert ctx == want
m = {ctx: ctx}
assert want in m
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -27,6 +27,14 @@ async def test_redis():
assert await conn.get("test") == b"test"
await conn.close()
key = CONFIG.REDIS_HOST
CONFIG.REDIS_HOST = "YOUR_REDIS_HOST"
conn = Redis()
await conn.set("test", "test", timeout_sec=0)
assert not await conn.get("test") == b"test"
CONFIG.REDIS_HOST = key
await conn.close()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -41,17 +41,15 @@ async def test_s3():
res = await conn.cache(data, ".bak", "script")
assert "http" in res
@pytest.mark.asyncio
async def test_s3_no_error():
key = CONFIG.S3_ACCESS_KEY
CONFIG.S3_ACCESS_KEY = "YOUR_S3_ACCESS_KEY"
conn = S3()
key = conn.auth_config["aws_secret_access_key"]
conn.auth_config["aws_secret_access_key"] = ""
assert not conn.is_valid
try:
res = await conn.cache("ABC", ".bak", "script")
assert not res
finally:
conn.auth_config["aws_secret_access_key"] = key
CONFIG.S3_ACCESS_KEY = key
if __name__ == "__main__":