Merge branch 'tool_manage_new' into 'code_intepreter'

convert local class or function to tool, tool clarification at role initialization

See merge request agents/data_agents_opt!55
This commit is contained in:
林义章 2024-01-22 09:12:02 +00:00
commit 7f5f95d41b
18 changed files with 807 additions and 147 deletions

View file

@ -10,7 +10,7 @@ from metagpt.utils.recovery_util import load_history, save_history
async def run_code_interpreter(
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools
):
"""
The main function to run the MLEngineer with optional history loading.
@ -25,7 +25,9 @@ async def run_code_interpreter(
"""
if role_class == "ci":
role = CodeInterpreter(goal=requirement, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs)
role = CodeInterpreter(
goal=requirement, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs, tools=tools
)
else:
role = MLEngineer(
goal=requirement,
@ -33,7 +35,7 @@ async def run_code_interpreter(
use_tools=use_tools,
use_code_steps=use_code_steps,
make_udfs=make_udfs,
use_udfs=use_udfs,
tools=tools,
)
if save_dir:
@ -73,6 +75,8 @@ if __name__ == "__main__":
use_tools = True
make_udfs = False
use_udfs = False
tools = []
# tools = ["FillMissingValue", "CatCross", "non_existing_test"]
async def main(
role_class: str = role_class,
@ -83,9 +87,10 @@ if __name__ == "__main__":
make_udfs: bool = make_udfs,
use_udfs: bool = use_udfs,
save_dir: str = save_dir,
tools=tools,
):
await run_code_interpreter(
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools
)
fire.Fire(main)

View file

@ -0,0 +1,158 @@
import pandas as pd
from metagpt.tools.tool_convert import convert_code_to_tool_schema, docstring_to_schema
def test_docstring_to_schema():
    """Parse a Google-style docstring and compare the result with a hand-written schema."""
    # Fixture: a summary line, one required argument, two optional arguments
    # (one of them wrapped over two lines and carrying an Enum list), and a
    # single Returns entry.
    source_doc = """
    Some test desc.
    Args:
        features (list): Columns to be processed.
        strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only be
            used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
        fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
            Defaults to None.
    Returns:
        pd.DataFrame: The transformed DataFrame.
    """

    features_param = {"type": "list", "description": "Columns to be processed."}
    # Defaults and enum members are expected as the raw docstring text,
    # i.e. still wrapped in their single quotes.
    strategy_param = {
        "type": "str",
        "description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
        "default": "'mean'",
        "enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
    }
    fill_value_param = {
        "type": "int",
        "description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
        "default": "None",
    }

    # NOTE(review): the expected description keeps one space on each side of
    # the summary - presumably the parser collapses the surrounding
    # newlines/indentation into single spaces; confirm against
    # docstring_to_schema.
    expected_schema = {
        "description": " Some test desc. ",
        "parameters": {
            "properties": {
                "features": features_param,
                "strategy": strategy_param,
                "fill_value": fill_value_param,
            },
            # Only the argument that is not marked "optional" is required.
            "required": ["features"],
        },
        "returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
    }

    assert docstring_to_schema(source_doc) == expected_schema
class DummyClass:
    """
    Completing missing values with simple strategies.
    """

    # Test fixture for test_convert_code_to_tool_schema_class: the class is
    # passed to convert_code_to_tool_schema, and the expected schema in that
    # test repeats the wording of the docstrings in this class. Change the
    # docstrings and the expected dict together.

    def __init__(self, features: list, strategy: str = "mean", fill_value=None):
        """
        Initialize self.
        Args:
            features (list): Columns to be processed.
            strategy (str, optional): The imputation strategy, notice 'mean' and 'median' can only
                be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.
            fill_value (int, optional): Fill_value is used to replace all occurrences of missing_values.
                Defaults to None.
        """
        # Intentionally empty: the tests visible in this file never instantiate the class.
        pass

    def fit(self, df: pd.DataFrame):
        """
        Fit the FillMissingValue model.
        Args:
            df (pd.DataFrame): The input DataFrame.
        """
        # Intentionally empty: no Returns section here, and the expected schema
        # for "fit" in the class test has no "returns" key either.
        pass

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Transform the input DataFrame with the fitted model.
        Args:
            df (pd.DataFrame): The input DataFrame.
        Returns:
            pd.DataFrame: The transformed DataFrame.
        """
        # Intentionally empty: despite the annotation this stub returns None;
        # the tests visible in this file never call it.
        pass
# Test fixture for test_convert_code_to_tool_schema_function: the function is
# passed to convert_code_to_tool_schema and the expected schema there repeats
# the wording of this docstring, so change the two together.
# NOTE(review): the expected schema in that test has no "returns" entry even
# though a Returns section is documented here - presumably the two-line
# Returns description is not picked up by the converter; confirm this is the
# intended behaviour.
def dummy_fn(df: pd.DataFrame) -> dict:
    """
    Analyzes a DataFrame and categorizes its columns based on data types.
    Args:
        df (pd.DataFrame): The DataFrame to be analyzed.
    Returns:
        dict: A dictionary with four keys ('Category', 'Numeric', 'Datetime', 'Others').
            Each key corresponds to a list of column names belonging to that category.
    """
    # Intentionally empty: despite the annotation this stub returns None; the
    # tests visible in this file never call it.
    pass
def test_convert_code_to_tool_schema_class():
    """Convert DummyClass to a tool schema and compare it with a hand-written one."""
    # Expected entry for the constructor; defaults and enum members are kept
    # as raw docstring text, single quotes included.
    init_schema = {
        "description": "Initialize self. ",
        "parameters": {
            "properties": {
                "features": {"type": "list", "description": "Columns to be processed."},
                "strategy": {
                    "type": "str",
                    "description": "The imputation strategy, notice 'mean' and 'median' can only be used for numeric features. Enum: ['mean', 'median', 'most_frequent', 'constant']. Defaults to 'mean'.",
                    "default": "'mean'",
                    "enum": ["'mean'", "'median'", "'most_frequent'", "'constant'"],
                },
                "fill_value": {
                    "type": "int",
                    "description": "Fill_value is used to replace all occurrences of missing_values. Defaults to None.",
                    "default": "None",
                },
            },
            "required": ["features"],
        },
    }

    # fit() documents no return value, so its entry has no "returns" key.
    fit_schema = {
        "description": "Fit the FillMissingValue model. ",
        "parameters": {
            "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
            "required": ["df"],
        },
    }

    transform_schema = {
        "description": "Transform the input DataFrame with the fitted model. ",
        "parameters": {
            "properties": {"df": {"type": "pd.DataFrame", "description": "The input DataFrame."}},
            "required": ["df"],
        },
        "returns": [{"type": "pd.DataFrame", "description": "The transformed DataFrame."}],
    }

    # The schema is keyed by the class name; methods are keyed by method name.
    expected_schema = {
        "DummyClass": {
            "type": "class",
            "description": "Completing missing values with simple strategies.",
            "methods": {
                "__init__": init_schema,
                "fit": fit_schema,
                "transform": transform_schema,
            },
        }
    }

    assert convert_code_to_tool_schema(DummyClass) == expected_schema
def test_convert_code_to_tool_schema_function():
    """Convert dummy_fn to a tool schema and compare it with a hand-written one."""
    df_param = {"type": "pd.DataFrame", "description": "The DataFrame to be analyzed."}

    # NOTE(review): no "returns" key is expected here although dummy_fn's
    # docstring has a Returns section - presumably its two-line description is
    # not picked up by the converter; confirm this is intended.
    function_schema = {
        "type": "function",
        "description": "Analyzes a DataFrame and categorizes its columns based on data types. ",
        "parameters": {
            "properties": {"df": df_param},
            "required": ["df"],
        },
    }

    # The schema is keyed by the function name.
    expected_schema = {"dummy_fn": function_schema}

    assert convert_code_to_tool_schema(dummy_fn) == expected_schema

View file

@ -98,4 +98,4 @@ def test_get_tools_by_type(tool_registry, schema_yaml):
# Test case for when the tool type does not exist
def test_get_tools_by_nonexistent_type(tool_registry):
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
assert tools_by_type is None
assert not tools_by_type