From 5376a869298b26eafc74a952a352e81b712bdacb Mon Sep 17 00:00:00 2001 From: yzlin Date: Thu, 11 Apr 2024 16:55:52 +0800 Subject: [PATCH] fix tool convert bug and add more tests --- metagpt/tools/tool_convert.py | 5 ++--- tests/metagpt/tools/test_tool_convert.py | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/metagpt/tools/tool_convert.py b/metagpt/tools/tool_convert.py index 829269b1b..a84cbeea0 100644 --- a/metagpt/tools/tool_convert.py +++ b/metagpt/tools/tool_convert.py @@ -20,8 +20,7 @@ def convert_code_to_tool_schema(obj, include: list[str] = None) -> dict: continue # method_doc = inspect.getdoc(method) method_doc = get_class_method_docstring(obj, name) - if method_doc: - schema["methods"][name] = function_docstring_to_schema(method, method_doc) + schema["methods"][name] = function_docstring_to_schema(method, method_doc) elif inspect.isfunction(obj): schema = function_docstring_to_schema(obj, docstring) @@ -39,7 +38,7 @@ def convert_code_to_tool_schema_ast(code: str) -> list[dict]: return visitor.get_tool_schemas() -def function_docstring_to_schema(fn_obj, docstring) -> dict: +def function_docstring_to_schema(fn_obj, docstring="") -> dict: """ Converts a function's docstring into a schema dictionary. diff --git a/tests/metagpt/tools/test_tool_convert.py b/tests/metagpt/tools/test_tool_convert.py index 4798d32b0..5aa53ce4f 100644 --- a/tests/metagpt/tools/test_tool_convert.py +++ b/tests/metagpt/tools/test_tool_convert.py @@ -48,6 +48,14 @@ class DummyClass: pass +class DummySubClass(DummyClass): + """sub class docstring""" + + def sub_method(self, df: pd.DataFrame): + """sub method""" + pass + + def dummy_fn( df: pd.DataFrame, s: str, @@ -117,6 +125,18 @@ def test_convert_code_to_tool_schema_class(): assert schema == expected +def test_convert_code_to_tool_schema_subclass(): + schema = convert_code_to_tool_schema(DummySubClass) + assert "sub_method" in schema["methods"] # sub class method should be included + assert "fit" in schema["methods"] # parent class method should be included + + +def test_convert_code_to_tool_schema_include(): + schema = convert_code_to_tool_schema(DummyClass, include=["fit"]) + assert "fit" in schema["methods"] + assert "transform" not in schema["methods"] + + def test_convert_code_to_tool_schema_function(): expected = { "type": "function",