mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
commit
d2260a5958
12 changed files with 306 additions and 66 deletions
|
|
@ -372,16 +372,6 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
|
||||
return msg
|
||||
|
||||
def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]:
|
||||
news = []
|
||||
# Warning, remove `id` here to make it work for recover
|
||||
observed_pure = [msg.dict(exclude={"id": True}) for msg in observed]
|
||||
existed_pure = [msg.dict(exclude={"id": True}) for msg in existed]
|
||||
for idx, new in enumerate(observed_pure):
|
||||
if (new["cause_by"] in self.rc.watch or self.name in new["send_to"]) and new not in existed_pure:
|
||||
news.append(observed[idx])
|
||||
return news
|
||||
|
||||
async def _observe(self, ignore_memory=False) -> int:
|
||||
"""Prepare new messages for processing from the message buffer and other sources."""
|
||||
# Read unprocessed messages from the msg buffer.
|
||||
|
|
@ -407,29 +397,6 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
logger.debug(f"{self._setting} observed: {news_text}")
|
||||
return len(self.rc.news)
|
||||
|
||||
# async def _observe(self, ignore_memory=False) -> int:
|
||||
# """Prepare new messages for processing from the message buffer and other sources."""
|
||||
# # Read unprocessed messages from the msg buffer.
|
||||
# news = self.rc.msg_buffer.pop_all()
|
||||
# if self.recovered:
|
||||
# news = [self.latest_observed_msg] if self.latest_observed_msg else []
|
||||
# else:
|
||||
# self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
|
||||
#
|
||||
# # Store the read messages in your own memory to prevent duplicate processing.
|
||||
# old_messages = [] if ignore_memory else self.rc.memory.get()
|
||||
# self.rc.memory.add_batch(news)
|
||||
# # Filter out messages of interest.
|
||||
# self.rc.news = self._find_news(news, old_messages)
|
||||
#
|
||||
# # Design Rules:
|
||||
# # If you need to further categorize Message objects, you can do so using the Message.set_meta function.
|
||||
# # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
|
||||
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
|
||||
# if news_text:
|
||||
# logger.debug(f"{self._setting} observed: {news_text}")
|
||||
# return len(self.rc.news)
|
||||
|
||||
def publish_message(self, msg):
|
||||
"""If the role belongs to env, then the role's messages will be broadcast to env"""
|
||||
if not msg:
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
@Author : mashenquan
|
||||
@File : redis.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
from datetime import timedelta
|
||||
|
|
@ -22,7 +23,7 @@ class Redis:
|
|||
async def _connect(self, force=False):
|
||||
if self._client and not force:
|
||||
return True
|
||||
if not CONFIG.REDIS_HOST or not CONFIG.REDIS_PORT or CONFIG.REDIS_DB is None or CONFIG.REDIS_PASSWORD is None:
|
||||
if not self.is_configured:
|
||||
return False
|
||||
|
||||
try:
|
||||
|
|
@ -37,7 +38,7 @@ class Redis:
|
|||
logger.warning(f"Redis initialization has failed:{e}")
|
||||
return False
|
||||
|
||||
async def get(self, key: str) -> bytes:
|
||||
async def get(self, key: str) -> bytes | None:
|
||||
if not await self._connect() or not key:
|
||||
return None
|
||||
try:
|
||||
|
|
@ -65,3 +66,14 @@ class Redis:
|
|||
@property
|
||||
def is_valid(self) -> bool:
|
||||
return self._client is not None
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(
|
||||
CONFIG.REDIS_HOST
|
||||
and CONFIG.REDIS_HOST != "YOUR_REDIS_HOST"
|
||||
and CONFIG.REDIS_PORT
|
||||
and CONFIG.REDIS_PORT != "YOUR_REDIS_PORT"
|
||||
and CONFIG.REDIS_DB is not None
|
||||
and CONFIG.REDIS_PASSWORD is not None
|
||||
)
|
||||
|
|
|
|||
|
|
@ -154,16 +154,17 @@ class S3:
|
|||
|
||||
@property
|
||||
def is_valid(self):
|
||||
is_invalid = (
|
||||
not CONFIG.S3_ACCESS_KEY
|
||||
or CONFIG.S3_ACCESS_KEY == "YOUR_S3_ACCESS_KEY"
|
||||
or not CONFIG.S3_SECRET_KEY
|
||||
or CONFIG.S3_SECRET_KEY == "YOUR_S3_SECRET_KEY"
|
||||
or not CONFIG.S3_ENDPOINT_URL
|
||||
or CONFIG.S3_ENDPOINT_URL == "YOUR_S3_ENDPOINT_URL"
|
||||
or not CONFIG.S3_BUCKET
|
||||
or CONFIG.S3_BUCKET == "YOUR_S3_BUCKET"
|
||||
return self.is_configured
|
||||
|
||||
@property
|
||||
def is_configured(self) -> bool:
|
||||
return bool(
|
||||
CONFIG.S3_ACCESS_KEY
|
||||
and CONFIG.S3_ACCESS_KEY != "YOUR_S3_ACCESS_KEY"
|
||||
and CONFIG.S3_SECRET_KEY
|
||||
and CONFIG.S3_SECRET_KEY != "YOUR_S3_SECRET_KEY"
|
||||
and CONFIG.S3_ENDPOINT_URL
|
||||
and CONFIG.S3_ENDPOINT_URL != "YOUR_S3_ENDPOINT_URL"
|
||||
and CONFIG.S3_BUCKET
|
||||
and CONFIG.S3_BUCKET != "YOUR_S3_BUCKET"
|
||||
)
|
||||
if is_invalid:
|
||||
logger.info("S3 is invalid")
|
||||
return not is_invalid
|
||||
|
|
|
|||
92
tests/data/demo_project/game.py
Normal file
92
tests/data/demo_project/game.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
## game.py
|
||||
|
||||
import random
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class Game:
|
||||
def __init__(self):
|
||||
self.grid: List[List[int]] = [[0 for _ in range(4)] for _ in range(4)]
|
||||
self.score: int = 0
|
||||
self.game_over: bool = False
|
||||
|
||||
def reset_game(self):
|
||||
self.grid = [[0 for _ in range(4)] for _ in range(4)]
|
||||
self.score = 0
|
||||
self.game_over = False
|
||||
self.add_new_tile()
|
||||
self.add_new_tile()
|
||||
|
||||
def move(self, direction: str):
|
||||
if direction == "up":
|
||||
self._move_up()
|
||||
elif direction == "down":
|
||||
self._move_down()
|
||||
elif direction == "left":
|
||||
self._move_left()
|
||||
elif direction == "right":
|
||||
self._move_right()
|
||||
|
||||
def is_game_over(self) -> bool:
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
if self.grid[i][j] == 0:
|
||||
return False
|
||||
if j < 3 and self.grid[i][j] == self.grid[i][j + 1]:
|
||||
return False
|
||||
if i < 3 and self.grid[i][j] == self.grid[i + 1][j]:
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_empty_cells(self) -> List[Tuple[int, int]]:
|
||||
empty_cells = []
|
||||
for i in range(4):
|
||||
for j in range(4):
|
||||
if self.grid[i][j] == 0:
|
||||
empty_cells.append((i, j))
|
||||
return empty_cells
|
||||
|
||||
def add_new_tile(self):
|
||||
empty_cells = self.get_empty_cells()
|
||||
if empty_cells:
|
||||
x, y = random.choice(empty_cells)
|
||||
self.grid[x][y] = 2 if random.random() < 0.9 else 4
|
||||
|
||||
def get_score(self) -> int:
|
||||
return self.score
|
||||
|
||||
def _move_up(self):
|
||||
for j in range(4):
|
||||
for i in range(1, 4):
|
||||
if self.grid[i][j] != 0:
|
||||
for k in range(i, 0, -1):
|
||||
if self.grid[k - 1][j] == 0:
|
||||
self.grid[k - 1][j] = self.grid[k][j]
|
||||
self.grid[k][j] = 0
|
||||
|
||||
def _move_down(self):
|
||||
for j in range(4):
|
||||
for i in range(2, -1, -1):
|
||||
if self.grid[i][j] != 0:
|
||||
for k in range(i, 3):
|
||||
if self.grid[k + 1][j] == 0:
|
||||
self.grid[k + 1][j] = self.grid[k][j]
|
||||
self.grid[k][j] = 0
|
||||
|
||||
def _move_left(self):
|
||||
for i in range(4):
|
||||
for j in range(1, 4):
|
||||
if self.grid[i][j] != 0:
|
||||
for k in range(j, 0, -1):
|
||||
if self.grid[i][k - 1] == 0:
|
||||
self.grid[i][k - 1] = self.grid[i][k]
|
||||
self.grid[i][k] = 0
|
||||
|
||||
def _move_right(self):
|
||||
for i in range(4):
|
||||
for j in range(2, -1, -1):
|
||||
if self.grid[i][j] != 0:
|
||||
for k in range(j, 3):
|
||||
if self.grid[i][k + 1] == 0:
|
||||
self.grid[i][k + 1] = self.grid[i][k]
|
||||
self.grid[i][k] = 0
|
||||
|
|
@ -15,20 +15,24 @@ from metagpt.learn.text_to_image import text_to_image
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test():
|
||||
async def test_metagpt_llm():
|
||||
# Prerequisites
|
||||
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
assert CONFIG.OPENAI_API_KEY
|
||||
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
key = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None
|
||||
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
finally:
|
||||
CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL = key
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -27,13 +27,16 @@ async def test_text_to_speech():
|
|||
assert "base64" in data or "http" in data
|
||||
|
||||
# test iflytek
|
||||
key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
|
||||
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = ""
|
||||
## Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = ""
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
data = await text_to_speech("panda emoji")
|
||||
assert "base64" in data or "http" in data
|
||||
finally:
|
||||
CONFIG.AZURE_TTS_SUBSCRIPTION_KEY = key
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -7,18 +7,38 @@
|
|||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message handling.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteDesign, WritePRD
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PRDS_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, awrite
|
||||
from tests.metagpt.roles.mock import MockMessages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect():
|
||||
# FIXME: make git as env? Or should we support
|
||||
# Prerequisites
|
||||
filename = uuid.uuid4().hex + ".json"
|
||||
await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)
|
||||
|
||||
role = Architect()
|
||||
role.put_message(MockMessages.req)
|
||||
rsp = await role.run(MockMessages.prd)
|
||||
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
assert rsp.cause_by == any_to_str(WriteDesign)
|
||||
|
||||
# test update
|
||||
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
|
||||
assert rsp
|
||||
assert rsp.cause_by == any_to_str(WriteDesign)
|
||||
assert len(rsp.content) > 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -5,3 +5,59 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_qa_engineer.py
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import DebugError, RunCode, WriteTest
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.roles import QaEngineer
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, aread, awrite
|
||||
|
||||
|
||||
async def test_qa():
|
||||
# Prerequisites
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "qa/game_2048"
|
||||
data = await aread(filename=demo_path / "game.py", encoding="utf-8")
|
||||
await awrite(filename=CONFIG.src_workspace / "game.py", data=data, encoding="utf-8")
|
||||
await awrite(filename=Path(CONFIG.git_repo.workdir) / "requirements.txt", data="")
|
||||
|
||||
class MockEnv(Environment):
|
||||
msgs: List[Message] = Field(default_factory=list)
|
||||
|
||||
def publish_message(self, message: Message, peekable: bool = True) -> bool:
|
||||
self.msgs.append(message)
|
||||
return True
|
||||
|
||||
env = MockEnv()
|
||||
|
||||
role = QaEngineer()
|
||||
role.set_env(env)
|
||||
await role.run(with_message=Message(content="", cause_by=SummarizeCode))
|
||||
assert env.msgs
|
||||
assert env.msgs[0].cause_by == any_to_str(WriteTest)
|
||||
msg = env.msgs[0]
|
||||
env.msgs.clear()
|
||||
await role.run(with_message=msg)
|
||||
assert env.msgs
|
||||
assert env.msgs[0].cause_by == any_to_str(RunCode)
|
||||
msg = env.msgs[0]
|
||||
env.msgs.clear()
|
||||
await role.run(with_message=msg)
|
||||
assert env.msgs
|
||||
assert env.msgs[0].cause_by == any_to_str(DebugError)
|
||||
msg = env.msgs[0]
|
||||
env.msgs.clear()
|
||||
role.test_round_allowed = 1
|
||||
rsp = await role.run(with_message=msg)
|
||||
assert "Exceeding" in rsp.content
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -10,15 +10,17 @@
|
|||
functionality is to be consolidated into the `Environment` class.
|
||||
"""
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions import Action, ActionOutput, UserRequirement
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.common import any_to_name, any_to_str
|
||||
|
||||
|
||||
class MockAction(Action):
|
||||
|
|
@ -96,7 +98,7 @@ async def test_react():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_msg_to():
|
||||
async def test_send_to():
|
||||
m = Message(content="a", send_to=["a", MockRole, Message])
|
||||
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
|
||||
|
||||
|
|
@ -107,5 +109,50 @@ async def test_msg_to():
|
|||
assert m.send_to == {"a", any_to_str(MockRole), any_to_str(Message)}
|
||||
|
||||
|
||||
def test_init_action():
|
||||
role = Role()
|
||||
role.init_actions([MockAction, MockAction])
|
||||
assert role.action_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_recover():
|
||||
# Mock LLM actions
|
||||
mock_llm = MagicMock(spec=BaseLLM)
|
||||
mock_llm.aask.side_effect = ["1"]
|
||||
|
||||
role = Role()
|
||||
assert role.is_watch(any_to_str(UserRequirement))
|
||||
role.put_message(None)
|
||||
role.publish_message(None)
|
||||
|
||||
role.llm = mock_llm
|
||||
role.init_actions([MockAction, MockAction])
|
||||
role.recovered = True
|
||||
role.latest_observed_msg = Message(content="recover_test")
|
||||
role.rc.state = 0
|
||||
assert role.todo == any_to_name(MockAction)
|
||||
|
||||
rsp = await role.run()
|
||||
assert rsp.cause_by == any_to_str(MockAction)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_think_act():
|
||||
# Mock LLM actions
|
||||
mock_llm = MagicMock(spec=BaseLLM)
|
||||
mock_llm.aask.side_effect = ["ok"]
|
||||
|
||||
role = Role()
|
||||
role.init_actions([MockAction])
|
||||
await role.think()
|
||||
role.rc.memory.add(Message("run"))
|
||||
assert len(role.get_memories()) == 1
|
||||
rsp = await role.act()
|
||||
assert rsp
|
||||
assert isinstance(rsp, ActionOutput)
|
||||
assert rsp.content == "run"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -16,8 +16,10 @@ from metagpt.actions import Action
|
|||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.schema import (
|
||||
AIMessage,
|
||||
CodeSummarizeContext,
|
||||
Document,
|
||||
Message,
|
||||
MessageQueue,
|
||||
|
|
@ -61,6 +63,8 @@ def test_message():
|
|||
assert m.role == "b"
|
||||
assert m.send_to == {"c"}
|
||||
assert m.cause_by == "c"
|
||||
m.sent_from = "e"
|
||||
assert m.sent_from == "e"
|
||||
|
||||
m.cause_by = "Message"
|
||||
assert m.cause_by == "Message"
|
||||
|
|
@ -121,6 +125,8 @@ def test_document():
|
|||
@pytest.mark.asyncio
|
||||
async def test_message_queue():
|
||||
mq = MessageQueue()
|
||||
val = await mq.dump()
|
||||
assert val == "[]"
|
||||
mq.push(Message(content="1"))
|
||||
mq.push(Message(content="2中文测试aaa"))
|
||||
msg = mq.pop()
|
||||
|
|
@ -132,5 +138,23 @@ async def test_message_queue():
|
|||
assert new_mq.pop_all() == mq.pop_all()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("file_list", "want"),
|
||||
[
|
||||
(
|
||||
[f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", f"{TASK_FILE_REPO}/b.txt"],
|
||||
CodeSummarizeContext(
|
||||
design_filename=f"{SYSTEM_DESIGN_FILE_REPO}/a.txt", task_filename=f"{TASK_FILE_REPO}/b.txt"
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
def test_CodeSummarizeContext(file_list, want):
|
||||
ctx = CodeSummarizeContext.loads(file_list)
|
||||
assert ctx == want
|
||||
m = {ctx: ctx}
|
||||
assert want in m
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -27,6 +27,19 @@ async def test_redis():
|
|||
assert await conn.get("test") == b"test"
|
||||
await conn.close()
|
||||
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["REDIS_HOST"] = "YOUR_REDIS_HOST"
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
conn = Redis()
|
||||
await conn.set("test", "test", timeout_sec=0)
|
||||
assert not await conn.get("test") == b"test"
|
||||
await conn.close()
|
||||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -41,17 +41,18 @@ async def test_s3():
|
|||
res = await conn.cache(data, ".bak", "script")
|
||||
assert "http" in res
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_s3_no_error():
|
||||
conn = S3()
|
||||
key = conn.auth_config["aws_secret_access_key"]
|
||||
conn.auth_config["aws_secret_access_key"] = ""
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["S3_ACCESS_KEY"] = "YOUR_S3_ACCESS_KEY"
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
conn = S3()
|
||||
assert not conn.is_valid
|
||||
res = await conn.cache("ABC", ".bak", "script")
|
||||
assert not res
|
||||
finally:
|
||||
conn.auth_config["aws_secret_access_key"] = key
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue