mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
commit
83ee76cca7
2 changed files with 17 additions and 6 deletions
|
|
@ -49,6 +49,14 @@ def get_docstring_statement(body: DocstringNode) -> cst.SimpleStatementLine:
|
|||
return statement
|
||||
|
||||
|
||||
def has_decorator(node: DocstringNode, name: str) -> bool:
|
||||
return hasattr(node, "decorators") and any(
|
||||
(hasattr(i.decorator, "value") and i.decorator.value == name)
|
||||
or (hasattr(i.decorator, "func") and hasattr(i.decorator.func, "value") and i.decorator.func.value == name)
|
||||
for i in node.decorators
|
||||
)
|
||||
|
||||
|
||||
class DocstringCollector(cst.CSTVisitor):
|
||||
"""A visitor class for collecting docstrings from a CST.
|
||||
|
||||
|
|
@ -82,7 +90,7 @@ class DocstringCollector(cst.CSTVisitor):
|
|||
def _leave(self, node: DocstringNode) -> None:
|
||||
key = tuple(self.stack)
|
||||
self.stack.pop()
|
||||
if hasattr(node, "decorators") and any(i.decorator.value == "overload" for i in node.decorators):
|
||||
if has_decorator(node, "overload"):
|
||||
return
|
||||
|
||||
statement = get_docstring_statement(node)
|
||||
|
|
@ -127,9 +135,7 @@ class DocstringTransformer(cst.CSTTransformer):
|
|||
key = tuple(self.stack)
|
||||
self.stack.pop()
|
||||
|
||||
if hasattr(updated_node, "decorators") and any(
|
||||
(i.decorator.value == "overload") for i in updated_node.decorators
|
||||
):
|
||||
if has_decorator(updated_node, "overload"):
|
||||
return updated_node
|
||||
|
||||
statement = self.docstrings.get(key)
|
||||
|
|
|
|||
|
|
@ -17,14 +17,15 @@ from metagpt.tools.prompt_writer import (
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("llm_api")
|
||||
def test_gpt_prompt_generator(llm_api):
|
||||
async def test_gpt_prompt_generator(llm_api):
|
||||
generator = GPTPromptGenerator()
|
||||
example = (
|
||||
"商品名称:WonderLab 新肌果味代餐奶昔 小胖瓶 胶原蛋白升级版 饱腹代餐粉6瓶 75g/瓶(6瓶/盒) 店铺名称:金力宁食品专营店 " "品牌:WonderLab 保质期:1年 产地:中国 净含量:450g"
|
||||
)
|
||||
|
||||
results = llm_api.ask_batch(generator.gen(example))
|
||||
results = await llm_api.aask_batch(generator.gen(example))
|
||||
logger.info(results)
|
||||
assert len(results) > 0
|
||||
|
||||
|
|
@ -58,3 +59,7 @@ def test_beagec_template():
|
|||
assert any(
|
||||
"Edit and revise this document to improve its grammar, vocabulary, spelling, and style." in r for r in results
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue