fixbug: tests

This commit is contained in:
莘权 马 2023-08-01 10:48:26 +08:00
parent 85c7148b62
commit d415ca5dbc
7 changed files with 15 additions and 17 deletions

View file

@ -15,6 +15,7 @@ from metagpt.llm import LLM
from metagpt.utils.common import OutputParser
from metagpt.logs import logger
class Action(ABC):
def __init__(self, name: str = '', context=None, llm: LLM = None):
self.name: str = name

View file

@ -135,8 +135,7 @@ class WriteDesign(Action):
self._save_prd(docs_path, resources_path, context[-1].content)
self._save_system_design(docs_path, resources_path, content)
async def run(self, *args, **kwargs):
context = args[0]
async def run(self, context):
prompt = PROMPT_TEMPLATE.format(context=context, format_example=FORMAT_EXAMPLE)
# system_design = await self._aask(prompt)
system_design = await self._aask_v1(prompt, "system_design", OUTPUT_MAPPING)

View file

@ -115,8 +115,7 @@ class WriteTasks(Action):
requirements_path = WORKSPACE_ROOT / ws_name / 'requirements.txt'
requirements_path.write_text(rsp.instruct_content.dict().get("Required Python third-party packages").strip('"\n'))
async def run(self, *args, **kwargs):
context = args[0]
async def run(self, context):
prompt = PROMPT_TEMPLATE.format(context=context, format_example=FORMAT_EXAMPLE)
rsp = await self._aask_v1(prompt, "task", OUTPUT_MAPPING)
self._save(context, rsp)

View file

@ -5,7 +5,6 @@
@Author : mashenquan
@File : write_teaching_plan.py
"""
from langchain.llms.base import LLM
from metagpt.logs import logger
from metagpt.actions import Action
from metagpt.schema import Message
@ -21,7 +20,7 @@ class TeachingPlanRequirement(Action):
class WriteTeachingPlanPart(Action):
"""Write Teaching Plan Part"""
def __init__(self, name: str = "", context=None, llm: LLM = None, topic: str = "", language: str = "Chinese"):
def __init__(self, name: str = "", context=None, llm=None, topic: str = "", language: str = "Chinese"):
"""
Args:
@ -35,8 +34,8 @@ class WriteTeachingPlanPart(Action):
self.language = language
self.rsp = None
async def run(self, *args, **kwargs):
if len(args) < 1 or len(args[0]) < 1 or not isinstance(args[0][0], Message):
async def run(self, messages, *args, **kwargs):
if len(messages) < 1 or not isinstance(messages[0], Message):
raise ValueError("Invalid args, a tuple of List[Message] is expected")
statement_patterns = self.TOPIC_STATEMENTS.get(self.topic, [])
@ -49,7 +48,7 @@ class WriteTeachingPlanPart(Action):
prompt = formatter.format(formation=self.FORMATION,
role=self.prefix,
statements="\n".join(statements),
lesson=args[0][0].content,
lesson=messages[0].content,
topic=self.topic,
language=self.language)