This commit is contained in:
刘棒棒 2023-12-20 12:11:42 +08:00
parent 19b0120c15
commit 913538639d

View file

@ -21,6 +21,7 @@ from metagpt.prompts.ml_engineer import (
PRINT_DATA_COLUMNS
)
from metagpt.roles import Role
from metagpt.roles.role import RoleContext
from metagpt.roles.kaggle_manager import DownloadData, SubmitResult
from metagpt.schema import Message, Plan
from metagpt.utils.common import remove_comments, create_func_config
@ -192,16 +193,6 @@ class MLEngineer(Role):
)
debug_context = [self.get_useful_memories(task_exclude_field={'result', 'code_steps'})[0]]
cause_by = WriteCodeByGenerate
if self.make_udfs and len(code.split('\n')) > 4:
# make and save user-defined function tools.
logger.warning(f"Making tools for task_id {self.plan.current_task_id}: \
`{self.plan.current_task.instruction}` \n code {code}")
make_tools = MakeTools()
code_prompt = f"The following code is about {self.plan.current_task.instruction},\
convert it to be a General Function, {code}"
tool_code = await make_tools.run(code_prompt)
make_tools.save(tool_code)
else:
logger.info("Write code with tools")
schema_path = PROJECT_ROOT / "metagpt/tools/functions/schemas"
@ -219,6 +210,9 @@ class MLEngineer(Role):
result, success = await self.execute_code.run(code)
print(result)
# make tools for successful code and long code.
if success and self.make_udfs and len(code.split('\n')) > 4:
await self.make_tools(code=code)
self.working_memory.add(
Message(content=result, role="user", cause_by=ExecutePyCode)
)
@ -304,6 +298,39 @@ class MLEngineer(Role):
def get_working_memories(self) -> List[Message]:
return self.working_memory.get()
def reset(self):
"""Restart role with the same goal."""
self.plan = Plan(goal=self.plan.goal)
self.execute_code = ExecutePyCode()
async def make_tools(self, code: str):
"""Make user-defined functions(udfs, aka tools) for pure generation code.
Args:
code (str): pure generation code by class WriteCodeByGenerate.
"""
logger.warning(f"Making tools for task_id {self.plan.current_task_id}: \
`{self.plan.current_task.instruction}` \n code: \n {code}")
make_tools = MakeTools()
code_prompt = f"The following code is about {self.plan.current_task.instruction},\
convert it to be a General Function, {code}"
tool_code = await make_tools.run(code_prompt)
# check tool_code by execute_code
logger.info(f"Checking task_id {self.plan.current_task_id} tool code by executor...")
_, success = await self.execute_code.run(tool_code)
make_tool_retries, make_tool_current_retry = 3, 1
while not success:
tool_code = await make_tools.run(code_prompt)
_, success = await self.execute_code.run(tool_code)
if make_tool_current_retry > make_tool_retries:
logger.error(f"We have tried the maximum number of attempts {make_tool_retries}\
and still have not created tools for task_id {self.plan.current_task_id} successfully,\
we will skip it.")
break
# save successful tool code in udf
if success:
make_tools.save(tool_code)
if __name__ == "__main__":
requirement = "Run data analysis on sklearn Iris dataset, include a plot"
@ -314,6 +341,12 @@ if __name__ == "__main__":
async def main(requirement: str = requirement, auto_run: bool = True):
role = MLEngineer(goal=requirement, auto_run=auto_run)
# make udfs
role.make_udfs = True
role.use_udfs = False
await role.run(requirement)
# use udfs
role.reset()
role.make_udfs = False
role.use_udfs = True
await role.run(requirement)