register tools from path

This commit is contained in:
yzlin 2024-03-18 22:00:51 +08:00
parent 0271cd7f77
commit e53a0acc8e
3 changed files with 54 additions and 6 deletions

View file

@ -7,7 +7,7 @@ PARSER = GoogleDocstringParser
def convert_code_to_tool_schema(obj, include: list[str] = None):
docstring = inspect.getdoc(obj)
assert docstring, "no docstring found for the objects, skip registering"
# assert docstring, "no docstring found for the objects, skip registering"
if inspect.isclass(obj):
schema = {"type": "class", "description": remove_spaces(docstring), "methods": {}}

View file

@ -7,10 +7,10 @@
"""
from __future__ import annotations
import importlib.util
import inspect
import os
from collections import defaultdict
from typing import Union
import yaml
from pydantic import BaseModel
@ -127,15 +127,63 @@ def make_schema(tool_source_object, include, path):
return schema
def validate_tool_names(tools: Union[list[str], str]) -> str:
def validate_tool_names(tools: list[str]) -> dict[str, Tool]:
assert isinstance(tools, list), "tools must be a list of str"
valid_tools = {}
for key in tools:
# one can define either tool names or tool type names, take union to get the whole set
if TOOL_REGISTRY.has_tool(key):
# one can define either tool names OR tool tags OR tool path, take union to get the whole set
# if tool paths are provided, they will be registered on the fly
if os.path.isdir(key) or os.path.isfile(key):
valid_tools.update(register_tools_from_path(key))
elif TOOL_REGISTRY.has_tool(key):
valid_tools.update({key: TOOL_REGISTRY.get_tool(key)})
elif TOOL_REGISTRY.has_tool_tag(key):
valid_tools.update(TOOL_REGISTRY.get_tools_by_tag(key))
else:
logger.warning(f"invalid tool name or tool type name: {key}, skipped")
return valid_tools
def load_module_from_file(filepath):
module_name = os.path.splitext(os.path.basename(filepath))[0]
spec = importlib.util.spec_from_file_location(module_name, filepath)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
def register_tools_from_file(file_path) -> dict[str, Tool]:
registered_tools = {}
module = load_module_from_file(file_path)
for name, obj in inspect.getmembers(module):
if inspect.isclass(obj) or inspect.isfunction(obj):
if obj.__module__ == module.__name__:
# excluding imported classes and functions, register only those defined in the file
if "metagpt" in file_path:
# split to handle ../metagpt/metagpt/tools/... where only metapgt/tools/... is needed
file_path = "metagpt" + file_path.split("metagpt")[-1]
TOOL_REGISTRY.register_tool(
tool_name=name,
tool_path=file_path,
tool_code="", # inspect.getsource(obj) will resulted in TypeError, skip it for now
tool_source_object=obj,
)
registered_tools.update({name: TOOL_REGISTRY.get_tool(name)})
return registered_tools
def register_tools_from_path(path) -> dict[str, Tool]:
tools_registered = {}
if os.path.isfile(path) and path.endswith(".py"):
# Path is a Python file
tools_registered.update(register_tools_from_file(path))
elif os.path.isdir(path):
# Path is a directory
for root, _, files in os.walk(path):
for file in files:
if file.endswith(".py"):
file_path = os.path.join(root, file)
tools_registered.update(register_tools_from_file(file_path))
return tools_registered

View file

@ -3,7 +3,7 @@ from typing import Tuple
def remove_spaces(text):
return re.sub(r"\s+", " ", text).strip()
return re.sub(r"\s+", " ", text).strip() if text else ""
class DocstringParser: