Merge pull request #1679 from jason-jszhang/feat_role_ut

Feat role ut
This commit is contained in:
better629 2025-02-26 14:12:18 +08:00 committed by GitHub
commit da349c9ec1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 435 additions and 102 deletions

View file

@ -1,21 +1,78 @@
from unittest.mock import AsyncMock
import pytest
from metagpt.const import TEST_DATA_PATH
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
from metagpt.logs import logger
from metagpt.roles.di.data_analyst import DataAnalyst
from metagpt.tools.tool_recommend import BM25ToolRecommender
@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.parametrize(
("query", "filename"), [("similarity search about '有哪些需求描述?' in document ", TEST_DATA_PATH / "requirements/2.pdf")]
)
async def test_similarity_search(query, filename):
di = DataAnalyst()
query += f"'{str(filename)}'"
class TestDataAnalyst:
def test_init(self):
analyst = DataAnalyst()
assert analyst.name == "David"
assert analyst.profile == "DataAnalyst"
assert "Browser" in analyst.tools
assert isinstance(analyst.write_code, WriteAnalysisCode)
assert isinstance(analyst.execute_code, ExecuteNbCode)
rsp = await di.run(query)
assert rsp
def test_set_custom_tool(self):
analyst = DataAnalyst()
analyst.custom_tools = ["web scraping", "Terminal"]
assert isinstance(analyst.custom_tool_recommender, BM25ToolRecommender)
@pytest.mark.asyncio
async def test_write_and_exec_code_no_task(self):
analyst = DataAnalyst()
result = await analyst.write_and_exec_code()
logger.info(result)
assert "No current_task found" in result
if __name__ == "__main__":
pytest.main([__file__, "-s"])
@pytest.mark.asyncio
async def test_write_and_exec_code_success(self):
analyst = DataAnalyst()
await analyst.execute_code.init_code()
analyst.planner.plan.goal = "construct a two-dimensional array"
analyst.planner.plan.append_task(
task_id="1",
dependent_task_ids=[],
instruction="construct a two-dimensional array",
assignee="David",
task_type="DATA_ANALYSIS",
)
result = await analyst.write_and_exec_code("construct a two-dimensional array")
logger.info(result)
assert "Success" in result
@pytest.mark.asyncio
async def test_write_and_exec_code_failure(self):
analyst = DataAnalyst()
await analyst.execute_code.init_code()
analyst.planner.plan.goal = "Execute a code that fails"
analyst.planner.plan.append_task(
task_id="1", dependent_task_ids=[], instruction="Execute a code that fails", assignee="David"
)
analyst.execute_code.run = AsyncMock(return_value=("Error: Division by zero", False))
result = await analyst.write_and_exec_code("divide by zero")
logger.info(result)
assert "Failed" in result
assert "Error: Division by zero" in result
@pytest.mark.asyncio
async def test_run_special_command(self):
analyst = DataAnalyst()
analyst.planner.plan.goal = "test goal"
analyst.planner.plan.append_task(task_id="1", dependent_task_ids=[], instruction="test task", assignee="David")
assert not analyst.planner.plan.is_plan_finished()
cmd = {"command_name": "end"}
result = await analyst._run_special_command(cmd)
assert "All tasks are finished" in result
assert analyst.planner.plan.is_plan_finished()

View file

@ -0,0 +1,41 @@
import pytest
from metagpt.actions import UserRequirement
from metagpt.logs import logger
from metagpt.roles.di.role_zero import RoleZero
from metagpt.schema import Message
@pytest.mark.asyncio
async def test_model_validators():
"""Test all model validators"""
role = RoleZero()
# Test set_plan_and_tool
assert role.react_mode == "react"
assert role.planner is not None
# Test set_tool_execution
assert "Plan.append_task" in role.tool_execution_map
assert "RoleZero.ask_human" in role.tool_execution_map
# Test set_longterm_memory
assert role.rc.memory is not None
@pytest.mark.asyncio
async def test_think_react_cycle():
"""Test the think-react cycle"""
# Setup test conditions
role = RoleZero(tools=["Plan"])
role.rc.todo = True
role.planner.plan.goal = "Test goal"
role.respond_language = "English"
# Test _think
result = await role._think()
assert result is True
role.rc.news = [Message(content="Test", cause_by=UserRequirement())]
result = await role._react()
logger.info(result)
assert isinstance(result, Message)

