diff --git a/examples/debate.py b/examples/debate.py index 54da73cca..44602a11a 100644 --- a/examples/debate.py +++ b/examples/debate.py @@ -2,6 +2,8 @@ Filename: MetaGPT/examples/debate.py Created Date: Tuesday, September 19th 2023, 6:52:25 pm Author: garylin2099 +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.1.3 of RFC 116, modify the data type of the `send_to` + value of the `Message` object; modify the argument type of `get_by_actions`. """ import asyncio import platform @@ -14,6 +16,8 @@ from metagpt.roles import Role from metagpt.schema import Message from metagpt.software_company import SoftwareCompany +from metagpt.utils.common import any_to_str + class ShoutOut(Action): """Action: Shout out loudly in a debate (quarrel)""" @@ -57,7 +61,8 @@ class Trump(Role): async def _observe(self) -> int: await super()._observe() # accept messages sent (from opponent) to self, disregard own messages from the last round - self._rc.news = [msg for msg in self._rc.news if msg.send_to == self.name] + + self._rc.news = [msg for msg in self._rc.news if msg.send_to == {self.name}] return len(self._rc.news) async def _act(self) -> Message: @@ -99,7 +104,9 @@ class Biden(Role): await super()._observe() # accept the very first human instruction (the debate topic) or messages sent (from opponent) to self, # disregard own messages from the last round - self._rc.news = [msg for msg in self._rc.news if msg.cause_by == BossRequirement or msg.send_to == self.name] + self._rc.news = [ + msg for msg in self._rc.news if msg.cause_by == any_to_str(BossRequirement) or msg.send_to == {self.name} + ] return len(self._rc.news) async def _act(self) -> Message: diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index a922d3694..58499df62 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -4,6 +4,8 @@ @Time : 2023/5/11 17:45 @Author : alexanderwu @File : write_code.py +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.1.3 of RFC 116, modify the data type of the `cause_by` + value of the `Message` object. """ from tenacity import retry, stop_after_attempt, wait_fixed @@ -12,7 +14,8 @@ from metagpt.actions.action import Action from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.schema import Message -from metagpt.utils.common import CodeParser + +from metagpt.utils.common import CodeParser, any_to_str PROMPT_TEMPLATE = """ NOTICE @@ -56,7 +59,7 @@ class WriteCode(Action): if self._is_invalid(filename): return - design = [i for i in context if i.cause_by == WriteDesign][0] + design = [i for i in context if i.cause_by == any_to_str(WriteDesign)][0] ws_name = CodeParser.parse_str(block="Python package name", text=design.content) ws_path = WORKSPACE_ROOT / ws_name diff --git a/metagpt/const.py b/metagpt/const.py index 7f3f87dfa..fa0ccc536 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -4,6 +4,8 @@ @Time : 2023/5/1 11:59 @Author : alexanderwu @File : const.py +@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, added key definitions for + common properties in the Message. """ from pathlib import Path @@ -41,3 +43,9 @@ INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table" SKILL_DIRECTORY = PROJECT_ROOT / "metagpt/skills" MEM_TTL = 24 * 30 * 3600 + +MESSAGE_ROUTE_FROM = "sent_from" +MESSAGE_ROUTE_TO = "send_to" +MESSAGE_ROUTE_CAUSE_BY = "cause_by" +MESSAGE_META_ROLE = "role" +MESSAGE_ROUTE_TO_ALL = "" diff --git a/metagpt/environment.py b/metagpt/environment.py index 2e2aa152a..a50dbde9e 100644 --- a/metagpt/environment.py +++ b/metagpt/environment.py @@ -4,15 +4,22 @@ @Time : 2023/5/11 22:12 @Author : alexanderwu @File : environment.py +@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.2 of RFC 116: + 1. Remove the functionality of `Environment` class as a public message buffer. + 2. Standardize the message forwarding behavior of the `Environment` class. + 3. Add the `is_idle` property. +@Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing + functionality is to be consolidated into the `Environment` class. """ import asyncio -from typing import Iterable +from typing import Iterable, Set from pydantic import BaseModel, Field -from metagpt.memory import Memory +from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message +from metagpt.utils.common import is_subscribed class Environment(BaseModel): @@ -22,8 +29,9 @@ class Environment(BaseModel): """ roles: dict[str, Role] = Field(default_factory=dict) - memory: Memory = Field(default_factory=Memory) - history: str = Field(default="") + members: dict[Role, Set] = Field(default_factory=dict) + history: str = Field(default="") # For debug + class Config: arbitrary_types_allowed = True @@ -42,22 +50,33 @@ class Environment(BaseModel): for role in roles: self.add_role(role) - def publish_message(self, message: Message): - """向当前环境发布信息 - Post information to the current environment + + def publish_message(self, message: Message) -> bool: """ - # self.message_queue.put(message) - self.memory.add(message) - self.history += f"\n{message}" + Distribute the message to the recipients. + In accordance with the Message routing structure design in Chapter 2.2.1 of RFC 116, as already planned + in RFC 113 for the entire system, the routing information in the Message is only responsible for + specifying the message recipient, without concern for where the message recipient is located. How to + route the message to the message recipient is a problem addressed by the transport framework designed + in RFC 113. + """ + logger.info(f"publish_message: {message.dump()}") + found = False + # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113 + for role, subscription in self.members.items(): + if is_subscribed(message, subscription): + role.put_message(message) + found = True + if not found: + logger.warning(f"Message no recipients: {message.dump()}") + self.history += f"\n{message}" # For debug + + return True async def run(self, k=1): """处理一次所有信息的运行 Process all Role runs at once """ - # while not self.message_queue.empty(): - # message = self.message_queue.get() - # rsp = await self.manager.handle(message, self) - # self.message_queue.put(rsp) for _ in range(k): futures = [] for role in self.roles.values(): @@ -65,6 +84,7 @@ class Environment(BaseModel): futures.append(future) await asyncio.gather(*futures) + logger.info(f"is idle: {self.is_idle}") def get_roles(self) -> dict[str, Role]: """获得环境内的所有角色 @@ -77,3 +97,19 @@ class Environment(BaseModel): get all the environment roles """ return self.roles.get(name, None) + + @property + def is_idle(self): + """If true, all actions have been executed.""" + for r in self.roles.values(): + if not r.is_idle: + return False + return True + + def get_subscription(self, obj): + """Get the labels for messages to be consumed by the object.""" + return self.members.get(obj, {}) + + def set_subscription(self, obj, tags): + """Set the labels for message to be consumed by the object""" + self.members[obj] = tags diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index e0b8e68c1..6fc8050ef 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : the implement of Long-term memory +""" +@Desc : the implement of Long-term memory +""" from metagpt.logs import logger from metagpt.memory import Memory diff --git a/metagpt/memory/memory.py b/metagpt/memory/memory.py index 282f5fe33..53b65fcf7 100644 --- a/metagpt/memory/memory.py +++ b/metagpt/memory/memory.py @@ -4,12 +4,13 @@ @Time : 2023/5/20 12:15 @Author : alexanderwu @File : memory.py +@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key. """ from collections import defaultdict -from typing import Iterable, Type +from typing import Iterable, Set -from metagpt.actions import Action from metagpt.schema import Message +from metagpt.utils.common import any_to_str, any_to_str_set class Memory: @@ -18,7 +19,7 @@ class Memory: def __init__(self): """Initialize an empty storage list and an empty index dictionary""" self.storage: list[Message] = [] - self.index: dict[Type[Action], list[Message]] = defaultdict(list) + self.index: dict[str, list[Message]] = defaultdict(list) def add(self, message: Message): """Add a new message to storage, while updating the index""" @@ -73,14 +74,16 @@ class Memory: news.append(i) return news - def get_by_action(self, action: Type[Action]) -> list[Message]: + def get_by_action(self, action) -> list[Message]: """Return all messages triggered by a specified Action""" - return self.index[action] + index = any_to_str(action) + return self.index[index] - def get_by_actions(self, actions: Iterable[Type[Action]]) -> list[Message]: + def get_by_actions(self, actions: Set) -> list[Message]: """Return all messages triggered by specified Actions""" rsp = [] - for action in actions: + indices = any_to_str_set(actions) + for action in indices: if action not in self.index: continue rsp += self.index[action] diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index 6d65575a8..ffd96849b 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -4,6 +4,12 @@ @Time : 2023/5/11 14:43 @Author : alexanderwu @File : engineer.py +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116: + 1. Modify the data type of the `cause_by` value in the `Message` to a string, and utilize the new message + distribution feature for message filtering. + 2. Consolidate message reception and processing logic within `_observe`. + 3. Fix bug: Add logic for handling asynchronous message processing when messages are not ready. + 4. Supplemented the external transmission of internal messages. """ import asyncio import shutil @@ -15,7 +21,7 @@ from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import CodeParser +from metagpt.utils.common import CodeParser, any_to_str from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP @@ -119,17 +125,13 @@ class Engineer(Role): file.write_text(code) return file - def recv(self, message: Message) -> None: - self._rc.memory.add(message) - if message in self._rc.important_memory: - self.todos = self.parse_tasks(message) - async def _act_mp(self) -> Message: # self.recreate_workspace() todo_coros = [] for todo in self.todos: todo_coro = WriteCode().run( - context=self._rc.memory.get_by_actions([WriteTasks, WriteDesign]), filename=todo + context=self._rc.memory.get_by_actions([WriteTasks, WriteDesign]), + filename=todo, ) todo_coros.append(todo_coro) @@ -139,12 +141,12 @@ class Engineer(Role): logger.info(todo) logger.info(code_rsp) # self.write_file(todo, code) - msg = Message(content=code_rsp, role=self.profile, cause_by=type(self._rc.todo)) + msg = Message(content=code_rsp, role=self.profile, cause_by=self._rc.todo) self._rc.memory.add(msg) del self.todos[0] logger.info(f"Done {self.get_workspace()} generating.") - msg = Message(content="all done.", role=self.profile, cause_by=type(self._rc.todo)) + msg = Message(content="all done.", role=self.profile, cause_by=self._rc.todo) return msg async def _act_sp(self) -> Message: @@ -155,7 +157,7 @@ class Engineer(Role): # logger.info(code_rsp) # code = self.parse_code(code_rsp) file_path = self.write_file(todo, code) - msg = Message(content=code, role=self.profile, cause_by=type(self._rc.todo)) + msg = Message(content=code, role=self.profile, cause_by=self._rc.todo) self._rc.memory.add(msg) code_msg = todo + FILENAME_CODE_SEP + str(file_path) @@ -163,7 +165,10 @@ class Engineer(Role): logger.info(f"Done {self.get_workspace()} generating.") msg = Message( - content=MSG_SEP.join(code_msg_all), role=self.profile, cause_by=type(self._rc.todo), send_to="QaEngineer" + content=MSG_SEP.join(code_msg_all), + role=self.profile, + cause_by=self._rc.todo, + send_to="Edward", # name of QaEngineer ) return msg @@ -201,12 +206,31 @@ class Engineer(Role): logger.info(f"Done {self.get_workspace()} generating.") msg = Message( - content=MSG_SEP.join(code_msg_all), role=self.profile, cause_by=type(self._rc.todo), send_to="QaEngineer" + content=MSG_SEP.join(code_msg_all), + role=self.profile, + cause_by=self._rc.todo, + send_to="Edward", # name of QaEngineer ) return msg async def _act(self) -> Message: """Determines the mode of action based on whether code review is used.""" + if not self._rc.todo: + return None if self.use_code_review: return await self._act_sp_precision() return await self._act_sp() + + async def _observe(self) -> int: + ret = await super(Engineer, self)._observe() + if ret == 0: + return ret + + # Parse task lists + for message in self._rc.news: + if not message.cause_by == any_to_str(WriteTasks): + continue + self.todos = self.parse_tasks(message) + return 1 + + return 0 diff --git a/metagpt/roles/qa_engineer.py b/metagpt/roles/qa_engineer.py index a763c2ce8..59a4135b8 100644 --- a/metagpt/roles/qa_engineer.py +++ b/metagpt/roles/qa_engineer.py @@ -4,6 +4,8 @@ @Time : 2023/5/11 14:43 @Author : alexanderwu @File : qa_engineer.py +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, modify the data + type of the `cause_by` value in the `Message` to a string, and utilize the new message filtering feature. """ import os from pathlib import Path @@ -20,7 +22,7 @@ from metagpt.const import WORKSPACE_ROOT from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -from metagpt.utils.common import CodeParser, parse_recipient +from metagpt.utils.common import CodeParser, any_to_str_set, parse_recipient from metagpt.utils.special_tokens import FILENAME_CODE_SEP, MSG_SEP @@ -98,10 +100,10 @@ class QaEngineer(Role): content=str(file_info), role=self.profile, cause_by=WriteTest, - sent_from=self.profile, - send_to=self.profile, + sent_from=self, + send_to=self, ) - self._publish_message(msg) + self.publish_message(msg) logger.info(f"Done {self.get_workspace()}/tests generating.") @@ -132,7 +134,7 @@ class QaEngineer(Role): recipient = parse_recipient(result_msg) # the recipient might be Engineer or myself content = str(file_info) + FILENAME_CODE_SEP + result_msg msg = Message(content=content, role=self.profile, cause_by=RunCode, sent_from=self.profile, send_to=recipient) - self._publish_message(msg) + self.publish_message(msg) async def _debug_error(self, msg): file_info, context = msg.content.split(FILENAME_CODE_SEP) @@ -141,16 +143,13 @@ class QaEngineer(Role): self.write_file(file_name, code) recipient = msg.sent_from # send back to the one who ran the code for another run, might be one's self msg = Message( - content=file_info, role=self.profile, cause_by=DebugError, sent_from=self.profile, send_to=recipient + content=file_info, + role=self.profile, + cause_by=DebugError, + sent_from=self, + send_to=self, ) - self._publish_message(msg) - - async def _observe(self) -> int: - await super()._observe() - self._rc.news = [ - msg for msg in self._rc.news if msg.send_to == self.profile - ] # only relevant msgs count as observed news - return len(self._rc.news) + self.publish_message(msg) async def _act(self) -> Message: if self.test_round > self.test_round_allowed: @@ -159,20 +158,23 @@ class QaEngineer(Role): role=self.profile, cause_by=WriteTest, sent_from=self.profile, - send_to="", + send_to="" ) return result_msg + code_filters = any_to_str_set({WriteCode, WriteCodeReview}) + test_filters = any_to_str_set({WriteTest, DebugError}) + run_filters = any_to_str_set({RunCode}) for msg in self._rc.news: # Decide what to do based on observed msg type, currently defined by human, # might potentially be moved to _think, that is, let the agent decides for itself - if msg.cause_by in [WriteCode, WriteCodeReview]: + if msg.cause_by in code_filters: # engineer wrote a code, time to write a test for it await self._write_test(msg) - elif msg.cause_by in [WriteTest, DebugError]: + elif msg.cause_by in test_filters: # I wrote or debugged my test code, time to run it await self._run_code(msg) - elif msg.cause_by == RunCode: + elif msg.cause_by in run_filters: # I ran my test code, time to fix bugs, if any await self._debug_error(msg) self.test_round += 1 @@ -181,6 +183,6 @@ class QaEngineer(Role): role=self.profile, cause_by=WriteTest, sent_from=self.profile, - send_to="", + send_to="" ) return result_msg diff --git a/metagpt/roles/researcher.py b/metagpt/roles/researcher.py index acb46c718..29889b8ec 100644 --- a/metagpt/roles/researcher.py +++ b/metagpt/roles/researcher.py @@ -1,4 +1,9 @@ #!/usr/bin/env python +""" +@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of + the `cause_by` value in the `Message` to a string to support the new message distribution feature. +""" + import asyncio @@ -58,18 +63,18 @@ class Researcher(Role): research_system_text = get_research_system_text(topic, self.language) if isinstance(todo, CollectLinks): links = await todo.run(topic, 4, 4) - ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=type(todo)) + ret = Message("", Report(topic=topic, links=links), role=self.profile, cause_by=todo) elif isinstance(todo, WebBrowseAndSummarize): links = instruct_content.links todos = (todo.run(*url, query=query, system_text=research_system_text) for (query, url) in links.items()) summaries = await asyncio.gather(*todos) summaries = list((url, summary) for i in summaries for (url, summary) in i.items() if summary) - ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=type(todo)) + ret = Message("", Report(topic=topic, summaries=summaries), role=self.profile, cause_by=todo) else: summaries = instruct_content.summaries summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries) content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text) - ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=type(self._rc.todo)) + ret = Message("", Report(topic=topic, content=content), role=self.profile, cause_by=self._rc.todo) self._rc.memory.add(ret) return ret diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 282431bf7..424a28c6f 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -4,21 +4,36 @@ @Time : 2023/5/11 14:42 @Author : alexanderwu @File : role.py +@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116: + 1. Merge the `recv` functionality into the `_observe` function. Future message reading operations will be + consolidated within the `_observe` function. + 2. Standardize the message filtering for string label matching. Role objects can access the message labels + they've subscribed to through the `subscribed_tags` property. + 3. Move the message receive buffer from the global variable `self._rc.env.memory` to the role's private variable + `self._rc.msg_buffer` for easier message identification and asynchronous appending of messages. + 4. Standardize the way messages are passed: `publish_message` sends messages out, while `put_message` places + messages into the Role object's private message receive buffer. There are no other message transmit methods. + 5. Standardize the parameters for the `run` function: the `test_message` parameter is used for testing purposes + only. In the normal workflow, you should use `publish_message` or `put_message` to transmit messages. +@Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing + functionality is to be consolidated into the `Environment` class. """ from __future__ import annotations -from typing import Iterable, Type +from typing import Iterable, Set, Type from pydantic import BaseModel, Field from metagpt.actions import Action, ActionOutput -# from metagpt.environment import Environment from metagpt.config import CONFIG from metagpt.llm import LLM from metagpt.logs import logger from metagpt.memory import LongTermMemory, Memory -from metagpt.schema import Message + +from metagpt.schema import Message, MessageQueue +from metagpt.utils.common import any_to_str + PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """ @@ -68,11 +83,12 @@ class RoleContext(BaseModel): """Role Runtime Context""" env: "Environment" = Field(default=None) + msg_buffer: MessageQueue = Field(default_factory=MessageQueue) # Message Buffer with Asynchronous Updates memory: Memory = Field(default_factory=Memory) long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory) state: int = Field(default=0) todo: Action = Field(default=None) - watch: set[Type[Action]] = Field(default_factory=set) + watch: set[str] = Field(default_factory=set) news: list[Type[Message]] = Field(default=[]) class Config: @@ -103,6 +119,7 @@ class Role: self._actions = [] self._role_id = str(self._setting) self._rc = RoleContext() + self._subscription = {any_to_str(self), name} if name else {any_to_str(self)} def _reset(self): self._states = [] @@ -120,11 +137,23 @@ class Role: self._states.append(f"{idx}. {action}") def _watch(self, actions: Iterable[Type[Action]]): - """Listen to the corresponding behaviors""" - self._rc.watch.update(actions) + """Watch Actions of interest. Role will select Messages caused by these Actions from its personal message + buffer during _observe. + """ + tags = {any_to_str(t) for t in actions} + self._rc.watch.update(tags) # check RoleContext after adding watch actions self._rc.check(self._role_id) + def subscribe(self, tags: Set[str]): + """Used to receive Messages with certain tags from the environment. Message will be put into personal message + buffer to be further processed in _observe. By default, a Role subscribes Messages with a tag of its own name + or profile. + """ + self._subscription = tags + if self._rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113 + self._rc.env.set_subscription(self, self._subscription) + def _set_state(self, state): """Update the current state.""" self._rc.state = state @@ -132,14 +161,27 @@ class Role: self._rc.todo = self._actions[self._rc.state] def set_env(self, env: "Environment"): - """Set the environment in which the role works. The role can talk to the environment and can also receive messages by observing.""" + """Set the environment in which the role works. The role can talk to the environment and can also receive + messages by observing.""" self._rc.env = env + if env: + env.set_subscription(self, self._subscription) @property def profile(self): """Get the role description (position)""" return self._setting.profile + @property + def name(self): + """Get virtual user name""" + return self._setting.name + + @property + def subscription(self) -> Set: + """The labels for messages to be consumed by the Role object.""" + return self._subscription + def _get_prefix(self): """Get the role prefix""" if self._setting.desc: @@ -164,90 +206,86 @@ class Role: self._set_state(int(next_state)) async def _act(self) -> Message: - # prompt = self.get_prefix() - # prompt += ROLE_TEMPLATE.format(name=self.profile, state=self.states[self.state], result=response, - # history=self.history) - logger.info(f"{self._setting}: ready to {self._rc.todo}") response = await self._rc.todo.run(self._rc.important_memory) - # logger.info(response) if isinstance(response, ActionOutput): msg = Message( content=response.content, instruct_content=response.instruct_content, role=self.profile, - cause_by=type(self._rc.todo), + cause_by=self._rc.todo, ) else: - msg = Message(content=response, role=self.profile, cause_by=type(self._rc.todo)) + msg = Message(content=response, role=self.profile, cause_by=self._rc.todo) self._rc.memory.add(msg) - # logger.debug(f"{response}") return msg async def _observe(self) -> int: - """Observe from the environment, obtain important information, and add it to memory""" - if not self._rc.env: - return 0 - env_msgs = self._rc.env.memory.get() - - observed = self._rc.env.memory.get_by_actions(self._rc.watch) - - self._rc.news = self._rc.memory.find_news( - observed - ) # find news (previously unseen messages) from observed messages - - for i in env_msgs: - self.recv(i) + """Prepare new messages for processing from the message buffer and other sources.""" + # Read unprocessed messages from the msg buffer. + news = self._rc.msg_buffer.pop_all() + # Store the read messages in your own memory to prevent duplicate processing. + old_messages = self._rc.memory.get() + self._rc.memory.add_batch(news) + # Filter out messages of interest. + self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages] + # Design Rules: + # If you need to further categorize Message objects, you can do so using the Message.set_meta function. + # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer. news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news] if news_text: logger.debug(f"{self._setting} observed: {news_text}") return len(self._rc.news) - def _publish_message(self, msg): + def publish_message(self, msg): """If the role belongs to env, then the role's messages will be broadcast to env""" + if not msg: + return if not self._rc.env: # If env does not exist, do not publish the message return self._rc.env.publish_message(msg) + def put_message(self, message): + """Place the message into the Role object's private message buffer.""" + if not message: + return + self._rc.msg_buffer.push(message) + async def _react(self) -> Message: """Think first, then act""" await self._think() logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}") return await self._act() - def recv(self, message: Message) -> None: - """add message to history.""" - # self._history += f"\n{message}" - # self._context = self._history - if message in self._rc.memory.get(): - return - self._rc.memory.add(message) - - async def handle(self, message: Message) -> Message: - """Receive information and reply with actions""" - # logger.debug(f"{self.name=}, {self.profile=}, {message.role=}") - self.recv(message) - - return await self._react() - - async def run(self, message=None): + async def run(self, with_message=None): """Observe, and think and act based on the results of the observation""" - if message: - if isinstance(message, str): - message = Message(message) - if isinstance(message, Message): - self.recv(message) - if isinstance(message, list): - self.recv(Message("\n".join(message))) - elif not await self._observe(): + if with_message: + msg = None + if isinstance(with_message, str): + msg = Message(with_message) + elif isinstance(with_message, Message): + msg = with_message + elif isinstance(with_message, list): + msg = Message("\n".join(with_message)) + self.put_message(msg) + + if not await self._observe(): # If there is no new information, suspend and wait logger.debug(f"{self._setting}: no news. waiting.") return rsp = await self._react() - # Publish the reply to the environment, waiting for the next subscriber to process - self._publish_message(rsp) + + # Reset the next action to be taken. + self._rc.todo = None + # Send the response message to the Environment object to have it relay the message to the subscribers. + self.publish_message(rsp) return rsp + + @property + def is_idle(self) -> bool: + """If true, all actions have been executed.""" + return not self._rc.news and not self._rc.todo and self._rc.msg_buffer.empty() diff --git a/metagpt/roles/seacher.py b/metagpt/roles/seacher.py index a2c4896e2..587698d1d 100644 --- a/metagpt/roles/seacher.py +++ b/metagpt/roles/seacher.py @@ -4,6 +4,8 @@ @Time : 2023/5/23 17:25 @Author : alexanderwu @File : seacher.py +@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of + the `cause_by` value in the `Message` to a string to support the new message distribution feature. """ from metagpt.actions import ActionOutput, SearchAndSummarize from metagpt.logs import logger @@ -61,10 +63,10 @@ class Searcher(Role): content=response.content, instruct_content=response.instruct_content, role=self.profile, - cause_by=type(self._rc.todo), + cause_by=self._rc.todo, ) else: - msg = Message(content=response, role=self.profile, cause_by=type(self._rc.todo)) + msg = Message(content=response, role=self.profile, cause_by=self._rc.todo) self._rc.memory.add(msg) return msg diff --git a/metagpt/roles/sk_agent.py b/metagpt/roles/sk_agent.py index b27841d74..2443b8b58 100644 --- a/metagpt/roles/sk_agent.py +++ b/metagpt/roles/sk_agent.py @@ -4,6 +4,8 @@ @Time : 2023/9/13 12:23 @Author : femto Zheng @File : sk_agent.py +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message + distribution feature for message filtering. """ from semantic_kernel.planning import SequentialPlanner from semantic_kernel.planning.action_planner.action_planner import ActionPlanner @@ -70,7 +72,6 @@ class SkAgent(Role): result = (await self.plan.invoke_async()).result logger.info(result) - msg = Message(content=result, role=self.profile, cause_by=type(self._rc.todo)) + msg = Message(content=result, role=self.profile, cause_by=self._rc.todo) self._rc.memory.add(msg) - # logger.debug(f"{response}") return msg diff --git a/metagpt/schema.py b/metagpt/schema.py index 19c7a6654..b40d65d5f 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -4,15 +4,27 @@ @Time : 2023/5/8 22:12 @Author : alexanderwu @File : schema.py +@Modified By: mashenquan, 2023-10-31. According to Chapter 2.2.1 of RFC 116: + Replanned the distribution of responsibilities and functional positioning of `Message` class attributes. """ from __future__ import annotations -from dataclasses import dataclass, field -from typing import Type, TypedDict +import asyncio +import json +from asyncio import Queue, QueueEmpty, wait_for +from json import JSONDecodeError +from typing import List, Set, TypedDict -from pydantic import BaseModel +from pydantic import BaseModel, Field +from metagpt.const import ( + MESSAGE_ROUTE_CAUSE_BY, + MESSAGE_ROUTE_FROM, + MESSAGE_ROUTE_TO, + MESSAGE_ROUTE_TO_ALL, +) from metagpt.logs import logger +from metagpt.utils.common import any_to_str, any_to_str_set class RawMessage(TypedDict): @@ -20,17 +32,57 @@ class RawMessage(TypedDict): role: str -@dataclass -class Message: +class Message(BaseModel): """list[: ]""" content: str - instruct_content: BaseModel = field(default=None) - role: str = field(default="user") # system / user / assistant - cause_by: Type["Action"] = field(default="") - sent_from: str = field(default="") - send_to: str = field(default="") - restricted_to: str = field(default="") + instruct_content: BaseModel = Field(default=None) + role: str = "user" # system / user / assistant + cause_by: str = "" + sent_from: str = "" + send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL}) + + def __init__( + self, + content, + instruct_content=None, + role="user", + cause_by="", + sent_from="", + send_to=MESSAGE_ROUTE_TO_ALL, + **kwargs, + ): + """ + Parameters not listed below will be stored as meta info, including custom parameters. + :param content: Message content. + :param instruct_content: Message content struct. + :param cause_by: Message producer + :param sent_from: Message route info tells who sent this message. + :param send_to: Specifies the target recipient or consumer for message delivery in the environment. + :param role: Message meta info tells who sent this message. + """ + super().__init__( + content=content, + instruct_content=instruct_content, + role=role, + cause_by=any_to_str(cause_by), + sent_from=any_to_str(sent_from), + send_to=any_to_str_set(send_to), + **kwargs, + ) + + def __setattr__(self, key, val): + """Override `@property.setter`, convert non-string parameters into string parameters.""" + if key == MESSAGE_ROUTE_CAUSE_BY: + new_val = any_to_str(val) + elif key == MESSAGE_ROUTE_FROM: + new_val = any_to_str(val) + elif key == MESSAGE_ROUTE_TO: + new_val = any_to_str_set(val) + else: + new_val = val + super().__setattr__(key, new_val) + def __str__(self): # prefix = '-'.join([self.role, str(self.cause_by)]) @@ -40,45 +92,115 @@ class Message: return self.__str__() def to_dict(self) -> dict: + """Return a dict containing `role` and `content` for the LLM call.l""" return {"role": self.role, "content": self.content} + def dump(self) -> str: + """Convert the object to json string""" + return self.json(exclude_none=True) + + @staticmethod + def load(val): + """Convert the json string to object.""" + try: + d = json.loads(val) + return Message(**d) + except JSONDecodeError as err: + logger.error(f"parse json failed: {val}, error:{err}") + return None + + -@dataclass class UserMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): - super().__init__(content, "user") + super().__init__(content=content, role="user") + -@dataclass class SystemMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): - super().__init__(content, "system") + super().__init__(content=content, role="system") -@dataclass class AIMessage(Message): """便于支持OpenAI的消息 Facilitate support for OpenAI messages """ def __init__(self, content: str): - super().__init__(content, "assistant") + super().__init__(content=content, role="assistant") -if __name__ == "__main__": - test_content = "test_message" - msgs = [ - UserMessage(test_content), - SystemMessage(test_content), - AIMessage(test_content), - Message(test_content, role="QA"), - ] - logger.info(msgs) +class MessageQueue: + """Message queue which supports asynchronous updates.""" + + def __init__(self): + self._queue = Queue() + + def pop(self) -> Message | None: + """Pop one message from the queue.""" + try: + item = self._queue.get_nowait() + if item: + self._queue.task_done() + return item + except QueueEmpty: + return None + + def pop_all(self) -> List[Message]: + """Pop all messages from the queue.""" + ret = [] + while True: + msg = self.pop() + if not msg: + break + ret.append(msg) + return ret + + def push(self, msg: Message): + """Push a message into the queue.""" + self._queue.put_nowait(msg) + + def empty(self): + """Return true if the queue is empty.""" + return self._queue.empty() + + async def dump(self) -> str: + """Convert the `MessageQueue` object to a json string.""" + if self.empty(): + return "[]" + + lst = [] + try: + while True: + item = await wait_for(self._queue.get(), timeout=1.0) + if item is None: + break + lst.append(item.dict(exclude_none=True)) + self._queue.task_done() + except asyncio.TimeoutError: + logger.debug("Queue is empty, exiting...") + return json.dumps(lst) + + @staticmethod + def load(self, v) -> "MessageQueue": + """Convert the json string to the `MessageQueue` object.""" + q = MessageQueue() + try: + lst = json.loads(v) + for i in lst: + msg = Message(**i) + q.push(msg) + except JSONDecodeError as e: + logger.warning(f"JSON load failed: {v}, error:{e}") + + return q + diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index f09666beb..798acf214 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -4,6 +4,8 @@ @Time : 2023/4/29 16:07 @Author : alexanderwu @File : common.py +@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.2 of RFC 116: + Add generic class-to-string and object-to-string conversion functionality. """ import ast import contextlib @@ -13,6 +15,7 @@ import platform import re from typing import List, Tuple, Union +from metagpt.const import MESSAGE_ROUTE_TO_ALL from metagpt.logs import logger @@ -85,10 +88,7 @@ class OutputParser: @staticmethod def parse_python_code(text: str) -> str: - for pattern in ( - r"(.*?```python.*?\s+)?(?P.*)(```.*?)", - r"(.*?```python.*?\s+)?(?P.*)", - ): + for pattern in (r"(.*?```python.*?\s+)?(?P.*)(```.*?)", r"(.*?```python.*?\s+)?(?P.*)"): match = re.search(pattern, text, re.DOTALL) if not match: continue @@ -305,3 +305,46 @@ def parse_recipient(text): pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now recipient = re.search(pattern, text) return recipient.group(1) if recipient else "" + + +def get_class_name(cls) -> str: + """Return class name""" + return f"{cls.__module__}.{cls.__name__}" + + +def get_object_name(obj) -> str: + """Return class name of the object""" + cls = type(obj) + return f"{cls.__module__}.{cls.__name__}" + + +def any_to_str(val) -> str: + """Return the class name or the class name of the object, or 'val' if it's a string type.""" + if isinstance(val, str): + return val + if not callable(val): + return get_object_name(val) + + return get_class_name(val) + + +def any_to_str_set(val) -> set: + """Convert any type to string set.""" + res = set() + if isinstance(val, dict) or isinstance(val, list) or isinstance(val, set) or isinstance(val, tuple): + for i in val: + res.add(any_to_str(i)) + else: + res.add(any_to_str(val)) + return res + + +def is_subscribed(message, tags): + """Return whether it's consumer""" + if MESSAGE_ROUTE_TO_ALL in message.send_to: + return True + + for t in tags: + if t in message.send_to: + return True + return False diff --git a/requirements.txt b/requirements.txt index 24a2d94c3..c3b909e77 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ langchain==0.0.231 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3 -openai +openai==0.28.1 openpyxl beautifulsoup4==4.12.2 pandas==2.0.3 diff --git a/tests/metagpt/actions/test_write_prd.py b/tests/metagpt/actions/test_write_prd.py index 38e4e5221..07d701cb9 100644 --- a/tests/metagpt/actions/test_write_prd.py +++ b/tests/metagpt/actions/test_write_prd.py @@ -4,6 +4,7 @@ @Time : 2023/5/11 17:45 @Author : alexanderwu @File : test_write_prd.py +@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, replace `handle` with `run`. """ import pytest @@ -17,7 +18,7 @@ from metagpt.schema import Message async def test_write_prd(): product_manager = ProductManager() requirements = "开发一个基于大语言模型与私有知识库的搜索引擎,希望可以基于大语言模型进行搜索总结" - prd = await product_manager.handle(Message(content=requirements, cause_by=BossRequirement)) + prd = await product_manager.run(Message(content=requirements, cause_by=BossRequirement)) logger.info(requirements) logger.info(prd) diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index 9682ba760..c5b5c6eb1 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : unittest of `metagpt/memory/longterm_memory.py` +""" +@Desc : unittest of `metagpt/memory/longterm_memory.py` +""" from metagpt.actions import BossRequirement from metagpt.config import CONFIG diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index 8b338a79e..251c70b02 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -1,6 +1,9 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : the unittests of metagpt/memory/memory_storage.py +""" +@Desc : the unittests of metagpt/memory/memory_storage.py +""" + from typing import List diff --git a/tests/metagpt/planner/test_action_planner.py b/tests/metagpt/planner/test_action_planner.py index 5ab9a493f..b8d4c1ad9 100644 --- a/tests/metagpt/planner/test_action_planner.py +++ b/tests/metagpt/planner/test_action_planner.py @@ -4,6 +4,8 @@ @Time : 2023/9/16 20:03 @Author : femto Zheng @File : test_basic_planner.py +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message + distribution feature for message handling. """ import pytest from semantic_kernel.core_skills import FileIOSkill, MathSkill, TextSkill, TimeSkill @@ -23,7 +25,7 @@ async def test_action_planner(): role.import_skill(TimeSkill(), "time") role.import_skill(TextSkill(), "text") task = "What is the sum of 110 and 990?" - role.recv(Message(content=task, cause_by=BossRequirement)) - + role.put_message(Message(content=task, cause_by=BossRequirement)) + await role._observe() await role._think() # it will choose mathskill.Add assert "1100" == (await role._act()).content diff --git a/tests/metagpt/planner/test_basic_planner.py b/tests/metagpt/planner/test_basic_planner.py index 03a82ec5e..24250a0b0 100644 --- a/tests/metagpt/planner/test_basic_planner.py +++ b/tests/metagpt/planner/test_basic_planner.py @@ -4,6 +4,8 @@ @Time : 2023/9/16 20:03 @Author : femto Zheng @File : test_basic_planner.py +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message + distribution feature for message handling. """ import pytest from semantic_kernel.core_skills import TextSkill @@ -26,7 +28,8 @@ async def test_basic_planner(): role.import_semantic_skill_from_directory(SKILL_DIRECTORY, "WriterSkill") role.import_skill(TextSkill(), "TextSkill") # using BasicPlanner - role.recv(Message(content=task, cause_by=BossRequirement)) + role.put_message(Message(content=task, cause_by=BossRequirement)) + await role._observe() await role._think() # assuming sk_agent will think he needs WriterSkill.Brainstorm and WriterSkill.Translate assert "WriterSkill.Brainstorm" in role.plan.generated_plan.result diff --git a/tests/metagpt/roles/test_architect.py b/tests/metagpt/roles/test_architect.py index d44e0d923..111438b0b 100644 --- a/tests/metagpt/roles/test_architect.py +++ b/tests/metagpt/roles/test_architect.py @@ -4,6 +4,8 @@ @Time : 2023/5/20 14:37 @Author : alexanderwu @File : test_architect.py +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message + distribution feature for message handling. """ import pytest @@ -15,7 +17,7 @@ from tests.metagpt.roles.mock import MockMessages @pytest.mark.asyncio async def test_architect(): role = Architect() - role.recv(MockMessages.req) - rsp = await role.handle(MockMessages.prd) + role.put_message(MockMessages.req) + rsp = await role.run(MockMessages.prd) logger.info(rsp) assert len(rsp.content) > 0 diff --git a/tests/metagpt/roles/test_engineer.py b/tests/metagpt/roles/test_engineer.py index f44188c17..3dc599770 100644 --- a/tests/metagpt/roles/test_engineer.py +++ b/tests/metagpt/roles/test_engineer.py @@ -4,6 +4,8 @@ @Time : 2023/5/12 10:14 @Author : alexanderwu @File : test_engineer.py +@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message + distribution feature for message handling. """ import pytest @@ -22,10 +24,10 @@ from tests.metagpt.roles.mock import ( async def test_engineer(): engineer = Engineer() - engineer.recv(MockMessages.req) - engineer.recv(MockMessages.prd) - engineer.recv(MockMessages.system_design) - rsp = await engineer.handle(MockMessages.tasks) + engineer.put_message(MockMessages.req) + engineer.put_message(MockMessages.prd) + engineer.put_message(MockMessages.system_design) + rsp = await engineer.run(MockMessages.tasks) logger.info(rsp) assert "all done." == rsp.content diff --git a/tests/metagpt/test_message.py b/tests/metagpt/test_message.py index ae6708943..f52dfb64c 100644 --- a/tests/metagpt/test_message.py +++ b/tests/metagpt/test_message.py @@ -4,6 +4,7 @@ @Time : 2023/5/16 10:57 @Author : alexanderwu @File : test_message.py +@Modified By: mashenquan, 2023-11-1. Modify coding style. """ import pytest @@ -34,3 +35,8 @@ def test_raw_message(): assert msg["content"] == "raw" with pytest.raises(KeyError): assert msg["1"] == 1, "KeyError: '1'" + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) + diff --git a/tests/metagpt/test_role.py b/tests/metagpt/test_role.py index 22cfa58a4..f93651303 100644 --- a/tests/metagpt/test_role.py +++ b/tests/metagpt/test_role.py @@ -4,11 +4,100 @@ @Time : 2023/5/11 14:44 @Author : alexanderwu @File : test_role.py +@Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for + the utilization of the new message distribution feature in message handling. +@Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing + functionality is to be consolidated into the `Environment` class. """ +import uuid + +import pytest +from pydantic import BaseModel + +from metagpt.actions import Action, ActionOutput +from metagpt.environment import Environment from metagpt.roles import Role +from metagpt.schema import Message +from metagpt.utils.common import get_class_name -def test_role_desc(): - i = Role(profile="Sales", desc="Best Seller") - assert i.profile == "Sales" - assert i._setting.desc == "Best Seller" +class MockAction(Action): + async def run(self, messages, *args, **kwargs): + assert messages + return ActionOutput(content=messages[-1].content, instruct_content=messages[-1]) + + +class MockRole(Role): + def __init__(self, name="", profile="", goal="", constraints="", desc=""): + super().__init__(name=name, profile=profile, goal=goal, constraints=constraints, desc=desc) + self._init_actions([MockAction()]) + + +@pytest.mark.asyncio +async def test_react(): + class Input(BaseModel): + name: str + profile: str + goal: str + constraints: str + desc: str + subscription: str + + inputs = [ + { + "name": "A", + "profile": "Tester", + "goal": "Test", + "constraints": "constraints", + "desc": "desc", + "subscription": "start", + } + ] + + for i in inputs: + seed = Input(**i) + role = MockRole( + name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc + ) + role.subscribe({seed.subscription}) + assert role._rc.watch == set({}) + assert role.name == seed.name + assert role.profile == seed.profile + assert role._setting.goal == seed.goal + assert role._setting.constraints == seed.constraints + assert role._setting.desc == seed.desc + assert role.is_idle + env = Environment() + env.add_role(role) + assert env.get_subscription(role) == {seed.subscription} + env.publish_message(Message(content="test", msg_to=seed.subscription)) + assert not role.is_idle + while not env.is_idle: + await env.run() + assert role.is_idle + env.publish_message(Message(content="test", cause_by=seed.subscription)) + assert not role.is_idle + while not env.is_idle: + await env.run() + assert role.is_idle + tag = uuid.uuid4().hex + role.subscribe({tag}) + assert env.get_subscription(role) == {tag} + + +@pytest.mark.asyncio +async def test_msg_to(): + m = Message(content="a", send_to=["a", MockRole, Message]) + assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + + m = Message(content="a", cause_by=MockAction, send_to={"a", MockRole, Message}) + assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + + m = Message(content="a", send_to=("a", MockRole, Message)) + assert m.send_to == set({"a", get_class_name(MockRole), get_class_name(Message)}) + + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) + diff --git a/tests/metagpt/test_schema.py b/tests/metagpt/test_schema.py index c154d77e1..51ebd5baa 100644 --- a/tests/metagpt/test_schema.py +++ b/tests/metagpt/test_schema.py @@ -4,10 +4,19 @@ @Time : 2023/5/20 10:40 @Author : alexanderwu @File : test_schema.py +@Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for + the utilization of the new feature of `Message` class. """ +import json + +import pytest + +from metagpt.actions import Action from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage +from metagpt.utils.common import get_class_name +@pytest.mark.asyncio def test_messages(): test_content = "test_message" msgs = [ @@ -19,3 +28,47 @@ def test_messages(): text = str(msgs) roles = ["user", "system", "assistant", "QA"] assert all([i in text for i in roles]) + + +@pytest.mark.asyncio +def test_message(): + m = Message("a", role="v1") + v = m.dump() + d = json.loads(v) + assert d + assert d.get("content") == "a" + assert d.get("role") == "v1" + m.role = "v2" + v = m.dump() + assert v + m = Message.load(v) + assert m.content == "a" + assert m.role == "v2" + + m = Message("a", role="b", cause_by="c", x="d", send_to="c") + assert m.content == "a" + assert m.role == "b" + assert m.send_to == {"c"} + assert m.cause_by == "c" + + m.cause_by = "Message" + assert m.cause_by == "Message" + m.cause_by = Action + assert m.cause_by == get_class_name(Action) + m.cause_by = Action() + assert m.cause_by == get_class_name(Action) + m.content = "b" + assert m.content == "b" + + +@pytest.mark.asyncio +def test_routes(): + m = Message("a", role="b", cause_by="c", x="d", send_to="c") + m.send_to = "b" + assert m.send_to == {"b"} + m.send_to = {"e", Action} + assert m.send_to == {"e", get_class_name(Action)} + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_common.py b/tests/metagpt/utils/test_common.py index d3837ca8f..6474b1233 100644 --- a/tests/metagpt/utils/test_common.py +++ b/tests/metagpt/utils/test_common.py @@ -4,13 +4,20 @@ @Time : 2023/4/29 16:19 @Author : alexanderwu @File : test_common.py +@Modified by: mashenquan, 2023/11/21. Add unit tests. """ import os +from typing import Any, Set import pytest +from pydantic import BaseModel +from metagpt.actions import RunCode from metagpt.const import get_project_root +from metagpt.roles.tutorial_assistant import TutorialAssistant +from metagpt.schema import Message +from metagpt.utils.common import any_to_str, any_to_str_set class TestGetProjectRoot: @@ -21,10 +28,55 @@ class TestGetProjectRoot: def test_get_project_root(self): project_root = get_project_root() - assert project_root.name == "metagpt" + assert project_root.name == "MetaGPT" def test_get_root_exception(self): with pytest.raises(Exception) as exc_info: self.change_etc_dir() get_project_root() assert str(exc_info.value) == "Project root not found." + + def test_any_to_str(self): + class Input(BaseModel): + x: Any + want: str + + inputs = [ + Input(x=TutorialAssistant, want="metagpt.roles.tutorial_assistant.TutorialAssistant"), + Input(x=TutorialAssistant(), want="metagpt.roles.tutorial_assistant.TutorialAssistant"), + Input(x=RunCode, want="metagpt.actions.run_code.RunCode"), + Input(x=RunCode(), want="metagpt.actions.run_code.RunCode"), + Input(x=Message, want="metagpt.schema.Message"), + Input(x=Message(""), want="metagpt.schema.Message"), + Input(x="A", want="A"), + ] + for i in inputs: + v = any_to_str(i.x) + assert v == i.want + + def test_any_to_str_set(self): + class Input(BaseModel): + x: Any + want: Set + + inputs = [ + Input( + x=[TutorialAssistant, RunCode(), "a"], + want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"}, + ), + Input( + x={TutorialAssistant, RunCode(), "a"}, + want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"}, + ), + Input( + x=(TutorialAssistant, RunCode(), "a"), + want={"metagpt.roles.tutorial_assistant.TutorialAssistant", "metagpt.actions.run_code.RunCode", "a"}, + ), + ] + for i in inputs: + v = any_to_str_set(i.x) + assert v == i.want + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index 69f317f79..ffa34866c 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -1,6 +1,8 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : the unittest of serialize +""" +@Desc : the unittest of serialize +""" from typing import List, Tuple