fixbug: recursive user requirement dead loop

This commit is contained in:
莘权 马 2023-12-18 16:13:21 +08:00
parent e43aaec932
commit 9c405dfa77
3 changed files with 23 additions and 18 deletions

View file

@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/18
@Author : mashenquan
@File : role_run.py
@Desc : Message type caused by `Role.run()` invocation.
"""
from metagpt.actions import Action
class RoleRun(Action):
"""Message type caused by `Role.run` invocation"""
async def run(self, *args, **kwargs):
raise NotImplementedError

View file

@ -27,7 +27,7 @@ from pydantic import BaseModel, Field
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.role_run import RoleRun
from metagpt.llm import LLM, HumanProvider
from metagpt.logs import logger
from metagpt.memory import Memory
@ -127,17 +127,7 @@ class RoleContext(BaseModel):
return self.memory.get()
class _RoleInjector(type):
def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
if not instance._rc.watch:
instance._watch([UserRequirement])
return instance
class Role(metaclass=_RoleInjector):
class Role:
"""Role/Agent"""
def __init__(self, name="", profile="", goal="", constraints="", desc="", is_human=False):
@ -152,7 +142,6 @@ class Role(metaclass=_RoleInjector):
self._rc = RoleContext()
self._subscription = {any_to_str(self), name} if name else {any_to_str(self)}
def _reset(self):
self._states = []
self._actions = []
@ -304,7 +293,9 @@ class Role(metaclass=_RoleInjector):
old_messages = [] if ignore_memory else self._rc.memory.get()
self._rc.memory.add_batch(news)
# Filter out messages of interest.
self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages]
watch = self._rc.watch or set()
watch.add(any_to_str(RoleRun))
self._rc.news = [n for n in news if n.cause_by in watch and n not in old_messages]
# Design Rules:
# If you need to further categorize Message objects, you can do so using the Message.set_meta function.
@ -401,6 +392,8 @@ class Role(metaclass=_RoleInjector):
msg = with_message
elif isinstance(with_message, list):
msg = Message("\n".join(with_message))
if not msg.cause_by:
msg.cause_by = RoleRun
self.put_message(msg)
if not await self._observe():

View file

@ -121,10 +121,6 @@ class Message(BaseModel):
:param send_to: Specifies the target recipient or consumer for message delivery in the environment.
:param role: Message meta info tells who sent this message.
"""
if not cause_by:
from metagpt.actions import UserRequirement
cause_by = UserRequirement
super().__init__(
id=uuid.uuid4().hex,
content=content,