diff --git a/metagpt/utils/pycst.py b/metagpt/utils/pycst.py index 1edfed81c..a26ba70ff 100644 --- a/metagpt/utils/pycst.py +++ b/metagpt/utils/pycst.py @@ -49,6 +49,14 @@ def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine: return statement +def has_decorator(node: DocstringNode, name: str) -> bool: + return hasattr(node, "decorators") and any( + (hasattr(i.decorator, "value") and i.decorator.value == name) + or (hasattr(i.decorator, "func") and hasattr(i.decorator.func, "value") and i.decorator.func.value == name) + for i in node.decorators + ) + + class DocstringCollector(cst.CSTVisitor): """A visitor class for collecting docstrings from a CST. @@ -82,7 +90,7 @@ class DocstringCollector(cst.CSTVisitor): def _leave(self, node: DocstringNode) -> None: key = tuple(self.stack) self.stack.pop() - if hasattr(node, "decorators") and any(i.decorator.value == "overload" for i in node.decorators): + if has_decorator(node, "overload"): return statement = get_docstring_statement(node) @@ -127,9 +135,7 @@ class DocstringTransformer(cst.CSTTransformer): key = tuple(self.stack) self.stack.pop() - if hasattr(updated_node, "decorators") and any( - (i.decorator.value == "overload") for i in updated_node.decorators - ): + if has_decorator(updated_node, "overload"): return updated_node statement = self.docstrings.get(key) diff --git a/tests/metagpt/tools/test_prompt_writer.py b/tests/metagpt/tools/test_prompt_writer.py index 9f0c25ba1..680d4fe54 100644 --- a/tests/metagpt/tools/test_prompt_writer.py +++ b/tests/metagpt/tools/test_prompt_writer.py @@ -17,14 +17,15 @@ from metagpt.tools.prompt_writer import ( ) +@pytest.mark.asyncio @pytest.mark.usefixtures("llm_api") -def test_gpt_prompt_generator(llm_api): +async def test_gpt_prompt_generator(llm_api): generator = GPTPromptGenerator() example = ( "商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " "品牌:WonderLab 保质期:1年 产地:中国 净含量:450g" ) - results = llm_api.ask_batch(generator.gen(example)) + results = await llm_api.aask_batch(generator.gen(example)) logger.info(results) assert len(results) > 0 @@ -58,3 +59,7 @@ def test_beagec_template(): assert any( "Edit and revise this document to improve its grammar, vocabulary, spelling, and style." in r for r in results ) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])