add tool code for ast parsing

This commit is contained in:
yzlin 2024-03-25 17:35:12 +08:00
parent e7423763bc
commit da98495657
2 changed files with 13 additions and 6 deletions

View file

@ -38,7 +38,7 @@ def convert_code_to_tool_schema_ast(code: str) -> list[dict]:
child.parent = parent
add_parent_references(child, parent=node)
visitor = CodeVisitor()
visitor = CodeVisitor(code)
parsed_code = ast.parse(code)
add_parent_references(parsed_code)
visitor.visit(parsed_code)
@ -86,8 +86,9 @@ def get_class_method_docstring(cls, method_name):
class CodeVisitor(ast.NodeVisitor):
"""Visit and convert the AST nodes within a code file to tool schemas"""
def __init__(self):
def __init__(self, source_code: str):
self.tool_schemas = {} # {tool_name: tool_schema}
self.source_code = source_code
def visit_ClassDef(self, node):
class_schemas = {"type": "class", "description": remove_spaces(ast.get_docstring(node)), "methods": {}}
@ -97,17 +98,21 @@ class CodeVisitor(ast.NodeVisitor):
):
func_schemas = self._get_function_schemas(body_node)
class_schemas["methods"].update({body_node.name: func_schemas})
class_schemas["code"] = ast.get_source_segment(self.source_code, node)
self.tool_schemas[node.name] = class_schemas
def visit_FunctionDef(self, node):
if isinstance(node.parent, ast.ClassDef) or node.name.startswith("_"):
return
self.tool_schemas[node.name] = self._get_function_schemas(node)
self._visit_function(node)
def visit_AsyncFunctionDef(self, node):
self._visit_function(node)
def _visit_function(self, node):
if isinstance(node.parent, ast.ClassDef) or node.name.startswith("_"):
return
self.tool_schemas[node.name] = self._get_function_schemas(node)
function_schemas = self._get_function_schemas(node)
function_schemas["code"] = ast.get_source_segment(self.source_code, node)
self.tool_schemas[node.name] = function_schemas
def _get_function_schemas(self, node):
docstring = remove_spaces(ast.get_docstring(node))

View file

@ -154,10 +154,12 @@ def register_tools_from_file(file_path) -> dict[str, Tool]:
code = Path(file_path).read_text(encoding="utf-8")
tool_schemas = convert_code_to_tool_schema_ast(code)
for name, schemas in tool_schemas.items():
tool_code = schemas.pop("code", "")
TOOL_REGISTRY.register_tool(
tool_name=name,
tool_path=file_path,
schemas=schemas,
tool_code=tool_code,
)
registered_tools.update({name: TOOL_REGISTRY.get_tool(name)})
return registered_tools