View file

@ -0,0 +1,47 @@
import pytest
from metagpt.roles.di.swe_agent import SWEAgent
from metagpt.schema import Message
from metagpt.tools.libs.terminal import Bash
from metagpt.environment.mgx.mgx_env import MGXEnv
from metagpt.roles.di.team_leader import TeamLeader
@pytest.fixture
def env():
test_env = MGXEnv()
tl = TeamLeader()
test_env.add_roles(
[
tl,
SWEAgent()
]
)
return test_env
@pytest.mark.asyncio
async def test_swe_agent(env):
requirement = "Fix bug in the calculator app"
swe = env.get_role("Swen")
message = Message(content=requirement, send_to={swe.name})
env.publish_message(message)
await swe.run()
history = env.history.get()
agent_messages = [msg for msg in history if msg.sent_from == swe.name]
assert swe.name == "Swen"
assert swe.profile == "Issue Solver"
assert isinstance(swe.terminal, Bash)
assert "Bash" in swe.tools
assert "git_create_pull" in swe.tool_execution_map
def is_valid_instruction_message(msg: Message) -> bool:
content = msg.content.lower()
return any(word in content for word in ["git", "bash", "check", "fix"])
assert any(is_valid_instruction_message(msg) for msg in agent_messages), "Should have valid instruction messages"

View file

