mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-26 15:49:42 +02:00
Merge branch 'code_intepreter' of https://gitlab.deepwisdomai.com/agents/data_agents_opt into code_intepreter
resolve conflict
This commit is contained in:
commit
7d38181f56
313 changed files with 8523 additions and 4255 deletions
|
|
@ -11,7 +11,7 @@ import pytest
|
|||
from pydantic import ValidationError
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.action_node import ActionNode, ReviewMode, ReviseMode
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.roles import Role
|
||||
|
|
@ -23,14 +23,12 @@ from metagpt.team import Team
|
|||
async def test_debate_two_roles():
|
||||
action1 = Action(name="AlexSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
action2 = Action(name="BobSay", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(
|
||||
alex = Role(
|
||||
name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2]
|
||||
)
|
||||
trump = Role(
|
||||
name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]
|
||||
)
|
||||
bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1])
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden, trump])
|
||||
team = Team(investment=10.0, env=env, roles=[alex, bob])
|
||||
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
|
@ -39,9 +37,9 @@ async def test_debate_two_roles():
|
|||
@pytest.mark.asyncio
|
||||
async def test_debate_one_role_in_env():
|
||||
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
env = Environment(desc="US election live broadcast")
|
||||
team = Team(investment=10.0, env=env, roles=[biden])
|
||||
team = Team(investment=10.0, env=env, roles=[alex])
|
||||
history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3)
|
||||
assert "Alex" in history
|
||||
|
||||
|
|
@ -49,8 +47,8 @@ async def test_debate_one_role_in_env():
|
|||
@pytest.mark.asyncio
|
||||
async def test_debate_one_role():
|
||||
action = Action(name="Debate", instruction="Express your opinion with emotion and don't repeat it")
|
||||
biden = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
msg: Message = await biden.run("Topic: climate change. Under 80 words per message.")
|
||||
alex = Role(name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action])
|
||||
msg: Message = await alex.run("Topic: climate change. Under 80 words per message.")
|
||||
|
||||
assert len(msg.content) > 10
|
||||
assert msg.sent_from == "metagpt.roles.role.Role"
|
||||
|
|
@ -98,6 +96,83 @@ async def test_action_node_two_layer():
|
|||
assert "579" in answer2.content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_node_review():
|
||||
key = "Project Name"
|
||||
node_a = ActionNode(
|
||||
key=key,
|
||||
expected_type=str,
|
||||
instruction='According to the content of "Original Requirements," name the project using snake case style '
|
||||
"with underline, like 'game_2048' or 'simple_crm.",
|
||||
example="game_2048",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = await node_a.review()
|
||||
|
||||
_ = await node_a.fill(context=None, llm=LLM())
|
||||
setattr(node_a.instruct_content, key, "game snake") # wrong content to review
|
||||
|
||||
review_comments = await node_a.review(review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 1
|
||||
assert list(review_comments.keys())[0] == key
|
||||
|
||||
review_comments = await node_a.review(strgy="complex", review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 0
|
||||
|
||||
node = ActionNode.from_children(key="WritePRD", nodes=[node_a])
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = await node.review()
|
||||
|
||||
_ = await node.fill(context=None, llm=LLM())
|
||||
|
||||
review_comments = await node.review(review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 1
|
||||
assert list(review_comments.keys())[0] == key
|
||||
|
||||
review_comments = await node.review(strgy="complex", review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 1
|
||||
assert list(review_comments.keys())[0] == key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_node_revise():
|
||||
key = "Project Name"
|
||||
node_a = ActionNode(
|
||||
key=key,
|
||||
expected_type=str,
|
||||
instruction='According to the content of "Original Requirements," name the project using snake case style '
|
||||
"with underline, like 'game_2048' or 'simple_crm.",
|
||||
example="game_2048",
|
||||
)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = await node_a.review()
|
||||
|
||||
_ = await node_a.fill(context=None, llm=LLM())
|
||||
setattr(node_a.instruct_content, key, "game snake") # wrong content to revise
|
||||
revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
assert "game_snake" in getattr(node_a.instruct_content, key)
|
||||
|
||||
revise_contents = await node_a.revise(strgy="complex", revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 0
|
||||
|
||||
node = ActionNode.from_children(key="WritePRD", nodes=[node_a])
|
||||
with pytest.raises(RuntimeError):
|
||||
_ = await node.revise()
|
||||
|
||||
_ = await node.fill(context=None, llm=LLM())
|
||||
setattr(node.instruct_content, key, "game snake")
|
||||
revise_contents = await node.revise(revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
assert "game_snake" in getattr(node.instruct_content, key)
|
||||
|
||||
revise_contents = await node.revise(strgy="complex", revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
assert "game_snake" in getattr(node.instruct_content, key)
|
||||
|
||||
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
"Required Other language third-party packages": '"""\nNo third-party packages required for other languages.\n"""\n',
|
||||
|
|
@ -138,10 +213,10 @@ def test_create_model_class():
|
|||
assert test_class.__name__ == "test_class"
|
||||
|
||||
output = test_class(**t_dict)
|
||||
print(output.schema())
|
||||
assert output.schema()["title"] == "test_class"
|
||||
assert output.schema()["type"] == "object"
|
||||
assert output.schema()["properties"]["Full API spec"]
|
||||
print(output.model_json_schema())
|
||||
assert output.model_json_schema()["title"] == "test_class"
|
||||
assert output.model_json_schema()["type"] == "object"
|
||||
assert output.model_json_schema()["properties"]["Full API spec"]
|
||||
|
||||
|
||||
def test_create_model_class_with_fields_unrecognized():
|
||||
|
|
|
|||
46
tests/metagpt/actions/test_action_outcls_registry.py
Normal file
46
tests/metagpt/actions/test_action_outcls_registry.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of action_outcls_registry
|
||||
|
||||
from typing import List
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
|
||||
def test_action_outcls_registry():
|
||||
class_name = "test"
|
||||
out_mapping = {"field": (list[str], ...), "field1": (str, ...)}
|
||||
out_data = {"field": ["field value1", "field value2"], "field1": "field1 value1"}
|
||||
|
||||
outcls = ActionNode.create_model_class(class_name, mapping=out_mapping)
|
||||
outinst = outcls(**out_data)
|
||||
|
||||
outcls1 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping)
|
||||
outinst1 = outcls1(**out_data)
|
||||
assert outinst1 == outinst
|
||||
|
||||
outcls2 = ActionNode(key="", expected_type=str, instruction="", example="").create_model_class(
|
||||
class_name, out_mapping
|
||||
)
|
||||
outinst2 = outcls2(**out_data)
|
||||
assert outinst2 == outinst
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field": (list[str], ...)} # different order
|
||||
outcls3 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping)
|
||||
outinst3 = outcls3(**out_data)
|
||||
assert outinst3 == outinst
|
||||
|
||||
out_mapping2 = {"field1": (str, ...), "field": (List[str], ...)} # typing case
|
||||
outcls4 = ActionNode.create_model_class(class_name=class_name, mapping=out_mapping2)
|
||||
outinst4 = outcls4(**out_data)
|
||||
assert outinst4 == outinst
|
||||
|
||||
out_data2 = {"field2": ["field2 value1", "field2 value2"], "field1": "field1 value1"}
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)} # List first
|
||||
outcls5 = ActionNode.create_model_class(class_name, out_mapping)
|
||||
outinst5 = outcls5(**out_data2)
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field2": (list[str], ...)}
|
||||
outcls6 = ActionNode.create_model_class(class_name, out_mapping)
|
||||
outinst6 = outcls6(**out_data2)
|
||||
assert outinst5 == outinst6
|
||||
|
|
@ -48,7 +48,7 @@ def sort_array(arr):
|
|||
async def test_debug_code():
|
||||
debug_context = Message(content=DebugContext)
|
||||
new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
|
||||
assert "def sort_array(arr)" in new_code
|
||||
assert "def sort_array(arr)" in new_code["code"]
|
||||
|
||||
|
||||
def test_messages_to_str():
|
||||
|
|
|
|||
|
|
@ -11,10 +11,7 @@ import uuid
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.debug_error import DebugError
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
CODE_CONTENT = '''
|
||||
from typing import List
|
||||
|
|
@ -117,8 +114,8 @@ if __name__ == '__main__':
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_error():
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / uuid.uuid4().hex
|
||||
async def test_debug_error(context):
|
||||
context.src_workspace = context.git_repo.workdir / uuid.uuid4().hex
|
||||
ctx = RunCodeContext(
|
||||
code_filename="player.py",
|
||||
test_filename="test_player.py",
|
||||
|
|
@ -126,8 +123,8 @@ async def test_debug_error():
|
|||
output_filename="output.log",
|
||||
)
|
||||
|
||||
await FileRepository.save_file(filename=ctx.code_filename, content=CODE_CONTENT, relative_path=CONFIG.src_workspace)
|
||||
await FileRepository.save_file(filename=ctx.test_filename, content=TEST_CONTENT, relative_path=TEST_CODES_FILE_REPO)
|
||||
await context.repo.with_src_path(context.src_workspace).srcs.save(filename=ctx.code_filename, content=CODE_CONTENT)
|
||||
await context.repo.tests.save(filename=ctx.test_filename, content=TEST_CONTENT)
|
||||
output_data = RunCodeResult(
|
||||
stdout=";",
|
||||
stderr="",
|
||||
|
|
@ -141,24 +138,11 @@ async def test_debug_error():
|
|||
"----------------------------------------------------------------------\n"
|
||||
"Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n",
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename=ctx.output_filename, content=output_data.model_dump_json(), relative_path=TEST_OUTPUTS_FILE_REPO
|
||||
)
|
||||
debug_error = DebugError(context=ctx)
|
||||
await context.repo.test_outputs.save(filename=ctx.output_filename, content=output_data.model_dump_json())
|
||||
debug_error = DebugError(i_context=ctx, context=context)
|
||||
|
||||
rsp = await debug_error.run()
|
||||
|
||||
assert "class Player" in rsp # rewrite the same class
|
||||
# Problematic code:
|
||||
# ```
|
||||
# if self.score > 21 and any(card.rank == 'A' for card in self.hand):
|
||||
# self.score -= 10
|
||||
# ```
|
||||
# Should rewrite to (used "gpt-3.5-turbo-1106"):
|
||||
# ```
|
||||
# ace_count = sum(1 for card in self.hand if card.rank == 'A')
|
||||
# while self.score > 21 and ace_count > 0:
|
||||
# self.score -= 10
|
||||
# ace_count -= 1
|
||||
# ```
|
||||
assert "while self.score > 21" in rsp
|
||||
# a key logic to rewrite to (original one is "if self.score > 12")
|
||||
assert "self.score" in rsp
|
||||
|
|
|
|||
|
|
@ -9,20 +9,17 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.const import PRDS_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock_markdown import PRD_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api():
|
||||
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。", PRD_SAMPLE]
|
||||
async def test_design_api(context):
|
||||
inputs = ["我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"] # PRD_SAMPLE
|
||||
for prd in inputs:
|
||||
await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO)
|
||||
await context.repo.docs.prd.save(filename="new_prd.txt", content=prd)
|
||||
|
||||
design_api = WriteDesign()
|
||||
design_api = WriteDesign(context=context)
|
||||
|
||||
result = await design_api.run(Message(content=prd, instruct_content=None))
|
||||
logger.info(result)
|
||||
|
|
|
|||
46
tests/metagpt/actions/test_design_api_an.py
Normal file
46
tests/metagpt/actions/test_design_api_an.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/03
|
||||
@Author : mannaandpoem
|
||||
@File : test_design_api_an.py
|
||||
"""
|
||||
import pytest
|
||||
from openai._models import BaseModel
|
||||
|
||||
from metagpt.actions.action_node import ActionNode, dict_to_markdown
|
||||
from metagpt.actions.design_api import NEW_REQ_TEMPLATE
|
||||
from metagpt.actions.design_api_an import REFINED_DESIGN_NODE
|
||||
from metagpt.llm import LLM
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
DESIGN_SAMPLE,
|
||||
REFINED_DESIGN_JSON,
|
||||
REFINED_PRD_JSON,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm():
|
||||
return LLM()
|
||||
|
||||
|
||||
def mock_refined_design_json():
|
||||
return REFINED_DESIGN_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_design_an(mocker):
|
||||
root = ActionNode.from_children(
|
||||
"RefinedDesignAPI", [ActionNode(key="", expected_type=str, instruction="", example="")]
|
||||
)
|
||||
root.instruct_content = BaseModel()
|
||||
root.instruct_content.model_dump = mock_refined_design_json
|
||||
mocker.patch("metagpt.actions.design_api_an.REFINED_DESIGN_NODE.fill", return_value=root)
|
||||
|
||||
prompt = NEW_REQ_TEMPLATE.format(old_design=DESIGN_SAMPLE, context=dict_to_markdown(REFINED_PRD_JSON))
|
||||
node = await REFINED_DESIGN_NODE.fill(prompt, llm)
|
||||
|
||||
assert "Refined Implementation Approach" in node.instruct_content.model_dump()
|
||||
assert "Refined File list" in node.instruct_content.model_dump()
|
||||
assert "Refined Data structures and interfaces" in node.instruct_content.model_dump()
|
||||
assert "Refined Program call flow" in node.instruct_content.model_dump()
|
||||
|
|
@ -11,7 +11,7 @@ from metagpt.actions.design_api_review import DesignReview
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api_review():
|
||||
async def test_design_api_review(context):
|
||||
prd = "我们需要一个音乐播放器,它应该有播放、暂停、上一曲、下一曲等功能。"
|
||||
api_design = """
|
||||
数据结构:
|
||||
|
|
@ -26,7 +26,7 @@ API列表:
|
|||
"""
|
||||
_ = "API设计看起来非常合理,满足了PRD中的所有需求。"
|
||||
|
||||
design_api_review = DesignReview()
|
||||
design_api_review = DesignReview(context=context)
|
||||
|
||||
result = await design_api_review.run(prd, api_design)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,6 @@ from metagpt.actions.fix_bug import FixBug
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fix_bug():
|
||||
fix_bug = FixBug()
|
||||
async def test_fix_bug(context):
|
||||
fix_bug = FixBug(context=context)
|
||||
assert fix_bug.name == "FixBug"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
from metagpt.actions.generate_questions import GenerateQuestions
|
||||
from metagpt.logs import logger
|
||||
|
||||
context = """
|
||||
msg = """
|
||||
## topic
|
||||
如何做一个生日蛋糕
|
||||
|
||||
|
|
@ -20,9 +20,9 @@ context = """
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_questions():
|
||||
action = GenerateQuestions()
|
||||
rsp = await action.run(context)
|
||||
async def test_generate_questions(context):
|
||||
action = GenerateQuestions(context=context)
|
||||
rsp = await action.run(msg)
|
||||
logger.info(f"{rsp.content=}")
|
||||
|
||||
assert "Questions" in rsp.content
|
||||
|
|
|
|||
|
|
@ -23,9 +23,9 @@ from metagpt.const import TEST_DATA_PATH
|
|||
Path("invoices/invoice-4.zip"),
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr(invoice_path: Path):
|
||||
async def test_invoice_ocr(invoice_path: Path, context):
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
resp = await InvoiceOCR().run(file_path=Path(invoice_path))
|
||||
resp = await InvoiceOCR(context=context).run(file_path=Path(invoice_path))
|
||||
assert isinstance(resp, list)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,52 +0,0 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.execute_code import ExecutePyCode
|
||||
from metagpt.actions.write_analysis_code import MakeTools
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_tools():
|
||||
code = "import yfinance as yf\n\n# Collect Alibaba stock data\nalibaba = yf.Ticker('BABA')\ndata = alibaba.history(period='1d', start='2022-01-01', end='2022-12-31')\nprint(data.head())"
|
||||
msgs = [{"role": "assistant", "content": code}]
|
||||
mt = MakeTools()
|
||||
tool_code = await mt.run(msgs)
|
||||
logger.debug(tool_code)
|
||||
ep = ExecutePyCode()
|
||||
tool_code = "!pip install yfinance\n" + tool_code
|
||||
result, res_type = await ep.run(tool_code)
|
||||
assert res_type is True
|
||||
logger.debug(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_tools2():
|
||||
code = """import pandas as pd\npath = "./tests/data/test.csv"\ndf = pd.read_csv(path)\ndata = df.copy()\n
|
||||
data['started_at'] = data['started_at'].apply(lambda r: pd.to_datetime(r))\n
|
||||
data['ended_at'] = data['ended_at'].apply(lambda r: pd.to_datetime(r))\ndata.head()"""
|
||||
msgs = [{"role": "assistant", "content": code}]
|
||||
mt = MakeTools()
|
||||
tool_code = await mt.run(msgs)
|
||||
logger.debug(tool_code)
|
||||
ep = ExecutePyCode()
|
||||
tool_code = tool_code
|
||||
result, res_type = await ep.run(tool_code)
|
||||
assert res_type is True
|
||||
logger.debug(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_tools3():
|
||||
code = """import pandas as pd\npath = "./tests/data/test.csv"\ndf = pd.read_csv(path)\ndata = df.copy()\n
|
||||
data['started_at'] = data['started_at'].apply(lambda r: pd.to_datetime(r))\n
|
||||
data['ended_at'] = data['ended_at'].apply(lambda r: pd.to_datetime(r))\n
|
||||
data['duration_hour'] = (data['ended_at'] - data['started_at']).dt.seconds/3600\ndata.head()"""
|
||||
msgs = [{"role": "assistant", "content": code}]
|
||||
mt = MakeTools()
|
||||
tool_code = await mt.run(msgs)
|
||||
logger.debug(tool_code)
|
||||
ep = ExecutePyCode()
|
||||
tool_code = tool_code
|
||||
result, res_type = await ep.run(tool_code)
|
||||
assert res_type is True
|
||||
logger.debug(result)
|
||||
46
tests/metagpt/actions/test_ml_action.py
Normal file
46
tests/metagpt/actions/test_ml_action.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.ml_action import WriteCodeWithToolsML
|
||||
from metagpt.schema import Plan, Task
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code_ml = WriteCodeWithToolsML()
|
||||
|
||||
task_map = {
|
||||
"1": Task(
|
||||
task_id="1",
|
||||
instruction="随机生成一个pandas DataFrame数据集",
|
||||
task_type="other",
|
||||
dependent_task_ids=[],
|
||||
code="""
|
||||
import pandas as pd
|
||||
df = pd.DataFrame({
|
||||
'a': [1, 2, 3, 4, 5],
|
||||
'b': [1.1, 2.2, 3.3, 4.4, np.nan],
|
||||
'c': ['aa', 'bb', 'cc', 'dd', 'ee'],
|
||||
'd': [1, 2, 3, 4, 5]
|
||||
})
|
||||
""",
|
||||
is_finished=True,
|
||||
),
|
||||
"2": Task(
|
||||
task_id="2",
|
||||
instruction="对数据集进行数据清洗",
|
||||
task_type="data_preprocess",
|
||||
dependent_task_ids=["1"],
|
||||
),
|
||||
}
|
||||
plan = Plan(
|
||||
goal="构造数据集并进行数据清洗",
|
||||
tasks=list(task_map.values()),
|
||||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
column_info = ""
|
||||
|
||||
_, code_with_ml = await write_code_ml.run([], plan, column_info)
|
||||
code_with_ml = code_with_ml["code"]
|
||||
assert len(code_with_ml) > 0
|
||||
print(code_with_ml)
|
||||
|
|
@ -9,22 +9,19 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.context import Context
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_documents():
|
||||
msg = Message(content="New user requirements balabala...")
|
||||
context = Context()
|
||||
|
||||
if CONFIG.git_repo:
|
||||
CONFIG.git_repo.delete_repository()
|
||||
CONFIG.git_repo = None
|
||||
|
||||
await PrepareDocuments().run(with_messages=[msg])
|
||||
assert CONFIG.git_repo
|
||||
doc = await FileRepository.get_file(filename=REQUIREMENT_FILENAME, relative_path=DOCS_FILE_REPO)
|
||||
await PrepareDocuments(context=context).run(with_messages=[msg])
|
||||
assert context.git_repo
|
||||
assert context.repo
|
||||
doc = await context.repo.docs.get(filename=REQUIREMENT_FILENAME)
|
||||
assert doc
|
||||
assert doc.content == msg.content
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from metagpt.logs import logger
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prepare_interview():
|
||||
action = PrepareInterview()
|
||||
async def test_prepare_interview(context):
|
||||
action = PrepareInterview(context=context)
|
||||
rsp = await action.run("I just graduated and hope to find a job as a Python engineer")
|
||||
logger.info(f"{rsp.content=}")
|
||||
|
||||
|
|
|
|||
|
|
@ -9,21 +9,18 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock_json import DESIGN, PRD
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_design_api():
|
||||
await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO)
|
||||
await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
logger.info(CONFIG.git_repo)
|
||||
async def test_design_api(context):
|
||||
await context.repo.docs.prd.save("1.txt", content=str(PRD))
|
||||
await context.repo.docs.system_design.save("1.txt", content=str(DESIGN))
|
||||
logger.info(context.git_repo)
|
||||
|
||||
action = WriteTasks()
|
||||
action = WriteTasks(context=context)
|
||||
|
||||
result = await action.run(Message(content="", instruct_content=None))
|
||||
logger.info(result)
|
||||
|
|
|
|||
45
tests/metagpt/actions/test_project_management_an.py
Normal file
45
tests/metagpt/actions/test_project_management_an.py
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/03
|
||||
@Author : mannaandpoem
|
||||
@File : test_project_management_an.py
|
||||
"""
|
||||
import pytest
|
||||
from openai._models import BaseModel
|
||||
|
||||
from metagpt.actions.action_node import ActionNode, dict_to_markdown
|
||||
from metagpt.actions.project_management import NEW_REQ_TEMPLATE
|
||||
from metagpt.actions.project_management_an import REFINED_PM_NODE
|
||||
from metagpt.llm import LLM
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
REFINED_DESIGN_JSON,
|
||||
REFINED_TASKS_JSON,
|
||||
TASKS_SAMPLE,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm():
|
||||
return LLM()
|
||||
|
||||
|
||||
def mock_refined_tasks_json():
|
||||
return REFINED_TASKS_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_management_an(mocker):
|
||||
root = ActionNode.from_children(
|
||||
"RefinedProjectManagement", [ActionNode(key="", expected_type=str, instruction="", example="")]
|
||||
)
|
||||
root.instruct_content = BaseModel()
|
||||
root.instruct_content.model_dump = mock_refined_tasks_json
|
||||
mocker.patch("metagpt.actions.project_management_an.REFINED_PM_NODE.fill", return_value=root)
|
||||
|
||||
prompt = NEW_REQ_TEMPLATE.format(old_task=TASKS_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON))
|
||||
node = await REFINED_PM_NODE.fill(prompt, llm)
|
||||
|
||||
assert "Refined Logic Analysis" in node.instruct_content.model_dump()
|
||||
assert "Refined Task list" in node.instruct_content.model_dump()
|
||||
assert "Refined Shared Knowledge" in node.instruct_content.model_dump()
|
||||
|
|
@ -11,19 +11,19 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.rebuild_class_view import RebuildClassView
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild():
|
||||
async def test_rebuild(context):
|
||||
action = RebuildClassView(
|
||||
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
|
||||
name="RedBean",
|
||||
i_context=str(Path(__file__).parent.parent.parent.parent / "metagpt"),
|
||||
llm=LLM(),
|
||||
context=context,
|
||||
)
|
||||
await action.run()
|
||||
graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
|
||||
assert graph_file_repo.changed_files
|
||||
assert context.repo.docs.graph_repo.changed_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
|
|
@ -10,33 +10,30 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.rebuild_sequence_view import RebuildSequenceView
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import GRAPH_REPO_FILE_REPO
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import ChangeType
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rebuild():
|
||||
@pytest.mark.skip
|
||||
async def test_rebuild(context):
|
||||
# Mock
|
||||
data = await aread(filename=Path(__file__).parent / "../../data/graph_db/networkx.json")
|
||||
graph_db_filename = Path(CONFIG.git_repo.workdir.name).with_suffix(".json")
|
||||
await FileRepository.save_file(
|
||||
filename=str(graph_db_filename),
|
||||
relative_path=GRAPH_REPO_FILE_REPO,
|
||||
content=data,
|
||||
)
|
||||
CONFIG.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
|
||||
CONFIG.git_repo.commit("commit1")
|
||||
graph_db_filename = Path(context.repo.workdir.name).with_suffix(".json")
|
||||
await context.repo.docs.graph_repo.save(filename=str(graph_db_filename), content=data)
|
||||
context.git_repo.add_change({f"{GRAPH_REPO_FILE_REPO}/{graph_db_filename}": ChangeType.UNTRACTED})
|
||||
context.git_repo.commit("commit1")
|
||||
|
||||
action = RebuildSequenceView(
|
||||
name="RedBean", context=str(Path(__file__).parent.parent.parent.parent / "metagpt"), llm=LLM()
|
||||
name="RedBean",
|
||||
i_context=str(Path(__file__).parent.parent.parent.parent / "metagpt"),
|
||||
llm=LLM(),
|
||||
context=context,
|
||||
)
|
||||
await action.run()
|
||||
graph_file_repo = CONFIG.git_repo.new_file_repository(relative_path=GRAPH_REPO_FILE_REPO)
|
||||
assert graph_file_repo.changed_files
|
||||
assert context.repo.docs.graph_repo.changed_files
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
|
|
|||
|
|
@ -9,10 +9,12 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import research
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collect_links(mocker):
|
||||
async def test_collect_links(mocker, search_engine_mocker, context):
|
||||
async def mock_llm_ask(self, prompt: str, system_msgs):
|
||||
if "Please provide up to 2 necessary keywords" in prompt:
|
||||
return '["metagpt", "llm"]'
|
||||
|
|
@ -26,13 +28,15 @@ async def test_collect_links(mocker):
|
|||
return "[1,2]"
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
resp = await research.CollectLinks().run("The application of MetaGPT")
|
||||
resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), context=context).run(
|
||||
"The application of MetaGPT"
|
||||
)
|
||||
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
|
||||
assert i in resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_collect_links_with_rank_func(mocker):
|
||||
async def test_collect_links_with_rank_func(mocker, search_engine_mocker, context):
|
||||
rank_before = []
|
||||
rank_after = []
|
||||
url_per_query = 4
|
||||
|
|
@ -45,14 +49,16 @@ async def test_collect_links_with_rank_func(mocker):
|
|||
return results
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
|
||||
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
|
||||
resp = await research.CollectLinks(
|
||||
search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func, context=context
|
||||
).run("The application of MetaGPT")
|
||||
for x, y, z in zip(rank_before, rank_after, resp.values()):
|
||||
assert x[::-1] == y
|
||||
assert [i["link"] for i in y] == z
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_web_browse_and_summarize(mocker):
|
||||
async def test_web_browse_and_summarize(mocker, context):
|
||||
async def mock_llm_ask(*args, **kwargs):
|
||||
return "metagpt"
|
||||
|
||||
|
|
@ -60,20 +66,20 @@ async def test_web_browse_and_summarize(mocker):
|
|||
url = "https://github.com/geekan/MetaGPT"
|
||||
url2 = "https://github.com/trending"
|
||||
query = "What's new in metagpt"
|
||||
resp = await research.WebBrowseAndSummarize().run(url, query=query)
|
||||
resp = await research.WebBrowseAndSummarize(context=context).run(url, query=query)
|
||||
|
||||
assert len(resp) == 1
|
||||
assert url in resp
|
||||
assert resp[url] == "metagpt"
|
||||
|
||||
resp = await research.WebBrowseAndSummarize().run(url, url2, query=query)
|
||||
resp = await research.WebBrowseAndSummarize(context=context).run(url, url2, query=query)
|
||||
assert len(resp) == 2
|
||||
|
||||
async def mock_llm_ask(*args, **kwargs):
|
||||
return "Not relevant."
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
resp = await research.WebBrowseAndSummarize().run(url, query=query)
|
||||
resp = await research.WebBrowseAndSummarize(context=context).run(url, query=query)
|
||||
|
||||
assert len(resp) == 1
|
||||
assert url in resp
|
||||
|
|
@ -81,7 +87,7 @@ async def test_web_browse_and_summarize(mocker):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conduct_research(mocker):
|
||||
async def test_conduct_research(mocker, context):
|
||||
data = None
|
||||
|
||||
async def mock_llm_ask(*args, **kwargs):
|
||||
|
|
@ -95,7 +101,7 @@ async def test_conduct_research(mocker):
|
|||
"outputs user stories / competitive analysis / requirements / data structures / APIs / documents, etc."
|
||||
)
|
||||
|
||||
resp = await research.ConductResearch().run("The application of MetaGPT", content)
|
||||
resp = await research.ConductResearch(context=context).run("The application of MetaGPT", content)
|
||||
assert resp == data
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,19 +24,19 @@ async def test_run_text():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_script():
|
||||
async def test_run_script(context):
|
||||
# Successful command
|
||||
out, err = await RunCode.run_script(".", command=["echo", "Hello World"])
|
||||
out, err = await RunCode(context=context).run_script(".", command=["echo", "Hello World"])
|
||||
assert out.strip() == "Hello World"
|
||||
assert err == ""
|
||||
|
||||
# Unsuccessful command
|
||||
out, err = await RunCode.run_script(".", command=["python", "-c", "print(1/0)"])
|
||||
out, err = await RunCode(context=context).run_script(".", command=["python", "-c", "print(1/0)"])
|
||||
assert "ZeroDivisionError" in err
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
async def test_run(context):
|
||||
inputs = [
|
||||
(RunCodeContext(mode="text", code_filename="a.txt", code="print('Hello, World')"), "PASS"),
|
||||
(
|
||||
|
|
@ -61,5 +61,5 @@ async def test_run():
|
|||
),
|
||||
]
|
||||
for ctx, result in inputs:
|
||||
rsp = await RunCode(context=ctx).run()
|
||||
rsp = await RunCode(i_context=ctx, context=context).run()
|
||||
assert result in rsp.summary
|
||||
|
|
|
|||
|
|
@ -47,18 +47,18 @@ class TestSkillAction:
|
|||
assert args.get("size_type") == "512x512"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parser_action(self, mocker):
|
||||
async def test_parser_action(self, mocker, context):
|
||||
# mock
|
||||
mocker.patch("metagpt.learn.text_to_image", return_value="https://mock.com/xxx")
|
||||
|
||||
parser_action = ArgumentsParingAction(skill=self.skill, ask="Draw an apple")
|
||||
parser_action = ArgumentsParingAction(skill=self.skill, ask="Draw an apple", context=context)
|
||||
rsp = await parser_action.run()
|
||||
assert rsp
|
||||
assert parser_action.args
|
||||
assert parser_action.args.get("text") == "Draw an apple"
|
||||
assert parser_action.args.get("size_type") == "512x512"
|
||||
|
||||
action = SkillAction(skill=self.skill, args=parser_action.args)
|
||||
action = SkillAction(skill=self.skill, args=parser_action.args, context=context)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert "image/png;base64," in rsp.content or "http" in rsp.content
|
||||
|
|
@ -81,8 +81,8 @@ class TestSkillAction:
|
|||
await SkillAction.find_and_call_function("dummy_call", {"a": 1})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_action_error(self):
|
||||
action = SkillAction(skill=self.skill, args={})
|
||||
async def test_skill_action_error(self, context):
|
||||
action = SkillAction(skill=self.skill, args={}, context=context)
|
||||
rsp = await action.run()
|
||||
assert "Error" in rsp.content
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,14 @@
|
|||
@File : test_summarize_code.py
|
||||
@Modifiled By: mashenquan, 2023-12-6. Unit test for summarize_code.py
|
||||
"""
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
DESIGN_CONTENT = """
|
||||
{"Implementation approach": "To develop this snake game, we will use the Python language and choose the Pygame library. Pygame is an open-source Python module collection specifically designed for writing video games. It provides functionalities such as displaying images and playing sounds, making it suitable for creating intuitive and responsive user interfaces. We will ensure efficient game logic to prevent any delays during gameplay. The scoring system will be simple, with the snake gaining points for each food it eats. We will use Pygame's event handling system to implement pause and resume functionality, as well as high-score tracking. The difficulty will increase by speeding up the snake's movement. In the initial version, we will focus on single-player mode and consider adding multiplayer mode and customizable skins in future updates. Based on the new requirement, we will also add a moving obstacle that appears randomly. If the snake eats this obstacle, the game will end. If the snake does not eat the obstacle, it will disappear after 5 seconds. For this, we need to add mechanisms for obstacle generation, movement, and disappearance in the game logic.", "Project_name": "snake_game", "File list": ["main.py", "game.py", "snake.py", "food.py", "obstacle.py", "scoreboard.py", "constants.py", "assets/styles.css", "assets/index.html"], "Data structures and interfaces": "```mermaid\n classDiagram\n class Game{\n +int score\n +int speed\n +bool game_over\n +bool paused\n +Snake snake\n +Food food\n +Obstacle obstacle\n +Scoreboard scoreboard\n +start_game() void\n +pause_game() void\n +resume_game() void\n +end_game() void\n +increase_difficulty() void\n +update() void\n +render() void\n Game()\n }\n class Snake{\n +list body_parts\n +str direction\n +bool grow\n +move() void\n +grow() void\n +check_collision() bool\n Snake()\n }\n class Food{\n +tuple position\n +spawn() void\n Food()\n }\n class Obstacle{\n +tuple position\n +int lifetime\n +bool active\n +spawn() void\n +move() void\n +check_collision() bool\n +disappear() void\n Obstacle()\n }\n class Scoreboard{\n +int high_score\n +update_score(int) void\n +reset_score() void\n +load_high_score() void\n +save_high_score() void\n Scoreboard()\n }\n class Constants{\n }\n Game \"1\" -- \"1\" Snake: has\n Game \"1\" -- \"1\" Food: has\n Game \"1\" -- \"1\" Obstacle: has\n Game \"1\" -- \"1\" Scoreboard: has\n ```", "Program call flow": "```sequenceDiagram\n participant M as Main\n participant G as Game\n participant S as Snake\n participant F as Food\n participant O as Obstacle\n participant SB as Scoreboard\n M->>G: start_game()\n loop game loop\n G->>S: move()\n G->>S: check_collision()\n G->>F: spawn()\n G->>O: spawn()\n G->>O: move()\n G->>O: check_collision()\n G->>O: disappear()\n G->>SB: update_score(score)\n G->>G: update()\n G->>G: render()\n alt if paused\n M->>G: pause_game()\n M->>G: resume_game()\n end\n alt if game_over\n G->>M: end_game()\n end\n end\n```", "Anything UNCLEAR": "There is no need for further clarification as the requirements are already clear."}
|
||||
|
|
@ -177,19 +177,28 @@ class Snake:
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summarize_code():
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / "src"
|
||||
await FileRepository.save_file(filename="1.json", relative_path=SYSTEM_DESIGN_FILE_REPO, content=DESIGN_CONTENT)
|
||||
await FileRepository.save_file(filename="1.json", relative_path=TASK_FILE_REPO, content=TASK_CONTENT)
|
||||
await FileRepository.save_file(filename="food.py", relative_path=CONFIG.src_workspace, content=FOOD_PY)
|
||||
await FileRepository.save_file(filename="game.py", relative_path=CONFIG.src_workspace, content=GAME_PY)
|
||||
await FileRepository.save_file(filename="main.py", relative_path=CONFIG.src_workspace, content=MAIN_PY)
|
||||
await FileRepository.save_file(filename="snake.py", relative_path=CONFIG.src_workspace, content=SNAKE_PY)
|
||||
async def test_summarize_code(context):
|
||||
git_dir = Path(__file__).parent / f"unittest/{uuid.uuid4().hex}"
|
||||
git_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
|
||||
all_files = src_file_repo.all_files
|
||||
ctx = CodeSummarizeContext(design_filename="1.json", task_filename="1.json", codes_filenames=all_files)
|
||||
action = SummarizeCode(context=ctx)
|
||||
context.src_workspace = context.git_repo.workdir / "src"
|
||||
await context.repo.docs.system_design.save(filename="1.json", content=DESIGN_CONTENT)
|
||||
await context.repo.docs.task.save(filename="1.json", content=TASK_CONTENT)
|
||||
await context.repo.with_src_path(context.src_workspace).srcs.save(filename="food.py", content=FOOD_PY)
|
||||
assert context.repo.srcs.workdir == context.src_workspace
|
||||
await context.repo.srcs.save(filename="game.py", content=GAME_PY)
|
||||
await context.repo.srcs.save(filename="main.py", content=MAIN_PY)
|
||||
await context.repo.srcs.save(filename="snake.py", content=SNAKE_PY)
|
||||
|
||||
all_files = context.repo.srcs.all_files
|
||||
summarization_context = CodeSummarizeContext(
|
||||
design_filename="1.json", task_filename="1.json", codes_filenames=all_files
|
||||
)
|
||||
action = SummarizeCode(context=context, i_context=summarization_context)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
logger.info(rsp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -9,13 +9,12 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("agent_description", "language", "context", "knowledge", "history_summary"),
|
||||
("agent_description", "language", "talk_context", "knowledge", "history_summary"),
|
||||
[
|
||||
(
|
||||
"mathematician",
|
||||
|
|
@ -33,12 +32,12 @@ from metagpt.schema import Message
|
|||
),
|
||||
],
|
||||
)
|
||||
async def test_prompt(agent_description, language, context, knowledge, history_summary):
|
||||
async def test_prompt(agent_description, language, talk_context, knowledge, history_summary, context):
|
||||
# Prerequisites
|
||||
CONFIG.agent_description = agent_description
|
||||
CONFIG.language = language
|
||||
context.kwargs.agent_description = agent_description
|
||||
context.kwargs.language = language
|
||||
|
||||
action = TalkAction(context=context, knowledge=knowledge, history_summary=history_summary)
|
||||
action = TalkAction(i_context=talk_context, knowledge=knowledge, history_summary=history_summary, context=context)
|
||||
assert "{" not in action.prompt
|
||||
assert "{" not in action.prompt_gpt4
|
||||
|
||||
|
|
|
|||
|
|
@ -3,16 +3,13 @@ import asyncio
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.execute_code import ExecutePyCode
|
||||
from metagpt.actions.write_analysis_code import (
|
||||
WriteCodeByGenerate,
|
||||
WriteCodeWithTools,
|
||||
WriteCodeWithToolsML,
|
||||
)
|
||||
from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools
|
||||
from metagpt.logs import logger
|
||||
from metagpt.plan.planner import STRUCTURAL_CONTEXT
|
||||
from metagpt.schema import Message, Plan, Task
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_by_list_plan():
|
||||
write_code = WriteCodeByGenerate()
|
||||
|
|
@ -23,35 +20,31 @@ async def test_write_code_by_list_plan():
|
|||
print(f"\n任务: {task}\n\n")
|
||||
messages.append(Message(task, role="assistant"))
|
||||
code = await write_code.run(messages)
|
||||
messages.append(Message(code, role="assistant"))
|
||||
messages.append(Message(code["code"], role="assistant"))
|
||||
assert len(code) > 0
|
||||
output = await execute_code.run(code)
|
||||
output = await execute_code.run(code["code"])
|
||||
print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n")
|
||||
messages.append(output[0])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_recommendation():
|
||||
task = "对已经读取的数据集进行数据清洗"
|
||||
code_steps = """
|
||||
step 1: 对数据集进行去重
|
||||
step 2: 对数据集进行缺失值处理
|
||||
"""
|
||||
task = "clean and preprocess the data"
|
||||
code_steps = ""
|
||||
available_tools = {
|
||||
"fill_missing_value": "Completing missing values with simple strategies",
|
||||
"split_bins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
|
||||
"FillMissingValue": "Filling missing values",
|
||||
"SplitBins": "Bin continuous data into intervals and return the bin identifier encoded as an integer value",
|
||||
}
|
||||
write_code = WriteCodeWithTools()
|
||||
tools = await write_code._tool_recommendation(task, code_steps, available_tools)
|
||||
tools = await write_code._recommend_tool(task, code_steps, available_tools)
|
||||
|
||||
assert len(tools) == 1
|
||||
assert tools[0] == "fill_missing_value"
|
||||
assert "FillMissingValue" in tools
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_with_tools():
|
||||
write_code = WriteCodeWithTools()
|
||||
write_code_ml = WriteCodeWithToolsML()
|
||||
|
||||
requirement = "构造数据集并进行数据清洗"
|
||||
task_map = {
|
||||
|
|
@ -84,7 +77,6 @@ async def test_write_code_with_tools():
|
|||
task_map=task_map,
|
||||
current_task_id="2",
|
||||
)
|
||||
column_info = ""
|
||||
|
||||
context = STRUCTURAL_CONTEXT.format(
|
||||
user_requirement=requirement,
|
||||
|
|
@ -95,13 +87,10 @@ async def test_write_code_with_tools():
|
|||
context_msg = [Message(content=context, role="user")]
|
||||
|
||||
code = await write_code.run(context_msg, plan)
|
||||
code = code["code"]
|
||||
assert len(code) > 0
|
||||
print(code)
|
||||
|
||||
code_with_ml = await write_code_ml.run([], plan, column_info)
|
||||
assert len(code_with_ml) > 0
|
||||
print(code_with_ml)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_to_correct_error():
|
||||
|
|
@ -150,6 +139,7 @@ async def test_write_code_to_correct_error():
|
|||
Message(content=error, role="user"),
|
||||
]
|
||||
new_code = await WriteCodeByGenerate().run(context=context)
|
||||
new_code = new_code["code"]
|
||||
print(new_code)
|
||||
assert "read_csv" in new_code # should correct read_excel to read_csv
|
||||
|
||||
|
|
@ -189,10 +179,12 @@ async def test_write_code_reuse_code_simple():
|
|||
Message(content=structural_context, role="user"),
|
||||
]
|
||||
code = await WriteCodeByGenerate().run(context=context)
|
||||
code = code["code"]
|
||||
print(code)
|
||||
assert "pandas" not in code and "read_csv" not in code # should reuse import and read statement from previous one
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long():
|
||||
"""test code reuse for long context"""
|
||||
|
|
@ -245,13 +237,14 @@ async def test_write_code_reuse_code_long():
|
|||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_iris" not in result and "iris_data" in result for result in trial_results
|
||||
"load_iris" not in result["code"] and "iris_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
assert success_rate >= 0.8
|
||||
|
||||
|
||||
@pytest.mark.skip
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_reuse_code_long_for_wine():
|
||||
"""test code reuse for long context"""
|
||||
|
|
@ -318,7 +311,7 @@ async def test_write_code_reuse_code_long_for_wine():
|
|||
trial_results = await asyncio.gather(*trials)
|
||||
print(*trial_results, sep="\n\n***\n\n")
|
||||
success = [
|
||||
"load_wine" not in result and "wine_data" in result for result in trial_results
|
||||
"load_wine" not in result["code"] and "wine_data" in result["code"] for result in trial_results
|
||||
] # should reuse iris_data from previous tasks
|
||||
success_rate = sum(success) / trials_num
|
||||
logger.info(f"success rate: {success_rate :.2f}")
|
||||
|
|
|
|||
|
|
@ -12,28 +12,22 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
CODE_SUMMARIES_FILE_REPO,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAILLM as LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from metagpt.utils.common import aread
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code():
|
||||
context = CodingContext(
|
||||
async def test_write_code(context):
|
||||
# Prerequisites
|
||||
context.src_workspace = context.git_repo.workdir / "writecode"
|
||||
|
||||
coding_ctx = CodingContext(
|
||||
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
|
||||
)
|
||||
doc = Document(content=context.model_dump_json())
|
||||
write_code = WriteCode(context=doc)
|
||||
doc = Document(content=coding_ctx.model_dump_json())
|
||||
write_code = WriteCode(i_context=doc, context=context)
|
||||
|
||||
code = await write_code.run()
|
||||
logger.info(code.model_dump_json())
|
||||
|
|
@ -44,48 +38,44 @@ async def test_write_code():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_directly():
|
||||
async def test_write_code_directly(context):
|
||||
prompt = WRITE_CODE_PROMPT_SAMPLE + "\n" + TASKS_2[0]
|
||||
llm = LLM()
|
||||
llm = context.llm_with_cost_manager_from_llm_config(context.config.llm)
|
||||
rsp = await llm.aask(prompt)
|
||||
logger.info(rsp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_deps():
|
||||
async def test_write_code_deps(context):
|
||||
# Prerequisites
|
||||
CONFIG.src_workspace = CONFIG.git_repo.workdir / "snake1/snake1"
|
||||
context.src_workspace = context.git_repo.workdir / "snake1/snake1"
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
await FileRepository.save_file(
|
||||
filename="test_game.py.json",
|
||||
content=await aread(str(demo_path / "test_game.py.json")),
|
||||
relative_path=TEST_OUTPUTS_FILE_REPO,
|
||||
await context.repo.test_outputs.save(
|
||||
filename="test_game.py.json", content=await aread(str(demo_path / "test_game.py.json"))
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
await context.repo.docs.code_summary.save(
|
||||
filename="20231221155954.json",
|
||||
content=await aread(str(demo_path / "code_summaries.json")),
|
||||
relative_path=CODE_SUMMARIES_FILE_REPO,
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
await context.repo.docs.system_design.save(
|
||||
filename="20231221155954.json",
|
||||
content=await aread(str(demo_path / "system_design.json")),
|
||||
relative_path=SYSTEM_DESIGN_FILE_REPO,
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json")), relative_path=TASK_FILE_REPO
|
||||
await context.repo.docs.task.save(
|
||||
filename="20231221155954.json", content=await aread(str(demo_path / "tasks.json"))
|
||||
)
|
||||
await FileRepository.save_file(
|
||||
filename="main.py", content='if __name__ == "__main__":\nmain()', relative_path=CONFIG.src_workspace
|
||||
await context.repo.with_src_path(context.src_workspace).srcs.save(
|
||||
filename="main.py", content='if __name__ == "__main__":\nmain()'
|
||||
)
|
||||
context = CodingContext(
|
||||
ccontext = CodingContext(
|
||||
filename="game.py",
|
||||
design_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=SYSTEM_DESIGN_FILE_REPO),
|
||||
task_doc=await FileRepository.get_file(filename="20231221155954.json", relative_path=TASK_FILE_REPO),
|
||||
design_doc=await context.repo.docs.system_design.get(filename="20231221155954.json"),
|
||||
task_doc=await context.repo.docs.task.get(filename="20231221155954.json"),
|
||||
code_doc=Document(filename="game.py", content="", root_path="snake1"),
|
||||
)
|
||||
coding_doc = Document(root_path="snake1", filename="game.py", content=context.json())
|
||||
coding_doc = Document(root_path="snake1", filename="game.py", content=ccontext.json())
|
||||
|
||||
action = WriteCode(context=coding_doc)
|
||||
action = WriteCode(i_context=coding_doc, context=context)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
assert rsp.code_doc.content
|
||||
|
|
|
|||
71
tests/metagpt/actions/test_write_code_plan_and_change_an.py
Normal file
71
tests/metagpt/actions/test_write_code_plan_and_change_an.py
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/03
|
||||
@Author : mannaandpoem
|
||||
@File : test_write_code_plan_and_change_an.py
|
||||
"""
|
||||
import pytest
|
||||
from openai._models import BaseModel
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.actions.write_code_plan_and_change_an import (
|
||||
REFINED_TEMPLATE,
|
||||
WriteCodePlanAndChange,
|
||||
)
|
||||
from metagpt.schema import CodePlanAndChangeContext
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
CODE_PLAN_AND_CHANGE_SAMPLE,
|
||||
DESIGN_SAMPLE,
|
||||
NEW_REQUIREMENT_SAMPLE,
|
||||
REFINED_CODE_INPUT_SAMPLE,
|
||||
REFINED_CODE_SAMPLE,
|
||||
TASKS_SAMPLE,
|
||||
)
|
||||
|
||||
|
||||
def mock_code_plan_and_change():
|
||||
return CODE_PLAN_AND_CHANGE_SAMPLE
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_plan_and_change_an(mocker):
|
||||
root = ActionNode.from_children(
|
||||
"WriteCodePlanAndChange", [ActionNode(key="", expected_type=str, instruction="", example="")]
|
||||
)
|
||||
root.instruct_content = BaseModel()
|
||||
root.instruct_content.model_dump = mock_code_plan_and_change
|
||||
mocker.patch("metagpt.actions.write_code_plan_and_change_an.WriteCodePlanAndChange.run", return_value=root)
|
||||
|
||||
requirement = "New requirement"
|
||||
prd_filename = "prd.md"
|
||||
design_filename = "design.md"
|
||||
task_filename = "task.md"
|
||||
code_plan_and_change_context = CodePlanAndChangeContext(
|
||||
requirement=requirement,
|
||||
prd_filename=prd_filename,
|
||||
design_filename=design_filename,
|
||||
task_filename=task_filename,
|
||||
)
|
||||
node = await WriteCodePlanAndChange(i_context=code_plan_and_change_context).run()
|
||||
|
||||
assert "Code Plan And Change" in node.instruct_content.model_dump()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refine_code(mocker):
|
||||
mocker.patch.object(WriteCode, "_aask", return_value=REFINED_CODE_SAMPLE)
|
||||
prompt = REFINED_TEMPLATE.format(
|
||||
user_requirement=NEW_REQUIREMENT_SAMPLE,
|
||||
code_plan_and_change=CODE_PLAN_AND_CHANGE_SAMPLE,
|
||||
design=DESIGN_SAMPLE,
|
||||
task=TASKS_SAMPLE,
|
||||
code=REFINED_CODE_INPUT_SAMPLE,
|
||||
logs="",
|
||||
feedback="",
|
||||
filename="game.py",
|
||||
summary_log="",
|
||||
)
|
||||
code = await WriteCode().write_code(prompt=prompt)
|
||||
assert "def" in code
|
||||
|
|
@ -12,28 +12,25 @@ from metagpt.schema import CodingContext, Document
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_review(capfd):
|
||||
async def test_write_code_review(capfd, context):
|
||||
context.src_workspace = context.repo.workdir / "srcs"
|
||||
code = """
|
||||
def add(a, b):
|
||||
return a +
|
||||
"""
|
||||
context = CodingContext(
|
||||
coding_context = CodingContext(
|
||||
filename="math.py", design_doc=Document(content="编写一个从a加b的函数,返回a+b"), code_doc=Document(content=code)
|
||||
)
|
||||
|
||||
context = await WriteCodeReview(context=context).run()
|
||||
await WriteCodeReview(i_context=coding_context, context=context).run()
|
||||
|
||||
# 我们不能精确地预测生成的代码评审,但我们可以检查返回的是否为字符串
|
||||
assert isinstance(context.code_doc.content, str)
|
||||
assert len(context.code_doc.content) > 0
|
||||
assert isinstance(coding_context.code_doc.content, str)
|
||||
assert len(coding_context.code_doc.content) > 0
|
||||
|
||||
captured = capfd.readouterr()
|
||||
print(f"输出内容: {captured.out}")
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_write_code_review_directly():
|
||||
# code = SEARCH_CODE_SAMPLE
|
||||
# write_code_review = WriteCodeReview("write_code_review")
|
||||
# review = await write_code_review.run(code)
|
||||
# logger.info(review)
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -27,8 +27,8 @@ class Person:
|
|||
],
|
||||
ids=["google", "numpy", "sphinx"],
|
||||
)
|
||||
async def test_write_docstring(style: str, part: str):
|
||||
ret = await WriteDocstring().run(code, style=style)
|
||||
async def test_write_docstring(style: str, part: str, context):
|
||||
ret = await WriteDocstring(context=context).run(code, style=style)
|
||||
assert part in ret
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,21 +9,19 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, PRDS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import RoleReactMode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd(new_filename):
|
||||
product_manager = ProductManager()
|
||||
async def test_write_prd(new_filename, context):
|
||||
product_manager = ProductManager(context=context)
|
||||
requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结"
|
||||
await FileRepository.save_file(filename=REQUIREMENT_FILENAME, content=requirements, relative_path=DOCS_FILE_REPO)
|
||||
await context.repo.docs.save(filename=REQUIREMENT_FILENAME, content=requirements)
|
||||
product_manager.rc.react_mode = RoleReactMode.BY_ORDER
|
||||
prd = await product_manager.run(Message(content=requirements, cause_by=UserRequirement))
|
||||
assert prd.cause_by == any_to_str(WritePRD)
|
||||
|
|
@ -33,7 +31,7 @@ async def test_write_prd(new_filename):
|
|||
# Assert the prd is not None or empty
|
||||
assert prd is not None
|
||||
assert prd.content != ""
|
||||
assert CONFIG.git_repo.new_file_repository(relative_path=PRDS_FILE_REPO).changed_files
|
||||
assert product_manager.context.repo.docs.prd.changed_files
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
48
tests/metagpt/actions/test_write_prd_an.py
Normal file
48
tests/metagpt/actions/test_write_prd_an.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/01/03
|
||||
@Author : mannaandpoem
|
||||
@File : test_write_prd_an.py
|
||||
"""
|
||||
import pytest
|
||||
from openai._models import BaseModel
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_prd import NEW_REQ_TEMPLATE
|
||||
from metagpt.actions.write_prd_an import REFINED_PRD_NODE
|
||||
from metagpt.llm import LLM
|
||||
from tests.data.incremental_dev_project.mock import (
|
||||
NEW_REQUIREMENT_SAMPLE,
|
||||
PRD_SAMPLE,
|
||||
REFINED_PRD_JSON,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def llm():
|
||||
return LLM()
|
||||
|
||||
|
||||
def mock_refined_prd_json():
|
||||
return REFINED_PRD_JSON
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd_an(mocker):
|
||||
root = ActionNode.from_children("RefinedPRD", [ActionNode(key="", expected_type=str, instruction="", example="")])
|
||||
root.instruct_content = BaseModel()
|
||||
root.instruct_content.model_dump = mock_refined_prd_json
|
||||
mocker.patch("metagpt.actions.write_prd_an.REFINED_PRD_NODE.fill", return_value=root)
|
||||
|
||||
prompt = NEW_REQ_TEMPLATE.format(
|
||||
requirements=NEW_REQUIREMENT_SAMPLE,
|
||||
old_prd=PRD_SAMPLE,
|
||||
)
|
||||
node = await REFINED_PRD_NODE.fill(prompt, llm)
|
||||
|
||||
assert "Refined Requirements" in node.instruct_content.model_dump()
|
||||
assert "Refined Product Goals" in node.instruct_content.model_dump()
|
||||
assert "Refined User Stories" in node.instruct_content.model_dump()
|
||||
assert "Refined Requirement Analysis" in node.instruct_content.model_dump()
|
||||
assert "Refined Requirement Pool" in node.instruct_content.model_dump()
|
||||
|
|
@ -11,7 +11,7 @@ from metagpt.actions.write_prd_review import WritePRDReview
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_prd_review():
|
||||
async def test_write_prd_review(context):
|
||||
prd = """
|
||||
Introduction: This is a new feature for our product.
|
||||
Goals: The goal is to improve user engagement.
|
||||
|
|
@ -23,7 +23,7 @@ async def test_write_prd_review():
|
|||
Timeline: The feature should be ready for testing in 1.5 months.
|
||||
"""
|
||||
|
||||
write_prd_review = WritePRDReview(name="write_prd_review")
|
||||
write_prd_review = WritePRDReview(name="write_prd_review", context=context)
|
||||
|
||||
prd_review = await write_prd_review.run(prd)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.write_review import WriteReview
|
||||
|
||||
CONTEXT = """
|
||||
TEMPLATE_CONTEXT = """
|
||||
{
|
||||
"Language": "zh_cn",
|
||||
"Programming Language": "Python",
|
||||
|
|
@ -46,8 +46,8 @@ CONTEXT = """
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_review():
|
||||
write_review = WriteReview()
|
||||
review = await write_review.run(CONTEXT)
|
||||
async def test_write_review(context):
|
||||
write_review = WriteReview(context=context)
|
||||
review = await write_review.run(TEMPLATE_CONTEXT)
|
||||
assert review.instruct_content
|
||||
assert review.get("LGTM") in ["LGTM", "LBTM"]
|
||||
|
|
|
|||
|
|
@ -13,11 +13,11 @@ from metagpt.actions.write_teaching_plan import WriteTeachingPlanPart
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
("topic", "context"),
|
||||
("topic", "content"),
|
||||
[("Title", "Lesson 1: Learn to draw an apple."), ("Teaching Content", "Lesson 1: Learn to draw an apple.")],
|
||||
)
|
||||
async def test_write_teaching_plan_part(topic, context):
|
||||
action = WriteTeachingPlanPart(topic=topic, context=context)
|
||||
async def test_write_teaching_plan_part(topic, content, context):
|
||||
action = WriteTeachingPlanPart(topic=topic, i_context=content, context=context)
|
||||
rsp = await action.run()
|
||||
assert rsp
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from metagpt.schema import Document, TestingContext
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_test():
|
||||
async def test_write_test(context):
|
||||
code = """
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
|
@ -25,8 +25,8 @@ async def test_write_test():
|
|||
def generate(self, max_y: int, max_x: int):
|
||||
self.position = (random.randint(1, max_y - 1), random.randint(1, max_x - 1))
|
||||
"""
|
||||
context = TestingContext(filename="food.py", code_doc=Document(filename="food.py", content=code))
|
||||
write_test = WriteTest(context=context)
|
||||
testing_context = TestingContext(filename="food.py", code_doc=Document(filename="food.py", content=code))
|
||||
write_test = WriteTest(i_context=testing_context, context=context)
|
||||
|
||||
context = await write_test.run()
|
||||
logger.info(context.model_dump_json())
|
||||
|
|
@ -39,12 +39,12 @@ async def test_write_test():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_invalid_code(mocker):
|
||||
async def test_write_code_invalid_code(mocker, context):
|
||||
# Mock the _aask method to return an invalid code string
|
||||
mocker.patch.object(WriteTest, "_aask", return_value="Invalid Code String")
|
||||
|
||||
# Create an instance of WriteTest
|
||||
write_test = WriteTest()
|
||||
write_test = WriteTest(context=context)
|
||||
|
||||
# Call the write_code method
|
||||
code = await write_test.write_code("Some prompt:")
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@ from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
|
||||
async def test_write_directory(language: str, topic: str):
|
||||
ret = await WriteDirectory(language=language).run(topic=topic)
|
||||
async def test_write_directory(language: str, topic: str, context):
|
||||
ret = await WriteDirectory(language=language, context=context).run(topic=topic)
|
||||
assert isinstance(ret, dict)
|
||||
assert "title" in ret
|
||||
assert "directory" in ret
|
||||
|
|
@ -29,8 +29,8 @@ async def test_write_directory(language: str, topic: str):
|
|||
("language", "topic", "directory"),
|
||||
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
|
||||
)
|
||||
async def test_write_content(language: str, topic: str, directory: Dict):
|
||||
ret = await WriteContent(language=language, directory=directory).run(topic=topic)
|
||||
async def test_write_content(language: str, topic: str, directory: Dict, context):
|
||||
ret = await WriteContent(language=language, directory=directory, context=context).run(topic=topic)
|
||||
assert isinstance(ret, str)
|
||||
assert list(directory.keys())[0] in ret
|
||||
for value in list(directory.values())[0]:
|
||||
|
|
|
|||
|
|
@ -10,13 +10,12 @@ from pathlib import Path
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.learn.skill_loader import SkillsDeclaration
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_suite():
|
||||
CONFIG.agent_skills = [
|
||||
async def test_suite(context):
|
||||
context.kwargs.agent_skills = [
|
||||
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
|
||||
|
|
@ -27,7 +26,7 @@ async def test_suite():
|
|||
]
|
||||
pathname = Path(__file__).parent / "../../../docs/.well-known/skills.yaml"
|
||||
loader = await SkillsDeclaration.load(skill_yaml_file_name=pathname)
|
||||
skills = loader.get_skill_list()
|
||||
skills = loader.get_skill_list(context=context)
|
||||
assert skills
|
||||
assert len(skills) >= 3
|
||||
for desc, name in skills.items():
|
||||
|
|
|
|||
|
|
@ -6,19 +6,33 @@
|
|||
@File : test_text_to_embedding.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.learn.text_to_embedding import text_to_embedding
|
||||
from metagpt.utils.common import aread
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_embedding():
|
||||
# Prerequisites
|
||||
assert CONFIG.OPENAI_API_KEY
|
||||
async def test_text_to_embedding(mocker):
|
||||
# mock
|
||||
config = Config.default()
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.post")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.status = 200
|
||||
data = await aread(Path(__file__).parent / "../../data/openai/embedding.json")
|
||||
mock_response.json.return_value = json.loads(data)
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
config.get_openai_llm().proxy = mocker.PropertyMock(return_value="http://mock.proxy")
|
||||
|
||||
v = await text_to_embedding(text="Panda emoji")
|
||||
# Prerequisites
|
||||
assert config.get_openai_llm().api_key
|
||||
assert config.get_openai_llm().proxy
|
||||
|
||||
v = await text_to_embedding(text="Panda emoji", config=config)
|
||||
assert len(v.data) > 0
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,13 @@
|
|||
@File : test_text_to_image.py
|
||||
@Desc : Unit tests.
|
||||
"""
|
||||
import base64
|
||||
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.learn.text_to_image import text_to_image
|
||||
from metagpt.tools.metagpt_text_to_image import MetaGPTText2Image
|
||||
from metagpt.tools.openai_text_to_image import OpenAIText2Image
|
||||
|
|
@ -24,23 +26,37 @@ async def test_text_to_image(mocker):
|
|||
mocker.patch.object(OpenAIText2Image, "text_2_image", return_value=b"mock OpenAIText2Image")
|
||||
mocker.patch.object(S3, "cache", return_value="http://mock/s3")
|
||||
|
||||
# Prerequisites
|
||||
assert CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
assert CONFIG.OPENAI_API_KEY
|
||||
config = Config.default()
|
||||
assert config.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
data = await text_to_image("Panda emoji", size_type="512x512", config=config)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
# Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["METAGPT_TEXT_TO_IMAGE_MODEL_URL"] = None
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
data = await text_to_image("Panda emoji", size_type="512x512")
|
||||
assert "base64" in data or "http" in data
|
||||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_text_to_image(mocker):
|
||||
# mocker
|
||||
mock_url = mocker.Mock()
|
||||
mock_url.url.return_value = "http://mock.com/0.png"
|
||||
|
||||
class _MockData(BaseModel):
|
||||
data: list
|
||||
|
||||
mock_data = _MockData(data=[mock_url])
|
||||
mocker.patch.object(openai.resources.images.AsyncImages, "generate", return_value=mock_data)
|
||||
mock_post = mocker.patch("aiohttp.ClientSession.get")
|
||||
mock_response = mocker.AsyncMock()
|
||||
mock_response.status = 200
|
||||
mock_response.read.return_value = base64.b64encode(b"success")
|
||||
mock_post.return_value.__aenter__.return_value = mock_response
|
||||
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png")
|
||||
|
||||
config = Config.default()
|
||||
config.METAGPT_TEXT_TO_IMAGE_MODEL_URL = None
|
||||
assert config.get_openai_llm()
|
||||
|
||||
data = await text_to_image("Panda emoji", size_type="512x512", config=config)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -8,35 +8,65 @@
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.learn.text_to_speech import text_to_speech
|
||||
from metagpt.tools.iflytek_tts import IFlyTekTTS
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_speech():
|
||||
# Prerequisites
|
||||
assert CONFIG.IFLYTEK_APP_ID
|
||||
assert CONFIG.IFLYTEK_API_KEY
|
||||
assert CONFIG.IFLYTEK_API_SECRET
|
||||
assert CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
|
||||
assert CONFIG.AZURE_TTS_REGION
|
||||
async def test_azure_text_to_speech(mocker):
|
||||
# mock
|
||||
config = Config.default()
|
||||
config.IFLYTEK_API_KEY = None
|
||||
config.IFLYTEK_API_SECRET = None
|
||||
config.IFLYTEK_APP_ID = None
|
||||
mock_result = mocker.Mock()
|
||||
mock_result.audio_data = b"mock audio data"
|
||||
mock_result.reason = ResultReason.SynthesizingAudioCompleted
|
||||
mock_data = mocker.Mock()
|
||||
mock_data.get.return_value = mock_result
|
||||
mocker.patch.object(SpeechSynthesizer, "speak_ssml_async", return_value=mock_data)
|
||||
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/1.wav")
|
||||
|
||||
# Prerequisites
|
||||
assert not config.IFLYTEK_APP_ID
|
||||
assert not config.IFLYTEK_API_KEY
|
||||
assert not config.IFLYTEK_API_SECRET
|
||||
assert config.AZURE_TTS_SUBSCRIPTION_KEY and config.AZURE_TTS_SUBSCRIPTION_KEY != "YOUR_API_KEY"
|
||||
assert config.AZURE_TTS_REGION
|
||||
|
||||
config.copy()
|
||||
# test azure
|
||||
data = await text_to_speech("panda emoji")
|
||||
data = await text_to_speech("panda emoji", config=config)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
# test iflytek
|
||||
## Mock session env
|
||||
old_options = CONFIG.options.copy()
|
||||
new_options = old_options.copy()
|
||||
new_options["AZURE_TTS_SUBSCRIPTION_KEY"] = ""
|
||||
CONFIG.set_context(new_options)
|
||||
try:
|
||||
data = await text_to_speech("panda emoji")
|
||||
assert "base64" in data or "http" in data
|
||||
finally:
|
||||
CONFIG.set_context(old_options)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_iflytek_text_to_speech(mocker):
|
||||
# mock
|
||||
config = Config.default()
|
||||
config.AZURE_TTS_SUBSCRIPTION_KEY = None
|
||||
config.AZURE_TTS_REGION = None
|
||||
mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None)
|
||||
mock_data = mocker.AsyncMock()
|
||||
mock_data.read.return_value = b"mock iflytek"
|
||||
mock_reader = mocker.patch("aiofiles.open")
|
||||
mock_reader.return_value.__aenter__.return_value = mock_data
|
||||
mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/1.mp3")
|
||||
|
||||
# Prerequisites
|
||||
assert config.IFLYTEK_APP_ID
|
||||
assert config.IFLYTEK_API_KEY
|
||||
assert config.IFLYTEK_API_SECRET
|
||||
assert not config.AZURE_TTS_SUBSCRIPTION_KEY or config.AZURE_TTS_SUBSCRIPTION_KEY == "YOUR_API_KEY"
|
||||
assert not config.AZURE_TTS_REGION
|
||||
|
||||
# test azure
|
||||
data = await text_to_speech("panda emoji", config=config)
|
||||
assert "base64" in data or "http" in data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -46,7 +45,7 @@ def test_extract_info(input, tag, val):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("llm", [LLM(provider=LLMProviderEnum.OPENAI), LLM(provider=LLMProviderEnum.METAGPT)])
|
||||
@pytest.mark.parametrize("llm", [LLM()])
|
||||
async def test_memory_llm(llm):
|
||||
memory = BrainMemory()
|
||||
for i in range(500):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Desc : unittest of `metagpt/memory/longterm_memory.py`
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -10,17 +9,15 @@ import os
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.memory.longterm_memory import LongTermMemory
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
|
||||
|
||||
def test_ltm_search():
|
||||
assert hasattr(CONFIG, "long_term_memory") is True
|
||||
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
|
||||
assert len(CONFIG.openai_api_key) > 20
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
from metagpt.environment import Environment
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ def test_memory():
|
|||
messages = memory.get_by_action(UserRequirement)
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.get_by_actions([UserRequirement])
|
||||
messages = memory.get_by_actions({UserRequirement})
|
||||
assert len(messages) == 2
|
||||
|
||||
messages = memory.try_remember("test message")
|
||||
|
|
|
|||
|
|
@ -11,12 +11,12 @@ from typing import List
|
|||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
|
||||
os.environ.setdefault("OPENAI_API_KEY", CONFIG.openai_api_key)
|
||||
os.environ.setdefault("OPENAI_API_KEY", config.get_openai_llm().api_key)
|
||||
|
||||
|
||||
def test_idea_message():
|
||||
|
|
|
|||
44
tests/metagpt/provider/mock_llm_config.py
Normal file
44
tests/metagpt/provider/mock_llm_config.py
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/1/8 17:03
|
||||
@Author : alexanderwu
|
||||
@File : mock_llm_config.py
|
||||
"""
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
|
||||
mock_llm_config = LLMConfig(
|
||||
llm_type="mock",
|
||||
api_key="mock_api_key",
|
||||
base_url="mock_base_url",
|
||||
app_id="mock_app_id",
|
||||
api_secret="mock_api_secret",
|
||||
domain="mock_domain",
|
||||
)
|
||||
|
||||
|
||||
mock_llm_config_proxy = LLMConfig(
|
||||
llm_type="mock",
|
||||
api_key="mock_api_key",
|
||||
base_url="mock_base_url",
|
||||
proxy="http://localhost:8080",
|
||||
)
|
||||
|
||||
|
||||
mock_llm_config_azure = LLMConfig(
|
||||
llm_type="azure",
|
||||
api_version="2023-09-01-preview",
|
||||
api_key="mock_api_key",
|
||||
base_url="mock_base_url",
|
||||
proxy="http://localhost:8080",
|
||||
)
|
||||
|
||||
|
||||
mock_llm_config_zhipu = LLMConfig(
|
||||
llm_type="zhipu",
|
||||
api_key="mock_api_key.zhipu",
|
||||
base_url="mock_base_url",
|
||||
model="mock_zhipu_model",
|
||||
proxy="http://localhost:8080",
|
||||
)
|
||||
|
|
@ -6,10 +6,8 @@
|
|||
import pytest
|
||||
from anthropic.resources.completions import Completion
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
|
||||
CONFIG.anthropic_api_key = "xxx"
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
prompt = "who are you"
|
||||
resp = "I'am Claude2"
|
||||
|
|
@ -25,10 +23,10 @@ async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_
|
|||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
|
||||
assert resp == Claude2().ask(prompt)
|
||||
assert resp == Claude2(mock_llm_config).ask(prompt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
|
||||
assert resp == await Claude2().aask(prompt)
|
||||
assert resp == await Claude2(mock_llm_config).aask(prompt)
|
||||
|
|
|
|||
12
tests/metagpt/provider/test_azure_llm.py
Normal file
12
tests/metagpt/provider/test_azure_llm.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from metagpt.provider import AzureOpenAILLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_azure
|
||||
|
||||
|
||||
def test_azure_llm():
|
||||
llm = AzureOpenAILLM(mock_llm_config_azure)
|
||||
kwargs = llm._make_client_kwargs()
|
||||
assert kwargs["azure_endpoint"] == mock_llm_config_azure.base_url
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
|
||||
CONFIG.OPENAI_API_VERSION = "xx"
|
||||
CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value
|
||||
|
||||
|
||||
def test_azure_openai_api():
|
||||
_ = AzureOpenAILLM()
|
||||
|
|
@ -8,6 +8,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -28,6 +29,9 @@ resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
|||
|
||||
|
||||
class MockBaseLLM(BaseLLM):
|
||||
def __init__(self, config: LLMConfig = None):
|
||||
pass
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
|
|
@ -102,5 +106,5 @@ async def test_async_base_llm():
|
|||
resp = await base_llm.aask_batch([prompt_msg])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await base_llm.aask_code([prompt_msg])
|
||||
assert resp == resp_content
|
||||
# resp = await base_llm.aask_code([prompt_msg])
|
||||
# assert resp == resp_content
|
||||
|
|
@ -13,17 +13,13 @@ from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
|||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireworksLLM,
|
||||
)
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
|
||||
CONFIG.fireworks_api_key = "xxx"
|
||||
CONFIG.max_budget = 10
|
||||
CONFIG.calc_usage = True
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
default_resp = ChatCompletion(
|
||||
|
|
@ -92,7 +88,7 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
|
|||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
fireworks_gpt = FireworksLLM()
|
||||
fireworks_gpt = FireworksLLM(mock_llm_config)
|
||||
fireworks_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
fireworks_gpt._update_costs(
|
||||
|
|
@ -9,10 +9,8 @@ import pytest
|
|||
from google.ai import generativelanguage as glm
|
||||
from google.generativeai.types import content_types
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.google_gemini_api import GeminiLLM
|
||||
|
||||
CONFIG.gemini_api_key = "xx"
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -62,7 +60,7 @@ async def test_gemini_acompletion(mocker):
|
|||
mock_gemini_generate_content_async,
|
||||
)
|
||||
|
||||
gemini_gpt = GeminiLLM()
|
||||
gemini_gpt = GeminiLLM(mock_llm_config)
|
||||
|
||||
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
|
||||
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
resp_content = "test"
|
||||
resp_exit = "exit"
|
||||
|
|
@ -13,7 +14,7 @@ resp_exit = "exit"
|
|||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("builtins.input", lambda _: resp_content)
|
||||
human_provider = HumanProvider()
|
||||
human_provider = HumanProvider(mock_llm_config)
|
||||
|
||||
resp = human_provider.ask(resp_content)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -1,14 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_api.py
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_llm():
|
||||
llm = LLM(provider=LLMProviderEnum.METAGPT)
|
||||
assert llm
|
||||
|
|
@ -3,13 +3,14 @@
|
|||
"""
|
||||
@Time : 2023/8/30
|
||||
@Author : mashenquan
|
||||
@File : test_metagpt_llm_api.py
|
||||
@File : test_metagpt_llm.py
|
||||
"""
|
||||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
|
||||
def test_metagpt():
|
||||
llm = MetaGPTLLM()
|
||||
llm = MetaGPTLLM(mock_llm_config)
|
||||
assert llm
|
||||
|
||||
|
||||
|
|
@ -7,8 +7,8 @@ from typing import Any, Tuple
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
|
@ -16,9 +16,6 @@ messages = [{"role": "user", "content": prompt_msg}]
|
|||
resp_content = "I'm ollama"
|
||||
default_resp = {"message": {"role": "assistant", "content": resp_content}}
|
||||
|
||||
CONFIG.ollama_api_base = "http://xxx"
|
||||
CONFIG.max_budget = 10
|
||||
|
||||
|
||||
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
|
||||
if stream:
|
||||
|
|
@ -44,7 +41,7 @@ async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[An
|
|||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
|
||||
|
||||
ollama_gpt = OllamaLLM()
|
||||
ollama_gpt = OllamaLLM(mock_llm_config)
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
|
|
|||
|
|
@ -13,12 +13,9 @@ from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
|||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.open_llm_api import OpenLLM
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
|
||||
CONFIG.max_budget = 10
|
||||
CONFIG.calc_usage = True
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
resp_content = "I'm llama2"
|
||||
default_resp = ChatCompletion(
|
||||
|
|
@ -71,7 +68,7 @@ async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs)
|
|||
async def test_openllm_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
openllm_gpt = OpenLLM()
|
||||
openllm_gpt = OpenLLM(mock_llm_config)
|
||||
openllm_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
|
|
@ -9,134 +7,92 @@ from openai.types.chat import (
|
|||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.provider import OpenAILLM
|
||||
from tests.metagpt.provider.mock_llm_config import (
|
||||
mock_llm_config,
|
||||
mock_llm_config_proxy,
|
||||
)
|
||||
|
||||
CONFIG.openai_proxy = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_to_speech():
|
||||
llm = LLM()
|
||||
resp = await llm.atext_to_speech(
|
||||
model="tts-1",
|
||||
voice="alloy",
|
||||
input="人生说起来长,但直到一个岁月回头看,许多事件仅是仓促的。一段一段拼凑一起,合成了人生。苦难当头时,当下不免觉得是折磨;回头看,也不够是一段短短的人生旅程。",
|
||||
)
|
||||
assert 200 == resp.response.status_code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_speech_to_text():
|
||||
llm = LLM()
|
||||
audio_file = open(f"{TEST_DATA_PATH}/audio/hello.mp3", "rb")
|
||||
resp = await llm.aspeech_to_text(file=audio_file, model="whisper-1")
|
||||
assert "你好" == resp.text
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_calls_rsp():
|
||||
function_rsps = [
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": \'print("hello world")\'}', name="execute"),
|
||||
Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": ```print("hello world")```}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": """print("hello world")"""}', name="execute"),
|
||||
Function(arguments='\nprint("hello world")\\n', name="execute"),
|
||||
# only `{` in arguments
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"),
|
||||
# no `{`, `}` in arguments
|
||||
Function(arguments='\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"),
|
||||
]
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCall(type="function", id=f"call_{i}", function=f) for i, f in enumerate(function_rsps)
|
||||
]
|
||||
messages = [ChatCompletionMessage(content=None, role="assistant", tool_calls=[t]) for t in tool_calls]
|
||||
# 添加一个纯文本响应
|
||||
messages.append(
|
||||
ChatCompletionMessage(content="Completed a python code for hello world!", role="assistant", tool_calls=None)
|
||||
)
|
||||
# 添加 openai tool calls respond bug, code 出现在ChatCompletionMessage.content中
|
||||
messages.extend(
|
||||
[
|
||||
ChatCompletionMessage(content="```python\nprint('hello world')```", role="assistant", tool_calls=None),
|
||||
ChatCompletionMessage(content="'''python\nprint('hello world')'''", role="assistant", tool_calls=None),
|
||||
ChatCompletionMessage(content='"""python\nprint(\'hello world\')"""', role="assistant", tool_calls=None),
|
||||
ChatCompletionMessage(content="'''python\nprint(\"hello world\")'''", role="assistant", tool_calls=None),
|
||||
ChatCompletionMessage(content="```python\nprint('hello world')```", role="assistant", tool_calls=None),
|
||||
]
|
||||
)
|
||||
choices = [
|
||||
Choice(finish_reason="tool_calls", logprobs=None, index=i, message=msg) for i, msg in enumerate(messages)
|
||||
]
|
||||
return [
|
||||
ChatCompletion(id=str(i), choices=[c], created=i, model="gpt-4", object="chat.completion")
|
||||
for i, c in enumerate(choices)
|
||||
]
|
||||
|
||||
|
||||
class TestOpenAI:
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_azure(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy=None,
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="other",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def config_azure_proxy(self):
|
||||
return Mock(
|
||||
openai_api_key="test_key",
|
||||
OPENAI_API_KEY="test_key",
|
||||
openai_api_version="test_version",
|
||||
openai_base_url="test_url",
|
||||
OPENAI_BASE_URL="test_url",
|
||||
openai_proxy="http://proxy.com",
|
||||
openai_api_type="azure",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def tool_calls_rsp(self):
|
||||
function_rsps = [
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": \'print("hello world")\'}', name="execute"),
|
||||
Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": ```print("hello world")```}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": """print("hello world")"""}', name="execute"),
|
||||
Function(arguments='\nprint("hello world")\\n', name="execute"),
|
||||
# only `{` in arguments
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"),
|
||||
# no `{`, `}` in arguments
|
||||
Function(arguments='\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"),
|
||||
]
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCall(type="function", id=f"call_{i}", function=f)
|
||||
for i, f in enumerate(function_rsps)
|
||||
]
|
||||
messages = [ChatCompletionMessage(content=None, role="assistant", tool_calls=[t]) for t in tool_calls]
|
||||
# 添加一个纯文本响应
|
||||
messages.append(
|
||||
ChatCompletionMessage(content="Completed a python code for hello world!", role="assistant", tool_calls=None)
|
||||
)
|
||||
# 添加 openai tool calls respond bug, code 出现在ChatCompletionMessage.content中
|
||||
messages.extend(
|
||||
[
|
||||
ChatCompletionMessage(content="```python\nprint('hello world')```", role="assistant", tool_calls=None),
|
||||
ChatCompletionMessage(content="'''python\nprint('hello world')'''", role="assistant", tool_calls=None),
|
||||
ChatCompletionMessage(
|
||||
content='"""python\nprint(\'hello world\')"""', role="assistant", tool_calls=None
|
||||
),
|
||||
ChatCompletionMessage(
|
||||
content="'''python\nprint(\"hello world\")'''", role="assistant", tool_calls=None
|
||||
),
|
||||
ChatCompletionMessage(content="```python\nprint('hello world')```", role="assistant", tool_calls=None),
|
||||
]
|
||||
)
|
||||
choices = [
|
||||
Choice(finish_reason="tool_calls", logprobs=None, index=i, message=msg) for i, msg in enumerate(messages)
|
||||
]
|
||||
return [
|
||||
ChatCompletion(id=str(i), choices=[c], created=i, model="gpt-4", object="chat.completion")
|
||||
for i, c in enumerate(choices)
|
||||
]
|
||||
|
||||
def test_make_client_kwargs_without_proxy(self, config):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config
|
||||
def test_make_client_kwargs_without_proxy(self):
|
||||
instance = OpenAILLM(mock_llm_config)
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert kwargs["api_key"] == "mock_api_key"
|
||||
assert kwargs["base_url"] == "mock_base_url"
|
||||
assert "http_client" not in kwargs
|
||||
|
||||
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
assert "http_client" not in kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy(self, config_proxy):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_proxy
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure_proxy
|
||||
def test_make_client_kwargs_with_proxy(self):
|
||||
instance = OpenAILLM(mock_llm_config_proxy)
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
||||
def test_get_choice_function_arguments_for_aask_code(self, tool_calls_rsp):
|
||||
instance = OpenAILLM()
|
||||
instance = OpenAILLM(mock_llm_config_proxy)
|
||||
for i, rsp in enumerate(tool_calls_rsp):
|
||||
code = instance.get_choice_function_arguments(rsp)
|
||||
logger.info(f"\ntest get function call arguments {i}: {code}")
|
||||
|
|
|
|||
|
|
@ -4,14 +4,9 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
|
||||
CONFIG.spark_appid = "xxx"
|
||||
CONFIG.spark_api_secret = "xxx"
|
||||
CONFIG.spark_api_key = "xxx"
|
||||
CONFIG.domain = "xxxxxx"
|
||||
CONFIG.spark_url = "xxxx"
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
|
@ -28,8 +23,8 @@ class MockWebSocketApp(object):
|
|||
def test_get_msg_from_web(mocker):
|
||||
mocker.patch("websocket.WebSocketApp", MockWebSocketApp)
|
||||
|
||||
get_msg_from_web = GetMessageFromWeb(text=prompt_msg)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx"
|
||||
get_msg_from_web = GetMessageFromWeb(prompt_msg, mock_llm_config)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "mock_domain"
|
||||
|
||||
ret = get_msg_from_web.run()
|
||||
assert ret == ""
|
||||
|
|
@ -39,11 +34,19 @@ def mock_spark_get_msg_from_web_run(self) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_aask():
|
||||
llm = SparkLLM(Config.from_home("spark.yaml").llm)
|
||||
|
||||
resp = await llm.aask("Hello!")
|
||||
print(resp)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
|
||||
spark_gpt = SparkLLM()
|
||||
spark_gpt = SparkLLM(mock_llm_config)
|
||||
|
||||
resp = await spark_gpt.acompletion([])
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -3,47 +3,25 @@
|
|||
# @Desc : the unittest of ZhiPuAILLM
|
||||
|
||||
import pytest
|
||||
from zhipuai.utils.sse_client import Event
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
|
||||
CONFIG.zhipuai_api_key = "xxx.xxx"
|
||||
from tests.metagpt.provider.mock_llm_config import mock_llm_config_zhipu
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"choices": [{"role": "assistant", "content": resp_content}],
|
||||
"usage": {"prompt_tokens": 20, "completion_tokens": 20},
|
||||
},
|
||||
"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": resp_content, "role": "assistant"}}],
|
||||
"usage": {"completion_tokens": 22, "prompt_tokens": 19, "total_tokens": 41},
|
||||
}
|
||||
|
||||
|
||||
def mock_zhipuai_invoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_zhipuai_ainvoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_zhipuai_asse_invoke(**kwargs):
|
||||
async def mock_zhipuai_acreate_stream(**kwargs):
|
||||
class MockResponse(object):
|
||||
async def _aread(self):
|
||||
class Iterator(object):
|
||||
events = [
|
||||
Event(id="xxx", event="add", data=resp_content, retry=0),
|
||||
Event(
|
||||
id="xxx",
|
||||
event="finish",
|
||||
data="",
|
||||
meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}',
|
||||
),
|
||||
]
|
||||
events = [{"choices": [{"index": 0, "delta": {"content": resp_content, "role": "assistant"}}]}]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
|
|
@ -52,23 +30,26 @@ async def mock_zhipuai_asse_invoke(**kwargs):
|
|||
async for chunk in Iterator():
|
||||
yield chunk
|
||||
|
||||
async def async_events(self):
|
||||
async def stream(self):
|
||||
async for chunk in self._aread():
|
||||
yield chunk
|
||||
|
||||
return MockResponse()
|
||||
|
||||
|
||||
async def mock_zhipuai_acreate(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate", mock_zhipuai_acreate)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.acreate_stream", mock_zhipuai_acreate_stream)
|
||||
|
||||
zhipu_gpt = ZhiPuAILLM()
|
||||
zhipu_gpt = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
assert resp["choices"][0]["message"]["content"] == resp_content
|
||||
|
||||
resp = await zhipu_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
|
@ -84,6 +65,7 @@ async def test_zhipuai_acompletion(mocker):
|
|||
|
||||
|
||||
def test_zhipuai_proxy():
|
||||
# CONFIG.openai_proxy = "http://127.0.0.1:8080"
|
||||
_ = ZhiPuAILLM()
|
||||
# assert openai.proxy == CONFIG.openai_proxy
|
||||
# it seems like zhipuai would be inflected by the proxy of openai, maybe it's a bug
|
||||
# but someone may want to use openai.proxy, so we keep this test case
|
||||
# assert openai.proxy == config.llm.proxy
|
||||
_ = ZhiPuAILLM(mock_llm_config_zhipu)
|
||||
|
|
|
|||
|
|
@ -11,16 +11,16 @@ from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
|||
async def test_async_sse_client():
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"data: test_value"
|
||||
yield b'data: {"test_key": "test_value"}'
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=Iterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert event.data, "test_value"
|
||||
async for chunk in async_sse_client.stream():
|
||||
assert "test_value" in chunk.values()
|
||||
|
||||
class InvalidIterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"invalid: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=InvalidIterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert not event
|
||||
async for chunk in async_sse_client.stream():
|
||||
assert not chunk
|
||||
|
|
|
|||
|
|
@ -6,15 +6,13 @@ from typing import Any, Tuple
|
|||
|
||||
import pytest
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import InvokeType
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
||||
api_key = "xxx.xxx"
|
||||
zhipuai.api_key = api_key
|
||||
|
||||
default_resp = b'{"result": "test response"}'
|
||||
default_resp = b'{"choices": [{"finish_reason": "stop", "index": 0, "message": {"content": "test response", "role": "assistant"}}]}'
|
||||
|
||||
|
||||
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
||||
|
|
@ -23,22 +21,15 @@ async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipu_model_api(mocker):
|
||||
header = ZhiPuModelAPI.get_header()
|
||||
zhipuai_default_headers.update({"Authorization": api_key})
|
||||
assert header == zhipuai_default_headers
|
||||
|
||||
sse_header = ZhiPuModelAPI.get_sse_header()
|
||||
assert len(sse_header["Authorization"]) == 191
|
||||
|
||||
url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"})
|
||||
url_prefix, url_suffix = ZhiPuModelAPI(api_key=api_key).split_zhipu_api_url()
|
||||
assert url_prefix == "https://open.bigmodel.cn/api"
|
||||
assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke"
|
||||
assert url_suffix == "/paas/v4/chat/completions"
|
||||
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest)
|
||||
result = await ZhiPuModelAPI.arequest(
|
||||
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
|
||||
result = await ZhiPuModelAPI(api_key=api_key).arequest(
|
||||
stream=False, method="get", headers={}, kwargs={"model": "glm-3-turbo"}
|
||||
)
|
||||
assert result == default_resp
|
||||
|
||||
result = await ZhiPuModelAPI.ainvoke()
|
||||
assert result["result"] == "test response"
|
||||
result = await ZhiPuModelAPI(api_key=api_key).acreate()
|
||||
assert result["choices"][0]["message"]["content"] == "test response"
|
||||
|
|
|
|||
|
|
@ -9,9 +9,7 @@ from metagpt.schema import Plan
|
|||
from metagpt.utils.recovery_util import load_history, save_history
|
||||
|
||||
|
||||
async def run_code_interpreter(
|
||||
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools
|
||||
):
|
||||
async def run_code_interpreter(role_class, requirement, auto_run, use_tools, save_dir, tools):
|
||||
"""
|
||||
The main function to run the MLEngineer with optional history loading.
|
||||
|
||||
|
|
@ -25,16 +23,11 @@ async def run_code_interpreter(
|
|||
"""
|
||||
|
||||
if role_class == "ci":
|
||||
role = CodeInterpreter(
|
||||
goal=requirement, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs, tools=tools
|
||||
)
|
||||
role = CodeInterpreter(auto_run=auto_run, use_tools=use_tools, tools=tools)
|
||||
else:
|
||||
role = MLEngineer(
|
||||
goal=requirement,
|
||||
auto_run=auto_run,
|
||||
use_tools=use_tools,
|
||||
use_code_steps=use_code_steps,
|
||||
make_udfs=make_udfs,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
|
|
@ -50,10 +43,10 @@ async def run_code_interpreter(
|
|||
try:
|
||||
await role.run(requirement)
|
||||
except Exception as e:
|
||||
save_path = save_history(role, save_dir)
|
||||
|
||||
logger.exception(f"An error occurred: {e}, save trajectory here: {save_path}")
|
||||
|
||||
save_history(role, save_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
|
|
@ -73,8 +66,6 @@ if __name__ == "__main__":
|
|||
role_class = "mle"
|
||||
auto_run = True
|
||||
use_tools = True
|
||||
make_udfs = False
|
||||
use_udfs = False
|
||||
tools = []
|
||||
# tools = ["FillMissingValue", "CatCross", "non_existing_test"]
|
||||
|
||||
|
|
@ -83,14 +74,9 @@ if __name__ == "__main__":
|
|||
requirement: str = requirement,
|
||||
auto_run: bool = auto_run,
|
||||
use_tools: bool = use_tools,
|
||||
use_code_steps: bool = False,
|
||||
make_udfs: bool = make_udfs,
|
||||
use_udfs: bool = use_udfs,
|
||||
save_dir: str = save_dir,
|
||||
tools=tools,
|
||||
):
|
||||
await run_code_interpreter(
|
||||
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools
|
||||
)
|
||||
await run_code_interpreter(role_class, requirement, auto_run, use_tools, save_dir, tools)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ import uuid
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import WriteDesign, WritePRD
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PRDS_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect
|
||||
|
|
@ -22,12 +21,12 @@ from tests.metagpt.roles.mock import MockMessages
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect():
|
||||
async def test_architect(context):
|
||||
# Prerequisites
|
||||
filename = uuid.uuid4().hex + ".json"
|
||||
await awrite(CONFIG.git_repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)
|
||||
await awrite(context.repo.workdir / PRDS_FILE_REPO / filename, data=MockMessages.prd.content)
|
||||
|
||||
role = Architect()
|
||||
role = Architect(context=context)
|
||||
rsp = await role.run(with_message=Message(content="", cause_by=WritePRD))
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from pydantic import BaseModel
|
|||
|
||||
from metagpt.actions.skill_action import SkillAction
|
||||
from metagpt.actions.talk_action import TalkAction
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.memory.brain_memory import BrainMemory
|
||||
from metagpt.roles.assistant import Assistant
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -20,8 +19,11 @@ from metagpt.utils.common import any_to_str
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
CONFIG.language = "Chinese"
|
||||
async def test_run(mocker, context):
|
||||
# mock
|
||||
mocker.patch("metagpt.learn.text_to_image", return_value="http://mock.com/1.png")
|
||||
|
||||
context.kwargs.language = "Chinese"
|
||||
|
||||
class Input(BaseModel):
|
||||
memory: BrainMemory
|
||||
|
|
@ -65,7 +67,7 @@ async def test_run():
|
|||
"cause_by": any_to_str(SkillAction),
|
||||
},
|
||||
]
|
||||
CONFIG.agent_skills = [
|
||||
agent_skills = [
|
||||
{"id": 1, "name": "text_to_speech", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 2, "name": "text_to_image", "type": "builtin", "config": {}, "enabled": True},
|
||||
{"id": 3, "name": "ai_call", "type": "builtin", "config": {}, "enabled": True},
|
||||
|
|
@ -77,9 +79,11 @@ async def test_run():
|
|||
|
||||
for i in inputs:
|
||||
seed = Input(**i)
|
||||
CONFIG.language = seed.language
|
||||
CONFIG.agent_description = seed.agent_description
|
||||
role = Assistant(language="Chinese")
|
||||
role = Assistant(language="Chinese", context=context)
|
||||
role.context.kwargs.language = seed.language
|
||||
role.context.kwargs.agent_description = seed.agent_description
|
||||
role.context.kwargs.agent_skills = agent_skills
|
||||
|
||||
role.memory = seed.memory # Restore historical conversation content.
|
||||
while True:
|
||||
has_action = await role.think()
|
||||
|
|
@ -110,21 +114,16 @@ async def test_run():
|
|||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory(memory):
|
||||
role = Assistant()
|
||||
async def test_memory(memory, context):
|
||||
role = Assistant(context=context)
|
||||
role.context.kwargs.agent_skills = []
|
||||
role.load_memory(memory)
|
||||
|
||||
val = role.get_memory()
|
||||
assert val
|
||||
|
||||
await role.talk("draw apple")
|
||||
|
||||
agent_skills = CONFIG.agent_skills
|
||||
CONFIG.agent_skills = []
|
||||
try:
|
||||
await role.think()
|
||||
finally:
|
||||
CONFIG.agent_skills = agent_skills
|
||||
await role.think()
|
||||
assert isinstance(role.rc.todo, TalkAction)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,9 @@ from metagpt.roles.code_interpreter import CodeInterpreter
|
|||
@pytest.mark.asyncio
|
||||
async def test_code_interpreter():
|
||||
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
ci = CodeInterpreter(goal=requirement, auto_run=True, use_tools=False)
|
||||
tools = []
|
||||
|
||||
ci = CodeInterpreter(auto_run=True, use_tools=True, tools=tools)
|
||||
rsp = await ci.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
|
|||
|
|
@ -1,50 +0,0 @@
|
|||
import pytest
|
||||
from tqdm import tqdm
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ml_engineer import ExecutePyCode, MLEngineer
|
||||
from metagpt.schema import Plan
|
||||
|
||||
|
||||
def reset(role):
|
||||
"""Restart role with the same goal."""
|
||||
role.working_memory.clear()
|
||||
role.planner.plan = Plan(goal=role.planner.plan.goal)
|
||||
role.execute_code = ExecutePyCode()
|
||||
|
||||
|
||||
async def make_use_tools(requirement: str, auto_run: bool = True):
|
||||
"""make and use tools for requirement."""
|
||||
role = MLEngineer(goal=requirement, auto_run=auto_run)
|
||||
# make udfs
|
||||
role.use_tools = False
|
||||
role.use_code_steps = False
|
||||
role.make_udfs = True
|
||||
role.use_udfs = False
|
||||
await role.run(requirement)
|
||||
# use udfs
|
||||
reset(role)
|
||||
role.make_udfs = False
|
||||
role.use_udfs = True
|
||||
role.use_code_steps = False
|
||||
role.use_tools = False
|
||||
await role.run(requirement)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_use_tools():
|
||||
requirements = [
|
||||
"Run data analysis on sklearn Iris dataset, include a plot",
|
||||
"Run data analysis on sklearn Diabetes dataset, include a plot",
|
||||
"Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy",
|
||||
"Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy",
|
||||
"Run EDA and visualization on this dataset, train a model to predict survival, report metrics on validation set (20%), dataset: tests/data/titanic.csv",
|
||||
]
|
||||
success = 0
|
||||
for requirement in tqdm(requirements, total=len(requirements)):
|
||||
try:
|
||||
await make_use_tools(requirement)
|
||||
success += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Found Error in {requirement}, {e}")
|
||||
logger.info(f"success: {round(success/len(requirements), 1)*100}%")
|
||||
|
|
@ -13,40 +13,30 @@ from pathlib import Path
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCode, WriteTasks
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
PRDS_FILE_REPO,
|
||||
REQUIREMENT_FILENAME,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
)
|
||||
from metagpt.const import REQUIREMENT_FILENAME, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.schema import CodingContext, Message
|
||||
from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import ChangeType
|
||||
from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer():
|
||||
async def test_engineer(context):
|
||||
# Prerequisites
|
||||
rqno = "20231221155954.json"
|
||||
await FileRepository.save_file(REQUIREMENT_FILENAME, content=MockMessages.req.content)
|
||||
await FileRepository.save_file(rqno, relative_path=PRDS_FILE_REPO, content=MockMessages.prd.content)
|
||||
await FileRepository.save_file(
|
||||
rqno, relative_path=SYSTEM_DESIGN_FILE_REPO, content=MockMessages.system_design.content
|
||||
)
|
||||
await FileRepository.save_file(rqno, relative_path=TASK_FILE_REPO, content=MockMessages.json_tasks.content)
|
||||
await context.repo.save(REQUIREMENT_FILENAME, content=MockMessages.req.content)
|
||||
await context.repo.docs.prd.save(rqno, content=MockMessages.prd.content)
|
||||
await context.repo.docs.system_design.save(rqno, content=MockMessages.system_design.content)
|
||||
await context.repo.docs.task.save(rqno, content=MockMessages.json_tasks.content)
|
||||
|
||||
engineer = Engineer()
|
||||
engineer = Engineer(context=context)
|
||||
rsp = await engineer.run(Message(content="", cause_by=WriteTasks))
|
||||
|
||||
logger.info(rsp)
|
||||
assert rsp.cause_by == any_to_str(WriteCode)
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
assert src_file_repo.changed_files
|
||||
assert context.repo.with_src_path(context.src_workspace).srcs.changed_files
|
||||
|
||||
|
||||
def test_parse_str():
|
||||
|
|
@ -109,54 +99,52 @@ def test_parse_code():
|
|||
|
||||
def test_todo():
|
||||
role = Engineer()
|
||||
assert role.todo == any_to_name(WriteCode)
|
||||
assert role.action_description == any_to_name(WriteCode)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_coding_context():
|
||||
async def test_new_coding_context(context):
|
||||
# Prerequisites
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
deps = json.loads(await aread(demo_path / "dependencies.json"))
|
||||
dependency = await CONFIG.git_repo.get_dependency()
|
||||
dependency = await context.git_repo.get_dependency()
|
||||
for k, v in deps.items():
|
||||
await dependency.update(k, set(v))
|
||||
data = await aread(demo_path / "system_design.json")
|
||||
rqno = "20231221155954.json"
|
||||
await awrite(CONFIG.git_repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data)
|
||||
await awrite(context.repo.workdir / SYSTEM_DESIGN_FILE_REPO / rqno, data)
|
||||
data = await aread(demo_path / "tasks.json")
|
||||
await awrite(CONFIG.git_repo.workdir / TASK_FILE_REPO / rqno, data)
|
||||
await awrite(context.repo.workdir / TASK_FILE_REPO / rqno, data)
|
||||
|
||||
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "game_2048"
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(relative_path=CONFIG.src_workspace)
|
||||
task_file_repo = CONFIG.git_repo.new_file_repository(relative_path=TASK_FILE_REPO)
|
||||
design_file_repo = CONFIG.git_repo.new_file_repository(relative_path=SYSTEM_DESIGN_FILE_REPO)
|
||||
context.src_workspace = Path(context.repo.workdir) / "game_2048"
|
||||
|
||||
filename = "game.py"
|
||||
ctx_doc = await Engineer._new_coding_doc(
|
||||
filename=filename,
|
||||
src_file_repo=src_file_repo,
|
||||
task_file_repo=task_file_repo,
|
||||
design_file_repo=design_file_repo,
|
||||
dependency=dependency,
|
||||
)
|
||||
assert ctx_doc
|
||||
assert ctx_doc.filename == filename
|
||||
assert ctx_doc.content
|
||||
ctx = CodingContext.model_validate_json(ctx_doc.content)
|
||||
assert ctx.filename == filename
|
||||
assert ctx.design_doc
|
||||
assert ctx.design_doc.content
|
||||
assert ctx.task_doc
|
||||
assert ctx.task_doc.content
|
||||
assert ctx.code_doc
|
||||
try:
|
||||
filename = "game.py"
|
||||
engineer = Engineer(context=context)
|
||||
ctx_doc = await engineer._new_coding_doc(
|
||||
filename=filename,
|
||||
dependency=dependency,
|
||||
)
|
||||
assert ctx_doc
|
||||
assert ctx_doc.filename == filename
|
||||
assert ctx_doc.content
|
||||
ctx = CodingContext.model_validate_json(ctx_doc.content)
|
||||
assert ctx.filename == filename
|
||||
assert ctx.design_doc
|
||||
assert ctx.design_doc.content
|
||||
assert ctx.task_doc
|
||||
assert ctx.task_doc.content
|
||||
assert ctx.code_doc
|
||||
|
||||
CONFIG.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED})
|
||||
CONFIG.git_repo.commit("mock env")
|
||||
await src_file_repo.save(filename=filename, content="content")
|
||||
role = Engineer()
|
||||
assert not role.code_todos
|
||||
await role._new_code_actions()
|
||||
assert role.code_todos
|
||||
context.git_repo.add_change({f"{TASK_FILE_REPO}/{rqno}": ChangeType.UNTRACTED})
|
||||
context.git_repo.commit("mock env")
|
||||
await context.repo.with_src_path(context.src_workspace).srcs.save(filename=filename, content="content")
|
||||
role = Engineer(context=context)
|
||||
assert not role.code_todos
|
||||
await role._new_code_actions()
|
||||
assert role.code_todos
|
||||
finally:
|
||||
context.git_repo.delete_repository()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -41,9 +41,11 @@ from metagpt.schema import Message
|
|||
),
|
||||
],
|
||||
)
|
||||
async def test_invoice_ocr_assistant(query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict):
|
||||
async def test_invoice_ocr_assistant(
|
||||
query: str, invoice_path: Path, invoice_table_path: Path, expected_result: dict, context
|
||||
):
|
||||
invoice_path = TEST_DATA_PATH / invoice_path
|
||||
role = InvoiceOCRAssistant()
|
||||
role = InvoiceOCRAssistant(context=context)
|
||||
await role.run(Message(content=query, instruct_content=InvoicePath(file_path=invoice_path)))
|
||||
invoice_table_path = DATA_PATH / invoice_table_path
|
||||
df = pd.read_excel(invoice_table_path)
|
||||
|
|
|
|||
21
tests/metagpt/roles/test_ml_engineer.py
Normal file
21
tests/metagpt/roles/test_ml_engineer.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ml_engineer import MLEngineer
|
||||
|
||||
|
||||
def test_mle_init():
|
||||
ci = MLEngineer(goal="test", auto_run=True, use_tools=True, tools=["tool1", "tool2"])
|
||||
assert ci.tools == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ml_engineer():
|
||||
data_path = "tests/data/ml_datasets/titanic"
|
||||
requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
|
||||
tools = ["FillMissingValue", "CatCross", "dummy_tool"]
|
||||
|
||||
mle = MLEngineer(goal=requirement, auto_run=True, use_tools=True, tools=tools)
|
||||
rsp = await mle.run(requirement)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
|
|
@ -5,17 +5,51 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_product_manager.py
|
||||
"""
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.const import REQUIREMENT_FILENAME
|
||||
from metagpt.context import Context
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import ProductManager
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.roles.mock import MockMessages
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_product_manager(new_filename):
|
||||
product_manager = ProductManager()
|
||||
rsp = await product_manager.run(MockMessages.req)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
assert rsp.content == MockMessages.req.content
|
||||
context = Context()
|
||||
try:
|
||||
assert context.git_repo is None
|
||||
assert context.repo is None
|
||||
product_manager = ProductManager(context=context)
|
||||
# prepare documents
|
||||
rsp = await product_manager.run(MockMessages.req)
|
||||
assert context.git_repo
|
||||
assert context.repo
|
||||
assert rsp.cause_by == any_to_str(PrepareDocuments)
|
||||
assert REQUIREMENT_FILENAME in context.repo.docs.changed_files
|
||||
|
||||
# write prd
|
||||
rsp = await product_manager.run(rsp)
|
||||
assert rsp.cause_by == any_to_str(WritePRD)
|
||||
logger.info(rsp)
|
||||
assert len(rsp.content) > 0
|
||||
doc = list(rsp.instruct_content.docs.values())[0]
|
||||
m = json.loads(doc.content)
|
||||
assert m["Original Requirements"] == MockMessages.req.content
|
||||
|
||||
# nothing to do
|
||||
rsp = await product_manager.run(rsp)
|
||||
assert rsp is None
|
||||
except Exception as e:
|
||||
assert not e
|
||||
finally:
|
||||
context.git_repo.delete_repository()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from tests.metagpt.roles.mock import MockMessages
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_manager():
|
||||
project_manager = ProjectManager()
|
||||
async def test_project_manager(context):
|
||||
project_manager = ProjectManager(context=context)
|
||||
rsp = await project_manager.run(MockMessages.system_design)
|
||||
logger.info(rsp)
|
||||
|
|
|
|||
|
|
@ -13,20 +13,19 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions import DebugError, RunCode, WriteTest
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.roles import QaEngineer
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, aread, awrite
|
||||
|
||||
|
||||
async def test_qa():
|
||||
async def test_qa(context):
|
||||
# Prerequisites
|
||||
demo_path = Path(__file__).parent / "../../data/demo_project"
|
||||
CONFIG.src_workspace = Path(CONFIG.git_repo.workdir) / "qa/game_2048"
|
||||
context.src_workspace = Path(context.repo.workdir) / "qa/game_2048"
|
||||
data = await aread(filename=demo_path / "game.py", encoding="utf-8")
|
||||
await awrite(filename=CONFIG.src_workspace / "game.py", data=data, encoding="utf-8")
|
||||
await awrite(filename=Path(CONFIG.git_repo.workdir) / "requirements.txt", data="")
|
||||
await awrite(filename=context.src_workspace / "game.py", data=data, encoding="utf-8")
|
||||
await awrite(filename=Path(context.repo.workdir) / "requirements.txt", data="")
|
||||
|
||||
class MockEnv(Environment):
|
||||
msgs: List[Message] = Field(default_factory=list)
|
||||
|
|
@ -37,7 +36,7 @@ async def test_qa():
|
|||
|
||||
env = MockEnv()
|
||||
|
||||
role = QaEngineer()
|
||||
role = QaEngineer(context=context)
|
||||
role.set_env(env)
|
||||
await role.run(with_message=Message(content="", cause_by=SummarizeCode))
|
||||
assert env.msgs
|
||||
|
|
|
|||
|
|
@ -4,7 +4,10 @@ from tempfile import TemporaryDirectory
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.research import CollectLinks
|
||||
from metagpt.roles import researcher
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
async def mock_llm_ask(self, prompt: str, system_msgs):
|
||||
|
|
@ -25,16 +28,20 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_researcher(mocker):
|
||||
async def test_researcher(mocker, search_engine_mocker, context):
|
||||
with TemporaryDirectory() as dirname:
|
||||
topic = "dataiku vs. datarobot"
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
researcher.RESEARCH_PATH = Path(dirname)
|
||||
await researcher.Researcher().run(topic)
|
||||
role = researcher.Researcher(context=context)
|
||||
for i in role.actions:
|
||||
if isinstance(i, CollectLinks):
|
||||
i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO)
|
||||
await role.run(topic)
|
||||
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
|
||||
|
||||
|
||||
def test_write_report(mocker):
|
||||
def test_write_report(mocker, context):
|
||||
with TemporaryDirectory() as dirname:
|
||||
for i, topic in enumerate(
|
||||
[
|
||||
|
|
@ -46,7 +53,7 @@ def test_write_report(mocker):
|
|||
):
|
||||
researcher.RESEARCH_PATH = Path(dirname)
|
||||
content = "# Research Report"
|
||||
researcher.Researcher().write_report(topic, content)
|
||||
researcher.Researcher(context=context).write_report(topic, content)
|
||||
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
# @Desc : unittest of Role
|
||||
import pytest
|
||||
|
||||
from metagpt.llm import HumanProvider
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
|
|
@ -13,8 +13,8 @@ def test_role_desc():
|
|||
assert role.desc == "Best Seller"
|
||||
|
||||
|
||||
def test_role_human():
|
||||
role = Role(is_human=True)
|
||||
def test_role_human(context):
|
||||
role = Role(is_human=True, context=context)
|
||||
assert isinstance(role.llm, HumanProvider)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,19 +5,17 @@
|
|||
@Author : mashenquan
|
||||
@File : test_teacher.py
|
||||
"""
|
||||
import os
|
||||
from typing import Dict, Optional
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.context import Context
|
||||
from metagpt.roles.teacher import Teacher
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip
|
||||
async def test_init():
|
||||
class Inputs(BaseModel):
|
||||
name: str
|
||||
|
|
@ -31,6 +29,7 @@ async def test_init():
|
|||
expect_goal: str
|
||||
expect_constraints: str
|
||||
expect_desc: str
|
||||
exclude: list = Field(default_factory=list)
|
||||
|
||||
inputs = [
|
||||
{
|
||||
|
|
@ -45,6 +44,7 @@ async def test_init():
|
|||
"kwargs": {},
|
||||
"desc": "aaa{language}",
|
||||
"expect_desc": "aaa{language}",
|
||||
"exclude": ["language", "key1", "something_big", "teaching_language"],
|
||||
},
|
||||
{
|
||||
"name": "Lily{language}",
|
||||
|
|
@ -58,20 +58,21 @@ async def test_init():
|
|||
"kwargs": {"language": "CN", "key1": "HaHa", "something_big": "sleep", "teaching_language": "EN"},
|
||||
"desc": "aaa{language}",
|
||||
"expect_desc": "aaaCN",
|
||||
"language": "CN",
|
||||
"teaching_language": "EN",
|
||||
},
|
||||
]
|
||||
|
||||
env = os.environ.copy()
|
||||
for i in inputs:
|
||||
seed = Inputs(**i)
|
||||
os.environ.clear()
|
||||
os.environ.update(env)
|
||||
CONFIG = Config()
|
||||
CONFIG.set_context(seed.kwargs)
|
||||
print(CONFIG.options)
|
||||
assert bool("language" in seed.kwargs) == bool("language" in CONFIG.options)
|
||||
context = Context()
|
||||
for k in seed.exclude:
|
||||
context.kwargs.set(k, None)
|
||||
for k, v in seed.kwargs.items():
|
||||
context.kwargs.set(k, v)
|
||||
|
||||
teacher = Teacher(
|
||||
context=context,
|
||||
name=seed.name,
|
||||
profile=seed.profile,
|
||||
goal=seed.goal,
|
||||
|
|
@ -105,7 +106,6 @@ async def test_new_file_name():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run():
|
||||
CONFIG.set_context({"language": "Chinese", "teaching_language": "English"})
|
||||
lesson = """
|
||||
UNIT 1 Making New Friends
|
||||
TOPIC 1 Welcome to China!
|
||||
|
|
@ -149,7 +149,10 @@ async def test_run():
|
|||
|
||||
3c Match the big letters with the small ones. Then write them on the lines.
|
||||
"""
|
||||
teacher = Teacher()
|
||||
context = Context()
|
||||
context.kwargs.language = "Chinese"
|
||||
context.kwargs.teaching_language = "English"
|
||||
teacher = Teacher(context=context)
|
||||
rsp = await teacher.run(Message(content=lesson))
|
||||
assert rsp
|
||||
|
||||
|
|
|
|||
|
|
@ -15,8 +15,8 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(("language", "topic"), [("Chinese", "Write a tutorial about pip")])
|
||||
async def test_tutorial_assistant(language: str, topic: str):
|
||||
role = TutorialAssistant(language=language)
|
||||
async def test_tutorial_assistant(language: str, topic: str, context):
|
||||
role = TutorialAssistant(language=language, context=context)
|
||||
msg = await role.run(topic)
|
||||
assert TUTORIAL_PATH.exists()
|
||||
filename = msg.content
|
||||
|
|
|
|||
|
|
@ -5,28 +5,22 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_action_serialize():
|
||||
action = Action()
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
assert "__module_class_name" not in ser_action_dict
|
||||
|
||||
action = Action(name="test")
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "test" in ser_action_dict["name"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = Action()
|
||||
serialized_data = action.model_dump()
|
||||
async def test_action_serdeser(context):
|
||||
action = Action(context=context)
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
assert "__module_class_name" in ser_action_dict
|
||||
|
||||
new_action = Action(**serialized_data)
|
||||
action = Action(name="test", context=context)
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "test" in ser_action_dict["name"]
|
||||
|
||||
assert new_action.name == "Action"
|
||||
assert isinstance(new_action.llm, type(LLM()))
|
||||
new_action = Action(**ser_action_dict, context=context)
|
||||
|
||||
assert new_action.name == "test"
|
||||
assert isinstance(new_action.llm, type(context.llm()))
|
||||
assert len(await new_action._aask("who are you")) > 0
|
||||
|
|
|
|||
|
|
@ -8,21 +8,21 @@ from metagpt.actions.action import Action
|
|||
from metagpt.roles.architect import Architect
|
||||
|
||||
|
||||
def test_architect_serialize():
|
||||
role = Architect()
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect_serdeser(context):
|
||||
role = Architect(context=context)
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect_deserialize():
|
||||
role = Architect()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
new_role = Architect(**ser_role_dict)
|
||||
# new_role = Architect.deserialize(ser_role_dict)
|
||||
new_role = Architect(**ser_role_dict, context=context)
|
||||
assert new_role.name == "Bob"
|
||||
assert len(new_role.actions) == 1
|
||||
assert len(new_role.rc.watch) == 1
|
||||
assert isinstance(new_role.actions[0], Action)
|
||||
await new_role.actions[0].run(with_messages="write a cli snake game")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
|
|
@ -10,7 +9,7 @@ from metagpt.actions.project_management import WriteTasks
|
|||
from metagpt.environment import Environment
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.common import any_to_str, read_json_file, write_json_file
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
ActionRaise,
|
||||
|
|
@ -19,23 +18,20 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
|||
)
|
||||
|
||||
|
||||
def test_env_serialize():
|
||||
env = Environment()
|
||||
def test_env_serdeser(context):
|
||||
env = Environment(context=context)
|
||||
env.publish_message(message=Message(content="test env serialize"))
|
||||
|
||||
ser_env_dict = env.model_dump()
|
||||
assert "roles" in ser_env_dict
|
||||
assert len(ser_env_dict["roles"]) == 0
|
||||
|
||||
|
||||
def test_env_deserialize():
|
||||
env = Environment()
|
||||
env.publish_message(message=Message(content="test env serialize"))
|
||||
ser_env_dict = env.model_dump()
|
||||
new_env = Environment(**ser_env_dict)
|
||||
new_env = Environment(**ser_env_dict, context=context)
|
||||
assert len(new_env.roles) == 0
|
||||
assert len(new_env.history) == 25
|
||||
|
||||
|
||||
def test_environment_serdeser():
|
||||
def test_environment_serdeser(context):
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
|
@ -44,7 +40,7 @@ def test_environment_serdeser():
|
|||
content="prd", instruct_content=ic_obj(**out_data), role="product manager", cause_by=any_to_str(UserRequirement)
|
||||
)
|
||||
|
||||
environment = Environment()
|
||||
environment = Environment(context=context)
|
||||
role_c = RoleC()
|
||||
environment.add_role(role_c)
|
||||
environment.publish_message(message)
|
||||
|
|
@ -52,7 +48,7 @@ def test_environment_serdeser():
|
|||
ser_data = environment.model_dump()
|
||||
assert ser_data["roles"]["Role C"]["name"] == "RoleC"
|
||||
|
||||
new_env: Environment = Environment(**ser_data)
|
||||
new_env: Environment = Environment(**ser_data, context=context)
|
||||
assert len(new_env.roles) == 1
|
||||
|
||||
assert list(new_env.roles.values())[0].states == list(environment.roles.values())[0].states
|
||||
|
|
@ -61,30 +57,31 @@ def test_environment_serdeser():
|
|||
assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise
|
||||
|
||||
|
||||
def test_environment_serdeser_v2():
|
||||
environment = Environment()
|
||||
def test_environment_serdeser_v2(context):
|
||||
environment = Environment(context=context)
|
||||
pm = ProjectManager()
|
||||
environment.add_role(pm)
|
||||
|
||||
ser_data = environment.model_dump()
|
||||
|
||||
new_env: Environment = Environment(**ser_data)
|
||||
new_env: Environment = Environment(**ser_data, context=context)
|
||||
role = new_env.get_role(pm.profile)
|
||||
assert isinstance(role, ProjectManager)
|
||||
assert isinstance(role.actions[0], WriteTasks)
|
||||
assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
|
||||
|
||||
|
||||
def test_environment_serdeser_save():
|
||||
environment = Environment()
|
||||
def test_environment_serdeser_save(context):
|
||||
environment = Environment(context=context)
|
||||
role_c = RoleC()
|
||||
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
env_path = stg_path.joinpath("env.json")
|
||||
environment.add_role(role_c)
|
||||
environment.serialize(stg_path)
|
||||
|
||||
new_env: Environment = Environment.deserialize(stg_path)
|
||||
write_json_file(env_path, environment.model_dump())
|
||||
|
||||
env_dict = read_json_file(env_path)
|
||||
new_env: Environment = Environment(**env_dict, context=context)
|
||||
assert len(new_env.roles) == 1
|
||||
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
|
||||
|
|
|
|||
|
|
@ -9,11 +9,11 @@ from metagpt.actions.add_requirement import UserRequirement
|
|||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.common import any_to_str, read_json_file, write_json_file
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path
|
||||
|
||||
|
||||
def test_memory_serdeser():
|
||||
def test_memory_serdeser(context):
|
||||
msg1 = Message(role="Boss", content="write a snake game", cause_by=UserRequirement)
|
||||
|
||||
out_mapping = {"field2": (list[str], ...)}
|
||||
|
|
@ -39,7 +39,7 @@ def test_memory_serdeser():
|
|||
assert memory.count() == 2
|
||||
|
||||
|
||||
def test_memory_serdeser_save():
|
||||
def test_memory_serdeser_save(context):
|
||||
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)
|
||||
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
|
|
@ -53,14 +53,14 @@ def test_memory_serdeser_save():
|
|||
memory.add_batch([msg1, msg2])
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
memory.serialize(stg_path)
|
||||
assert stg_path.joinpath("memory.json").exists()
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
write_json_file(memory_path, memory.model_dump())
|
||||
assert memory_path.exists()
|
||||
|
||||
new_memory = Memory.deserialize(stg_path)
|
||||
memory_dict = read_json_file(memory_path)
|
||||
new_memory = Memory(**memory_dict)
|
||||
assert new_memory.count() == 2
|
||||
new_msg2 = new_memory.get(1)[0]
|
||||
assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"]
|
||||
assert new_msg2.cause_by == any_to_str(WriteDesign)
|
||||
assert len(new_memory.index) == 2
|
||||
|
||||
stg_path.joinpath("memory.json").unlink()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of polymorphic conditions
|
||||
import copy
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, SerializeAsAny
|
||||
|
||||
|
|
@ -12,6 +13,8 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
|||
|
||||
|
||||
class ActionSubClasses(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
actions: list[SerializeAsAny[Action]] = []
|
||||
|
||||
|
||||
|
|
@ -40,19 +43,21 @@ def test_no_serialize_as_any():
|
|||
|
||||
|
||||
def test_polymorphic():
|
||||
_ = ActionOKV2(
|
||||
ok_v2 = ActionOKV2(
|
||||
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
|
||||
)
|
||||
|
||||
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
action_subcls_dict2 = copy.deepcopy(action_subcls_dict)
|
||||
|
||||
assert "__module_class_name" in action_subcls_dict["actions"][0]
|
||||
|
||||
new_action_subcls = ActionSubClasses(**action_subcls_dict)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert new_action_subcls.actions[0].extra_field == ok_v2.extra_field
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
||||
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
|
||||
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict2)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ from metagpt.actions.prepare_interview import PrepareInterview
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = PrepareInterview()
|
||||
async def test_action_serdeser(context):
|
||||
action = PrepareInterview(context=context)
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "PrepareInterview"
|
||||
|
||||
new_action = PrepareInterview(**serialized_data)
|
||||
new_action = PrepareInterview(**serialized_data, context=context)
|
||||
|
||||
assert new_action.name == "PrepareInterview"
|
||||
assert type(await new_action.run("python developer")) == ActionNode
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ from metagpt.schema import Message
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_product_manager_deserialize(new_filename):
|
||||
role = ProductManager()
|
||||
async def test_product_manager_serdeser(new_filename, context):
|
||||
role = ProductManager(context=context)
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
new_role = ProductManager(**ser_role_dict)
|
||||
new_role = ProductManager(**ser_role_dict, context=context)
|
||||
|
||||
assert new_role.name == "Alice"
|
||||
assert len(new_role.actions) == 2
|
||||
|
|
|
|||
|
|
@ -9,20 +9,15 @@ from metagpt.actions.project_management import WriteTasks
|
|||
from metagpt.roles.project_manager import ProjectManager
|
||||
|
||||
|
||||
def test_project_manager_serialize():
|
||||
role = ProjectManager()
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_manager_serdeser(context):
|
||||
role = ProjectManager(context=context)
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_manager_deserialize():
|
||||
role = ProjectManager()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
|
||||
new_role = ProjectManager(**ser_role_dict)
|
||||
new_role = ProjectManager(**ser_role_dict, context=context)
|
||||
assert new_role.name == "Eve"
|
||||
assert len(new_role.actions) == 1
|
||||
assert isinstance(new_role.actions[0], Action)
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ from metagpt.roles.researcher import Researcher
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tutorial_assistant_deserialize():
|
||||
role = Researcher()
|
||||
async def test_tutorial_assistant_serdeser(context):
|
||||
role = Researcher(context=context)
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
assert "language" in ser_role_dict
|
||||
|
||||
new_role = Researcher(**ser_role_dict)
|
||||
new_role = Researcher(**ser_role_dict, context=context)
|
||||
assert new_role.language == "en-us"
|
||||
assert len(new_role.actions) == 3
|
||||
assert isinstance(new_role.actions[0], CollectLinks)
|
||||
|
|
|
|||
|
|
@ -10,13 +10,12 @@ from pydantic import BaseModel, SerializeAsAny
|
|||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import format_trackback_info
|
||||
from metagpt.utils.common import format_trackback_info, read_json_file, write_json_file
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
RoleA,
|
||||
|
|
@ -27,7 +26,7 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
|||
)
|
||||
|
||||
|
||||
def test_roles():
|
||||
def test_roles(context):
|
||||
role_a = RoleA()
|
||||
assert len(role_a.rc.watch) == 1
|
||||
role_b = RoleB()
|
||||
|
|
@ -38,7 +37,7 @@ def test_roles():
|
|||
assert len(role_d.actions) == 1
|
||||
|
||||
|
||||
def test_role_subclasses():
|
||||
def test_role_subclasses(context):
|
||||
"""test subclasses of role with same fields in ser&deser"""
|
||||
|
||||
class RoleSubClasses(BaseModel):
|
||||
|
|
@ -52,7 +51,7 @@ def test_role_subclasses():
|
|||
assert isinstance(new_role_subcls.roles[1], RoleB)
|
||||
|
||||
|
||||
def test_role_serialize():
|
||||
def test_role_serialize(context):
|
||||
role = Role()
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
|
|
@ -60,60 +59,56 @@ def test_role_serialize():
|
|||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
def test_engineer_serialize():
|
||||
def test_engineer_serdeser(context):
|
||||
role = Engineer()
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer_deserialize():
|
||||
role = Engineer(use_code_review=True)
|
||||
ser_role_dict = role.model_dump()
|
||||
|
||||
new_role = Engineer(**ser_role_dict)
|
||||
assert new_role.name == "Alex"
|
||||
assert new_role.use_code_review is True
|
||||
assert new_role.use_code_review is False
|
||||
assert len(new_role.actions) == 1
|
||||
assert isinstance(new_role.actions[0], WriteCode)
|
||||
# await new_role.actions[0].run(context="write a cli snake game", filename="test_code")
|
||||
|
||||
|
||||
def test_role_serdeser_save():
|
||||
stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles")
|
||||
def test_role_serdeser_save(context):
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
pm = ProductManager()
|
||||
role_tag = f"{pm.__class__.__name__}_{pm.name}"
|
||||
stg_path = stg_path_prefix.joinpath(role_tag)
|
||||
pm.serialize(stg_path)
|
||||
|
||||
new_pm = Role.deserialize(stg_path)
|
||||
stg_path = serdeser_path.joinpath("team", "environment", "roles", f"{pm.__class__.__name__}_{pm.name}")
|
||||
role_path = stg_path.joinpath("role.json")
|
||||
write_json_file(role_path, pm.model_dump())
|
||||
|
||||
role_dict = read_json_file(role_path)
|
||||
new_pm = ProductManager(**role_dict)
|
||||
assert new_pm.name == pm.name
|
||||
assert len(new_pm.get_memories(1)) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_serdeser_interrupt():
|
||||
async def test_role_serdeser_interrupt(context):
|
||||
role_c = RoleC()
|
||||
shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True)
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = SERDESER_PATH.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}")
|
||||
stg_path = serdeser_path.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}")
|
||||
role_path = stg_path.joinpath("role.json")
|
||||
try:
|
||||
await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
except Exception:
|
||||
logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}")
|
||||
role_c.serialize(stg_path)
|
||||
logger.error(f"Exception in `role_c.run`, detail: {format_trackback_info()}")
|
||||
write_json_file(role_path, role_c.model_dump())
|
||||
|
||||
assert role_c.rc.memory.count() == 1
|
||||
|
||||
new_role_a: Role = Role.deserialize(stg_path)
|
||||
assert new_role_a.rc.state == 1
|
||||
role_dict = read_json_file(role_path)
|
||||
new_role_c: Role = RoleC(**role_dict)
|
||||
assert new_role_c.rc.state == 1
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
await new_role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of schema ser&deser
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.schema import Document, Documents, Message
|
||||
from metagpt.schema import CodingContext, Document, Documents, Message, TestingContext
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
MockICMessage,
|
||||
|
|
@ -12,12 +13,16 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
|||
)
|
||||
|
||||
|
||||
def test_message_serdeser():
|
||||
def test_message_serdeser_from_create_model():
|
||||
with pytest.raises(KeyError):
|
||||
_ = Message(content="code", instruct_content={"class": "test", "key": "value"})
|
||||
|
||||
out_mapping = {"field3": (str, ...), "field4": (list[str], ...)}
|
||||
out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
ic_inst = ic_obj(**out_data)
|
||||
|
||||
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
|
||||
message = Message(content="code", instruct_content=ic_inst, role="engineer", cause_by=WriteCode)
|
||||
ser_data = message.model_dump()
|
||||
assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode"
|
||||
assert ser_data["instruct_content"]["class"] == "code"
|
||||
|
|
@ -25,28 +30,72 @@ def test_message_serdeser():
|
|||
new_message = Message(**ser_data)
|
||||
assert new_message.cause_by == any_to_str(WriteCode)
|
||||
assert new_message.cause_by in [any_to_str(WriteCode)]
|
||||
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
|
||||
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
|
||||
|
||||
message = Message(content="test_ic", instruct_content=MockICMessage())
|
||||
assert new_message.instruct_content == ic_obj(**out_data)
|
||||
assert new_message.instruct_content == ic_inst
|
||||
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
|
||||
assert new_message == message
|
||||
|
||||
mock_msg = MockMessage()
|
||||
message = Message(content="test_ic", instruct_content=mock_msg)
|
||||
ser_data = message.model_dump()
|
||||
new_message = Message(**ser_data)
|
||||
assert new_message.instruct_content != MockICMessage() # TODO
|
||||
|
||||
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
|
||||
ser_data = message.model_dump()
|
||||
assert "class" in ser_data["instruct_content"]
|
||||
assert new_message.instruct_content == mock_msg
|
||||
assert new_message == message
|
||||
|
||||
|
||||
def test_message_without_postprocess():
|
||||
"""to explain `instruct_content` should be postprocessed"""
|
||||
"""to explain `instruct_content` from `create_model_class` should be postprocessed"""
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
message = MockMessage(content="code", instruct_content=ic_obj(**out_data))
|
||||
message = MockICMessage(content="code", instruct_content=ic_obj(**out_data))
|
||||
ser_data = message.model_dump()
|
||||
assert ser_data["instruct_content"] == {}
|
||||
|
||||
ser_data["instruct_content"] = None
|
||||
new_message = MockMessage(**ser_data)
|
||||
new_message = MockICMessage(**ser_data)
|
||||
assert new_message.instruct_content != ic_obj(**out_data)
|
||||
assert new_message != message
|
||||
|
||||
|
||||
def test_message_serdeser_from_basecontext():
|
||||
doc_msg = Message(content="test_document", instruct_content=Document(content="test doc"))
|
||||
ser_data = doc_msg.model_dump()
|
||||
assert ser_data["instruct_content"]["value"]["content"] == "test doc"
|
||||
assert ser_data["instruct_content"]["value"]["filename"] == ""
|
||||
|
||||
docs_msg = Message(
|
||||
content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")})
|
||||
)
|
||||
ser_data = docs_msg.model_dump()
|
||||
assert ser_data["instruct_content"]["class"] == "Documents"
|
||||
assert ser_data["instruct_content"]["value"]["docs"]["doc1"]["content"] == "test doc"
|
||||
assert ser_data["instruct_content"]["value"]["docs"]["doc1"]["filename"] == ""
|
||||
|
||||
code_ctxt = CodingContext(
|
||||
filename="game.py",
|
||||
design_doc=Document(root_path="docs/system_design", filename="xx.json", content="xxx"),
|
||||
task_doc=Document(root_path="docs/tasks", filename="xx.json", content="xxx"),
|
||||
code_doc=Document(root_path="xxx", filename="game.py", content="xxx"),
|
||||
)
|
||||
code_ctxt_msg = Message(content="coding_context", instruct_content=code_ctxt)
|
||||
ser_data = code_ctxt_msg.model_dump()
|
||||
assert ser_data["instruct_content"]["class"] == "CodingContext"
|
||||
|
||||
new_code_ctxt_msg = Message(**ser_data)
|
||||
assert new_code_ctxt_msg.instruct_content == code_ctxt
|
||||
assert new_code_ctxt_msg.instruct_content.code_doc.filename == "game.py"
|
||||
assert new_code_ctxt_msg == code_ctxt_msg
|
||||
|
||||
testing_ctxt = TestingContext(
|
||||
filename="test.py",
|
||||
code_doc=Document(root_path="xxx", filename="game.py", content="xxx"),
|
||||
test_doc=Document(root_path="docs/tests", filename="test.py", content="xxx"),
|
||||
)
|
||||
testing_ctxt_msg = Message(content="testing_context", instruct_content=testing_ctxt)
|
||||
ser_data = testing_ctxt_msg.model_dump()
|
||||
new_testing_ctxt_msg = Message(**ser_data)
|
||||
assert new_testing_ctxt_msg.instruct_content == testing_ctxt
|
||||
assert new_testing_ctxt_msg.instruct_content.test_doc.filename == "test.py"
|
||||
assert new_testing_ctxt_msg == testing_ctxt_msg
|
||||
|
|
|
|||
|
|
@ -16,14 +16,14 @@ from metagpt.roles.role import Role, RoleReactMode
|
|||
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
|
||||
|
||||
|
||||
class MockICMessage(BaseModel):
|
||||
content: str = "test_ic"
|
||||
|
||||
|
||||
class MockMessage(BaseModel):
|
||||
content: str = "test_msg"
|
||||
|
||||
|
||||
class MockICMessage(BaseModel):
|
||||
"""to test normal dict without postprocess"""
|
||||
|
||||
content: str = ""
|
||||
content: str = "test_ic_msg"
|
||||
instruct_content: Optional[BaseModel] = Field(default=None)
|
||||
|
||||
|
||||
|
|
@ -67,7 +67,7 @@ class RoleA(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super(RoleA, self).__init__(**kwargs)
|
||||
self._init_actions([ActionPass])
|
||||
self.set_actions([ActionPass])
|
||||
self._watch([UserRequirement])
|
||||
|
||||
|
||||
|
|
@ -79,7 +79,7 @@ class RoleB(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super(RoleB, self).__init__(**kwargs)
|
||||
self._init_actions([ActionOK, ActionRaise])
|
||||
self.set_actions([ActionOK, ActionRaise])
|
||||
self._watch([ActionPass])
|
||||
self.rc.react_mode = RoleReactMode.BY_ORDER
|
||||
|
||||
|
|
@ -92,7 +92,7 @@ class RoleC(Role):
|
|||
|
||||
def __init__(self, **kwargs):
|
||||
super(RoleC, self).__init__(**kwargs)
|
||||
self._init_actions([ActionOK, ActionRaise])
|
||||
self.set_actions([ActionOK, ActionRaise])
|
||||
self._watch([UserRequirement])
|
||||
self.rc.react_mode = RoleReactMode.BY_ORDER
|
||||
self.rc.memory.ignore_id = True
|
||||
|
|
|
|||
|
|
@ -5,15 +5,8 @@ import pytest
|
|||
from metagpt.roles.sk_agent import SkAgent
|
||||
|
||||
|
||||
def test_sk_agent_serialize():
|
||||
role = SkAgent()
|
||||
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
|
||||
assert "name" in ser_role_dict
|
||||
assert "planner" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sk_agent_deserialize():
|
||||
async def test_sk_agent_serdeser():
|
||||
role = SkAgent()
|
||||
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
|
||||
assert "name" in ser_role_dict
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@
|
|||
# @Desc :
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager
|
||||
from metagpt.team import Team
|
||||
from metagpt.utils.common import write_json_file
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
RoleA,
|
||||
|
|
@ -20,8 +21,8 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
|||
)
|
||||
|
||||
|
||||
def test_team_deserialize():
|
||||
company = Team()
|
||||
def test_team_deserialize(context):
|
||||
company = Team(context=context)
|
||||
|
||||
pm = ProductManager()
|
||||
arch = Architect()
|
||||
|
|
@ -45,9 +46,16 @@ def test_team_deserialize():
|
|||
assert new_company.env.get_role(arch.profile) is not None
|
||||
|
||||
|
||||
def test_team_serdeser_save():
|
||||
company = Team()
|
||||
def mock_team_serialize(self, stg_path: Path = serdeser_path.joinpath("team")):
|
||||
team_info_path = stg_path.joinpath("team.json")
|
||||
|
||||
write_json_file(team_info_path, self.model_dump())
|
||||
|
||||
|
||||
def test_team_serdeser_save(mocker, context):
|
||||
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
|
||||
|
||||
company = Team(context=context)
|
||||
company.hire([RoleC()])
|
||||
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
|
|
@ -61,12 +69,14 @@ def test_team_serdeser_save():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover():
|
||||
async def test_team_recover(mocker, context):
|
||||
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
|
||||
|
||||
idea = "write a snake game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company = Team()
|
||||
company = Team(context=context)
|
||||
role_c = RoleC()
|
||||
company.hire([role_c])
|
||||
company.run_project(idea)
|
||||
|
|
@ -75,9 +85,9 @@ async def test_team_recover():
|
|||
ser_data = company.model_dump()
|
||||
new_company = Team(**ser_data)
|
||||
|
||||
new_company.env.get_role(role_c.profile)
|
||||
# assert new_role_c.rc.memory == role_c.rc.memory # TODO
|
||||
# assert new_role_c.rc.env != role_c.rc.env # TODO
|
||||
new_role_c = new_company.env.get_role(role_c.profile)
|
||||
assert new_role_c.rc.memory == role_c.rc.memory
|
||||
assert new_role_c.rc.env != role_c.rc.env
|
||||
assert type(list(new_company.env.roles.values())[0].actions[0]) == ActionOK
|
||||
|
||||
new_company.run_project(idea)
|
||||
|
|
@ -85,12 +95,14 @@ async def test_team_recover():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover_save():
|
||||
async def test_team_recover_save(mocker, context):
|
||||
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
|
||||
|
||||
idea = "write a 2048 web game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company = Team()
|
||||
company = Team(context=context)
|
||||
role_c = RoleC()
|
||||
company.hire([role_c])
|
||||
company.run_project(idea)
|
||||
|
|
@ -98,8 +110,8 @@ async def test_team_recover_save():
|
|||
|
||||
new_company = Team.deserialize(stg_path)
|
||||
new_role_c = new_company.env.get_role(role_c.profile)
|
||||
# assert new_role_c.rc.memory == role_c.rc.memory
|
||||
# assert new_role_c.rc.env != role_c.rc.env
|
||||
assert new_role_c.rc.memory == role_c.rc.memory
|
||||
assert new_role_c.rc.env != role_c.rc.env
|
||||
assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=`
|
||||
assert new_role_c.rc.todo != role_c.rc.todo # serialize exclude `rc.todo`
|
||||
assert new_role_c.rc.news != role_c.rc.news # serialize exclude `rc.news`
|
||||
|
|
@ -109,15 +121,17 @@ async def test_team_recover_save():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover_multi_roles_save():
|
||||
async def test_team_recover_multi_roles_save(mocker, context):
|
||||
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
|
||||
|
||||
idea = "write a snake game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
role_a = RoleA()
|
||||
role_b = RoleB()
|
||||
|
||||
company = Team()
|
||||
company = Team(context=context)
|
||||
company.hire([role_a, role_b])
|
||||
company.run_project(idea)
|
||||
await company.run(n_round=4)
|
||||
|
|
@ -130,3 +144,7 @@ async def test_team_recover_multi_roles_save():
|
|||
assert new_company.env.get_role(role_b.profile).rc.state == 1
|
||||
|
||||
await new_company.run(n_round=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tutorial_assistant_deserialize():
|
||||
async def test_tutorial_assistant_serdeser(context):
|
||||
role = TutorialAssistant()
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
|
|
|
|||
|
|
@ -9,22 +9,23 @@ from metagpt.actions import WriteCode
|
|||
from metagpt.schema import CodingContext, Document
|
||||
|
||||
|
||||
def test_write_design_serialize():
|
||||
action = WriteCode()
|
||||
def test_write_design_serdeser(context):
|
||||
action = WriteCode(context=context)
|
||||
ser_action_dict = action.model_dump()
|
||||
assert ser_action_dict["name"] == "WriteCode"
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_deserialize():
|
||||
context = CodingContext(
|
||||
async def test_write_code_serdeser(context):
|
||||
context.src_workspace = context.repo.workdir / "srcs"
|
||||
coding_context = CodingContext(
|
||||
filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers")
|
||||
)
|
||||
doc = Document(content=context.model_dump_json())
|
||||
action = WriteCode(context=doc)
|
||||
doc = Document(content=coding_context.model_dump_json())
|
||||
action = WriteCode(i_context=doc, context=context)
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WriteCode(**serialized_data)
|
||||
new_action = WriteCode(**serialized_data, context=context)
|
||||
|
||||
assert new_action.name == "WriteCode"
|
||||
await action.run()
|
||||
|
|
|
|||
|
|
@ -9,22 +9,23 @@ from metagpt.schema import CodingContext, Document
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_review_deserialize():
|
||||
async def test_write_code_review_serdeser(context):
|
||||
context.src_workspace = context.repo.workdir / "srcs"
|
||||
code_content = """
|
||||
def div(a: int, b: int = 0):
|
||||
return a / b
|
||||
"""
|
||||
context = CodingContext(
|
||||
coding_context = CodingContext(
|
||||
filename="test_op.py",
|
||||
design_doc=Document(content="divide two numbers"),
|
||||
code_doc=Document(content=code_content),
|
||||
)
|
||||
|
||||
action = WriteCodeReview(context=context)
|
||||
action = WriteCodeReview(i_context=coding_context)
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "WriteCodeReview"
|
||||
|
||||
new_action = WriteCodeReview(**serialized_data)
|
||||
new_action = WriteCodeReview(**serialized_data, context=context)
|
||||
|
||||
assert new_action.name == "WriteCodeReview"
|
||||
await new_action.run()
|
||||
|
|
|
|||
|
|
@ -7,33 +7,25 @@ import pytest
|
|||
from metagpt.actions import WriteDesign, WriteTasks
|
||||
|
||||
|
||||
def test_write_design_serialize():
|
||||
action = WriteDesign()
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
|
||||
def test_write_task_serialize():
|
||||
action = WriteTasks()
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_design_deserialize():
|
||||
action = WriteDesign()
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WriteDesign(**serialized_data)
|
||||
async def test_write_design_serialize(context):
|
||||
action = WriteDesign(context=context)
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
new_action = WriteDesign(**ser_action_dict, context=context)
|
||||
assert new_action.name == "WriteDesign"
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_task_deserialize():
|
||||
action = WriteTasks()
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WriteTasks(**serialized_data)
|
||||
async def test_write_task_serialize(context):
|
||||
action = WriteTasks(context=context)
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
new_action = WriteTasks(**ser_action_dict, context=context)
|
||||
assert new_action.name == "WriteTasks"
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
|
|
|||
|
|
@ -29,14 +29,14 @@ class Person:
|
|||
],
|
||||
ids=["google", "numpy", "sphinx"],
|
||||
)
|
||||
async def test_action_deserialize(style: str, part: str):
|
||||
action = WriteDocstring()
|
||||
async def test_action_serdeser(style: str, part: str, context):
|
||||
action = WriteDocstring(context=context)
|
||||
serialized_data = action.model_dump()
|
||||
|
||||
assert "name" in serialized_data
|
||||
assert serialized_data["desc"] == "Write docstring for code."
|
||||
|
||||
new_action = WriteDocstring(**serialized_data)
|
||||
new_action = WriteDocstring(**serialized_data, context=context)
|
||||
|
||||
assert new_action.name == "WriteDocstring"
|
||||
assert new_action.desc == "Write docstring for code."
|
||||
|
|
|
|||
|
|
@ -9,18 +9,14 @@ from metagpt.actions import WritePRD
|
|||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_action_serialize(new_filename):
|
||||
action = WritePRD()
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_serdeser(new_filename, context):
|
||||
action = WritePRD(context=context)
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize(new_filename):
|
||||
action = WritePRD()
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WritePRD(**serialized_data)
|
||||
new_action = WritePRD(**ser_action_dict, context=context)
|
||||
assert new_action.name == "WritePRD"
|
||||
action_output = await new_action.run(with_messages=Message(content="write a cli snake game"))
|
||||
assert len(action_output.content) > 0
|
||||
with pytest.raises(FileNotFoundError):
|
||||
await new_action.run(with_messages=Message(content="write a cli snake game"))
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_review import WriteReview
|
||||
|
||||
CONTEXT = """
|
||||
TEMPLATE_CONTEXT = """
|
||||
{
|
||||
"Language": "zh_cn",
|
||||
"Programming Language": "Python",
|
||||
|
|
@ -42,13 +42,13 @@ CONTEXT = """
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = WriteReview()
|
||||
async def test_action_serdeser(context):
|
||||
action = WriteReview(context=context)
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "WriteReview"
|
||||
|
||||
new_action = WriteReview(**serialized_data)
|
||||
review = await new_action.run(CONTEXT)
|
||||
new_action = WriteReview(**serialized_data, context=context)
|
||||
review = await new_action.run(TEMPLATE_CONTEXT)
|
||||
|
||||
assert new_action.name == "WriteReview"
|
||||
assert type(review) == ActionNode
|
||||
|
|
|
|||
|
|
@ -9,13 +9,13 @@ from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
|
||||
async def test_write_directory_deserialize(language: str, topic: str):
|
||||
action = WriteDirectory()
|
||||
async def test_write_directory_serdeser(language: str, topic: str, context):
|
||||
action = WriteDirectory(context=context)
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "WriteDirectory"
|
||||
assert serialized_data["language"] == "Chinese"
|
||||
|
||||
new_action = WriteDirectory(**serialized_data)
|
||||
new_action = WriteDirectory(**serialized_data, context=context)
|
||||
ret = await new_action.run(topic=topic)
|
||||
assert isinstance(ret, dict)
|
||||
assert "title" in ret
|
||||
|
|
@ -30,12 +30,12 @@ async def test_write_directory_deserialize(language: str, topic: str):
|
|||
("language", "topic", "directory"),
|
||||
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
|
||||
)
|
||||
async def test_write_content_deserialize(language: str, topic: str, directory: Dict):
|
||||
action = WriteContent(language=language, directory=directory)
|
||||
async def test_write_content_serdeser(language: str, topic: str, directory: Dict, context):
|
||||
action = WriteContent(language=language, directory=directory, context=context)
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "WriteContent"
|
||||
|
||||
new_action = WriteContent(**serialized_data)
|
||||
new_action = WriteContent(**serialized_data, context=context)
|
||||
ret = await new_action.run(topic=topic)
|
||||
assert isinstance(ret, str)
|
||||
assert list(directory.keys())[0] in ret
|
||||
|
|
|
|||
|
|
@ -1,7 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/11 14:44
|
||||
@Author : alexanderwu
|
||||
@File : test_action.py
|
||||
"""
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue