patch release v0.5.2: fix user requirement dead loop in startup

fixbug: recursive user requirement dead loop
This commit is contained in:
garylin2099 2023-12-18 19:20:01 +08:00 committed by GitHub
commit fa727c48c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 9 additions and 24 deletions

View file

@ -25,9 +25,8 @@ from typing import Iterable, Set, Type
from pydantic import BaseModel, Field
from metagpt.actions import Action, ActionOutput
from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.llm import LLM, HumanProvider
from metagpt.logs import logger
from metagpt.memory import Memory
@ -127,17 +126,7 @@ class RoleContext(BaseModel):
return self.memory.get()
class _RoleInjector(type):
def __call__(cls, *args, **kwargs):
instance = super().__call__(*args, **kwargs)
if not instance._rc.watch:
instance._watch([UserRequirement])
return instance
class Role(metaclass=_RoleInjector):
class Role:
"""Role/Agent"""
def __init__(self, name="", profile="", goal="", constraints="", desc="", is_human=False):
@ -149,10 +138,9 @@ class Role(metaclass=_RoleInjector):
self._states = []
self._actions = []
self._role_id = str(self._setting)
self._rc = RoleContext()
self._rc = RoleContext(watch={any_to_str(UserRequirement)})
self._subscription = {any_to_str(self), name} if name else {any_to_str(self)}
def _reset(self):
self._states = []
self._actions = []
@ -203,8 +191,7 @@ class Role(metaclass=_RoleInjector):
"""Watch Actions of interest. Role will select Messages caused by these Actions from its personal message
buffer during _observe.
"""
tags = {any_to_str(t) for t in actions}
self._rc.watch.update(tags)
self._rc.watch = {any_to_str(t) for t in actions}
# check RoleContext after adding watch actions
self._rc.check(self._role_id)
@ -401,6 +388,8 @@ class Role(metaclass=_RoleInjector):
msg = with_message
elif isinstance(with_message, list):
msg = Message("\n".join(with_message))
if not msg.cause_by:
msg.cause_by = UserRequirement
self.put_message(msg)
if not await self._observe():

View file

@ -121,10 +121,6 @@ class Message(BaseModel):
:param send_to: Specifies the target recipient or consumer for message delivery in the environment.
:param role: Message meta info tells who sent this message.
"""
if not cause_by:
from metagpt.actions import UserRequirement
cause_by = UserRequirement
super().__init__(
id=uuid.uuid4().hex,
content=content,

View file

@ -14,11 +14,11 @@ import uuid
import pytest
from pydantic import BaseModel
from metagpt.actions import Action, ActionOutput
from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.environment import Environment
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import get_class_name
from metagpt.utils.common import any_to_str, get_class_name
class MockAction(Action):
@ -60,7 +60,7 @@ async def test_react():
name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc
)
role.subscribe({seed.subscription})
assert role._rc.watch == set({})
assert role._rc.watch == {any_to_str(UserRequirement)}
assert role.name == seed.name
assert role.profile == seed.profile
assert role._setting.goal == seed.goal