From 3faa094248d819a178156471c9990089b9a8d5a7 Mon Sep 17 00:00:00 2001 From: yzlin Date: Thu, 18 Jan 2024 23:45:37 +0800 Subject: [PATCH] fix aask_code issues in ml_engineer --- metagpt/actions/debug_code.py | 3 +-- metagpt/actions/ml_da_action.py | 2 +- metagpt/actions/write_analysis_code.py | 8 ++++---- metagpt/roles/code_interpreter.py | 11 ++++------- metagpt/roles/ml_engineer.py | 4 ++-- 5 files changed, 12 insertions(+), 16 deletions(-) diff --git a/metagpt/actions/debug_code.py b/metagpt/actions/debug_code.py index e5e0ac5d4..121c126c4 100644 --- a/metagpt/actions/debug_code.py +++ b/metagpt/actions/debug_code.py @@ -119,5 +119,4 @@ class DebugCode(BaseWriteAnalysisCode): runtime_result=runtime_result, ) # 根据reflection结果重写代码 - improv_code = reflection["improved_impl"] - return improv_code + return {"code": reflection["improved_impl"]} diff --git a/metagpt/actions/ml_da_action.py b/metagpt/actions/ml_da_action.py index 584c4db7a..d4e77773f 100644 --- a/metagpt/actions/ml_da_action.py +++ b/metagpt/actions/ml_da_action.py @@ -63,4 +63,4 @@ class UpdateDataColumns(Action): prompt = UPDATE_DATA_COLUMNS.format(history_code=code_context) tool_config = create_func_config(PRINT_DATA_COLUMNS) rsp = await self.llm.aask_code(prompt, **tool_config) - return rsp["code"] + return rsp diff --git a/metagpt/actions/write_analysis_code.py b/metagpt/actions/write_analysis_code.py index 65be198ef..cf806a986 100644 --- a/metagpt/actions/write_analysis_code.py +++ b/metagpt/actions/write_analysis_code.py @@ -59,7 +59,7 @@ class BaseWriteAnalysisCode(Action): } return messages - async def run(self, context: List[Message], plan: Plan = None) -> str: + async def run(self, context: List[Message], plan: Plan = None) -> dict: """Run of a code writing action, used in data analysis or modeling Args: @@ -67,7 +67,7 @@ class BaseWriteAnalysisCode(Action): plan (Plan, optional): Overall plan. Defaults to None. Returns: - str: The code string. + dict: code result in the format of {"code": "print('hello world')", "language": "python"} """ @@ -174,7 +174,7 @@ class WriteCodeWithTools(BaseWriteAnalysisCode): tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS) rsp = await self.llm.aask_code(prompt, **tool_config) - return rsp["code"] + return rsp class WriteCodeWithToolsML(WriteCodeWithTools): @@ -230,7 +230,7 @@ class WriteCodeWithToolsML(WriteCodeWithTools): tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS) rsp = await self.llm.aask_code(prompt, **tool_config) context = [Message(content=prompt, role="user")] - return context, rsp["code"] + return context, rsp class MakeTools(WriteCodeByGenerate): diff --git a/metagpt/roles/code_interpreter.py b/metagpt/roles/code_interpreter.py index 46cc00d5e..f972e72e2 100644 --- a/metagpt/roles/code_interpreter.py +++ b/metagpt/roles/code_interpreter.py @@ -54,7 +54,7 @@ class CodeInterpreter(Role): async def _act_on_task(self, current_task: Task) -> TaskResult: code, result, is_success = await self._write_and_exec_code() - task_result = TaskResult(code=code['code'], result=result, is_success=is_success) + task_result = TaskResult(code=code, result=result, is_success=is_success) return task_result async def _write_and_exec_code(self, max_retry: int = 3): @@ -69,7 +69,7 @@ class CodeInterpreter(Role): ### write code ### code, cause_by = await self._write_code() - self.working_memory.add(Message(content=code['code'], role="assistant", cause_by=cause_by)) + self.working_memory.add(Message(content=code["code"], role="assistant", cause_by=cause_by)) ### execute code ### result, success = await self.execute_code.run(**code) @@ -78,7 +78,7 @@ class CodeInterpreter(Role): self.working_memory.add(Message(content=result, role="user", cause_by=ExecutePyCode)) ### process execution result ### - if "!pip" in code: + if "!pip" in code["code"]: success = False counter += 1 @@ -89,7 +89,7 @@ class CodeInterpreter(Role): if ReviewConst.CHANGE_WORD[0] in review: counter = 0 # redo the task again with help of human suggestions - return code, result, success + return code["code"], result, success async def _write_code(self): todo = WriteCodeByGenerate() if not self.use_tools else WriteCodeWithTools() @@ -98,9 +98,6 @@ class CodeInterpreter(Role): context = self.planner.get_useful_memories() # print(*context, sep="\n***\n") code = await todo.run(context=context, plan=self.planner.plan, temperature=0.0) - # 暂时在这里转换 WriteCodeWithTools 的输出 - if isinstance(code, str): - code = {'code': code, 'language': 'python'} return code, todo diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index aeea39c0c..6b671f9c2 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -46,7 +46,7 @@ class MLEngineer(CodeInterpreter): logger.info(f"new code \n{code}") cause_by = DebugCode - self.latest_code = code + self.latest_code = code["code"] return code, cause_by @@ -61,6 +61,6 @@ class MLEngineer(CodeInterpreter): logger.info("Check columns in updated data") code = await UpdateDataColumns().run(self.planner.plan) success = False - result, success = await self.execute_code.run(code) + result, success = await self.execute_code.run(**code) print(result) return result if success else ""