From 1265d3d924b0e1553591c6628d0c2de2a18d5722 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Fri, 8 Dec 2023 12:37:06 +0800 Subject: [PATCH] feat: make_tools by function. --- metagpt/actions/make_tools.py | 49 ++++++++++++++++++++++++ metagpt/provider/base_gpt_api.py | 2 +- tests/metagpt/actions/test_make_tools.py | 18 +++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 metagpt/actions/make_tools.py create mode 100644 tests/metagpt/actions/test_make_tools.py diff --git a/metagpt/actions/make_tools.py b/metagpt/actions/make_tools.py new file mode 100644 index 000000000..7fd05751e --- /dev/null +++ b/metagpt/actions/make_tools.py @@ -0,0 +1,49 @@ +from typing import List, Dict +from pathlib import Path +import re + +from tenacity import retry, stop_after_attempt, wait_fixed + +from metagpt.logs import logger +from metagpt.schema import Message +from metagpt.actions.write_analysis_code import WriteCodeByGenerate + + +class MakeTools(WriteCodeByGenerate): + DEFAULT_SYSTEM_MSG = """Please Create a General Function Code startswith `def` from any codes you got.\n + **Notice:The import statement must be written after `def`, it is very important for you. + Reflect on whether it meets the requirements of function. Must Write example code, and we will execute the example code.** + """ + + def __init__(self, name: str = '', context=None, llm=None, workspace: str = None): + super().__init__(name, context, llm) + self.workspace = workspace or "." + self.file_suffix = '.py' + + def parse_function_name(self, function_code: str) -> str: + # 定义正则表达式模式 + pattern = r'\bdef\s+([a-zA-Z_]\w*)\s*\(' + # 在代码中搜索匹配的模式 + match = re.search(pattern, function_code) + # 如果找到匹配项,则返回匹配的函数名;否则返回None + if match: + return match.group(1) + else: + return None + + def save(self, tool_code: str) -> None: + func_name = self.parse_function_name(tool_code) + if func_name is None: + raise ValueError(f"No function name found in {tool_code}") + saved_path = Path(self.workspace).joinpath(func_name+self.file_suffix) + logger.info(f"Saved tool_code {func_name} in {str(saved_path)}.") + saved_path.write_text(tool_code, encoding='utf-8') + + @retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) + async def run(self, code_message: List[Message | Dict], **kwargs) -> str: + msgs = self.process_msg(code_message) + logger.info(f"Ask: {msgs[-1]}") + tool_code = await self.llm.aask_code(msgs, **kwargs) + logger.info(f"Respond: Got {tool_code} from llm.") + self.save(tool_code['code']) + return tool_code["code"] diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py index b6b034329..5516ceb7c 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_gpt_api.py @@ -150,7 +150,7 @@ class BaseGPTAPI(BaseChatbot): :return dict: return the first function arguments of choice, for example, {'language': 'python', 'code': "print('Hello, World!')"} """ - return json.loads(self.get_choice_function(rsp)["arguments"]) + return json.loads(self.get_choice_function(rsp)["arguments"], strict=False) def messages_to_prompt(self, messages: list[dict]): """[{"role": "user", "content": msg}] to user: etc.""" diff --git a/tests/metagpt/actions/test_make_tools.py b/tests/metagpt/actions/test_make_tools.py new file mode 100644 index 000000000..2c5168bf1 --- /dev/null +++ b/tests/metagpt/actions/test_make_tools.py @@ -0,0 +1,18 @@ +import pytest + +from metagpt.actions.execute_code import ExecutePyCode +from metagpt.actions.make_tools import MakeTools + + +@pytest.mark.asyncio +async def test_make_tools(): + code = "import yfinance as yf\n\n# Collect Alibaba stock data\nalibaba = yf.Ticker('BABA')\ndata = alibaba.history(period='1d', start='2022-01-01', end='2022-12-31')\nprint(data.head())" + msgs = [{'role': 'assistant', 'content': code}] + mt = MakeTools() + tool_code = await mt.run(msgs) + print(tool_code) + ep = ExecutePyCode() + tool_code = "!pip install yfinance\n" + tool_code + result, res_type = await ep.run(tool_code) + assert res_type is True + print(result)