From b7624d7298536135e84c1af1f08ad3e51bf09093 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Tue, 12 Dec 2023 14:41:43 +0800 Subject: [PATCH] feat: add WriteCodeWithUDFs. --- metagpt/actions/write_analysis_code.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/metagpt/actions/write_analysis_code.py b/metagpt/actions/write_analysis_code.py index 1127dc78b..725c4aa2a 100644 --- a/metagpt/actions/write_analysis_code.py +++ b/metagpt/actions/write_analysis_code.py @@ -7,6 +7,7 @@ from typing import Dict, List, Union, Tuple from metagpt.actions import Action +from metagpt.llm import LLM from metagpt.logs import logger from metagpt.prompts.ml_engineer import ( TOOL_RECOMMENDATION_PROMPT, @@ -19,7 +20,7 @@ from metagpt.prompts.ml_engineer import ( ) from metagpt.schema import Message, Plan from metagpt.tools.functions import registry -from metagpt.utils.common import create_func_config +from metagpt.utils.common import create_func_config, CodeParser class BaseWriteAnalysisCode(Action): @@ -203,3 +204,24 @@ class WriteCodeWithTools(BaseWriteAnalysisCode): tool_config = create_func_config(CODE_GENERATOR_WITH_TOOLS) rsp = await self.llm.aask_code(prompt, **tool_config) return rsp["code"] + + +class WriteCodeWithUDFs(WriteCodeByGenerate): + """Write code with user defined function.""" + from metagpt.tools.functions.libs.udf import UDFS + + DEFAULT_SYSTEM_MSG = f"""Please remember these functions, you will use these functions to write code:\n + {UDFS} + """ + + async def aask_code_and_text(self, context: List[Dict], **kwargs) -> Tuple[str]: + rsp = await self.llm.acompletion(context, **kwargs) + rsp_content = self.llm.get_choice_text(rsp) + code = CodeParser.parse_code(None, rsp_content) + return code, rsp_content + + async def run(self, context: List[Message], plan: Plan = None, task_guide: str = "", **kwargs) -> str: + prompt = self.process_msg(context) + logger.info(prompt[-1]) + code, _ = await self.aask_code_and_text(prompt, **kwargs) + return code