feat: + subscribe

This commit is contained in:
莘权 马 2023-11-02 11:51:10 +08:00
parent 8572fa8ecd
commit 660f788683
3 changed files with 68 additions and 14 deletions

View file

@ -134,6 +134,10 @@ class Role(Named):
def _watch(self, actions: Iterable[Type[Action]]):
"""Listen to the corresponding behaviors"""
tags = [get_class_name(t) for t in actions]
self.subscribe(tags)
def subscribe(self, tags: Set[str]):
"""Listen to the corresponding behaviors"""
self._rc.watch.update(tags)
# check RoleContext after adding watch actions
self._rc.check(self._role_id)

View file

@ -0,0 +1,64 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023-11-1
@Author : mashenquan
@File : test_role.py
"""
import pytest
from pydantic import BaseModel
from metagpt.actions import Action, ActionOutput
from metagpt.environment import Environment
from metagpt.roles import Role
from metagpt.schema import Message
class MockAction(Action):
async def run(self, messages, *args, **kwargs):
assert messages
return ActionOutput(content=messages[-1].content, instruct_content=messages[-1])
class MockRole(Role):
def __init__(self, name="", profile="", goal="", constraints="", desc=""):
super().__init__(name=name, profile=profile, goal=goal, constraints=constraints, desc=desc)
self._init_actions([MockAction()])
@pytest.mark.asyncio
async def test_react():
class Input(BaseModel):
name: str
profile: str
goal: str
constraints: str
desc: str
subscription: str
inputs = [
{
"name": "A",
"profile": "Tester",
"goal": "Test",
"constraints": "constraints",
"desc": "desc",
"subscription": "start",
}
]
for i in inputs:
seed = Input(**i)
role = MockRole(
name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc
)
role.subscribe({seed.subscription})
env = Environment()
env.add_role(role)
env.publish_message(Message(content="test", cause_by=seed.subscription))
while not env.is_idle:
await env.run()
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -1,14 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/11 14:44
@Author : alexanderwu
@File : test_role.py
"""
from metagpt.roles import Role
def test_role_desc():
i = Role(profile='Sales', desc='Best Seller')
assert i.profile == 'Sales'
assert i._setting.desc == 'Best Seller'