@ -40,78 +40,70 @@ def env():
@pytest.mark.asyncio
async def test_plan_for_software_requirement(env):
requirement = "create a 2048 game"
tl = env.get_role("Team Leader")
tl = env.get_role("Mike")
env.publish_message(Message(content=requirement, send_to=tl.name))
await tl.run()
# TL should assign tasks to 5 members first, then send message to the first assignee, 6 commands in total
assert len(tl.commands) == 6
plan_cmd = tl.commands[:5]
route_cmd = tl.commands[5]
history = env.history.get()
task_assignment = [task["args"]["assignee"] for task in plan_cmd]
assert task_assignment == [
ProductManager().name,
Architect().name,
ProjectManager().name,
Engineer().name,
QaEngineer().name,
]
assert route_cmd["command_name"] == "publish_message"
assert route_cmd["args"]["send_to"] == ProductManager().name
messages_to_team = [msg for msg in history if msg.sent_from == tl.name]
pm_messages = [msg for msg in messages_to_team if "Alice" in msg.send_to]
assert len(pm_messages) > 0, "Should have message sent to Product Manager"
found_task_msg = False
for msg in messages_to_team:
if "prd" in msg.content.lower() and any(role in msg.content for role in ["Alice", "Bob", "Alex", "David"]):
found_task_msg = True
break
assert found_task_msg, "Should have task assignment message"
@pytest.mark.asyncio
async def test_plan_for_data_related_requirement(env):
requirement = "I want to use yolov5 for target detection, yolov5 all the information from the following link, please help me according to the content of the link (https://github.com/ultralytics/yolov5), set up the environment and download the model parameters, and finally provide a few pictures for inference, the inference results will be saved!"
tl = env.get_role("Team Leader")
tl = env.get_role("Mike")
env.publish_message(Message(content=requirement, send_to=tl.name))
await tl.run()
# TL should assign 1 task to Data Analyst and send message to it
assert len(tl.commands) == 2
plan_cmd = tl.commands[0]
route_cmd = tl.commands[-1]
history = env.history.get()
messages_from_tl = [msg for msg in history if msg.sent_from == tl.name]
da_messages = [msg for msg in messages_from_tl if "David" in msg.send_to]
assert len(da_messages) > 0
da = env.get_role("Data Analyst")
assert plan_cmd["command_name"] == "append_task"
assert plan_cmd["args"]["assignee"] == da.name
da_message = da_messages[0]
assert "https://github.com/ultralytics/yolov5" in da_message.content
assert route_cmd["command_name"] == "publish_message"
assert "https://github.com" in route_cmd["args"]["content"] # necessary info must be in the message
assert route_cmd["args"]["send_to"] == da.name
def is_valid_task_message(msg: Message) -> bool:
content = msg.content.lower()
has_model_info = "yolov5" in content
has_task_info = any(word in content for word in ["detection", "inference", "environment", "parameters"])
has_link = "github.com" in content
return has_model_info and has_task_info and has_link
assert is_valid_task_message(da_message)
@pytest.mark.asyncio
async def test_plan_for_mixed_requirement(env):
requirement = "Search the web for the new game 2048X, then replicate it"
tl = env.get_role("Team Leader")
tl = env.get_role("Mike")
env.publish_message(Message(content=requirement, send_to=tl.name))
await tl.run()
# TL should assign 6 tasks, first to Data Analyst to search the web, following by the software team sequence
# TL should send message to Data Analyst after task assignment
assert len(tl.commands) == 7
plan_cmd = tl.commands[:6]
route_cmd = tl.commands[-1]
history = env.history.get()
messages_from_tl = [msg for msg in history if msg.sent_from == tl.name]
task_assignment = [task["args"]["assignee"] for task in plan_cmd]
da = env.get_role("Data Analyst")
assert task_assignment == [
da.name,
ProductManager().name,
Architect().name,
ProjectManager().name,
Engineer().name,
QaEngineer().name,
]
da_messages = [msg for msg in messages_from_tl if "David" in msg.send_to]
assert len(da_messages) > 0
assert route_cmd["command_name"] == "publish_message"
assert route_cmd["args"]["send_to"] == da.name
da_message = da_messages[0]
def is_valid_search_task(msg: Message) -> bool:
content = msg.content.lower()
return "2048x" in content and "search" in content
assert is_valid_search_task(da_message)
PRD_MSG_CONTENT = """{'docs': {'20240424153821.json': {'root_path': 'docs/prd', 'filename': '20240424153821.json', 'content': '{"Language":"en_us","Programming Language":"Python","Original Requirements":"create a 2048 game","Project Name":"game_2048","Product Goals":["Develop an intuitive and addictive 2048 game variant","Ensure the game is accessible and performs well on various devices","Design a visually appealing and modern user interface"],"User Stories":["As a player, I want to be able to undo my last move so I can correct mistakes","As a player, I want to see my high scores to track my progress over time","As a player, I want to be able to play the game without any internet connection"],"Competitive Analysis":["2048 Original: Classic gameplay, minimalistic design, lacks social sharing features","2048 Hex: Unique hexagon board, but not mobile-friendly","2048 Multiplayer: Offers real-time competition, but overwhelming ads","2048 Bricks: Innovative gameplay with bricks, but poor performance on older devices","2048.io: Multiplayer battle royale mode, but complicated UI for new players","2048 Animated: Animated tiles add fun, but the game consumes a lot of battery","2048 3D: 3D version of the game, but has a steep learning curve"],"Competitive Quadrant Chart":"quadrantChart\\n title \\"User Experience and Feature Set of 2048 Games\\"\\n x-axis \\"Basic Features\\" --> \\"Rich Features\\"\\n y-axis \\"Poor Experience\\" --> \\"Great Experience\\"\\n quadrant-1 \\"Need Improvement\\"\\n quadrant-2 \\"Feature-Rich but Complex\\"\\n quadrant-3 \\"Simplicity with Poor UX\\"\\n quadrant-4 \\"Balanced\\"\\n \\"2048 Original\\": [0.2, 0.7]\\n \\"2048 Hex\\": [0.3, 0.4]\\n \\"2048 Multiplayer\\": [0.6, 0.5]\\n \\"2048 Bricks\\": [0.4, 0.3]\\n \\"2048.io\\": [0.7, 0.4]\\n \\"2048 Animated\\": [0.5, 0.6]\\n \\"2048 3D\\": [0.6, 0.3]\\n \\"Our Target Product\\": [0.8, 0.9]","Requirement Analysis":"The game must be engaging and retain players, which requires a balance of simplicity and challenge. Accessibility on various devices is crucial for a wider reach. A modern UI is needed to attract and retain the modern user. The ability to play offline is important for users on the go. High score tracking and the ability to undo moves are features that will enhance user experience.","Requirement Pool":[["P0","Implement core 2048 gameplay mechanics"],["P0","Design responsive UI for multiple devices"],["P1","Develop undo move feature"],["P1","Integrate high score tracking system"],["P2","Enable offline gameplay capability"]],"UI Design draft":"The UI will feature a clean and modern design with a minimalist color scheme. The game board will be center-aligned with smooth tile animations. Score and high score will be displayed at the top. Undo and restart buttons will be easily accessible. The design will be responsive to fit various screen sizes.","Anything UNCLEAR":"The monetization strategy for the game is not specified. Further clarification is needed on whether the game should include advertisements, in-app purchases, or be completely free."}'}}}"""
@ -122,48 +114,60 @@ DESIGN_CONTENT = """{"docs":{"20240424214432.json":{"root_path":"docs/system_des
async def test_plan_update_and_routing(env):
requirement = "create a 2048 game"
tl = env.get_role("Team Leader")
tl = env.get_role("Mike")
env.publish_message(Message(content=requirement))
await tl.run()
# Assuming Product Manager finishes its task
env.publish_message(Message(content=PRD_MSG_CONTENT, role="Alice(Product Manager)", sent_from="Alice"))
# Verify message routing after PM completes task
env.publish_message(Message(content=PRD_MSG_CONTENT, sent_from="Alice", send_to={"<all>"}))
await tl.run()
# TL should mark current task as finished, and forward Product Manager's message to Architect
# Current task should be updated to the second task
plan_cmd = tl.commands[0]
route_cmd = tl.commands[-1]
assert plan_cmd["command_name"] == "finish_current_task"
assert route_cmd["command_name"] == "publish_message"
assert route_cmd["args"]["send_to"] == Architect().name
assert tl.planner.plan.current_task_id == "2"
# Get message history
history = env.history.get()
messages_from_tl = [msg for msg in history if msg.sent_from == tl.name]
# Next step, assuming Architect finishes its task
env.publish_message(Message(content=DESIGN_CONTENT, role="Bob(Architect)", sent_from="Bob"))
# Verify messages sent to architect
architect_messages = [msg for msg in messages_from_tl if "Bob" in msg.send_to]
assert len(architect_messages) > 0, "Should have message forwarded to architect"
# Verify message content contains PRD info
architect_message = architect_messages[-1]
assert "2048 game based on the PRD" in architect_message.content, "Message to architect should contain PRD info"
# Verify message routing after architect completes task
env.publish_message(Message(content=DESIGN_CONTENT, sent_from="Bob", send_to={"<all>"}))
await tl.run()
plan_cmd = tl.commands[0]
route_cmd = tl.commands[-1]
assert plan_cmd["command_name"] == "finish_current_task"
assert route_cmd["command_name"] == "publish_message"
assert route_cmd["args"]["send_to"] == ProjectManager().name
assert tl.planner.plan.current_task_id == "3"
@pytest.mark.asyncio
async def test_reply_to_human(env):
requirement = "create a 2048 game"
tl = env.get_role("Team Leader")
tl = env.get_role("Mike")
env.publish_message(Message(content=requirement))
await tl.run()
# Assuming Product Manager finishes its task
env.publish_message(Message(content=PRD_MSG_CONTENT, role="Alice(Product Manager)", sent_from="Alice"))
# PM finishes task
env.publish_message(Message(content=PRD_MSG_CONTENT, sent_from="Alice", send_to={"<all>"}))
await tl.run()
# Human inquires about the progress
env.publish_message(Message(content="Who is working? How does the project go?"))
# Get history before human inquiry
history_before = env.history.get()
# Human inquires about progress
env.publish_message(Message(content="Who is working? How does the project go?", send_to={tl.name}))
await tl.run()
assert tl.commands[0]["command_name"] == "reply_to_human"
# Get new messages after human inquiry
history_after = env.history.get()
new_messages = [msg for msg in history_after if msg not in history_before]
# Verify team leader's response
tl_responses = [msg for msg in new_messages if msg.sent_from == tl.name]
assert len(tl_responses) > 0, "Should have response from team leader"
# Verify response contains project status
response = tl_responses[0].content
assert any(
keyword in response.lower() for keyword in ["progress", "status", "working"]
), "Response should contain project status information"

View file

@ -11,31 +11,32 @@ import uuid
import pytest
from metagpt.actions import WriteDesign, WritePRD
from metagpt.actions import WritePRD
from metagpt.const import PRDS_FILE_REPO
from metagpt.logs import logger
from metagpt.roles import Architect
from metagpt.schema import Message
from metagpt.utils.common import any_to_str, awrite
from tests.metagpt.roles.mock import MockMessages
from pathlib import Path
from metagpt.actions.di.run_command import RunCommand
@pytest.mark.asyncio
async def test_architect(context):
# Prerequisites
filename = uuid.uuid4().hex + ".json"
await awrite(context.repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)
await awrite(Path(context.config.project_path) / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)
role = Architect(context=context)
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
logger.info(rsp)
assert len(rsp.content) > 0
assert rsp.cause_by == any_to_str(WriteDesign)
assert rsp.cause_by == any_to_str(RunCommand)
# test update
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
assert rsp
assert rsp.cause_by == any_to_str(WriteDesign)
assert rsp.cause_by == any_to_str(RunCommand)
assert len(rsp.content) > 0

View file

@ -20,18 +20,27 @@ from metagpt.schema import CodingContext, Message
from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite
from metagpt.utils.git_repository import ChangeType
from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
from metagpt.utils.project_repo import ProjectRepo
from types import SimpleNamespace
@pytest.mark.asyncio
async def test_engineer(context):
# Prerequisites
rqno = "20231221155954.json"
await context.repo.save(REQUIREMENT_FILENAME, content=MockMessages.req.content)
await context.repo.docs.prd.save(rqno, content=MockMessages.prd.content)
await context.repo.docs.system_design.save(rqno, content=MockMessages.system_design.content)
await context.repo.docs.task.save(rqno, content=MockMessages.json_tasks.content)
project_repo = ProjectRepo(context.config.project_path)
# 设置engineer
engineer = Engineer(context=context)
engineer.repo = project_repo
engineer.input_args = SimpleNamespace(project_path=context.config.project_path)
# 使用project_repo保存所需文件
await project_repo.save(REQUIREMENT_FILENAME, content=MockMessages.req.content)
await project_repo.docs.prd.save(rqno, content=MockMessages.prd.content)
await project_repo.docs.system_design.save(rqno, content=MockMessages.system_design.content)
await project_repo.docs.task.save(rqno, content=MockMessages.json_tasks.content)
rsp = await engineer.run(Message(content="", cause_by=WriteTasks))
logger.info(rsp)

View file

@ -10,27 +10,25 @@ import json
import pytest
from metagpt.actions import WritePRD
from metagpt.const import REQUIREMENT_FILENAME
from metagpt.context import Context
from metagpt.logs import logger
from metagpt.roles import ProductManager
from metagpt.utils.common import any_to_str
from tests.metagpt.roles.mock import MockMessages
from metagpt.utils.git_repository import GitRepository
@pytest.mark.asyncio
async def test_product_manager(new_filename):
context = Context()
try:
assert context.git_repo is None
assert context.repo is None
product_manager = ProductManager(context=context)
# prepare documents
logger.info(MockMessages.req)
rsp = await product_manager.run(MockMessages.req)
assert context.git_repo
assert context.repo
logger.info(rsp)
assert rsp.cause_by == any_to_str(WritePRD)
assert REQUIREMENT_FILENAME in context.repo.docs.changed_files
# assert REQUIREMENT_FILENAME in context.repo.docs.changed_files
logger.info(rsp)
assert len(rsp.content) > 0
doc = list(rsp.instruct_content.docs.values())[0]
@ -43,7 +41,10 @@ async def test_product_manager(new_filename):
except Exception as e:
assert not e
finally:
context.git_repo.delete_repository()
# Clean up using the project path
if context.config.project_path:
git_repo = GitRepository(context.config.project_path)
git_repo.delete_repository()
if __name__ == "__main__":

View file

@ -15,5 +15,5 @@ from tests.metagpt.roles.mock import MockMessages
@pytest.mark.asyncio
async def test_project_manager(context):
project_manager = ProjectManager(context=context)
rsp = await project_manager.run(MockMessages.system_design)
rsp = await project_manager.run(MockMessages.tasks)
logger.info(rsp)