allow select tool at role initialization & restructure writecodewithtools

This commit is contained in:
yzlin 2024-01-20 21:06:48 +08:00
parent 2ccfe31123
commit 540542834e
8 changed files with 127 additions and 90 deletions

View file

@ -10,7 +10,7 @@ from metagpt.utils.recovery_util import load_history, save_history
async def run_code_interpreter(
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools
):
"""
The main function to run the MLEngineer with optional history loading.
@ -25,7 +25,9 @@ async def run_code_interpreter(
"""
if role_class == "ci":
role = CodeInterpreter(goal=requirement, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs)
role = CodeInterpreter(
goal=requirement, auto_run=auto_run, use_tools=use_tools, make_udfs=make_udfs, tools=tools
)
else:
role = MLEngineer(
goal=requirement,
@ -33,7 +35,7 @@ async def run_code_interpreter(
use_tools=use_tools,
use_code_steps=use_code_steps,
make_udfs=make_udfs,
use_udfs=use_udfs,
tools=tools,
)
if save_dir:
@ -73,6 +75,8 @@ if __name__ == "__main__":
use_tools = True
make_udfs = False
use_udfs = False
tools = []
# tools = ["FillMissingValue", "CatCross", "non_existing_test"]
async def main(
role_class: str = role_class,
@ -83,9 +87,10 @@ if __name__ == "__main__":
make_udfs: bool = make_udfs,
use_udfs: bool = use_udfs,
save_dir: str = save_dir,
tools=tools,
):
await run_code_interpreter(
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir
role_class, requirement, auto_run, use_tools, use_code_steps, make_udfs, use_udfs, save_dir, tools
)
fire.Fire(main)

View file

@ -98,4 +98,4 @@ def test_get_tools_by_type(tool_registry, schema_yaml):
# Test case for when the tool type does not exist
def test_get_tools_by_nonexistent_type(tool_registry):
tools_by_type = tool_registry.get_tools_by_type("NonexistentType")
assert tools_by_type is None
assert not tools_by_type