diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 81815e91b..f74c32fea 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -372,16 +372,6 @@ class Role(SerializationMixin, is_polymorphic_base=True): return msg - def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]: - news = [] - # Warning, remove `id` here to make it work for recover - observed_pure = [msg.dict(exclude={"id": True}) for msg in observed] - existed_pure = [msg.dict(exclude={"id": True}) for msg in existed] - for idx, new in enumerate(observed_pure): - if (new["cause_by"] in self.rc.watch or self.name in new["send_to"]) and new not in existed_pure: - news.append(observed[idx]) - return news - async def _observe(self, ignore_memory=False) -> int: """Prepare new messages for processing from the message buffer and other sources.""" # Read unprocessed messages from the msg buffer. @@ -407,29 +397,6 @@ class Role(SerializationMixin, is_polymorphic_base=True): logger.debug(f"{self._setting} observed: {news_text}") return len(self.rc.news) - # async def _observe(self, ignore_memory=False) -> int: - # """Prepare new messages for processing from the message buffer and other sources.""" - # # Read unprocessed messages from the msg buffer. - # news = self.rc.msg_buffer.pop_all() - # if self.recovered: - # news = [self.latest_observed_msg] if self.latest_observed_msg else [] - # else: - # self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg - # - # # Store the read messages in your own memory to prevent duplicate processing. - # old_messages = [] if ignore_memory else self.rc.memory.get() - # self.rc.memory.add_batch(news) - # # Filter out messages of interest. - # self.rc.news = self._find_news(news, old_messages) - # - # # Design Rules: - # # If you need to further categorize Message objects, you can do so using the Message.set_meta function. - # # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer. - # news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news] - # if news_text: - # logger.debug(f"{self._setting} observed: {news_text}") - # return len(self.rc.news) - def publish_message(self, msg): """If the role belongs to env, then the role's messages will be broadcast to env""" if not msg: diff --git a/metagpt/utils/redis.py b/metagpt/utils/redis.py index 1ad39be59..e4b455c6b 100644 --- a/metagpt/utils/redis.py +++ b/metagpt/utils/redis.py @@ -5,6 +5,7 @@ @Author : mashenquan @File : redis.py """ +from __future__ import annotations import traceback from datetime import timedelta @@ -22,7 +23,15 @@ class Redis: async def _connect(self, force=False): if self._client and not force: return True - if not CONFIG.REDIS_HOST or not CONFIG.REDIS_PORT or CONFIG.REDIS_DB is None or CONFIG.REDIS_PASSWORD is None: + is_ready = ( + CONFIG.REDIS_HOST + and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST" + and CONFIG.REDIS_PORT + and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT" + and CONFIG.REDIS_DB is not None + and CONFIG.REDIS_PASSWORD is not None + ) + if not is_ready: return False try: @@ -37,7 +46,7 @@ class Redis: logger.warning(f"Redis initialization has failed:{e}") return False - async def get(self, key: str) -> bytes: + async def get(self, key: str) -> bytes | None: if not await self._connect() or not key: return None try: diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index 33320715c..52d08e92e 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -10,15 +10,17 @@ functionality is to be consolidated into the `Environment` class. """ import uuid +from unittest.mock import MagicMock import pytest from pydantic import BaseModel from metagpt.actions import Action, ActionOutput, UserRequirement from metagpt.environment import Environment +from metagpt.provider.base_llm import BaseLLM from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_name, any_to_str class MockAction(Action): @@ -96,7 +98,7 @@ async def test_react(): @pytest.mark.asyncio -async def test_msg_to(): +async def test_send_to(): m = Message(content="a", send_to=["a", MockRole, Message]) assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} @@ -107,5 +109,50 @@ async def test_msg_to(): assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} +def test_init_action(): + role = Role() + role.init_actions([MockAction, MockAction]) + assert role.action_count == 2 + + +@pytest.mark.asyncio +async def test_recover(): + # Mock LLM actions + mock_llm = MagicMock(spec=BaseLLM) + mock_llm.aask.side_effect = ["1"] + + role = Role() + assert role.is_watch(any_to_str(UserRequirement)) + role.put_message(None) + role.publish_message(None) + + role.llm = mock_llm + role.init_actions([MockAction, MockAction]) + role.recovered = True + role.latest_observed_msg = Message(content="recover_test") + role.rc.state = 0 + assert role.todo == any_to_name(MockAction) + + rsp = await role.run() + assert rsp.cause_by == any_to_str(MockAction) + + +@pytest.mark.asyncio +async def test_think_act(): + # Mock LLM actions + mock_llm = MagicMock(spec=BaseLLM) + mock_llm.aask.side_effect = ["ok"] + + role = Role() + role.init_actions([MockAction]) + await role.think() + role.rc.memory.add(Message("run")) + assert len(role.get_memories()) == 1 + rsp = await role.act() + assert rsp + assert isinstance(rsp, ActionOutput) + assert rsp.content == "run" + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 1bf0d4c4c..816c186e2 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -16,8 +16,10 @@ from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode from metagpt.config import CONFIG +from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.schema import ( AIMessage, + CodeSummarizeContext, Document, Message, MessageQueue, @@ -61,6 +63,8 @@ def test_message(): assert m.role == "b" assert m.send_to == {"c"} assert m.cause_by == "c" + m.sent_from = "e" + assert m.sent_from == "e" m.cause_by = "Message" assert m.cause_by == "Message" @@ -121,6 +125,8 @@ def test_document(): @pytest.mark.asyncio async def test_message_queue(): mq = MessageQueue() + val = await mq.dump() + assert val == "[]" mq.push(Message(content="1")) mq.push(Message(content="2中文测试aaa")) msg = mq.pop() @@ -132,5 +138,23 @@ async def test_message_queue(): assert new_mq.pop_all() == mq.pop_all() +@pytest.mark.parametrize( + ("file_list", "want"), + [ + ( + [f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", f"{TASK_FILE_REPO}/b.txt"], + CodeSummarizeContext( + design_filename=f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", task_filename=f"{TASK_FILE_REPO}/b.txt" + ), + ) + ], +) +def test_CodeSummarizeContext(file_list, want): + ctx = CodeSummarizeContext.loads(file_list) + assert ctx == want + m = {ctx: ctx} + assert want in m + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_redis.py b/tests/metagpt/utils/test_redis.py index 7c3fd26a9..a75341433 100644 --- a/tests/metagpt/utils/test_redis.py +++ b/tests/metagpt/utils/test_redis.py @@ -27,6 +27,14 @@ async def test_redis(): assert await conn.get("test") == b"test" await conn.close() + key = CONFIG.REDIS_HOST + CONFIG.REDIS_HOST = "YOUR_REDIS_HOST" + conn = Redis() + await conn.set("test", "test", timeout_sec=0) + assert not await conn.get("test") == b"test" + CONFIG.REDIS_HOST = key + await conn.close() + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_s3.py b/tests/metagpt/utils/test_s3.py index edf198028..9906d566f 100644 --- a/tests/metagpt/utils/test_s3.py +++ b/tests/metagpt/utils/test_s3.py @@ -41,17 +41,15 @@ async def test_s3(): res = await conn.cache(data, ".bak", "script") assert "http" in res - -@pytest.mark.asyncio -async def test_s3_no_error(): + key = CONFIG.S3_ACCESS_KEY + CONFIG.S3_ACCESS_KEY = "YOUR_S3_ACCESS_KEY" conn = S3() - key = conn.auth_config["aws_secret_access_key"] - conn.auth_config["aws_secret_access_key"] = "" + assert not conn.is_valid try: res = await conn.cache("ABC", ".bak", "script") assert not res finally: - conn.auth_config["aws_secret_access_key"] = key + CONFIG.S3_ACCESS_KEY = key if __name__ == "__main__":