add agent subscription

This commit is contained in:
shenchucheng 2023-12-01 16:10:38 +08:00
parent 6f345002c4
commit dfc6e13ac3
3 changed files with 217 additions and 2 deletions

101
metagpt/subscription.py Normal file
View file

@ -0,0 +1,101 @@
import asyncio
from typing import AsyncGenerator, Awaitable, Callable
from pydantic import BaseModel, Field
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
class SubscriptionRunner(BaseModel):
"""A simple wrapper to manage subscription tasks for different roles using asyncio.
Example:
>>> import asyncio
>>> from metagpt.subscription import SubscriptionRunner
>>> from metagpt.roles import Searcher
>>> from metagpt.schema import Message
>>> async def trigger():
... while True:
... yield Message("the latest news about OpenAI")
... await asyncio.sleep(3600 * 24)
>>> async def callback(msg: Message):
... print(msg.content)
>>> async def main():
... pb = SubscriptionRunner()
... await pb.subscribe(Searcher(), trigger(), callback)
... await pb.run()
>>> asyncio.run(main())
"""
tasks: dict[Role, asyncio.Task] = Field(default_factory=dict)
class Config:
arbitrary_types_allowed = True
async def subscribe(
self,
role: Role,
trigger: AsyncGenerator[Message, None],
callback: Callable[
[
Message,
],
Awaitable[None],
],
):
"""Subscribes a role to a trigger and sets up a callback to be called with the role's response.
Args:
role: The role to subscribe.
trigger: An asynchronous generator that yields Messages to be processed by the role.
callback: An asynchronous function to be called with the response from the role.
"""
loop = asyncio.get_running_loop()
async def _start_role():
async for msg in trigger:
resp = await role.run(msg)
await callback(resp)
self.tasks[role] = loop.create_task(_start_role(), name=f"Subscription-{role}")
async def unsubscribe(self, role: Role):
"""Unsubscribes a role from its trigger and cancels the associated task.
Args:
role: The role to unsubscribe.
"""
task = self.tasks.pop(role)
task.cancel()
async def run(self, raise_exception: bool = True):
"""Runs all subscribed tasks and handles their completion or exception.
Args:
raise_exception: _description_. Defaults to True.
Raises:
task.exception: _description_
"""
while True:
for role, task in self.tasks.items():
if task.done():
if task.exception():
if raise_exception:
raise task.exception()
logger.opt(exception=task.exception()).error(f"Task {task.get_name()} run error")
else:
logger.warning(
f"Task {task.get_name()} has completed. "
"If this is unexpected behavior, please check the trigger function."
)
self.tasks.pop(role)
break
else:
await asyncio.sleep(1)

View file

@ -6,14 +6,15 @@
@File : conftest.py
"""
import asyncio
import logging
import re
from unittest.mock import Mock
import pytest
from metagpt.logs import logger
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
import asyncio
import re
class Context:
@ -68,3 +69,14 @@ def proxy():
server = asyncio.get_event_loop().run_until_complete(asyncio.start_server(handle_client, "127.0.0.1", 0))
return "http://{}:{}".format(*server.sockets[0].getsockname())
# see https://github.com/Delgan/loguru/issues/59#issuecomment-466591978
@pytest.fixture
def loguru_caplog(caplog):
class PropogateHandler(logging.Handler):
def emit(self, record):
logging.getLogger(record.name).handle(record)
logger.add(PropogateHandler(), format="{message}")
yield caplog

View file

@ -0,0 +1,102 @@
import asyncio
import pytest
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.subscription import SubscriptionRunner
@pytest.mark.asyncio
async def test_subscription_run():
callback_done = 0
async def trigger():
while True:
yield Message("the latest news about OpenAI")
await asyncio.sleep(3600 * 24)
class MockRole(Role):
async def run(self, message=None):
return Message("")
async def callback(message):
nonlocal callback_done
callback_done += 1
runner = SubscriptionRunner()
roles = []
for _ in range(2):
role = MockRole()
roles.append(role)
await runner.subscribe(role, trigger(), callback)
task = asyncio.get_running_loop().create_task(runner.run())
for _ in range(10):
if callback_done == 2:
break
await asyncio.sleep(0)
else:
raise TimeoutError("callback not call")
role = roles[0]
assert role in runner.tasks
await runner.unsubscribe(roles[0])
for _ in range(10):
if role not in runner.tasks:
break
await asyncio.sleep(0)
else:
raise TimeoutError("callback not call")
task.cancel()
for i in runner.tasks.values():
i.cancel()
@pytest.mark.asyncio
async def test_subscription_run_error(loguru_caplog):
async def trigger1():
while True:
yield Message("the latest news about OpenAI")
await asyncio.sleep(3600 * 24)
async def trigger2():
yield Message("the latest news about OpenAI")
class MockRole1(Role):
async def run(self, message=None):
raise RuntimeError
class MockRole2(Role):
async def run(self, message=None):
return Message("")
async def callback(msg: Message):
print(msg)
runner = SubscriptionRunner()
await runner.subscribe(MockRole1(), trigger1(), callback)
with pytest.raises(RuntimeError):
await runner.run()
await runner.subscribe(MockRole2(), trigger2(), callback)
task = asyncio.get_running_loop().create_task(runner.run(False))
for _ in range(10):
if not runner.tasks:
break
await asyncio.sleep(0)
else:
raise TimeoutError("wait runner tasks empty timeout")
task.cancel()
for i in runner.tasks.values():
i.cancel()
assert len(loguru_caplog.records) >= 2
logs = "".join(loguru_caplog.messages)
assert "run error" in logs
assert "has completed" in logs