From 2f3e4c7f1555745718bcfab002723c2d4fc654df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 2 Jan 2024 11:59:03 +0800 Subject: [PATCH 1/3] feat: +unit test --- metagpt/roles/role.py | 33 -------------------- metagpt/utils/redis.py | 13 ++++++-- tests/metagpt/test_role.py | 51 +++++++++++++++++++++++++++++-- tests/metagpt/test_schema.py | 24 +++++++++++++++ tests/metagpt/utils/test_redis.py | 8 +++++ tests/metagpt/utils/test_s3.py | 10 +++--- 6 files changed, 96 insertions(+), 43 deletions(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 81815e91b..f74c32fea 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -372,16 +372,6 @@ class Role(SerializationMixin, is_polymorphic_base=True): return msg - def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]: - news = [] - # Warning, remove `id` here to make it work for recover - observed_pure = [msg.dict(exclude={"id": True}) for msg in observed] - existed_pure = [msg.dict(exclude={"id": True}) for msg in existed] - for idx, new in enumerate(observed_pure): - if (new["cause_by"] in self.rc.watch or self.name in new["send_to"]) and new not in existed_pure: - news.append(observed[idx]) - return news - async def _observe(self, ignore_memory=False) -> int: """Prepare new messages for processing from the message buffer and other sources.""" # Read unprocessed messages from the msg buffer. @@ -407,29 +397,6 @@ class Role(SerializationMixin, is_polymorphic_base=True): logger.debug(f"{self._setting} observed: {news_text}") return len(self.rc.news) - # async def _observe(self, ignore_memory=False) -> int: - # """Prepare new messages for processing from the message buffer and other sources.""" - # # Read unprocessed messages from the msg buffer. - # news = self.rc.msg_buffer.pop_all() - # if self.recovered: - # news = [self.latest_observed_msg] if self.latest_observed_msg else [] - # else: - # self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg - # - # # Store the read messages in your own memory to prevent duplicate processing. - # old_messages = [] if ignore_memory else self.rc.memory.get() - # self.rc.memory.add_batch(news) - # # Filter out messages of interest. - # self.rc.news = self._find_news(news, old_messages) - # - # # Design Rules: - # # If you need to further categorize Message objects, you can do so using the Message.set_meta function. - # # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer. - # news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news] - # if news_text: - # logger.debug(f"{self._setting} observed: {news_text}") - # return len(self.rc.news) - def publish_message(self, msg): """If the role belongs to env, then the role's messages will be broadcast to env""" if not msg: diff --git a/metagpt/utils/redis.py b/metagpt/utils/redis.py index 1ad39be59..e4b455c6b 100644 --- a/metagpt/utils/redis.py +++ b/metagpt/utils/redis.py @@ -5,6 +5,7 @@ @Author : mashenquan @File : redis.py """ +from __future__ import annotations import traceback from datetime import timedelta @@ -22,7 +23,15 @@ class Redis: async def _connect(self, force=False): if self._client and not force: return True - if not CONFIG.REDIS_HOST or not CONFIG.REDIS_PORT or CONFIG.REDIS_DB is None or CONFIG.REDIS_PASSWORD is None: + is_ready = ( + CONFIG.REDIS_HOST + and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST" + and CONFIG.REDIS_PORT + and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT" + and CONFIG.REDIS_DB is not None + and CONFIG.REDIS_PASSWORD is not None + ) + if not is_ready: return False try: @@ -37,7 +46,7 @@ class Redis: logger.warning(f"Redis initialization has failed:{e}") return False - async def get(self, key: str) -> bytes: + async def get(self, key: str) -> bytes | None: if not await self._connect() or not key: return None try: diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index 33320715c..52d08e92e 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -10,15 +10,17 @@ functionality is to be consolidated into the `Environment` class. """ import uuid +from unittest.mock import MagicMock import pytest from pydantic import BaseModel from metagpt.actions import Action, ActionOutput, UserRequirement from metagpt.environment import Environment +from metagpt.provider.base_llm import BaseLLM from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import any_to_str +from metagpt.utils.common import any_to_name, any_to_str class MockAction(Action): @@ -96,7 +98,7 @@ async def test_react(): @pytest.mark.asyncio -async def test_msg_to(): +async def test_send_to(): m = Message(content="a", send_to=["a", MockRole, Message]) assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} @@ -107,5 +109,50 @@ async def test_msg_to(): assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)} +def test_init_action(): + role = Role() + role.init_actions([MockAction, MockAction]) + assert role.action_count == 2 + + +@pytest.mark.asyncio +async def test_recover(): + # Mock LLM actions + mock_llm = MagicMock(spec=BaseLLM) + mock_llm.aask.side_effect = ["1"] + + role = Role() + assert role.is_watch(any_to_str(UserRequirement)) + role.put_message(None) + role.publish_message(None) + + role.llm = mock_llm + role.init_actions([MockAction, MockAction]) + role.recovered = True + role.latest_observed_msg = Message(content="recover_test") + role.rc.state = 0 + assert role.todo == any_to_name(MockAction) + + rsp = await role.run() + assert rsp.cause_by == any_to_str(MockAction) + + +@pytest.mark.asyncio +async def test_think_act(): + # Mock LLM actions + mock_llm = MagicMock(spec=BaseLLM) + mock_llm.aask.side_effect = ["ok"] + + role = Role() + role.init_actions([MockAction]) + await role.think() + role.rc.memory.add(Message("run")) + assert len(role.get_memories()) == 1 + rsp = await role.act() + assert rsp + assert isinstance(rsp, ActionOutput) + assert rsp.content == "run" + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index 1bf0d4c4c..816c186e2 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -16,8 +16,10 @@ from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.actions.write_code import WriteCode from metagpt.config import CONFIG +from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO from metagpt.schema import ( AIMessage, + CodeSummarizeContext, Document, Message, MessageQueue, @@ -61,6 +63,8 @@ def test_message(): assert m.role == "b" assert m.send_to == {"c"} assert m.cause_by == "c" + m.sent_from = "e" + assert m.sent_from == "e" m.cause_by = "Message" assert m.cause_by == "Message" @@ -121,6 +125,8 @@ def test_document(): @pytest.mark.asyncio async def test_message_queue(): mq = MessageQueue() + val = await mq.dump() + assert val == "[]" mq.push(Message(content="1")) mq.push(Message(content="2中文测试aaa")) msg = mq.pop() @@ -132,5 +138,23 @@ async def test_message_queue(): assert new_mq.pop_all() == mq.pop_all() +@pytest.mark.parametrize( + ("file_list", "want"), + [ + ( + [f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", f"{TASK_FILE_REPO}/b.txt"], + CodeSummarizeContext( + design_filename=f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", task_filename=f"{TASK_FILE_REPO}/b.txt" + ), + ) + ], +) +def test_CodeSummarizeContext(file_list, want): + ctx = CodeSummarizeContext.loads(file_list) + assert ctx == want + m = {ctx: ctx} + assert want in m + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_redis.py b/tests/metagpt/utils/test_redis.py index 7c3fd26a9..a75341433 100644 --- a/tests/metagpt/utils/test_redis.py +++ b/tests/metagpt/utils/test_redis.py @@ -27,6 +27,14 @@ async def test_redis(): assert await conn.get("test") == b"test" await conn.close() + key = CONFIG.REDIS_HOST + CONFIG.REDIS_HOST = "YOUR_REDIS_HOST" + conn = Redis() + await conn.set("test", "test", timeout_sec=0) + assert not await conn.get("test") == b"test" + CONFIG.REDIS_HOST = key + await conn.close() + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_s3.py b/tests/metagpt/utils/test_s3.py index edf198028..9906d566f 100644 --- a/tests/metagpt/utils/test_s3.py +++ b/tests/metagpt/utils/test_s3.py @@ -41,17 +41,15 @@ async def test_s3(): res = await conn.cache(data, ".bak", "script") assert "http" in res - -@pytest.mark.asyncio -async def test_s3_no_error(): + key = CONFIG.S3_ACCESS_KEY + CONFIG.S3_ACCESS_KEY = "YOUR_S3_ACCESS_KEY" conn = S3() - key = conn.auth_config["aws_secret_access_key"] - conn.auth_config["aws_secret_access_key"] = "" + assert not conn.is_valid try: res = await conn.cache("ABC", ".bak", "script") assert not res finally: - conn.auth_config["aws_secret_access_key"] = key + CONFIG.S3_ACCESS_KEY = key if __name__ == "__main__": From d9c5809ccd6acc74aea3307c9ef0f5069f530c61 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 2 Jan 2024 14:21:55 +0800 Subject: [PATCH 2/3] feat: +qa unit test --- tests/data/demo_project/game.py | 92 +++++++++++++++++++++++++ tests/metagpt/roles/test_qa_engineer.py | 56 +++++++++++++++ 2 files changed, 148 insertions(+) create mode 100644 tests/data/demo_project/game.py diff --git a/tests/data/demo_project/game.py b/tests/data/demo_project/game.py new file mode 100644 index 000000000..22e77b260 --- /dev/null +++ b/tests/data/demo_project/game.py @@ -0,0 +1,92 @@ +## game.py + +import random +from typing import List, Tuple + + +class Game: + def __init__(self): + self.grid: List[List[int]] = [[0 for _ in range(4)] for _ in range(4)] + self.score: int = 0 + self.game_over: bool = False + + def reset_game(self): + self.grid = [[0 for _ in range(4)] for _ in range(4)] + self.score = 0 + self.game_over = False + self.add_new_tile() + self.add_new_tile() + + def move(self, direction: str): + if direction == "up": + self._move_up() + elif direction == "down": + self._move_down() + elif direction == "left": + self._move_left() + elif direction == "right": + self._move_right() + + def is_game_over(self) -> bool: + for i in range(4): + for j in range(4): + if self.grid[i][j] == 0: + return False + if j < 3 and self.grid[i][j] == self.grid[i][j + 1]: + return False + if i < 3 and self.grid[i][j] == self.grid[i + 1][j]: + return False + return True + + def get_empty_cells(self) -> List[Tuple[int, int]]: + empty_cells = [] + for i in range(4): + for j in range(4): + if self.grid[i][j] == 0: + empty_cells.append((i, j)) + return empty_cells + + def add_new_tile(self): + empty_cells = self.get_empty_cells() + if empty_cells: + x, y = random.choice(empty_cells) + self.grid[x][y] = 2 if random.random() < 0.9 else 4 + + def get_score(self) -> int: + return self.score + + def _move_up(self): + for j in range(4): + for i in range(1, 4): + if self.grid[i][j] != 0: + for k in range(i, 0, -1): + if self.grid[k - 1][j] == 0: + self.grid[k - 1][j] = self.grid[k][j] + self.grid[k][j] = 0 + + def _move_down(self): + for j in range(4): + for i in range(2, -1, -1): + if self.grid[i][j] != 0: + for k in range(i, 3): + if self.grid[k + 1][j] == 0: + self.grid[k + 1][j] = self.grid[k][j] + self.grid[k][j] = 0 + + def _move_left(self): + for i in range(4): + for j in range(1, 4): + if self.grid[i][j] != 0: + for k in range(j, 0, -1): + if self.grid[i][k - 1] == 0: + self.grid[i][k - 1] = self.grid[i][k] + self.grid[i][k] = 0 + + def _move_right(self): + for i in range(4): + for j in range(2, -1, -1): + if self.grid[i][j] != 0: + for k in range(j, 3): + if self.grid[i][k + 1] == 0: + self.grid[i][k + 1] = self.grid[i][k] + self.grid[i][k] = 0 diff --git a/tests/metagpt/roles/test_qa_engineer.py b/tests/metagpt/roles/test_qa_engineer.py index 8fd7c0373..784c26a06 100644 --- a/tests/metagpt/roles/test_qa_engineer.py +++ b/tests/metagpt/roles/test_qa_engineer.py @@ -5,3 +5,59 @@ @Author : alexanderwu @File : test_qa_engineer.py """ +from pathlib import Path +from typing import List + +import pytest +from pydantic import Field + +from metagpt.actions import DebugError, RunCode, WriteTest +from metagpt.actions.summarize_code import SummarizeCode +from metagpt.config import CONFIG +from metagpt.environment import Environment +from metagpt.roles import QaEngineer +from metagpt.schema import Message +from metagpt.utils.common import any_to_str, aread, awrite + + +async def test_qa(): + # Prerequisites + demo_path = Path(__file__).parent / "../../data/demo_project" + CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "qa/game_2048" + data = await aread(filename=demo_path / "game.py", encoding="utf-8") + await awrite(filename=CONFIG.src_workspace / "game.py", data=data, encoding="utf-8") + await awrite(filename=Path(CONFIG.git_repo.workdir) / "requirements.txt", data="") + + class MockEnv(Environment): + msgs: List[Message] = Field(default_factory=list) + + def publish_message(self, message: Message, peekable: bool = True) -> bool: + self.msgs.append(message) + return True + + env = MockEnv() + + role = QaEngineer() + role.set_env(env) + await role.run(with_message=Message(content="", cause_by=SummarizeCode)) + assert env.msgs + assert env.msgs[0].cause_by == any_to_str(WriteTest) + msg = env.msgs[0] + env.msgs.clear() + await role.run(with_message=msg) + assert env.msgs + assert env.msgs[0].cause_by == any_to_str(RunCode) + msg = env.msgs[0] + env.msgs.clear() + await role.run(with_message=msg) + assert env.msgs + assert env.msgs[0].cause_by == any_to_str(DebugError) + msg = env.msgs[0] + env.msgs.clear() + role.test_round_allowed = 1 + rsp = await role.run(with_message=msg) + assert "Exceeding" in rsp.content + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) From b7d74c64836f6b6ab293a9952a5b8fe04c6613b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Tue, 2 Jan 2024 15:31:49 +0800 Subject: [PATCH 3/3] fixbug: azure openai --- metagpt/provider/azure_openai_api.py | 2 +- metagpt/utils/redis.py | 21 +++++++++-------- metagpt/utils/s3.py | 25 +++++++++++---------- tests/metagpt/learn/test_text_to_image.py | 12 ++++++---- tests/metagpt/learn/test_text_to_speech.py | 9 +++++--- tests/metagpt/roles/test_architect.py | 26 +++++++++++++++++++--- tests/metagpt/utils/test_redis.py | 19 ++++++++++------ tests/metagpt/utils/test_s3.py | 13 ++++++----- 8 files changed, 83 insertions(+), 44 deletions(-) diff --git a/metagpt/provider/azure_openai_api.py b/metagpt/provider/azure_openai_api.py index b59326c7f..d15d1c82e 100644 --- a/metagpt/provider/azure_openai_api.py +++ b/metagpt/provider/azure_openai_api.py @@ -27,7 +27,7 @@ class AzureOpenAILLM(OpenAILLM): def _init_client(self): kwargs = self._make_client_kwargs() # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix - self.async_client = AsyncAzureOpenAI(**kwargs) + self.aclient = AsyncAzureOpenAI(**kwargs) self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs def _make_client_kwargs(self) -> dict: diff --git a/metagpt/utils/redis.py b/metagpt/utils/redis.py index e4b455c6b..10f33285c 100644 --- a/metagpt/utils/redis.py +++ b/metagpt/utils/redis.py @@ -23,15 +23,7 @@ class Redis: async def _connect(self, force=False): if self._client and not force: return True - is_ready = ( - CONFIG.REDIS_HOST - and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST" - and CONFIG.REDIS_PORT - and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT" - and CONFIG.REDIS_DB is not None - and CONFIG.REDIS_PASSWORD is not None - ) - if not is_ready: + if not self.is_configured: return False try: @@ -74,3 +66,14 @@ class Redis: @property def is_valid(self) -> bool: return self._client is not None + + @property + def is_configured(self) -> bool: + return bool( + CONFIG.REDIS_HOST + and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST" + and CONFIG.REDIS_PORT + and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT" + and CONFIG.REDIS_DB is not None + and CONFIG.REDIS_PASSWORD is not None + ) diff --git a/metagpt/utils/s3.py b/metagpt/utils/s3.py index 6a38a80a4..2a2c1a31c 100644 --- a/metagpt/utils/s3.py +++ b/metagpt/utils/s3.py @@ -154,16 +154,17 @@ class S3: @property def is_valid(self): - is_invalid = ( - not CONFIG.S3_ACCESS_KEY - or CONFIG.S3_ACCESS_KEY == "YOUR_S3_ACCESS_KEY" - or not CONFIG.S3_SECRET_KEY - or CONFIG.S3_SECRET_KEY == "YOUR_S3_SECRET_KEY" - or not CONFIG.S3_ENDPOINT_URL - or CONFIG.S3_ENDPOINT_URL == "YOUR_S3_ENDPOINT_URL" - or not CONFIG.S3_BUCKET - or CONFIG.S3_BUCKET == "YOUR_S3_BUCKET" + return self.is_configured + + @property + def is_configured(self) -> bool: + return bool( + CONFIG.S3_ACCESS_KEY + and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY" + and CONFIG.S3_SECRET_KEY + and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY" + and CONFIG.S3_ENDPOINT_URL + and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL" + and CONFIG.S3_BUCKET + and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET" ) - if is_invalid: - logger.info("S3 is invalid") - return not is_invalid diff --git a/tests/metagpt/learn/test_text_to_image.py b/tests/metagpt/learn/test_text_to_image.py index 0afe8534d..760b9d09c 100644 --- a/tests/metagpt/learn/test_text_to_image.py +++ b/tests/metagpt/learn/test_text_to_image.py @@ -15,20 +15,24 @@ from metagpt.learn.text_to_image import text_to_image @pytest.mark.asyncio -async def test(): +async def test_metagpt_llm(): # Prerequisites assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL assert CONFIG.OPENAI_API_KEY data = await text_to_image("Panda emoji", size_type="512x512") assert "base64" in data or "http" in data - key = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL - CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None + + # Mock session env + old_options = CONFIG.options.copy() + new_options = old_options.copy() + new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None + CONFIG.set_context(new_options) try: data = await text_to_image("Panda emoji", size_type="512x512") assert "base64" in data or "http" in data finally: - CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = key + CONFIG.set_context(old_options) if __name__ == "__main__": diff --git a/tests/metagpt/learn/test_text_to_speech.py b/tests/metagpt/learn/test_text_to_speech.py index 02faecdde..aca08b9a2 100644 --- a/tests/metagpt/learn/test_text_to_speech.py +++ b/tests/metagpt/learn/test_text_to_speech.py @@ -27,13 +27,16 @@ async def test_text_to_speech(): assert "base64" in data or "http" in data # test iflytek - key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY - CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = "" + ## Mock session env + old_options = CONFIG.options.copy() + new_options = old_options.copy() + new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = "" + CONFIG.set_context(new_options) try: data = await text_to_speech("panda emoji") assert "base64" in data or "http" in data finally: - CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key + CONFIG.set_context(old_options) if __name__ == "__main__": diff --git a/tests/metagpt/roles/test_architect.py b/tests/metagpt/roles/test_architect.py index 0c8fbfe04..06e4b2d11 100644 --- a/tests/metagpt/roles/test_architect.py +++ b/tests/metagpt/roles/test_architect.py @@ -7,18 +7,38 @@ @Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message distribution feature for message handling. """ +import uuid + import pytest +from metagpt.actions import WriteDesign, WritePRD +from metagpt.config import CONFIG +from metagpt.const import PRDS_FILE_REPO from metagpt.logs import logger from metagpt.roles import Architect +from metagpt.schema import Message +from metagpt.utils.common import any_to_str, awrite from tests.metagpt.roles.mock import MockMessages @pytest.mark.asyncio async def test_architect(): - # FIXME: make git as env? Or should we support + # Prerequisites + filename = uuid.uuid4().hex + ".json" + await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content) + role = Architect() - role.put_message(MockMessages.req) - rsp = await role.run(MockMessages.prd) + rsp = await role.run(with_message=Message(content="", cause_by=WritePRD)) logger.info(rsp) assert len(rsp.content) > 0 + assert rsp.cause_by == any_to_str(WriteDesign) + + # test update + rsp = await role.run(with_message=Message(content="", cause_by=WritePRD)) + assert rsp + assert rsp.cause_by == any_to_str(WriteDesign) + assert len(rsp.content) > 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_redis.py b/tests/metagpt/utils/test_redis.py index a75341433..b93ff0cdb 100644 --- a/tests/metagpt/utils/test_redis.py +++ b/tests/metagpt/utils/test_redis.py @@ -27,13 +27,18 @@ async def test_redis(): assert await conn.get("test") == b"test" await conn.close() - key = CONFIG.REDIS_HOST - CONFIG.REDIS_HOST = "YOUR_REDIS_HOST" - conn = Redis() - await conn.set("test", "test", timeout_sec=0) - assert not await conn.get("test") == b"test" - CONFIG.REDIS_HOST = key - await conn.close() + # Mock session env + old_options = CONFIG.options.copy() + new_options = old_options.copy() + new_options["REDIS_HOST"] = "YOUR_REDIS_HOST" + CONFIG.set_context(new_options) + try: + conn = Redis() + await conn.set("test", "test", timeout_sec=0) + assert not await conn.get("test") == b"test" + await conn.close() + finally: + CONFIG.set_context(old_options) if __name__ == "__main__": diff --git a/tests/metagpt/utils/test_s3.py b/tests/metagpt/utils/test_s3.py index 9906d566f..f74e7b52a 100644 --- a/tests/metagpt/utils/test_s3.py +++ b/tests/metagpt/utils/test_s3.py @@ -41,15 +41,18 @@ async def test_s3(): res = await conn.cache(data, ".bak", "script") assert "http" in res - key = CONFIG.S3_ACCESS_KEY - CONFIG.S3_ACCESS_KEY = "YOUR_S3_ACCESS_KEY" - conn = S3() - assert not conn.is_valid + # Mock session env + old_options = CONFIG.options.copy() + new_options = old_options.copy() + new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY" + CONFIG.set_context(new_options) try: + conn = S3() + assert not conn.is_valid res = await conn.cache("ABC", ".bak", "script") assert not res finally: - CONFIG.S3_ACCESS_KEY = key + CONFIG.set_context(old_options) if __name__ == "__main__":