From 824cc247b66a385a91ff32367ec0d67936543630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Fri, 24 Nov 2023 11:05:51 +0800 Subject: [PATCH 1/3] chore --- metagpt/actions/write_code_function.py | 2 +- metagpt/actions/write_plan.py | 3 +- metagpt/roles/ml_engineer.py | 10 ++--- .../actions/test_write_code_function.py | 38 +++++++++---------- 4 files changed, 26 insertions(+), 27 deletions(-) diff --git a/metagpt/actions/write_code_function.py b/metagpt/actions/write_code_function.py index 4ec565eb1..406d215a2 100644 --- a/metagpt/actions/write_code_function.py +++ b/metagpt/actions/write_code_function.py @@ -58,4 +58,4 @@ class WriteCodeFunction(BaseWriteAnalysisCode): ) -> str: prompt = self.process_msg(context, system_msg) code_content = await self.llm.aask_code(prompt, **kwargs) - return code_content + return code_content['code'] diff --git a/metagpt/actions/write_plan.py b/metagpt/actions/write_plan.py index 48cb1aad5..8db988c01 100644 --- a/metagpt/actions/write_plan.py +++ b/metagpt/actions/write_plan.py @@ -9,6 +9,7 @@ import json from metagpt.actions import Action from metagpt.schema import Message, Task +from metagpt.utils.common import CodeParser class WritePlan(Action): PROMPT_TEMPLATE = """ @@ -35,7 +36,7 @@ class WritePlan(Action): .replace("__current_plan__", current_plan).replace("__max_tasks__", str(max_tasks)) ) rsp = await self._aask(prompt) - return rsp + return CodeParser.parse_code(None, rsp) if rsp.startswith("```") else rsp @staticmethod def rsp_to_tasks(rsp: str) -> List[Task]: diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index c795bda11..3bb0c1660 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -64,14 +64,14 @@ class MLEngineer(Role): success = False while not success and counter < max_retry: context = self.get_memories() - - code = "print('abc')" - # code = await WriteCodeFunction().run(context=context) + print(f"{'*'*20}\n {context}") + # code = "print('abc')" + code = await WriteCodeFunction().run(context=context) # code = await WriteCodeWithOps.run(context, task, result) self._rc.memory.add(Message(content=code, role="assistant", cause_by=WriteCodeFunction)) result, success = await ExecutePyCode().run(code) - self._rc.memory.add(Message(content=result, role="assistant", cause_by=ExecutePyCode)) + self._rc.memory.add(Message(content=result, role="user", cause_by=ExecutePyCode)) # if not success: # await self._ask_review() @@ -83,7 +83,7 @@ class MLEngineer(Role): async def _ask_review(self): context = self.get_memories() review, confirmed = await AskReview().run(context=context[-5:], plan=self.plan) - self._rc.memory.add(Message(content=review, role="assistant", cause_by=AskReview)) + self._rc.memory.add(Message(content=review, role="user", cause_by=AskReview)) return confirmed async def _update_plan(self, max_tasks: int = 3): diff --git a/tests/metagpt/actions/test_write_code_function.py b/tests/metagpt/actions/test_write_code_function.py index 4ff1a63c4..1940c9667 100644 --- a/tests/metagpt/actions/test_write_code_function.py +++ b/tests/metagpt/actions/test_write_code_function.py @@ -2,25 +2,24 @@ import pytest from metagpt.actions.write_code_function import WriteCodeFunction from metagpt.actions.execute_code import ExecutePyCode +from metagpt.schema import Message -@pytest.mark.asyncio -async def test_write_code(): - write_code = WriteCodeFunction() - code = await write_code.run("Write a hello world code.") - assert "language" in code.content - assert "code" in code.content - print(code) +# @pytest.mark.asyncio +# async def test_write_code(): +# write_code = WriteCodeFunction() +# code = await write_code.run("Write a hello world code.") +# assert len(code) > 0 +# print(code) -@pytest.mark.asyncio -async def test_write_code_by_list_prompt(): - write_code = WriteCodeFunction() - msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] - code = await write_code.run(msg) - assert "language" in code.content - assert "code" in code.content - print(code) +# @pytest.mark.asyncio +# async def test_write_code_by_list_prompt(): +# write_code = WriteCodeFunction() +# msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] +# code = await write_code.run(msg) +# assert len(code) > 0 +# print(code) @pytest.mark.asyncio @@ -31,11 +30,10 @@ async def test_write_code_by_list_plan(): plan = ["随机生成一个pandas DataFrame时间序列", "绘制这个时间序列的直方图", "求均值"] for task in plan: print(f"\n任务: {task}\n\n") - messages.append(task) + messages.append(Message(task, role='assistant')) code = await write_code.run(messages) - messages.append(code) - assert "language" in code.content - assert "code" in code.content + messages.append(Message(code, role='assistant')) + assert len(code) > 0 output = await execute_code.run(code) print(f"\n[Output]: 任务{task}的执行结果是: \n{output}\n") - messages.append(output) + messages.append(output[0]) From bba3db5ffe062f6b5cb849b701bdd9c11665bf85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Fri, 24 Nov 2023 12:27:17 +0800 Subject: [PATCH 2/3] fix: --- metagpt/actions/write_code_function.py | 47 +++++++++++++------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/metagpt/actions/write_code_function.py b/metagpt/actions/write_code_function.py index 406d215a2..1f273a707 100644 --- a/metagpt/actions/write_code_function.py +++ b/metagpt/actions/write_code_function.py @@ -27,31 +27,30 @@ class WriteCodeFunction(BaseWriteAnalysisCode): super().__init__(name, context, llm) def process_msg(self, prompt: Union[str, List[Dict], Message, List[Message]], system_msg: str = None): - if isinstance(prompt, str): - return system_msg + prompt if system_msg else prompt + default_system_msg = """You are Open Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.""" + # 全部转成list + if not isinstance(prompt, list): + prompt = [prompt] + assert isinstance(prompt, list) + # 转成list[dict] + messages = [] + for p in prompt: + if isinstance(p, str): + messages.append({'role': 'user', 'content': p}) + elif isinstance(p, dict): + messages.append(p) + elif isinstance(p, Message): + if isinstance(p.content, str): + messages.append(p.to_dict()) + elif isinstance(p.content, dict) and 'code' in p.content: + messages.append(p.content['code']) - if isinstance(prompt, Message): - if isinstance(prompt.content, dict): - prompt.content = system_msg + str([(k, v) for k, v in prompt.content.items()])\ - if system_msg else prompt.content - else: - prompt.content = system_msg + prompt.content if system_msg else prompt.content - return prompt - - if isinstance(prompt, list): - _prompt = [] - for msg in prompt: - if isinstance(msg, Message) and isinstance(msg.content, dict): - msg.content = str([(k, v) for k, v in msg.content.items()]) - if isinstance(msg, Message): - msg = msg.to_dict() - _prompt.append(msg) - prompt = _prompt - - if isinstance(prompt, list) and system_msg: - if system_msg not in prompt[0]['content']: - prompt[0]['content'] = system_msg + prompt[0]['content'] - return prompt + # 添加默认的提示词 + if default_system_msg not in messages[0]['content'] and messages[0]['role'] != 'system': + messages.insert(0, {'role': 'system', 'content': default_system_msg}) + elif default_system_msg not in messages[0]['content'] and messages[0]['role'] == 'system': + messages[0] = {'role': 'system', 'content': messages[0]['content']+default_system_msg} + return messages async def run( self, context: [List[Message]], plan: Plan = None, task_guidance: str = "", system_msg: str = None, **kwargs From 8b171b51337d5c416da9c52e78c0e0cbae138d82 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Fri, 24 Nov 2023 12:37:20 +0800 Subject: [PATCH 3/3] fix: write_code_function bug. --- metagpt/actions/write_code_function.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/actions/write_code_function.py b/metagpt/actions/write_code_function.py index 1f273a707..c2fb6189a 100644 --- a/metagpt/actions/write_code_function.py +++ b/metagpt/actions/write_code_function.py @@ -27,7 +27,7 @@ class WriteCodeFunction(BaseWriteAnalysisCode): super().__init__(name, context, llm) def process_msg(self, prompt: Union[str, List[Dict], Message, List[Message]], system_msg: str = None): - default_system_msg = """You are Open Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.""" + default_system_msg = """You are Open Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: The code for the next step depends on the code for the previous step.**""" # 全部转成list if not isinstance(prompt, list): prompt = [prompt]