diff --git a/metagpt/actions/write_code_function.py b/metagpt/actions/write_code_function.py index 4ec565eb1..406d215a2 100644 --- a/metagpt/actions/write_code_function.py +++ b/metagpt/actions/write_code_function.py @@ -58,4 +58,4 @@ class WriteCodeFunction(BaseWriteAnalysisCode): ) -> str: prompt = self.process_msg(context, system_msg) code_content = await self.llm.aask_code(prompt, **kwargs) - return code_content + return code_content['code'] diff --git a/metagpt/actions/write_plan.py b/metagpt/actions/write_plan.py index 48cb1aad5..8db988c01 100644 --- a/metagpt/actions/write_plan.py +++ b/metagpt/actions/write_plan.py @@ -9,6 +9,7 @@ import json from metagpt.actions import Action from metagpt.schema import Message, Task +from metagpt.utils.common import CodeParser class WritePlan(Action): PROMPT_TEMPLATE = """ @@ -35,7 +36,7 @@ class WritePlan(Action): .replace("__current_plan__", current_plan).replace("__max_tasks__", str(max_tasks)) ) rsp = await self._aask(prompt) - return rsp + return CodeParser.parse_code(None, rsp) if rsp.startswith("```") else rsp @staticmethod def rsp_to_tasks(rsp: str) -> List[Task]: diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index c795bda11..3bb0c1660 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -64,14 +64,14 @@ class MLEngineer(Role): success = False while not success and counter < max_retry: context = self.get_memories() - - code = "print('abc')" - # code = await WriteCodeFunction().run(context=context) + print(f"{'*'*20}\n {context}") + # code = "print('abc')" + code = await WriteCodeFunction().run(context=context) # code = await WriteCodeWithOps.run(context, task, result) self._rc.memory.add(Message(content=code, role="assistant", cause_by=WriteCodeFunction)) result, success = await ExecutePyCode().run(code) - self._rc.memory.add(Message(content=result, role="assistant", cause_by=ExecutePyCode)) + self._rc.memory.add(Message(content=result, role="user", cause_by=ExecutePyCode)) # if not success: # await self._ask_review() @@ -83,7 +83,7 @@ class MLEngineer(Role): async def _ask_review(self): context = self.get_memories() review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan) - self._rc.memory.add(Message(content=review, role="assistant", cause_by=AskReview)) + self._rc.memory.add(Message(content=review, role="user", cause_by=AskReview)) return confirmed async def _update_plan(self, max_tasks: int = 3): diff --git a/tests/metagpt/actions/test_write_code_function.py b/tests/metagpt/actions/test_write_code_function.py index 4ff1a63c4..1940c9667 100644 --- a/tests/metagpt/actions/test_write_code_function.py +++ b/tests/metagpt/actions/test_write_code_function.py @@ -2,25 +2,24 @@ import pytest from metagpt.actions.write_code_function import WriteCodeFunction from metagpt.actions.execute_code import ExecutePyCode +from metagpt.schema import Message -@pytest.mark.asyncio -async def test_write_code(): - write_code = WriteCodeFunction() - code = await write_code.run("Write a hello world code.") - assert "language" in code.content - assert "code" in code.content - print(code) +# @pytest.mark.asyncio +# async def test_write_code(): +# write_code = WriteCodeFunction() +# code = await write_code.run("Write a hello world code.") +# assert len(code) > 0 +# print(code) -@pytest.mark.asyncio -async def test_write_code_by_list_prompt(): - write_code = WriteCodeFunction() - msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] - code = await write_code.run(msg) - assert "language" in code.content - assert "code" in code.content - print(code) +# @pytest.mark.asyncio +# async def test_write_code_by_list_prompt(): +# write_code = WriteCodeFunction() +# msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] +# code = await write_code.run(msg) +# assert len(code) > 0 +# print(code) @pytest.mark.asyncio @@ -31,11 +30,10 @@ async def test_write_code_by_list_plan(): plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "求均值"] for task in plan: print(f"\n任务: {task}\n\n") - messages.append(task) + messages.append(Message(task, role='assistant')) code = await write_code.run(messages) - messages.append(code) - assert "language" in code.content - assert "code" in code.content + messages.append(Message(code, role='assistant')) + assert len(code) > 0 output = await execute_code.run(code) print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n") - messages.append(output) + messages.append(output[0])