fixbug: recursive user requirement dead loop

This commit is contained in:
莘权 马 2023-12-18 16:13:21 +08:00 committed by better629
parent e8a848a614
commit 31f1be98a0
2 changed files with 8 additions and 7 deletions

View file

@ -27,15 +27,15 @@ from typing import Iterable, Set, Type, Any
from pydantic import BaseModel, Field
from metagpt.actions.action import Action, ActionOutput, action_subclass_registry
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
from metagpt.llm import LLM
from metagpt.llm import LLM, HumanProvider
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.human_provider import HumanProvider
from metagpt.schema import Message, MessageQueue
from metagpt.utils.common import any_to_str, read_json_file, write_json_file, import_class, role_raise_decorator
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
@ -293,8 +293,7 @@ class Role(BaseModel):
"""Watch Actions of interest. Role will select Messages caused by these Actions from its personal message
buffer during _observe.
"""
tags = {any_to_str(t) for t in actions}
self._rc.watch.update(tags)
self._rc.watch = {any_to_str(t) for t in actions}
# check RoleContext after adding watch actions
self._rc.check(self._role_id)
@ -509,6 +508,8 @@ class Role(BaseModel):
msg = with_message
elif isinstance(with_message, list):
msg = Message(content="\n".join(with_message))
if not msg.cause_by:
msg.cause_by = UserRequirement
self.put_message(msg)
if not await self._observe():

View file

@ -14,11 +14,11 @@ import uuid
import pytest
from pydantic import BaseModel
from metagpt.actions import Action, ActionOutput
from metagpt.actions import Action, ActionOutput, UserRequirement
from metagpt.environment import Environment
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.utils.common import get_class_name
from metagpt.utils.common import any_to_str, get_class_name
class MockAction(Action):
@ -60,7 +60,7 @@ async def test_react():
name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc
)
role.subscribe({seed.subscription})
assert role._rc.watch == set({})
assert role._rc.watch == {any_to_str(UserRequirement)}
assert role.name == seed.name
assert role.profile == seed.profile
assert role._setting.goal == seed.goal