mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-27 14:25:20 +02:00
Merge branch 'dev_make_tools' into 'dev'
feat: make_tools by function. See merge request agents/data_agents_opt!17
This commit is contained in:
commit
a93173df65
11 changed files with 445 additions and 27 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -129,6 +129,7 @@ venv.bak/
|
|||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
metagpt/tools/functions/libs/udf/*.py
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
|
|
|||
|
|
@ -5,10 +5,15 @@
|
|||
@File : write_code_v2.py
|
||||
"""
|
||||
from typing import Dict, List, Union, Tuple
|
||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
||||
from pathlib import Path
|
||||
import re
|
||||
import json
|
||||
|
||||
import yaml
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.ml_engineer import (
|
||||
TOOL_RECOMMENDATION_PROMPT,
|
||||
|
|
@ -24,7 +29,7 @@ from metagpt.utils.common import create_func_config, remove_comments
|
|||
|
||||
|
||||
class BaseWriteAnalysisCode(Action):
|
||||
DEFAULT_SYSTEM_MSG = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: The code for the next step depends on the code for the previous step. Must reuse variables in the lastest other code directly, dont creat it again, it is very import for you. Use !pip install in a standalone block to install missing packages.**""" # prompt reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt
|
||||
DEFAULT_SYSTEM_MSG = """You are Code Interpreter, a world-class programmer that can complete any goal by executing code. Strictly follow the plan and generate code step by step. Each step of the code will be executed on the user's machine, and the user will provide the code execution results to you.**Notice: The code for the next step depends on the code for the previous step. Must reuse variables in the lastest other code directly, dont creat it again, it is very import for you. Use !pip install in a standalone block to install missing packages.Usually the libraries you need are already installed.Dont check if packages already imported.**""" # prompt reference: https://github.com/KillianLucas/open-interpreter/blob/v0.1.4/interpreter/system_message.txt
|
||||
# REUSE_CODE_INSTRUCTION = """ATTENTION: DONT include codes from previous tasks in your current code block, include new codes only, DONT repeat codes!"""
|
||||
|
||||
def process_msg(self, prompt: Union[str, List[Dict], Message, List[Message]], system_msg: str = None):
|
||||
|
|
@ -107,13 +112,17 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
|
|||
if self.schema_path is not None:
|
||||
self._load_tools(schema_path)
|
||||
|
||||
def _load_tools(self, schema_path):
|
||||
def _load_tools(self, schema_path, schema_module=None):
|
||||
"""Load tools from yaml file"""
|
||||
yml_files = schema_path.glob("*.yml")
|
||||
for yml_file in yml_files:
|
||||
module = yml_file.stem
|
||||
with open(yml_file, "r", encoding="utf-8") as f:
|
||||
self.available_tools[module] = yaml.safe_load(f)
|
||||
if isinstance(schema_path, dict):
|
||||
schema_module = schema_module or 'udf'
|
||||
self.available_tools.update({schema_module: schema_path})
|
||||
else:
|
||||
yml_files = schema_path.glob("*.yml")
|
||||
for yml_file in yml_files:
|
||||
module = yml_file.stem
|
||||
with open(yml_file, "r", encoding="utf-8") as f:
|
||||
self.available_tools[module] = yaml.safe_load(f)
|
||||
|
||||
def _parse_recommend_tools(self, module: str, recommend_tools: list) -> dict:
|
||||
"""
|
||||
|
|
@ -217,3 +226,82 @@ class WriteCodeWithTools(BaseWriteAnalysisCode):
|
|||
rsp = await self.llm.aask_code(prompt, **tool_config)
|
||||
context = [Message(content=prompt, role="user")]
|
||||
return context, rsp["code"]
|
||||
|
||||
|
||||
class MakeTools(WriteCodeByGenerate):
|
||||
DEFAULT_SYSTEM_MSG = """Convert any codes provied for you to a very General Function Code startswith `def`.\n
|
||||
**Notice:
|
||||
1. Your code must contain a general function start with `def`.
|
||||
2. Refactor your code to get the most efficient implementation for large input data in the shortest amount of time.
|
||||
3. Must use Google style for function docstring, and your docstring must be consistent with the code,without missing anything.
|
||||
4. Write example code after `if __name__ == '__main__':`by using old varibales in old code,
|
||||
and make sure it could be execute in the user's machine.
|
||||
5. Only use the imported packages**
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = '', context: list[Message] = None, llm: LLM = None, workspace: str = None):
|
||||
"""
|
||||
:param str name: name, defaults to ''
|
||||
:param list[Message] context: context, defaults to None
|
||||
:param LLM llm: llm, defaults to None
|
||||
:param str workspace: tools code saved file path dir, defaults to None
|
||||
"""
|
||||
super().__init__(name, context, llm)
|
||||
self.workspace = workspace or str(Path(__file__).parents[1].joinpath("./tools/functions/libs/udf"))
|
||||
self.file_suffix: str = '.py'
|
||||
self.context = []
|
||||
|
||||
def parse_function_name(self, function_code: str) -> str:
|
||||
# 定义正则表达式模式
|
||||
pattern = r'\bdef\s+([a-zA-Z_]\w*)\s*\('
|
||||
# 在代码中搜索匹配的模式
|
||||
match = re.search(pattern, function_code)
|
||||
# 如果找到匹配项,则返回匹配的函数名;否则返回None
|
||||
if match:
|
||||
return match.group(1)
|
||||
else:
|
||||
return None
|
||||
|
||||
def save(self, tool_code: str) -> None:
|
||||
func_name = self.parse_function_name(tool_code)
|
||||
if func_name is None:
|
||||
raise ValueError(f"No function name found in {tool_code}")
|
||||
saved_path = Path(self.workspace).joinpath(func_name+self.file_suffix)
|
||||
logger.info(f"Saved tool_code {func_name} in {str(saved_path)}.")
|
||||
saved_path.write_text(tool_code, encoding='utf-8')
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1))
|
||||
async def run(self, code: str | List[dict], code_desc: str = None, **kwargs) -> str:
|
||||
# 拼接code prompt
|
||||
code_prompt = f"The following code is about {code_desc}, convert it to be a General Function, {code}"
|
||||
if not self.context:
|
||||
self.context = self.process_msg(code_prompt)
|
||||
else:
|
||||
self.context.append(self.process_msg(code_prompt)[-1])
|
||||
logger.info(f"\n\nAsk to Make tools:\n{'-'*60}\n {self.context[-1]}")
|
||||
|
||||
# 更新kwargs
|
||||
if 'code' in kwargs:
|
||||
kwargs.pop('code')
|
||||
if 'code_desc' in kwargs:
|
||||
kwargs.pop('code_desc')
|
||||
|
||||
max_tries, current_try = 3, 0
|
||||
while True:
|
||||
tool_code = await self.llm.aask_code(self.context, **kwargs)
|
||||
func_name = self.parse_function_name(tool_code['code'])
|
||||
current_try += 1
|
||||
# make tools failed, add error message to context.
|
||||
if not func_name:
|
||||
logger.info(f"\n\nTools Respond\n{'-'*60}\n: {tool_code}")
|
||||
logger.error(f"No function name found in code, we will retry make tools.\n{tool_code['code']}\n")
|
||||
self.context.append({'role': 'user', 'content': 'We need a general function in above code,but not found function.'})
|
||||
# end make tools
|
||||
if func_name is not None or current_try >= max_tries:
|
||||
if current_try >= max_tries:
|
||||
logger.error(f"We have tried the maximum number of attempts {max_tries}\
|
||||
and still have not created tools successfully, we will skip it.")
|
||||
break
|
||||
logger.info(f"\n\nTools Respond\n{'-'*60}\n: {tool_code}")
|
||||
self.save(tool_code['code'])
|
||||
return tool_code["code"]
|
||||
|
|
|
|||
|
|
@ -307,6 +307,7 @@ ML_SPECIFIC_PROMPT = {
|
|||
ML_MODULE_MAP = {
|
||||
"data_preprocess": "metagpt.tools.functions.libs.data_preprocess",
|
||||
"feature_engineering": "metagpt.tools.functions.libs.feature_engineering",
|
||||
"udf": "metagpt.tools.functions.libs.udf",
|
||||
}
|
||||
|
||||
STRUCTURAL_CONTEXT = """
|
||||
|
|
|
|||
|
|
@ -150,7 +150,7 @@ class BaseGPTAPI(BaseChatbot):
|
|||
:return dict: return the first function arguments of choice, for example,
|
||||
{'language': 'python', 'code': "print('Hello, World!')"}
|
||||
"""
|
||||
return json.loads(self.get_choice_function(rsp)["arguments"])
|
||||
return json.loads(self.get_choice_function(rsp)["arguments"], strict=False)
|
||||
|
||||
def messages_to_prompt(self, messages: list[dict]):
|
||||
"""[{"role": "user", "content": msg}] to user: <msg> etc."""
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import fire
|
|||
from metagpt.actions.debug_code import DebugCode
|
||||
from metagpt.actions.execute_code import ExecutePyCode
|
||||
from metagpt.actions.ml_da_action import AskReview, SummarizeAnalysis, Reflect, ReviewConst, UpdateDataColumns
|
||||
from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools
|
||||
from metagpt.actions.write_analysis_code import WriteCodeByGenerate, WriteCodeWithTools, MakeTools
|
||||
from metagpt.actions.write_code_steps import WriteCodeSteps
|
||||
from metagpt.actions.write_plan import WritePlan
|
||||
from metagpt.actions.write_plan import update_plan_from_rsp, precheck_update_plan_from_rsp
|
||||
|
|
@ -20,6 +20,7 @@ from metagpt.roles.kaggle_manager import DownloadData, SubmitResult
|
|||
from metagpt.schema import Message, Plan
|
||||
from metagpt.utils.save_code import save_code_file
|
||||
from metagpt.utils.recovery_util import save_history, load_history
|
||||
from metagpt.utils.common import remove_comments
|
||||
|
||||
|
||||
class MLEngineer(Role):
|
||||
|
|
@ -31,6 +32,8 @@ class MLEngineer(Role):
|
|||
self._watch([DownloadData, SubmitResult])
|
||||
|
||||
self.plan = Plan(goal=goal)
|
||||
self.make_udfs = False # user-defined functions
|
||||
self.use_udfs = False
|
||||
self.execute_code = ExecutePyCode()
|
||||
self.auto_run = auto_run
|
||||
self.use_tools = use_tools
|
||||
|
|
@ -81,7 +84,7 @@ class MLEngineer(Role):
|
|||
self.plan.finish_current_task()
|
||||
self.working_memory.clear()
|
||||
|
||||
if self.use_tools and task.task_type not in ['model_train', 'model_evaluate']:
|
||||
if (self.use_tools and task.task_type not in ['model_train', 'model_evaluate']) or self.use_udfs:
|
||||
success, new_code = await self._update_data_columns()
|
||||
if success:
|
||||
task.code = task.code + "\n\n" + new_code
|
||||
|
|
@ -137,8 +140,8 @@ class MLEngineer(Role):
|
|||
|
||||
while not success and counter < max_retry:
|
||||
context = self.get_useful_memories()
|
||||
|
||||
if counter > 0 and self.use_tools:
|
||||
if counter > 0 and (self.use_tools or self.use_udfs):
|
||||
logger.warning('We got a bug code, now start to debug...')
|
||||
code = await DebugCode().run(
|
||||
plan=self.plan.current_task.instruction,
|
||||
code=code,
|
||||
|
|
@ -147,8 +150,10 @@ class MLEngineer(Role):
|
|||
)
|
||||
logger.info(f"new code \n{code}")
|
||||
cause_by = DebugCode
|
||||
elif not self.use_tools or self.plan.current_task.task_type == "other":
|
||||
elif (not self.use_tools and not self.use_udfs) or (
|
||||
self.plan.current_task.task_type == 'other' and not self.use_udfs):
|
||||
logger.info("Write code with pure generation")
|
||||
# TODO: 添加基于current_task.instruction-code_path的k-v缓存
|
||||
code = await WriteCodeByGenerate().run(
|
||||
context=context, plan=self.plan, temperature=0.0
|
||||
)
|
||||
|
|
@ -156,7 +161,17 @@ class MLEngineer(Role):
|
|||
cause_by = WriteCodeByGenerate
|
||||
else:
|
||||
logger.info("Write code with tools")
|
||||
schema_path = PROJECT_ROOT / "metagpt/tools/functions/schemas"
|
||||
if self.use_udfs:
|
||||
# use user-defined function tools.
|
||||
from metagpt.tools.functions.libs.udf import UDFS_YAML
|
||||
logger.warning("Writing code with user-defined function tools by WriteCodeWithTools.")
|
||||
logger.info(f"Local user defined function as following:\
|
||||
\n{json.dumps(list(UDFS_YAML.keys()), indent=2, ensure_ascii=False)}")
|
||||
# set task_type to `udf`
|
||||
self.plan.current_task.task_type = 'udf'
|
||||
schema_path = UDFS_YAML
|
||||
else:
|
||||
schema_path = PROJECT_ROOT / "metagpt/tools/functions/schemas"
|
||||
tool_context, code = await WriteCodeWithTools(schema_path=schema_path).run(
|
||||
context=context,
|
||||
plan=self.plan,
|
||||
|
|
@ -164,13 +179,16 @@ class MLEngineer(Role):
|
|||
)
|
||||
debug_context = tool_context
|
||||
cause_by = WriteCodeWithTools
|
||||
|
||||
self.working_memory.add(
|
||||
Message(content=code, role="assistant", cause_by=cause_by)
|
||||
)
|
||||
|
||||
result, success = await self.execute_code.run(code)
|
||||
print(result)
|
||||
# make tools for successful code and long code.
|
||||
if success and self.make_udfs and len(remove_comments(code).split('\n')) > 4:
|
||||
logger.info('Execute code successfully. Now start to make tools ...')
|
||||
await self.make_tools(code=code)
|
||||
self.working_memory.add(
|
||||
Message(content=result, role="user", cause_by=ExecutePyCode)
|
||||
)
|
||||
|
|
@ -254,20 +272,52 @@ class MLEngineer(Role):
|
|||
def get_working_memories(self) -> List[Message]:
|
||||
return self.working_memory.get()
|
||||
|
||||
def reset(self):
|
||||
"""Restart role with the same goal."""
|
||||
self.plan = Plan(goal=self.plan.goal)
|
||||
self.execute_code = ExecutePyCode()
|
||||
self.working_memory = Memory()
|
||||
|
||||
async def make_tools(self, code: str):
|
||||
"""Make user-defined functions(udfs, aka tools) for pure generation code.
|
||||
|
||||
Args:
|
||||
code (str): pure generation code by class WriteCodeByGenerate.
|
||||
"""
|
||||
logger.warning(f"Making tools for task_id {self.plan.current_task_id}: \
|
||||
`{self.plan.current_task.instruction}` \n code: \n {code}")
|
||||
make_tools = MakeTools()
|
||||
make_tool_retries, make_tool_current_retry = 3, 0
|
||||
while True:
|
||||
# start make tools
|
||||
tool_code = await make_tools.run(code, self.plan.current_task.instruction)
|
||||
make_tool_current_retry += 1
|
||||
|
||||
# check tool_code by execute_code
|
||||
logger.info(f"Checking task_id {self.plan.current_task_id} tool code by executor...")
|
||||
execute_result, execute_success = await self.execute_code.run(tool_code)
|
||||
if not execute_success:
|
||||
logger.error(f"Tool code faild to execute, \n{execute_result}\n.We will try to fix it ...")
|
||||
# end make tools
|
||||
if execute_success or make_tool_current_retry >= make_tool_retries:
|
||||
if make_tool_current_retry >= make_tool_retries:
|
||||
logger.error(f"We have tried the maximum number of attempts {make_tool_retries}\
|
||||
and still have not created tools for task_id {self.plan.current_task_id} successfully,\
|
||||
we will skip it.")
|
||||
break
|
||||
# save successful tool code in udf
|
||||
if execute_success:
|
||||
make_tools.save(tool_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# requirement = "Run data analysis on sklearn Iris dataset, include a plot"
|
||||
# requirement = "Run data analysis on sklearn Diabetes dataset, include a plot"
|
||||
# requirement = "Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy"
|
||||
# requirement = "Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy"
|
||||
# requirement = "Run EDA and visualization on this dataset, train a model to predict survival, report metrics on validation set (20%), dataset: workspace/titanic/train.csv"
|
||||
# requirement = "Perform data analysis on the provided data. Train a model to predict the target variable Survived. Include data preprocessing, feature engineering, and modeling in your pipeline. The metric is accuracy."
|
||||
requirement = "Perform data analysis on the provided data. Train a model to predict the target variable Survived. Include data preprocessing, feature engineering, and modeling in your pipeline. The metric is accuracy."
|
||||
|
||||
# data_path = f"{DATA_PATH}/titanic"
|
||||
# requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
|
||||
# requirement = f"Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy"
|
||||
# data_path = f"{DATA_PATH}/icr-identify-age-related-conditions"
|
||||
# requirement = f"This is a medical dataset with over fifty anonymized health characteristics linked to three age-related conditions. Your goal is to predict whether a subject has or has not been diagnosed with one of these conditions.The target column is Class. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report f1 score on the eval data. Train data path: {data_path}/split_train.csv, eval data path: {data_path}/split_eval.csv."
|
||||
data_path = f"{DATA_PATH}/titanic"
|
||||
requirement = f"This is a titanic passenger survival dataset, your goal is to predict passenger survival outcome. The target column is Survived. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report accuracy on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv'."
|
||||
requirement = f"Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy"
|
||||
data_path = f"{DATA_PATH}/icr-identify-age-related-conditions"
|
||||
requirement = f"This is a medical dataset with over fifty anonymized health characteristics linked to three age-related conditions. Your goal is to predict whether a subject has or has not been diagnosed with one of these conditions.The target column is Class. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report f1 score on the eval data. Train data path: {data_path}/split_train.csv, eval data path: {data_path}/split_eval.csv."
|
||||
|
||||
# data_path = f"{DATA_PATH}/santander-customer-transaction-prediction"
|
||||
# requirement = f"This is a customers financial dataset. Your goal is to predict which customers will make a specific transaction in the future. The target column is target. Perform data analysis, data preprocessing, feature engineering, and modeling to predict the target. Report AUC Score on the eval data. Train data path: '{data_path}/split_train.csv', eval data path: '{data_path}/split_eval.csv' ."
|
||||
|
|
|
|||
123
metagpt/tools/functions/libs/udf/__init__.py
Normal file
123
metagpt/tools/functions/libs/udf/__init__.py
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
import ast
|
||||
import os
|
||||
import re
|
||||
import yaml
|
||||
import inspect
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def extract_function_signatures(file_path):
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
source_code = file.read()
|
||||
|
||||
tree = ast.parse(source_code)
|
||||
function_signatures = []
|
||||
function_returns = []
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
# 只提取用户自定义函数,排除内置函数
|
||||
if not (node.name.startswith('__') and node.name.endswith('__')):
|
||||
# 获取函数名
|
||||
function_name = node.name
|
||||
# 获取参数列表
|
||||
args = [arg.arg for arg in node.args.args]
|
||||
# 获取函数签名
|
||||
function_signature = f"{function_name}({', '.join(args)})"
|
||||
# 导入函数
|
||||
module_name = Path(file_path).parts[-1][:-len(Path(file_path).suffix)]
|
||||
module = importlib.import_module(f"metagpt.tools.functions.libs.udf.{module_name}")
|
||||
# 将函数导入到当前命名空间
|
||||
globals().update({function_name: getattr(module, function_name)})
|
||||
# 获取函数注释和函数路径
|
||||
function_schema = {'udf_name': function_signature,
|
||||
'udf_path': f'from metagpt.tools.functions.libs.udf.{module_name} import {function_name}',
|
||||
'udf_doc': inspect.getdoc(getattr(module, function_name))}
|
||||
function_signatures.append(function_schema)
|
||||
# 获取函数返回变量名
|
||||
source_lines, _ = inspect.getsourcelines(getattr(module, function_name))
|
||||
for line in source_lines:
|
||||
if line.strip().startswith("return "):
|
||||
function_returns.append({
|
||||
'udf_name': function_name,
|
||||
'udf_returns': [var.strip() for var in line.strip()[len("return "):].split(',')]
|
||||
})
|
||||
break
|
||||
|
||||
# 没有返回值的函数
|
||||
if not function_returns or function_returns[-1]['udf_name'] != function_name:
|
||||
function_returns.append({
|
||||
'udf_name': function_name,
|
||||
'udf_returns': [None]
|
||||
})
|
||||
return function_signatures, function_returns
|
||||
|
||||
|
||||
def get_function_signatures_in_folder(folder_path):
|
||||
python_files = [f for f in os.listdir(folder_path) if f.endswith('.py') and f != '__init__.py']
|
||||
all_function_signatures = []
|
||||
all_function_returns = []
|
||||
|
||||
for file_name in python_files:
|
||||
file_path = os.path.join(folder_path, file_name)
|
||||
function_signatures, function_returns = extract_function_signatures(file_path)
|
||||
all_function_signatures.extend(function_signatures)
|
||||
all_function_returns.extend(function_returns)
|
||||
return all_function_signatures, all_function_returns
|
||||
|
||||
|
||||
# Create Tools Yaml Style Schema
|
||||
def docstring_to_yaml(docstring: str, return_vars: List[str] = None):
|
||||
logger.debug(f"\n\nFunction Docstring: \n{'-'*60}\n {docstring} \n\nFunction Returns: \n{'-'*60}\n{return_vars}\n")
|
||||
if docstring is None:
|
||||
return {}
|
||||
# 匹配简介部分
|
||||
description_match = re.search(r'^(.*?)(?:Args:|Returns:|Raises:|$)', docstring, re.DOTALL)
|
||||
description = description_match.group(1).strip() if description_match else ""
|
||||
|
||||
# 匹配Args部分
|
||||
args_match = re.search(r'Args:\s*(.*?)(?:Returns:|Raises:|$)', docstring, re.DOTALL)
|
||||
_args = args_match.group(1).strip() if args_match else ""
|
||||
variable_pattern = re.compile(r'(\w+)\s*\((.*?)\):\s*(.*)')
|
||||
params = variable_pattern.findall(_args)
|
||||
if not params:
|
||||
params = ((None, None, None),)
|
||||
# 匹配Returns部分
|
||||
returns_match = re.search(r'Returns:\s*(.*?)(?:Raises:|$)', docstring, re.DOTALL)
|
||||
returns = returns_match.group(1).strip() if returns_match else ""
|
||||
return_pattern = re.compile(r'^(.*)\s*:\s*(.*)$')
|
||||
# 添加返回值变量名
|
||||
return_vars = return_vars if isinstance(return_vars, list) else [return_vars]
|
||||
returns = [(r, *r_desc) for r_desc, r in zip(return_pattern.findall(returns), return_vars)]
|
||||
# 构建YAML字典
|
||||
yaml_data = {
|
||||
'description': description.strip('.').strip(),
|
||||
'parameters': {
|
||||
'properties': {param[0]: {'type': param[1], 'description': param[2]} for param in params if param[0] is not None},
|
||||
'required': [param[0] for param in params if param[0] is not None]
|
||||
},
|
||||
'returns': {ret[0]: {'type': ret[1], 'description': ret[2]} for ret in returns}
|
||||
}
|
||||
return yaml_data
|
||||
|
||||
|
||||
def extract_function_schema_yaml_in_folder(folder_path: str):
|
||||
function_signatures, function_returns = get_function_signatures_in_folder(folder_path)
|
||||
function_schema_yaml_data = {}
|
||||
for func_docstring, func_returns in zip(function_signatures, function_returns):
|
||||
if func_docstring['udf_doc']:
|
||||
fun_yaml_data = docstring_to_yaml(func_docstring['udf_doc'], func_returns['udf_returns'])
|
||||
fun_yaml_data.update({'type': 'function'})
|
||||
function_schema_yaml_data.update({func_returns['udf_name']: fun_yaml_data})
|
||||
return yaml.dump(function_schema_yaml_data, default_flow_style=False)
|
||||
|
||||
|
||||
folder_path = str(Path(__file__).parent.absolute())
|
||||
function_signatures, function_returns = get_function_signatures_in_folder(folder_path)
|
||||
|
||||
UDFS = [func for func in function_signatures]
|
||||
|
||||
UDFS_YAML_STR: str = extract_function_schema_yaml_in_folder(folder_path)
|
||||
UDFS_YAML: dict = yaml.load(UDFS_YAML_STR, Loader=yaml.FullLoader)
|
||||
|
|
@ -50,4 +50,4 @@ nbformat==5.9.2
|
|||
ipython==8.17.2
|
||||
ipykernel==6.27.0
|
||||
scikit_learn==1.3.2
|
||||
typing-extensions==4.8.0
|
||||
typing-extensions==4.9.0
|
||||
52
tests/metagpt/actions/test_make_tools.py
Normal file
52
tests/metagpt/actions/test_make_tools.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.execute_code import ExecutePyCode
|
||||
from metagpt.actions.write_analysis_code import MakeTools
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_tools():
|
||||
code = "import yfinance as yf\n\n# Collect Alibaba stock data\nalibaba = yf.Ticker('BABA')\ndata = alibaba.history(period='1d', start='2022-01-01', end='2022-12-31')\nprint(data.head())"
|
||||
msgs = [{'role': 'assistant', 'content': code}]
|
||||
mt = MakeTools()
|
||||
tool_code = await mt.run(msgs)
|
||||
logger.debug(tool_code)
|
||||
ep = ExecutePyCode()
|
||||
tool_code = "!pip install yfinance\n" + tool_code
|
||||
result, res_type = await ep.run(tool_code)
|
||||
assert res_type is True
|
||||
logger.debug(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_tools2():
|
||||
code = '''import pandas as pd\npath = "./tests/data/test.csv"\ndf = pd.read_csv(path)\ndata = df.copy()\n
|
||||
data['started_at'] = data['started_at'].apply(lambda r: pd.to_datetime(r))\n
|
||||
data['ended_at'] = data['ended_at'].apply(lambda r: pd.to_datetime(r))\ndata.head()'''
|
||||
msgs = [{'role': 'assistant', 'content': code}]
|
||||
mt = MakeTools()
|
||||
tool_code = await mt.run(msgs)
|
||||
logger.debug(tool_code)
|
||||
ep = ExecutePyCode()
|
||||
tool_code = tool_code
|
||||
result, res_type = await ep.run(tool_code)
|
||||
assert res_type is True
|
||||
logger.debug(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_tools3():
|
||||
code = '''import pandas as pd\npath = "./tests/data/test.csv"\ndf = pd.read_csv(path)\ndata = df.copy()\n
|
||||
data['started_at'] = data['started_at'].apply(lambda r: pd.to_datetime(r))\n
|
||||
data['ended_at'] = data['ended_at'].apply(lambda r: pd.to_datetime(r))\n
|
||||
data['duration_hour'] = (data['ended_at'] - data['started_at']).dt.seconds/3600\ndata.head()'''
|
||||
msgs = [{'role': 'assistant', 'content': code}]
|
||||
mt = MakeTools()
|
||||
tool_code = await mt.run(msgs)
|
||||
logger.debug(tool_code)
|
||||
ep = ExecutePyCode()
|
||||
tool_code = tool_code
|
||||
result, res_type = await ep.run(tool_code)
|
||||
assert res_type is True
|
||||
logger.debug(result)
|
||||
|
|
@ -78,3 +78,17 @@ def test_ask_code_list_str():
|
|||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_code_steps2():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = ["step by setp 生成代码: Step 1. 先生成随机数组a, Step 2. 求a中最大值, Step 3. 绘制数据a的直方图"]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
|
||||
print(rsp)
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
assert "Step 1" in rsp["code"]
|
||||
assert "Step 2" in rsp["code"]
|
||||
assert "Step 3" in rsp["code"]
|
||||
|
|
|
|||
40
tests/metagpt/roles/test_daml.py
Normal file
40
tests/metagpt/roles/test_daml.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
import pytest
|
||||
from tqdm import tqdm
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.ml_engineer import MLEngineer
|
||||
|
||||
|
||||
async def make_use_tools(requirement: str, auto_run: bool = True):
|
||||
"""make and use tools for requirement."""
|
||||
role = MLEngineer(goal=requirement, auto_run=auto_run)
|
||||
# make udfs
|
||||
role.use_tools = False
|
||||
role.use_code_steps = False
|
||||
role.make_udfs = True
|
||||
role.use_udfs = False
|
||||
await role.run(requirement)
|
||||
# use udfs
|
||||
role.reset()
|
||||
role.make_udfs = False
|
||||
role.use_udfs = True
|
||||
role.use_code_steps = False
|
||||
role.use_tools = False
|
||||
await role.run(requirement)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_make_use_tools():
|
||||
requirements = ["Run data analysis on sklearn Iris dataset, include a plot",
|
||||
"Run data analysis on sklearn Diabetes dataset, include a plot",
|
||||
"Run data analysis on sklearn Wine recognition dataset, include a plot, and train a model to predict wine class (20% as validation), and show validation accuracy",
|
||||
"Run data analysis on sklearn Wisconsin Breast Cancer dataset, include a plot, train a model to predict targets (20% as validation), and show validation accuracy",
|
||||
"Run EDA and visualization on this dataset, train a model to predict survival, report metrics on validation set (20%), dataset: tests/data/titanic.csv"]
|
||||
success = 0
|
||||
for requirement in tqdm(requirements, total=len(requirements)):
|
||||
try:
|
||||
await make_use_tools(requirement)
|
||||
success += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Found Error in {requirement}, {e}")
|
||||
logger.info(f"success: {round(success/len(requirements), 1)*100}%")
|
||||
49
tests/metagpt/tools/functions/test_udf.py
Normal file
49
tests/metagpt/tools/functions/test_udf.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
import pytest
|
||||
import yaml
|
||||
import json
|
||||
|
||||
from metagpt.tools.functions.libs.udf import UDFS, docstring_to_yaml, UDFS_YAML
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def test_udfs():
|
||||
assert len(UDFS) > 0
|
||||
assert 'udf_name' in UDFS[0]
|
||||
assert 'udf_doc' in UDFS[0]
|
||||
logger.info(UDFS)
|
||||
|
||||
|
||||
def test_docstring2yaml():
|
||||
docstring = """Calculate the duration in hours between two datetime columns.
|
||||
|
||||
Args:
|
||||
dataframe (pd.DataFrame): The dataframe containing the datetime columns.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: The dataframe with an additional column 'duration_hour' added.
|
||||
"""
|
||||
|
||||
yaml_result = docstring_to_yaml(docstring, return_vars='dataframe')
|
||||
assert 'parameters' in yaml_result
|
||||
assert 'properties' in yaml_result['parameters']
|
||||
assert 'dataframe' in yaml_result['parameters']['properties']
|
||||
|
||||
|
||||
def test_UDFS_YAML():
|
||||
assert len(UDFS_YAML) > 0
|
||||
logger.info(f"\n\n{json.dumps(UDFS_YAML, indent=2, ensure_ascii=False)}")
|
||||
function_schema = UDFS_YAML
|
||||
assert 'description' in function_schema[list(function_schema.keys())[0]]
|
||||
assert 'type' in function_schema[list(function_schema.keys())[0]]
|
||||
assert 'parameters' in function_schema[list(function_schema.keys())[0]]
|
||||
assert 'properties' in function_schema[list(function_schema.keys())[0]]['parameters']
|
||||
assert 'required' in function_schema[list(function_schema.keys())[0]]['parameters']
|
||||
assert 'returns' in function_schema[list(function_schema.keys())[0]]
|
||||
# 指定要保存的文件路径
|
||||
file_path = './tests/data/function_schema.yaml'
|
||||
|
||||
# 使用 PyYAML 将字典保存为 YAML 文件
|
||||
with open(file_path, 'w') as file:
|
||||
yaml.dump(function_schema, file, default_flow_style=False)
|
||||
|
||||
print(f'Data has been saved to {file_path}')
|
||||
Loading…
Add table
Add a link
Reference in a new issue