feat: merge

This commit is contained in:
莘权 马 2023-12-28 18:05:33 +08:00
commit f76078dedf
95 changed files with 1629 additions and 948 deletions

View file

@ -54,8 +54,8 @@ # Step 2: Clone the repository to your local machine for latest version, and ins
# Step 3: setup your OPENAI_API_KEY, or make sure it existed in the env
mkdir ~/.metagpt
cp config/config.yaml ~/.metagpt/key.yaml
vim ~/.metagpt/key.yaml
cp config/config.yaml ~/.metagpt/config.yaml
vim ~/.metagpt/config.yaml
# Step 4: run metagpt cli
metagpt "Create a 2048 game in python"

View file

@ -17,7 +17,7 @@ MULTI_ACTION_AGENT_CODE_EXAMPLE = EXAMPLE_CODE_FILE.read_text()
class CreateAgent(Action):
PROMPT_TEMPLATE = """
PROMPT_TEMPLATE: str = """
### BACKGROUND
You are using an agent framework called metagpt to write agents capable of different actions,
the usage of metagpt can be illustrated by the following example:
@ -64,9 +64,9 @@ class AgentCreator(Role):
self._init_actions([CreateAgent])
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
todo = self._rc.todo
msg = self._rc.memory.get()[-1]
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
todo = self.rc.todo
msg = self.rc.memory.get()[-1]
instruction = msg.content
code_text = await CreateAgent().run(example=self.agent_template, instruction=instruction)

View file

@ -16,7 +16,7 @@ from metagpt.schema import Message
class SimpleWriteCode(Action):
PROMPT_TEMPLATE = """
PROMPT_TEMPLATE: str = """
Write a python function that can {instruction} and provide two runnnable test cases.
Return ```python your_code_here ``` with NO other texts,
your code:
@ -60,8 +60,8 @@ class SimpleCoder(Role):
self._init_actions([SimpleWriteCode])
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
todo = self._rc.todo # todo will be SimpleWriteCode()
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
todo = self.rc.todo # todo will be SimpleWriteCode()
msg = self.get_memories(k=1)[0] # find the most recent messages
code_text = await todo.run(msg.content)
@ -80,16 +80,16 @@ class RunnableCoder(Role):
self._set_react_mode(react_mode=RoleReactMode.BY_ORDER.value)
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
# By choosing the Action by order under the hood
# todo will be first SimpleWriteCode() then SimpleRunCode()
todo = self._rc.todo
todo = self.rc.todo
msg = self.get_memories(k=1)[0] # find the most k recent messages
result = await todo.run(msg.content)
msg = Message(content=result, role=self.profile, cause_by=type(todo))
self._rc.memory.add(msg)
self.rc.memory.add(msg)
return msg

View file

@ -22,7 +22,7 @@ def parse_code(rsp):
class SimpleWriteCode(Action):
PROMPT_TEMPLATE = """
PROMPT_TEMPLATE: str = """
Write a python function that can {instruction}.
Return ```python your_code_here ``` with NO other texts,
your code:
@ -50,7 +50,7 @@ class SimpleCoder(Role):
class SimpleWriteTest(Action):
PROMPT_TEMPLATE = """
PROMPT_TEMPLATE: str = """
Context: {context}
Write {k} unit tests using pytest for the given function, assuming you have imported it.
Return ```python your_code_here ``` with NO other texts,
@ -80,8 +80,8 @@ class SimpleTester(Role):
self._watch([SimpleWriteCode, SimpleWriteReview]) # feel free to try this too
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
todo = self._rc.todo
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
todo = self.rc.todo
# context = self.get_memories(k=1)[0].content # use the most recent memory as context
context = self.get_memories() # use all memories as context
@ -93,7 +93,7 @@ class SimpleTester(Role):
class SimpleWriteReview(Action):
PROMPT_TEMPLATE = """
PROMPT_TEMPLATE: str = """
Context: {context}
Review the test cases and provide one critical comments:
"""

View file

@ -7,6 +7,7 @@ Author: garylin2099
"""
import asyncio
import platform
from typing import Any
import fire
@ -20,7 +21,7 @@ from metagpt.team import Team
class SpeakAloud(Action):
"""Action: Speak out aloud in a debate (quarrel)"""
PROMPT_TEMPLATE = """
PROMPT_TEMPLATE: str = """
## BACKGROUND
Suppose you are {name}, you are in a debate with {opponent_name}.
## DEBATE HISTORY
@ -30,9 +31,7 @@ class SpeakAloud(Action):
Now it's your turn, you should closely respond to your opponent's latest argument, state your position, defend your arguments, and attack your opponent's arguments,
craft a strong and emotional response in 80 words, in {name}'s rhetoric and viewpoints, your will argue:
"""
def __init__(self, name="SpeakAloud", context=None, llm=None):
super().__init__(name, context, llm)
name: str = "SpeakAloud"
async def run(self, context: str, name: str, opponent_name: str):
prompt = self.PROMPT_TEMPLATE.format(context=context, name=name, opponent_name=opponent_name)
@ -44,27 +43,24 @@ class SpeakAloud(Action):
class Debator(Role):
def __init__(
self,
name: str,
profile: str,
opponent_name: str,
**kwargs,
):
super().__init__(name, profile, **kwargs)
name: str = ""
profile: str = ""
opponent_name: str = ""
def __init__(self, **data: Any):
super().__init__(**data)
self._init_actions([SpeakAloud])
self._watch([UserRequirement, SpeakAloud])
self.opponent_name = opponent_name
async def _observe(self) -> int:
await super()._observe()
# accept messages sent (from opponent) to self, disregard own messages from the last round
self._rc.news = [msg for msg in self._rc.news if msg.send_to == {self.name}]
return len(self._rc.news)
self.rc.news = [msg for msg in self.rc.news if msg.send_to == {self.name}]
return len(self.rc.news)
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
todo = self._rc.todo # An instance of SpeakAloud
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
todo = self.rc.todo # An instance of SpeakAloud
memories = self.get_memories()
context = "\n".join(f"{msg.sent_from}: {msg.content}" for msg in memories)
@ -79,7 +75,7 @@ class Debator(Role):
sent_from=self.name,
send_to=self.opponent_name,
)
self._rc.memory.add(msg)
self.rc.memory.add(msg)
return msg

View file

@ -13,7 +13,7 @@ from metagpt.actions.add_requirement import UserRequirement
from metagpt.actions.debug_error import DebugError
from metagpt.actions.design_api import WriteDesign
from metagpt.actions.design_api_review import DesignReview
from metagpt.actions.project_management import AssignTasks, WriteTasks
from metagpt.actions.project_management import WriteTasks
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
from metagpt.actions.run_code import RunCode
from metagpt.actions.search_and_summarize import SearchAndSummarize
@ -38,7 +38,6 @@ class ActionType(Enum):
RUN_CODE = RunCode
DEBUG_ERROR = DebugError
WRITE_TASKS = WriteTasks
ASSIGN_TASKS = AssignTasks
SEARCH_AND_SUMMARIZE = SearchAndSummarize
COLLECT_LINKS = CollectLinks
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize

View file

@ -10,7 +10,7 @@ from __future__ import annotations
from typing import Any, Optional, Union
from pydantic import BaseModel, Field
from pydantic import ConfigDict, Field
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
@ -19,50 +19,31 @@ from metagpt.schema import (
CodeSummarizeContext,
CodingContext,
RunCodeContext,
SerializationMixin,
TestingContext,
)
action_subclass_registry = {}
class Action(SerializationMixin, is_polymorphic_base=True):
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
class Action(BaseModel):
name: str = ""
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
prefix = "" # aask*时会加上prefix作为system_message
desc = "" # for skill manager
prefix: str = "" # aask*时会加上prefix作为system_message
desc: str = "" # for skill manager
node: ActionNode = Field(default=None, exclude=True)
# builtin variables
builtin_class_name: str = ""
class Config:
arbitrary_types_allowed = True
def __init_with_instruction(self, instruction: str):
"""Initialize action with instruction"""
self.node = ActionNode(key=self.name, expected_type=str, instruction=instruction, example="", schema="raw")
return self
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
def __init__(self, **data: Any):
super().__init__(**data)
# deserialize child classes dynamically for inherited `action`
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
self.__fields__["builtin_class_name"].default = self.__class__.__name__
if "instruction" in kwargs:
self.__init_with_instruction(kwargs["instruction"])
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
action_subclass_registry[cls.__name__] = cls
def dict(self, *args, **kwargs) -> "DictStrAny":
obj_dict = super().dict(*args, **kwargs)
if "llm" in obj_dict:
obj_dict.pop("llm")
return obj_dict
if "instruction" in data:
self.__init_with_instruction(data["instruction"])
def set_prefix(self, prefix):
"""Set prefix for later usage"""

View file

@ -11,7 +11,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
import json
from typing import Any, Dict, List, Optional, Tuple, Type
from pydantic import BaseModel, create_model, root_validator, validator
from pydantic import BaseModel, create_model, field_validator, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.config import CONFIG
@ -137,13 +137,15 @@ class ActionNode:
"""基于pydantic v1的模型动态生成用来检验结果类型正确性"""
new_class = create_model(class_name, **mapping)
@validator("*", allow_reuse=True)
@field_validator("*", mode="before")
@classmethod
def check_name(v, field):
if field.name not in mapping.keys():
raise ValueError(f"Unrecognized block: {field.name}")
return v
@root_validator(pre=True, allow_reuse=True)
@model_validator(mode="before")
@classmethod
def check_missing_fields(values):
required_fields = set(mapping.keys())
missing_fields = required_fields - set(values.keys())
@ -273,7 +275,9 @@ class ActionNode:
output_class = self.create_model_class(output_class_name, output_data_mapping)
if schema == "json":
parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key=f"[/{TAG}]")
parsed_data = llm_output_postprecess(
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
)
else: # using markdown parser
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
@ -282,7 +286,7 @@ class ActionNode:
return content, instruct_content
def get(self, key):
return self.instruct_content.dict()[key]
return self.instruct_content.model_dump()[key]
def set_recursive(self, name, value):
setattr(self, name, value)
@ -348,17 +352,3 @@ class ActionNode:
cls = self.create_children_class()
self.instruct_content = cls(**tmp)
return self
def action_node_example():
node = ActionNode(key="key-0", expected_type=str, instruction="instruction-a", example="example-b")
logger.info(node.compile(context="123", schema="raw", mode="auto"))
logger.info(node.compile(context="123", schema="json", mode="auto"))
logger.info(node.compile(context="123", schema="markdown", mode="auto"))
logger.info(node.to_dict())
logger.info(node)
if __name__ == "__main__":
action_node_example()

View file

@ -10,6 +10,3 @@ from metagpt.actions import Action
class UserRequirement(Action):
"""User Requirement without any implementation details"""
async def run(self, *args, **kwargs):
raise NotImplementedError

View file

@ -79,7 +79,7 @@ class WriteDesign(Action):
logger.info("Nothing has changed.")
# Wait until all files under `docs/system_designs/` are processed before sending the publish message,
# leaving room for global optimization in subsequent steps.
return ActionOutput(content=changed_files.json(), instruct_content=changed_files)
return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files)
async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
@ -88,7 +88,7 @@ class WriteDesign(Action):
async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
system_design_doc.content = node.instruct_content.json(ensure_ascii=False)
system_design_doc.content = node.instruct_content.model_dump_json()
return system_design_doc
async def _update_system_design(self, filename, prds_file_repo, system_design_file_repo) -> Document:
@ -99,7 +99,7 @@ class WriteDesign(Action):
doc = Document(
root_path=SYSTEM_DESIGN_FILE_REPO,
filename=filename,
content=system_design.instruct_content.json(ensure_ascii=False),
content=system_design.instruct_content.model_dump_json(),
)
else:
doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)

View file

@ -8,7 +8,6 @@
from typing import List
from metagpt.actions.action_node import ActionNode
from metagpt.logs import logger
from metagpt.utils.mermaid import MMC1, MMC2
IMPLEMENTATION_APPROACH = ActionNode(
@ -63,12 +62,3 @@ NODES = [
]
DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
def main():
prompt = DESIGN_API_NODE.compile(context="")
logger.info(prompt)
if __name__ == "__main__":
main()

View file

@ -39,6 +39,8 @@ class PrepareDocuments(Action):
path = Path(CONFIG.project_path)
if path.exists() and not CONFIG.inc:
shutil.rmtree(path)
CONFIG.project_path = path
CONFIG.project_name = path.name
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
async def run(self, with_messages, **kwargs):

View file

@ -73,7 +73,7 @@ class WriteTasks(Action):
logger.info("Nothing has changed.")
# Wait until all files under `docs/tasks/` are processed before sending the publish_message, leaving room for
# global optimization in subsequent steps.
return ActionOutput(content=change_files.json(), instruct_content=change_files)
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _update_tasks(self, filename, system_design_file_repo, tasks_file_repo):
system_design_doc = await system_design_file_repo.get(filename)
@ -83,7 +83,7 @@ class WriteTasks(Action):
else:
rsp = await self._run_new_tasks(context=system_design_doc.content)
task_doc = Document(
root_path=TASK_FILE_REPO, filename=filename, content=rsp.instruct_content.json(ensure_ascii=False)
root_path=TASK_FILE_REPO, filename=filename, content=rsp.instruct_content.model_dump_json()
)
await tasks_file_repo.save(
filename=filename, content=task_doc.content, dependencies={system_design_doc.root_relative_path}
@ -102,7 +102,7 @@ class WriteTasks(Action):
async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document:
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content)
node = await PM_NODE.fill(context, self.llm, schema)
task_doc.content = node.instruct_content.json(ensure_ascii=False)
task_doc.content = node.instruct_content.model_dump_json()
return task_doc
@staticmethod
@ -123,9 +123,3 @@ class WriteTasks(Action):
@staticmethod
async def _save_pdf(task_doc):
await FileRepository.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)
class AssignTasks(Action):
async def run(self, *args, **kwargs):
# Here you should implement the actual action
pass

View file

@ -50,7 +50,7 @@ class RebuildClassView(Action):
# try:
# node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
# class_view = node.instruct_content.dict()["Class View"]
# class_view = node.instruct_content.model_dump()["Class View"]
# except Exception as e:
# class_view = RepoParser.rebuild_class_view(src_code, code_type)
# await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)

View file

@ -84,8 +84,9 @@ class CollectLinks(Action):
context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
desc: str = "Collect links from a search engine."
search_engine: SearchEngine = Field(default_factory=SearchEngine)
rank_func: Union[Callable[[list[str]], None], None] = None
rank_func: Optional[Callable[[list[str]], None]] = None
async def run(
self,
@ -181,18 +182,18 @@ class WebBrowseAndSummarize(Action):
llm: BaseLLM = Field(default_factory=LLM)
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: WebBrowserEngine = Field(
default_factory=lambda: WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if WebBrowseAndSummarize.browse_func else None,
run_func=WebBrowseAndSummarize.browse_func,
)
)
web_browser_engine: Optional[WebBrowserEngine] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
if CONFIG.model_for_researcher_summary:
self.llm.model = CONFIG.model_for_researcher_summary
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
run_func=self.browse_func,
)
async def run(
self,
url: str,

View file

@ -82,11 +82,13 @@ class RunCode(Action):
llm: BaseLLM = Field(default_factory=LLM)
@classmethod
@handle_exception
async def run_text(cls, code) -> Tuple[str, str]:
# We will document_store the result in this dictionary
namespace = {}
exec(code, namespace)
try:
# We will document_store the result in this dictionary
namespace = {}
exec(code, namespace)
except Exception as e:
return "", str(e)
return namespace.get("result", ""), ""
@classmethod

View file

@ -8,7 +8,7 @@
from typing import Any, Optional
import pydantic
from pydantic import Field, root_validator
from pydantic import Field, model_validator
from metagpt.actions import Action
from metagpt.config import CONFIG, Config
@ -114,10 +114,10 @@ class SearchAndSummarize(Action):
engine: Optional[SearchEngineType] = CONFIG.search_engine
search_func: Optional[Any] = None
search_engine: SearchEngine = None
result: str = ""
result = ""
@root_validator
@model_validator(mode="before")
@classmethod
def validate_engine_and_run_func(cls, values):
engine = values.get("engine")
search_func = values.get("search_func")

View file

@ -80,7 +80,7 @@ class WritePRD(Action):
await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="")
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
return Message(
content=bug_fix.json(),
content=bug_fix.model_dump_json(),
instruct_content=bug_fix,
role="",
cause_by=FixBug,
@ -112,7 +112,7 @@ class WritePRD(Action):
# Once all files under 'docs/prds/' have been compared with the newly added requirements, trigger the
# 'publish' message to transition the workflow to the next stage. This design allows room for global
# optimization in subsequent steps.
return ActionOutput(content=change_files.json(), instruct_content=change_files)
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
# sas = SearchAndSummarize()
@ -139,7 +139,7 @@ class WritePRD(Action):
CONFIG.project_name = Path(CONFIG.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema)
prd_doc.content = node.instruct_content.json(ensure_ascii=False)
prd_doc.content = node.instruct_content.model_dump_json()
await self._rename_workspace(node)
return prd_doc
@ -151,7 +151,7 @@ class WritePRD(Action):
new_prd_doc = Document(
root_path=PRDS_FILE_REPO,
filename=FileRepository.new_filename() + ".json",
content=prd.instruct_content.json(ensure_ascii=False),
content=prd.instruct_content.model_dump_json(),
)
elif await self._is_relative(requirement_doc, prd_doc):
new_prd_doc = await self._merge(requirement_doc, prd_doc)
@ -181,18 +181,13 @@ class WritePRD(Action):
@staticmethod
async def _rename_workspace(prd):
if CONFIG.project_path: # Updating on the old version has already been specified if it's valid. According to
# Section 2.2.3.10 of RFC 135
if not CONFIG.project_name:
CONFIG.project_name = Path(CONFIG.project_path).name
return
if not CONFIG.project_name:
if isinstance(prd, (ActionOutput, ActionNode)):
ws_name = prd.instruct_content.dict()["Project Name"]
ws_name = prd.instruct_content.model_dump()["Project Name"]
else:
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
CONFIG.project_name = ws_name
if ws_name:
CONFIG.project_name = ws_name
CONFIG.git_repo.rename_root(CONFIG.project_name)
async def _is_bugfix(self, context) -> bool:

View file

@ -19,6 +19,7 @@ class WritePRDReview(Action):
name: str = ""
context: Optional[str] = None
llm: BaseLLM = Field(default_factory=LLM)
prd: Optional[str] = None
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
prd_review_prompt_template: str = """

View file

@ -72,6 +72,7 @@ class Config(metaclass=Singleton):
self.inc = False
self.reqa_file = ""
self.max_auto_summarize_code = 0
self.git_reinit = False
self._init_with_config_files_and_env(yaml_file)
# The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
@ -146,7 +147,7 @@ class Config(metaclass=Singleton):
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
_ = self.get_default_llm_provider_enum()
# self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_base_url = self._get("OPENAI_BASE_URL")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
self.openai_api_type = self._get("OPENAI_API_TYPE")
self.openai_api_version = self._get("OPENAI_API_VERSION")

View file

@ -17,7 +17,7 @@ from langchain.document_loaders import (
UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import CharacterTextSplitter
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm
from metagpt.config import CONFIG
@ -117,13 +117,12 @@ class IndexableDocument(Document):
Advanced document handling: For vector databases or search engines.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
data: Union[pd.DataFrame, list]
content_col: Optional[str] = Field(default="")
meta_col: Optional[str] = Field(default="")
class Config:
arbitrary_types_allowed = True
@classmethod
def from_path(cls, data_path: Path, content_col="content", meta_col="metadata"):
if not data_path.exists():

View file

@ -1,111 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/5/28 00:00
@Author : alexanderwu
@File : milvus_store.py
"""
from typing import TypedDict
import numpy as np
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections
from metagpt.document_store.base_store import BaseStore
type_mapping = {int: DataType.INT64, str: DataType.VARCHAR, float: DataType.DOUBLE, np.ndarray: DataType.FLOAT_VECTOR}
def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: str = ""):
"""Assume the structure of columns is str: regular type"""
fields = []
for col, ctype in columns.items():
if ctype == str:
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], max_length=100)
elif ctype == np.ndarray:
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], dim=2)
else:
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], is_primary=(col == primary_col_name))
fields.append(mcol)
schema = CollectionSchema(fields, description=desc)
return schema
class MilvusConnection(TypedDict):
alias: str
host: str
port: str
class MilvusStore(BaseStore):
"""
FIXME: ADD TESTS
https://milvus.io/docs/v2.0.x/create_collection.md
"""
def __init__(self, connection):
connections.connect(**connection)
self.collection = None
def _create_collection(self, name, schema):
collection = Collection(name=name, schema=schema, using="default", shards_num=2, consistency_level="Strong")
return collection
def create_collection(self, name, columns):
schema = columns_to_milvus_schema(columns, "idx")
self.collection = self._create_collection(name, schema)
return self.collection
def drop(self, name):
Collection(name).drop()
def load_collection(self):
self.collection.load()
def build_index(self, field="emb"):
self.collection.create_index(field, {"index_type": "FLAT", "metric_type": "L2", "params": {}})
def search(self, query: list[list[float]], *args, **kwargs):
"""
FIXME: ADD TESTS
https://milvus.io/docs/v2.0.x/search.md
All search and query operations within Milvus are executed in memory. Load the collection to memory before conducting a vector similarity search.
Note the above description, is this logic serious? This should take a long time, right?
"""
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
results = self.collection.search(
data=query,
anns_field=kwargs.get("field", "emb"),
param=search_params,
limit=10,
expr=None,
consistency_level="Strong",
)
# FIXME: results contain id, but to get the actual value from the id, we still need to call the query interface
return results
def write(self, name, schema, *args, **kwargs):
"""
FIXME: ADD TESTS
https://milvus.io/docs/v2.0.x/create_collection.md
:param args:
:param kwargs:
:return:
"""
raise NotImplementedError
def add(self, data, *args, **kwargs):
"""
FIXME: ADD TESTS
https://milvus.io/docs/v2.0.x/insert_data.md
import random
data = [
[i for i in range(2000)],
[i for i in range(10000, 12000)],
[[random.random() for _ in range(2)] for _ in range(2000)],
]
:param args:
:param kwargs:
:return:
"""
self.collection.insert(data)

View file

@ -15,11 +15,11 @@ import asyncio
from pathlib import Path
from typing import Iterable, Set
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.config import CONFIG
from metagpt.logs import logger
from metagpt.roles.role import Role, role_subclass_registry
from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import is_subscribed, read_json_file, write_json_file
@ -29,30 +29,17 @@ class Environment(BaseModel):
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
desc: str = Field(default="") # 环境描述
roles: dict[str, Role] = Field(default_factory=dict)
members: dict[Role, Set] = Field(default_factory=dict)
roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
members: dict[Role, Set] = Field(default_factory=dict, exclude=True)
history: str = "" # For debug
class Config:
arbitrary_types_allowed = True
def __init__(self, **kwargs):
roles = []
for role_key, role in kwargs.get("roles", {}).items():
current_role = kwargs["roles"][role_key]
if isinstance(current_role, dict):
item_class_name = current_role.get("builtin_class_name", None)
for name, subclass in role_subclass_registry.items():
registery_class_name = subclass.__fields__["builtin_class_name"].default
if item_class_name == registery_class_name:
current_role = subclass(**current_role)
break
kwargs["roles"][role_key] = current_role
roles.append(current_role)
super().__init__(**kwargs)
self.add_roles(roles) # add_roles again to init the Role.set_env
@model_validator(mode="after")
def init_roles(self):
self.add_roles(self.roles.values())
return self
def serialize(self, stg_path: Path):
roles_path = stg_path.joinpath("roles.json")

View file

@ -4,7 +4,7 @@
@Time : 2023/6/5 01:44
@Author : alexanderwu
@File : skill_manager.py
@Modified By: mashenquan, 2023/8/20. Remove useless `_llm`
@Modified By: mashenquan, 2023/8/20. Remove useless `llm`
"""
from metagpt.actions import Action
from metagpt.const import PROMPT_PATH

View file

@ -73,7 +73,7 @@ class BrainMemory(BaseModel):
redis = Redis()
if not redis.is_valid or not redis_key:
return False
v = self.json(ensure_ascii=False)
v = self.model_dump_json()
if self.cacheable:
await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec)
logger.debug(f"REDIS SET {redis_key} {v}")
@ -99,6 +99,7 @@ class BrainMemory(BaseModel):
if msg.id:
if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
return
self.history.append(msg)
self.last_history_id = str(msg.id)
self.is_dirty = True
@ -156,7 +157,7 @@ class BrainMemory(BaseModel):
if left == 0:
break
m.content = m.content[0:left]
msgs.append(m.dict())
msgs.append(m.model_dump())
break
msgs.append(m)
total_length += delta

View file

@ -7,7 +7,7 @@
from typing import Optional
from pydantic import Field
from pydantic import ConfigDict, Field
from metagpt.logs import logger
from metagpt.memory import Memory
@ -22,13 +22,12 @@ class LongTermMemory(Memory):
- update memory when it changed
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
memory_storage: MemoryStorage = Field(default_factory=MemoryStorage)
rc: Optional["RoleContext"] = None
msg_from_recover: bool = False
class Config:
arbitrary_types_allowed = True
def recover_memory(self, role_id: str, rc: "RoleContext"):
messages = self.memory_storage.recover_memory(role_id)
self.rc = rc

View file

@ -8,9 +8,9 @@
"""
from collections import defaultdict
from pathlib import Path
from typing import Iterable, Set
from typing import DefaultDict, Iterable, Set
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, SerializeAsAny
from metagpt.const import IGNORED_MESSAGE_ID
from metagpt.schema import Message
@ -25,23 +25,14 @@ from metagpt.utils.common import (
class Memory(BaseModel):
"""The most basic memory: super-memory"""
storage: list[Message] = []
index: dict[str, list[Message]] = Field(default_factory=defaultdict(list))
storage: list[SerializeAsAny[Message]] = []
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
ignore_id: bool = False
def __init__(self, **kwargs):
index = kwargs.get("index", {})
new_index = defaultdict(list)
for action_str, value in index.items():
new_index[action_str] = [Message(**item_dict) for item_dict in value]
kwargs["index"] = new_index
super(Memory, self).__init__(**kwargs)
self.index = new_index
def serialize(self, stg_path: Path):
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
memory_path = stg_path.joinpath("memory.json")
storage = self.dict()
storage = self.model_dump()
write_json_file(memory_path, storage)
@classmethod

View file

@ -69,7 +69,7 @@ class OpenAILLM(BaseLLM):
self.aclient = AsyncOpenAI(**kwargs)
def _make_client_kwargs(self) -> dict:
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}
# to use proxy, openai v1 needs http_client
if proxy_params := self._get_proxy_params():
@ -81,8 +81,8 @@ class OpenAILLM(BaseLLM):
params = {}
if self.config.openai_proxy:
params = {"proxies": self.config.openai_proxy}
if self.config.OPENAI_BASE_URL:
params["base_url"] = self.config.OPENAI_BASE_URL
if self.config.openai_base_url:
params["base_url"] = self.config.openai_base_url
return params

View file

@ -65,22 +65,20 @@ class Assistant(Role):
prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
rsp = await self._llm.aask(prompt, [])
rsp = await self.llm.aask(prompt, [])
logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
return await self._plan(rsp, last_talk=last_talk)
async def act(self) -> Message:
result = await self._rc.todo.run()
result = await self.rc.todo.run()
if not result:
return None
if isinstance(result, str):
msg = Message(content=result, role="assistant", cause_by=self._rc.todo)
msg = Message(content=result, role="assistant", cause_by=self.rc.todo)
elif isinstance(result, Message):
msg = result
else:
msg = Message(
content=result.content, instruct_content=result.instruct_content, cause_by=type(self._rc.todo)
)
msg = Message(content=result.content, instruct_content=result.instruct_content, cause_by=type(self.rc.todo))
self.memory.add_answer(msg)
return msg
@ -99,8 +97,8 @@ class Assistant(Role):
async def talk_handler(self, text, **kwargs) -> bool:
history = self.memory.history_text
text = kwargs.get("last_talk") or text
self._rc.todo = TalkAction(
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self._llm, **kwargs
self.rc.todo = TalkAction(
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs
)
return True
@ -110,13 +108,11 @@ class Assistant(Role):
if not skill:
logger.info(f"skill not found: {text}")
return await self.talk_handler(text=last_talk, **kwargs)
action = ArgumentsParingAction(skill=skill, llm=self._llm, ask=last_talk, **kwargs)
action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk, **kwargs)
await action.run(**kwargs)
if action.args is None:
return await self.talk_handler(text=last_talk, **kwargs)
self._rc.todo = SkillAction(
skill=skill, args=action.args, llm=self._llm, name=skill.name, desc=skill.description
)
self.rc.todo = SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description)
return True
async def refine_memory(self) -> str:
@ -125,16 +121,16 @@ class Assistant(Role):
return None
if not self.memory.is_history_available:
return last_talk
history_summary = await self.memory.summarize(max_words=800, keep_language=True, llm=self._llm)
if last_talk and await self.memory.is_related(text1=last_talk, text2=history_summary, llm=self._llm):
history_summary = await self.memory.summarize(max_words=800, keep_language=True, llm=self.llm)
if last_talk and await self.memory.is_related(text1=last_talk, text2=history_summary, llm=self.llm):
# Merge relevant content.
merged = await self.memory.rewrite(sentence=last_talk, context=history_summary, llm=self._llm)
merged = await self.memory.rewrite(sentence=last_talk, context=history_summary, llm=self.llm)
return f"{merged} {last_talk}"
return last_talk
def get_memory(self) -> str:
return self.memory.json()
return self.memory.model_dump_json()
def load_memory(self, jsn):
try:

View file

@ -109,7 +109,7 @@ class Engineer(Role):
coding_context = await todo.run()
# Code review
if review:
action = WriteCodeReview(context=coding_context, llm=self._llm)
action = WriteCodeReview(context=coding_context, llm=self.llm)
self._init_action_system_message(action)
coding_context = await action.run()
await src_file_repo.save(
@ -118,9 +118,12 @@ class Engineer(Role):
content=coding_context.code_doc.content,
)
msg = Message(
content=coding_context.json(), instruct_content=coding_context, role=self.profile, cause_by=WriteCode
content=coding_context.model_dump_json(),
instruct_content=coding_context,
role=self.profile,
cause_by=WriteCode,
)
self._rc.memory.add(msg)
self.rc.memory.add(msg)
changed_files.add(coding_context.code_doc.filename)
if not changed_files:
@ -129,12 +132,12 @@ class Engineer(Role):
async def _act(self) -> Message | None:
"""Determines the mode of action based on whether code review is used."""
if self._rc.todo is None:
if self.rc.todo is None:
return None
if isinstance(self._rc.todo, WriteCode):
if isinstance(self.rc.todo, WriteCode):
self.next_todo_action = any_to_name(SummarizeCode)
return await self._act_write_code()
if isinstance(self._rc.todo, SummarizeCode):
if isinstance(self.rc.todo, SummarizeCode):
self.next_todo_action = any_to_name(WriteCode)
return await self._act_summarize()
return None
@ -170,7 +173,7 @@ class Engineer(Role):
tasks.append(todo.context.dict())
await code_summaries_file_repo.save(
filename=Path(todo.context.design_filename).name,
content=todo.context.json(),
content=todo.context.model_dump_json(),
dependencies=dependencies,
)
else:
@ -193,7 +196,7 @@ class Engineer(Role):
)
async def _is_pass(self, summary) -> (str, str):
rsp = await self._llm.aask(msg=IS_PASS_PROMPT.format(context=summary), stream=False)
rsp = await self.llm.aask(msg=IS_PASS_PROMPT.format(context=summary), stream=False)
logger.info(rsp)
if "YES" in rsp:
return True, rsp
@ -204,17 +207,17 @@ class Engineer(Role):
CONFIG.src_workspace = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
write_code_filters = any_to_str_set([WriteTasks, SummarizeCode, FixBug])
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
if not self._rc.news:
if not self.rc.news:
return None
msg = self._rc.news[0]
msg = self.rc.news[0]
if msg.cause_by in write_code_filters:
logger.debug(f"TODO WriteCode:{msg.json()}")
logger.debug(f"TODO WriteCode:{msg.model_dump_json()}")
await self._new_code_actions(bug_fix=msg.cause_by == any_to_str(FixBug))
return self._rc.todo
return self.rc.todo
if msg.cause_by in summarize_code_filters and msg.sent_from == any_to_str(self):
logger.debug(f"TODO SummarizeCode:{msg.json()}")
logger.debug(f"TODO SummarizeCode:{msg.model_dump_json()}")
await self._new_summarize_actions()
return self._rc.todo
return self.rc.todo
return None
@staticmethod
@ -241,7 +244,9 @@ class Engineer(Role):
context = await Engineer._new_coding_context(
filename, src_file_repo, task_file_repo, design_file_repo, dependency
)
coding_doc = Document(root_path=str(src_file_repo.root_path), filename=filename, content=context.json())
coding_doc = Document(
root_path=str(src_file_repo.root_path), filename=filename, content=context.model_dump_json()
)
return coding_doc
async def _new_code_actions(self, bug_fix=False):
@ -266,15 +271,15 @@ class Engineer(Role):
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
)
coding_doc = Document(
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.json()
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.model_dump_json()
)
if task_filename in changed_files.docs:
logger.warning(
f"Log to expose potential conflicts: {coding_doc.json()} & "
f"{changed_files.docs[task_filename].json()}"
f"Log to expose potential conflicts: {coding_doc.model_dump_json()} & "
f"{changed_files.docs[task_filename].model_dump_json()}"
)
changed_files.docs[task_filename] = coding_doc
self.code_todos = [WriteCode(context=i, llm=self._llm) for i in changed_files.docs.values()]
self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()]
# Code directly modified by the user.
dependency = await CONFIG.git_repo.get_dependency()
for filename in changed_src_files:
@ -288,10 +293,10 @@ class Engineer(Role):
dependency=dependency,
)
changed_files.docs[filename] = coding_doc
self.code_todos.append(WriteCode(context=coding_doc, llm=self._llm))
self.code_todos.append(WriteCode(context=coding_doc, llm=self.llm))
if self.code_todos:
self._rc.todo = self.code_todos[0]
self.rc.todo = self.code_todos[0]
async def _new_summarize_actions(self):
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
@ -304,9 +309,9 @@ class Engineer(Role):
summarizations[ctx].append(filename)
for ctx, filenames in summarizations.items():
ctx.codes_filenames = filenames
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self._llm))
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self.llm))
if self.summarize_todos:
self._rc.todo = self.summarize_todos[0]
self.rc.todo = self.summarize_todos[0]
@property
def todo(self) -> str:

View file

@ -69,8 +69,8 @@ class InvoiceOCRAssistant(Role):
Returns:
A message containing the result of the action.
"""
msg = self._rc.memory.get(k=1)[0]
todo = self._rc.todo
msg = self.rc.memory.get(k=1)[0]
todo = self.rc.todo
if isinstance(todo, InvoiceOCR):
self.origin_query = msg.content
invoice_path: InvoicePath = msg.instruct_content
@ -87,11 +87,11 @@ class InvoiceOCRAssistant(Role):
else:
self._init_actions([GenerateTable])
self._rc.todo = None
self.rc.todo = None
content = INVOICE_OCR_SUCCESS
resp = OCRResults(ocr_result=json.dumps(resp))
msg = Message(content=content, instruct_content=resp)
self._rc.memory.add(msg)
self.rc.memory.add(msg)
return await super().react()
elif isinstance(todo, GenerateTable):
ocr_results: OCRResults = msg.instruct_content
@ -108,5 +108,5 @@ class InvoiceOCRAssistant(Role):
resp = ReplyData(content=resp)
msg = Message(content=content, instruct_content=resp)
self._rc.memory.add(msg)
self.rc.memory.add(msg)
return msg

View file

@ -40,12 +40,13 @@ class ProductManager(Role):
async def _think(self) -> bool:
"""Decide what to do"""
if CONFIG.git_repo:
if CONFIG.git_repo and not CONFIG.git_reinit:
self._set_state(1)
else:
self._set_state(0)
CONFIG.git_reinit = False
self.todo_action = any_to_name(WritePRD)
return bool(self._rc.todo)
return bool(self.rc.todo)
async def _observe(self, ignore_memory=False) -> int:
return await super()._observe(ignore_memory=True)

View file

@ -69,7 +69,7 @@ class QaEngineer(Role):
)
logger.info(f"Writing {test_doc.filename}..")
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
context = await WriteTest(context=context, llm=self._llm).run()
context = await WriteTest(context=context, llm=self.llm).run()
await tests_file_repo.save(
filename=context.test_doc.filename,
content=context.test_doc.content,
@ -86,7 +86,7 @@ class QaEngineer(Role):
)
self.publish_message(
Message(
content=run_code_context.json(),
content=run_code_context.model_dump_json(),
role=self.profile,
cause_by=WriteTest,
sent_from=self,
@ -106,11 +106,11 @@ class QaEngineer(Role):
return
run_code_context.code = src_doc.content
run_code_context.test_code = test_doc.content
result = await RunCode(context=run_code_context, llm=self._llm).run()
result = await RunCode(context=run_code_context, llm=self.llm).run()
run_code_context.output_filename = run_code_context.test_filename + ".json"
await CONFIG.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
filename=run_code_context.output_filename,
content=result.json(),
content=result.model_dump_json(),
dependencies={src_doc.root_relative_path, test_doc.root_relative_path},
)
run_code_context.code = None
@ -120,7 +120,7 @@ class QaEngineer(Role):
mappings = {"Engineer": "Alex", "QaEngineer": "Edward"}
self.publish_message(
Message(
content=run_code_context.json(),
content=run_code_context.model_dump_json(),
role=self.profile,
cause_by=RunCode,
sent_from=self,
@ -130,14 +130,14 @@ class QaEngineer(Role):
async def _debug_error(self, msg):
run_code_context = RunCodeContext.loads(msg.content)
code = await DebugError(context=run_code_context, llm=self._llm).run()
code = await DebugError(context=run_code_context, llm=self.llm).run()
await FileRepository.save_file(
filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO
)
run_code_context.output = None
self.publish_message(
Message(
content=run_code_context.json(),
content=run_code_context.model_dump_json(),
role=self.profile,
cause_by=DebugError,
sent_from=self,
@ -159,7 +159,7 @@ class QaEngineer(Role):
code_filters = any_to_str_set({SummarizeCode})
test_filters = any_to_str_set({WriteTest, DebugError})
run_filters = any_to_str_set({RunCode})
for msg in self._rc.news:
for msg in self.rc.news:
# Decide what to do based on observed msg type, currently defined by human,
# might potentially be moved to _think, that is, let the agent decides for itself
if msg.cause_by in code_filters:

View file

@ -6,6 +6,7 @@
"""
import asyncio
import re
from pydantic import BaseModel
@ -41,20 +42,20 @@ class Researcher(Role):
logger.warning(f"The language `{self.language}` has not been tested, it may not work.")
async def _think(self) -> bool:
if self._rc.todo is None:
if self.rc.todo is None:
self._set_state(0)
return True
if self._rc.state + 1 < len(self._states):
self._set_state(self._rc.state + 1)
if self.rc.state + 1 < len(self.states):
self._set_state(self.rc.state + 1)
else:
self._rc.todo = None
self.rc.todo = None
return False
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
todo = self._rc.todo
msg = self._rc.memory.get(k=1)[0]
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
todo = self.rc.todo
msg = self.rc.memory.get(k=1)[0]
if isinstance(msg.instruct_content, Report):
instruct_content = msg.instruct_content
topic = instruct_content.topic
@ -78,14 +79,14 @@ class Researcher(Role):
else:
summaries = instruct_content.summaries
summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text)
content = await self.rc.todo.run(topic, summary_text, system_text=research_system_text)
ret = Message(
content="",
instruct_content=Report(topic=topic, content=content),
role=self.profile,
cause_by=self._rc.todo,
cause_by=self.rc.todo,
)
self._rc.memory.add(ret)
self.rc.memory.add(ret)
return ret
def research_system_text(self, topic, current_task: Action) -> str:
@ -107,9 +108,11 @@ class Researcher(Role):
return msg
def write_report(self, topic: str, content: str):
filename = re.sub(r'[\\/:"*?<>|]+', " ", topic)
filename = filename.replace("\n", "")
if not RESEARCH_PATH.exists():
RESEARCH_PATH.mkdir(parents=True)
filepath = RESEARCH_PATH / f"{topic}.md"
filepath = RESEARCH_PATH / f"{filename}.md"
filepath.write_text(content)

View file

@ -10,8 +10,8 @@
consolidated within the `_observe` function.
2. Standardize the message filtering for string label matching. Role objects can access the message labels
they've subscribed to through the `subscribed_tags` property.
3. Move the message receive buffer from the global variable `self._rc.env.memory` to the role's private variable
`self._rc.msg_buffer` for easier message identification and asynchronous appending of messages.
3. Move the message receive buffer from the global variable `self.rc.env.memory` to the role's private variable
`self.rc.msg_buffer` for easier message identification and asynchronous appending of messages.
4. Standardize the way messages are passed: `publish_message` sends messages out, while `put_message` places
messages into the Role object's private message receive buffer. There are no other message transmit methods.
5. Standardize the parameters for the `run` function: the `test_message` parameter is used for testing purposes
@ -24,12 +24,11 @@ from __future__ import annotations
from enum import Enum
from pathlib import Path
from typing import Any, Iterable, Set, Type
from typing import Any, Iterable, Optional, Set, Type
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
from metagpt.actions import Action, ActionOutput
from metagpt.actions.action import action_subclass_registry
from metagpt.actions.action_node import ActionNode
from metagpt.actions.add_requirement import UserRequirement
from metagpt.const import SERDESER_PATH
@ -37,7 +36,7 @@ from metagpt.llm import LLM, HumanProvider
from metagpt.logs import logger
from metagpt.memory import Memory
from metagpt.provider.base_llm import BaseLLM
from metagpt.schema import Message, MessageQueue
from metagpt.schema import Message, MessageQueue, SerializationMixin
from metagpt.utils.common import (
any_to_name,
any_to_str,
@ -92,8 +91,10 @@ class RoleReactMode(str, Enum):
class RoleContext(BaseModel):
"""Role Runtime Context"""
model_config = ConfigDict(arbitrary_types_allowed=True)
# # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison`
env: "Environment" = Field(default=None, exclude=True)
env: "Environment" = Field(default=None, exclude=True) # # avoid circular import
# TODO judge if ser&deser
msg_buffer: MessageQueue = Field(
default_factory=MessageQueue, exclude=True
@ -109,9 +110,6 @@ class RoleContext(BaseModel):
) # see `Role._set_react_mode` for definitions of the following two attributes
max_react_loop: int = 1
class Config:
arbitrary_types_allowed = True
def check(self, role_id: str):
# if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory:
# self.long_term_memory.recover_memory(role_id, self)
@ -128,12 +126,11 @@ class RoleContext(BaseModel):
return self.memory.get()
role_subclass_registry = {}
class Role(BaseModel):
class Role(SerializationMixin, is_polymorphic_base=True):
"""Role/Agent"""
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
name: str = ""
profile: str = ""
goal: str = ""
@ -141,84 +138,40 @@ class Role(BaseModel):
desc: str = ""
is_human: bool = False
_llm: BaseLLM = Field(default_factory=LLM) # Each role has its own LLM, use different system message
_role_id: str = ""
_states: list[str] = []
_actions: list[Action] = []
_rc: RoleContext = Field(default_factory=RoleContext)
llm: BaseLLM = Field(default_factory=LLM, exclude=True) # Each role has its own LLM, use different system message
role_id: str = ""
states: list[str] = []
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
rc: RoleContext = Field(default_factory=RoleContext)
subscription: set[str] = set()
# builtin variables
recovered: bool = False # to tag if a recovered role
latest_observed_msg: Message = None # record the latest observed message when interrupted
builtin_class_name: str = ""
_private_attributes = {
"_llm": None,
"_role_id": _role_id,
"_states": [],
"_actions": [],
"_rc": RoleContext(),
"_subscription": set(),
}
latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
class Config:
arbitrary_types_allowed = True
exclude = ["_llm"]
@model_validator(mode="after")
def check_subscription(self) -> set:
if not self.subscription:
self.subscription = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
return self
def __init__(self, **kwargs: Any):
for index in range(len(kwargs.get("_actions", []))):
current_action = kwargs["_actions"][index]
if isinstance(current_action, dict):
item_class_name = current_action.get("builtin_class_name", None)
for name, subclass in action_subclass_registry.items():
registery_class_name = subclass.__fields__["builtin_class_name"].default
if item_class_name == registery_class_name:
current_action = subclass(**current_action)
break
kwargs["_actions"][index] = current_action
def __init__(self, **data: Any):
# --- avoid PydanticUndefinedAnnotation name 'Environment' is not defined #
from metagpt.environment import Environment
super().__init__(**kwargs)
Environment
# ------
Role.model_rebuild()
super().__init__(**data)
# 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655
self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider()
self._private_attributes["_role_id"] = str(self._setting)
self.subscription = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
for key in self._private_attributes.keys():
if key in kwargs:
object.__setattr__(self, key, kwargs[key])
if key == "_rc":
_rc = RoleContext(**kwargs["_rc"])
object.__setattr__(self, "_rc", _rc)
else:
if key == "_rc":
# # Warning, if use self._private_attributes["_rc"],
# # self._rc will be a shared object between roles, so init one or reset it inside `_reset`
object.__setattr__(self, key, RoleContext())
else:
object.__setattr__(self, key, self._private_attributes[key])
self._llm.system_prompt = self._get_prefix()
# deserialize child classes dynamically for inherited `role`
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
self.__fields__["builtin_class_name"].default = self.__class__.__name__
if "actions" in kwargs:
self._init_actions(kwargs["actions"])
self._watch(kwargs.get("watch") or [UserRequirement])
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
role_subclass_registry[cls.__name__] = cls
self.llm.system_prompt = self._get_prefix()
self._watch(data.get("watch") or [UserRequirement])
def _reset(self):
object.__setattr__(self, "_states", [])
object.__setattr__(self, "_actions", [])
self.states = []
self.actions = []
@property
def _setting(self):
@ -231,12 +184,12 @@ class Role(BaseModel):
else stg_path
)
role_info = self.dict(exclude={"_rc": {"memory": True, "msg_buffer": True}, "_llm": True})
role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True})
role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__})
role_info_path = stg_path.joinpath("role_info.json")
write_json_file(role_info_path, role_info)
self._rc.memory.serialize(stg_path) # serialize role's memory alone
self.rc.memory.serialize(stg_path) # serialize role's memory alone
@classmethod
def deserialize(cls, stg_path: Path) -> "Role":
@ -260,13 +213,13 @@ class Role(BaseModel):
action.set_prefix(self._get_prefix())
def refresh_system_message(self):
self._llm.system_prompt = self._get_prefix()
self.llm.system_prompt = self._get_prefix()
def set_recovered(self, recovered: bool = False):
self.recovered = recovered
def set_memory(self, memory: Memory):
self._rc.memory = memory
self.rc.memory = memory
def init_actions(self, actions):
self._init_actions(actions)
@ -276,7 +229,7 @@ class Role(BaseModel):
for idx, action in enumerate(actions):
if not isinstance(action, Action):
## 默认初始化
i = action(name="", llm=self._llm)
i = action(name="", llm=self.llm)
else:
if self.is_human and not isinstance(action.llm, HumanProvider):
logger.warning(
@ -285,10 +238,9 @@ class Role(BaseModel):
f"try passing in Action classes instead of initialized instances"
)
i = action
# i.set_env(self._rc.env)
self._init_action_system_message(i)
self._actions.append(i)
self._states.append(f"{idx}. {action}")
self.actions.append(i)
self.states.append(f"{idx}. {action}")
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1):
"""Set strategy of the Role reacting to observed Message. Variation lies in how
@ -307,20 +259,20 @@ class Role(BaseModel):
Defaults to 1, i.e. _think -> _act (-> return result and end)
"""
assert react_mode in RoleReactMode.values(), f"react_mode must be one of {RoleReactMode.values()}"
self._rc.react_mode = react_mode
self.rc.react_mode = react_mode
if react_mode == RoleReactMode.REACT:
self._rc.max_react_loop = max_react_loop
self.rc.max_react_loop = max_react_loop
def _watch(self, actions: Iterable[Type[Action]] | Iterable[Action]):
"""Watch Actions of interest. Role will select Messages caused by these Actions from its personal message
buffer during _observe.
"""
self._rc.watch = {any_to_str(t) for t in actions}
self.rc.watch = {any_to_str(t) for t in actions}
# check RoleContext after adding watch actions
self._rc.check(self._role_id)
self.rc.check(self.role_id)
def is_watch(self, caused_by: str):
return caused_by in self._rc.watch
return caused_by in self.rc.watch
def subscribe(self, tags: Set[str]):
"""Used to receive Messages with certain tags from the environment. Message will be put into personal message
@ -328,19 +280,19 @@ class Role(BaseModel):
or profile.
"""
self.subscription = tags
if self._rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
self._rc.env.set_subscription(self, self.subscription)
if self.rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
self.rc.env.set_subscription(self, self.subscription)
def _set_state(self, state: int):
"""Update the current state."""
self._rc.state = state
logger.debug(f"actions={self._actions}, state={state}")
self._rc.todo = self._actions[self._rc.state] if state >= 0 else None
self.rc.state = state
logger.debug(f"actions={self.actions}, state={state}")
self.rc.todo = self.actions[self.rc.state] if state >= 0 else None
def set_env(self, env: "Environment"):
"""Set the environment in which the role works. The role can talk to the environment and can also receive
messages by observing."""
self._rc.env = env
self.rc.env = env
if env:
env.set_subscription(self, self.subscription)
self.refresh_system_message() # add env message to system message
@ -348,7 +300,7 @@ class Role(BaseModel):
@property
def action_count(self):
"""Return number of action"""
return len(self._actions)
return len(self.actions)
def _get_prefix(self):
"""Get the role prefix"""
@ -360,38 +312,38 @@ class Role(BaseModel):
if self.constraints:
prefix += CONSTRAINT_TEMPLATE.format(**{"constraints": self.constraints})
if self._rc.env and self._rc.env.desc:
other_role_names = ", ".join(self._rc.env.role_names())
env_desc = f"You are in {self._rc.env.desc} with roles({other_role_names})."
if self.rc.env and self.rc.env.desc:
other_role_names = ", ".join(self.rc.env.role_names())
env_desc = f"You are in {self.rc.env.desc} with roles({other_role_names})."
prefix += env_desc
return prefix
async def _think(self) -> bool:
"""Consider what to do and decide on the next course of action. Return false if nothing can be done."""
if len(self._actions) == 1:
if len(self.actions) == 1:
# If there is only one action, then only this one can be performed
self._set_state(0)
return True
if self.recovered and self._rc.state >= 0:
self._set_state(self._rc.state) # action to run from recovered state
if self.recovered and self.rc.state >= 0:
self._set_state(self.rc.state) # action to run from recovered state
self.set_recovered(False) # avoid max_react_loop out of work
return True
prompt = self._get_prefix()
prompt += STATE_TEMPLATE.format(
history=self._rc.history,
states="\n".join(self._states),
n_states=len(self._states) - 1,
previous_state=self._rc.state,
history=self.rc.history,
states="\n".join(self.states),
n_states=len(self.states) - 1,
previous_state=self.rc.state,
)
next_state = await self._llm.aask(prompt)
next_state = await self.llm.aask(prompt)
next_state = extract_state_value_from_output(next_state)
logger.debug(f"{prompt=}")
if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self._states)):
if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self.states)):
logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1")
next_state = -1
else:
@ -402,21 +354,21 @@ class Role(BaseModel):
return True
async def _act(self) -> Message:
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
response = await self._rc.todo.run(self._rc.history)
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
response = await self.rc.todo.run(self.rc.history)
if isinstance(response, (ActionOutput, ActionNode)):
msg = Message(
content=response.content,
instruct_content=response.instruct_content,
role=self._setting,
cause_by=self._rc.todo,
cause_by=self.rc.todo,
sent_from=self,
)
elif isinstance(response, Message):
msg = response
else:
msg = Message(content=response, role=self.profile, cause_by=self._rc.todo, sent_from=self)
self._rc.memory.add(msg)
msg = Message(content=response, role=self.profile, cause_by=self.rc.todo, sent_from=self)
self.rc.memory.add(msg)
return msg
@ -426,7 +378,7 @@ class Role(BaseModel):
observed_pure = [msg.dict(exclude={"id": True}) for msg in observed]
existed_pure = [msg.dict(exclude={"id": True}) for msg in existed]
for idx, new in enumerate(observed_pure):
if (new["cause_by"] in self._rc.watch or self.name in new["send_to"]) and new not in existed_pure:
if (new["cause_by"] in self.rc.watch or self.name in new["send_to"]) and new not in existed_pure:
news.append(observed[idx])
return news
@ -437,59 +389,59 @@ class Role(BaseModel):
if self.recovered:
news = [self.latest_observed_msg] if self.latest_observed_msg else []
if not news:
news = self._rc.msg_buffer.pop_all()
news = self.rc.msg_buffer.pop_all()
# Store the read messages in your own memory to prevent duplicate processing.
old_messages = [] if ignore_memory else self._rc.memory.get()
self._rc.memory.add_batch(news)
old_messages = [] if ignore_memory else self.rc.memory.get()
self.rc.memory.add_batch(news)
# Filter out messages of interest.
self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages]
self.latest_observed_msg = self._rc.news[-1] if self._rc.news else None # record the latest observed msg
self.rc.news = [n for n in news if n.cause_by in self.rc.watch and n not in old_messages]
self.latest_observed_msg = self.rc.news[-1] if self.rc.news else None # record the latest observed msg
# Design Rules:
# If you need to further categorize Message objects, you can do so using the Message.set_meta function.
# msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news]
news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
if news_text:
logger.debug(f"{self._setting} observed: {news_text}")
return len(self._rc.news)
return len(self.rc.news)
# async def _observe(self, ignore_memory=False) -> int:
# """Prepare new messages for processing from the message buffer and other sources."""
# # Read unprocessed messages from the msg buffer.
# news = self._rc.msg_buffer.pop_all()
# news = self.rc.msg_buffer.pop_all()
# if self.recovered:
# news = [self.latest_observed_msg] if self.latest_observed_msg else []
# else:
# self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
#
# # Store the read messages in your own memory to prevent duplicate processing.
# old_messages = [] if ignore_memory else self._rc.memory.get()
# self._rc.memory.add_batch(news)
# old_messages = [] if ignore_memory else self.rc.memory.get()
# self.rc.memory.add_batch(news)
# # Filter out messages of interest.
# self._rc.news = self._find_news(news, old_messages)
# self.rc.news = self._find_news(news, old_messages)
#
# # Design Rules:
# # If you need to further categorize Message objects, you can do so using the Message.set_meta function.
# # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news]
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
# if news_text:
# logger.debug(f"{self._setting} observed: {news_text}")
# return len(self._rc.news)
# return len(self.rc.news)
def publish_message(self, msg):
"""If the role belongs to env, then the role's messages will be broadcast to env"""
if not msg:
return
if not self._rc.env:
if not self.rc.env:
# If env does not exist, do not publish the message
return
self._rc.env.publish_message(msg)
self.rc.env.publish_message(msg)
def put_message(self, message):
"""Place the message into the Role object's private message buffer."""
if not message:
return
self._rc.msg_buffer.push(message)
self.rc.msg_buffer.push(message)
async def _react(self) -> Message:
"""Think first, then act, until the Role _think it is time to stop and requires no more todo.
@ -498,22 +450,22 @@ class Role(BaseModel):
"""
actions_taken = 0
rsp = Message(content="No actions taken yet") # will be overwritten after Role _act
while actions_taken < self._rc.max_react_loop:
while actions_taken < self.rc.max_react_loop:
# think
await self._think()
if self._rc.todo is None:
if self.rc.todo is None:
break
# act
logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}")
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
rsp = await self._act() # 这个rsp是否需要publish_message
actions_taken += 1
return rsp # return output from the last action
async def _act_by_order(self) -> Message:
"""switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ..."""
start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state
rsp = Message(content="No actions taken yet") # return default message if _actions=[]
for i in range(start_idx, len(self._states)):
start_idx = self.rc.state if self.rc.state >= 0 else 0 # action to run from recovered state
rsp = Message(content="No actions taken yet") # return default message if actions=[]
for i in range(start_idx, len(self.states)):
self._set_state(i)
rsp = await self._act()
return rsp # return output from the last action
@ -525,18 +477,18 @@ class Role(BaseModel):
async def react(self) -> Message:
"""Entry to one of three strategies by which Role reacts to the observed Message"""
if self._rc.react_mode == RoleReactMode.REACT:
if self.rc.react_mode == RoleReactMode.REACT:
rsp = await self._react()
elif self._rc.react_mode == RoleReactMode.BY_ORDER:
elif self.rc.react_mode == RoleReactMode.BY_ORDER:
rsp = await self._act_by_order()
elif self._rc.react_mode == RoleReactMode.PLAN_AND_ACT:
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
rsp = await self._plan_and_act()
self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None
return rsp
def get_memories(self, k=0) -> list[Message]:
"""A wrapper to return the most recent k memories of this role, return all when k=0"""
return self._rc.memory.get(k=k)
return self.rc.memory.get(k=k)
@role_raise_decorator
async def run(self, with_message=None) -> Message | None:
@ -561,7 +513,7 @@ class Role(BaseModel):
rsp = await self.react()
# Reset the next action to be taken.
self._rc.todo = None
self.rc.todo = None
# Send the response message to the Environment object to have it relay the message to the subscribers.
self.publish_message(rsp)
return rsp
@ -569,12 +521,12 @@ class Role(BaseModel):
@property
def is_idle(self) -> bool:
"""If true, all actions have been executed."""
return not self._rc.news and not self._rc.todo and self._rc.msg_buffer.empty()
return not self.rc.news and not self.rc.todo and self.rc.msg_buffer.empty()
async def think(self) -> Action:
"""The exported `think` function"""
await self._think()
return self._rc.todo
return self.rc.todo
async def act(self) -> ActionOutput:
"""The exported `act` function"""
@ -584,6 +536,6 @@ class Role(BaseModel):
@property
def todo(self) -> str:
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
if self._actions:
return any_to_name(self._actions[0])
if self.actions:
return any_to_name(self.actions[0])
return ""

View file

@ -57,19 +57,19 @@ class Searcher(Role):
async def _act_sp(self) -> Message:
"""Performs the search action in a single process."""
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
response = await self._rc.todo.run(self._rc.memory.get(k=0))
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
response = await self.rc.todo.run(self.rc.memory.get(k=0))
if isinstance(response, (ActionOutput, ActionNode)):
msg = Message(
content=response.content,
instruct_content=response.instruct_content,
role=self.profile,
cause_by=self._rc.todo,
cause_by=self.rc.todo,
)
else:
msg = Message(content=response, role=self.profile, cause_by=self._rc.todo)
self._rc.memory.add(msg)
msg = Message(content=response, role=self.profile, cause_by=self.rc.todo)
self.rc.memory.add(msg)
return msg
async def _act(self) -> Message:

View file

@ -7,13 +7,13 @@
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
distribution feature for message filtering.
"""
from typing import Any, Type
from typing import Any, Callable, Union
from pydantic import Field
from semantic_kernel import Kernel
from semantic_kernel.planning import SequentialPlanner
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
from semantic_kernel.planning.basic_planner import BasicPlanner
from semantic_kernel.planning.basic_planner import BasicPlanner, Plan
from metagpt.actions import UserRequirement
from metagpt.actions.execute_task import ExecuteTask
@ -41,17 +41,17 @@ class SkAgent(Role):
goal: str = "Execute task based on passed in task description"
constraints: str = ""
plan: Any = None
plan: Plan = None
planner_cls: Any = None
planner: Any = None
planner: Union[BasicPlanner, SequentialPlanner, ActionPlanner] = None
llm: BaseLLM = Field(default_factory=LLM)
kernel: Kernel = Field(default_factory=Kernel)
import_semantic_skill_from_directory: Type[Kernel.import_semantic_skill_from_directory] = None
import_skill: Type[Kernel.import_skill] = None
import_semantic_skill_from_directory: Callable = None
import_skill: Callable = None
def __init__(self, **kwargs) -> None:
def __init__(self, **data: Any) -> None:
"""Initializes the Engineer role with given attributes."""
super().__init__(**kwargs)
super().__init__(**data)
self._init_actions([ExecuteTask()])
self._watch([UserRequirement])
self.kernel = make_sk_kernel()
@ -71,10 +71,10 @@ class SkAgent(Role):
self._set_state(0)
# how funny the interface is inconsistent
if isinstance(self.planner, BasicPlanner):
self.plan = await self.planner.create_plan_async(self._rc.important_memory[-1].content, self.kernel)
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content, self.kernel)
logger.info(self.plan.generated_plan)
elif any(isinstance(self.planner, cls) for cls in [SequentialPlanner, ActionPlanner]):
self.plan = await self.planner.create_plan_async(self._rc.important_memory[-1].content)
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content)
async def _act(self) -> Message:
# how funny the interface is inconsistent
@ -85,6 +85,6 @@ class SkAgent(Role):
result = (await self.plan.invoke_async()).result
logger.info(result)
msg = Message(content=result, role=self.profile, cause_by=self._rc.todo)
self._rc.memory.add(msg)
msg = Message(content=result, role=self.profile, cause_by=self.rc.todo)
self.rc.memory.add(msg)
return msg

View file

@ -42,34 +42,34 @@ class Teacher(Role):
async def _think(self) -> bool:
"""Everything will be done part by part."""
if not self._actions:
if not self._rc.news or self._rc.news[0].cause_by != any_to_str(UserRequirement):
if not self.actions:
if not self.rc.news or self.rc.news[0].cause_by != any_to_str(UserRequirement):
raise ValueError("Lesson content invalid.")
actions = []
print(TeachingPlanBlock.TOPICS)
for topic in TeachingPlanBlock.TOPICS:
act = WriteTeachingPlanPart(context=self._rc.news[0].content, topic=topic, llm=self._llm)
act = WriteTeachingPlanPart(context=self.rc.news[0].content, topic=topic, llm=self.llm)
actions.append(act)
self._init_actions(actions)
if self._rc.todo is None:
if self.rc.todo is None:
self._set_state(0)
return True
if self._rc.state + 1 < len(self._states):
self._set_state(self._rc.state + 1)
if self.rc.state + 1 < len(self.states):
self._set_state(self.rc.state + 1)
return True
self._rc.todo = None
self.rc.todo = None
return False
async def _react(self) -> Message:
ret = Message(content="")
while True:
await self._think()
if self._rc.todo is None:
if self.rc.todo is None:
break
logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}")
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
msg = await self._act()
if ret.content != "":
ret.content += "\n\n\n"
@ -104,7 +104,7 @@ class Teacher(Role):
def course_title(self):
"""Return course title of teaching plan"""
default_title = "teaching_plan"
for act in self._actions:
for act in self.actions:
if act.topic != TeachingPlanBlock.COURSE_TITLE:
continue
if act.rsp is None:

View file

@ -34,9 +34,9 @@ class TutorialAssistant(Role):
constraints: str = "Strictly follow Markdown's syntax, with neat and standardized layout"
language: str = "Chinese"
topic = ""
main_title = ""
total_content = ""
topic: str = ""
main_title: str = ""
total_content: str = ""
def __init__(self, **kwargs):
super().__init__(**kwargs)
@ -71,9 +71,9 @@ class TutorialAssistant(Role):
Returns:
A message containing the result of the action.
"""
todo = self._rc.todo
todo = self.rc.todo
if type(todo) is WriteDirectory:
msg = self._rc.memory.get(k=1)[0]
msg = self.rc.memory.get(k=1)[0]
self.topic = msg.content
resp = await todo.run(topic=self.topic)
logger.info(resp)

View file

@ -23,9 +23,17 @@ from abc import ABC
from asyncio import Queue, QueueEmpty, wait_for
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Type, TypeVar
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
from pydantic import BaseModel, Field
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
field_serializer,
field_validator,
)
from pydantic_core import core_schema
from metagpt.config import CONFIG
from metagpt.const import (
@ -46,6 +54,64 @@ from metagpt.utils.serialize import (
)
class SerializationMixin(BaseModel):
"""SereDeserMixin for subclass' ser&deser"""
__is_polymorphic_base = False
__subclasses_map__ = {}
@classmethod
def __get_pydantic_core_schema__(
cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
) -> core_schema.CoreSchema:
schema = handler(source)
og_schema_ref = schema["ref"]
schema["ref"] += ":mixin"
return core_schema.no_info_before_validator_function(
cls.__deserialize_with_real_type__,
schema=schema,
ref=og_schema_ref,
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
)
@classmethod
def __serialize_add_class_type__(
cls,
value,
handler: core_schema.SerializerFunctionWrapHandler,
) -> Any:
ret = handler(value)
if not len(cls.__subclasses__()):
# only subclass add `__module_class_name`
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
return ret
@classmethod
def __deserialize_with_real_type__(cls, value: Any):
if not isinstance(value, dict):
return value
if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
# add right condition to init BaseClass like Action()
return value
module_class_name = value.get("__module_class_name", None)
if module_class_name is None:
raise ValueError("Missing field: __module_class_name")
class_type = cls.__subclasses_map__.get(module_class_name, None)
if class_type is None:
raise TypeError("Trying to instantiate {module_class_name} which not defined yet.")
return class_type(**value)
def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
cls.__is_polymorphic_base = is_polymorphic_base
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
super().__init_subclass__(**kwargs)
class SimpleMessage(BaseModel):
content: str
role: str
@ -102,33 +168,64 @@ class Documents(BaseModel):
class Message(BaseModel):
"""list[<role>: <content>]"""
id: str # According to Section 2.2.3.1.1 of RFC 135
id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135
content: str
instruct_content: BaseModel = None
instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)
role: str = "user" # system / user / assistant
cause_by: str = ""
sent_from: str = ""
send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL})
cause_by: str = Field(default="", validate_default=True)
sent_from: str = Field(default="", validate_default=True)
send_to: set = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)
def __init__(self, content: str = "", **kwargs):
ic = kwargs.get("instruct_content", None)
@field_validator("id", mode="before")
@classmethod
def check_id(cls, id: str) -> str:
return id if id else uuid.uuid4().hex
@field_validator("instruct_content", mode="before")
@classmethod
def check_instruct_content(cls, ic: Any) -> BaseModel:
if ic and not isinstance(ic, BaseModel) and "class" in ic:
# compatible with custom-defined ActionOutput
mapping = actionoutput_str_to_mapping(ic["mapping"])
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
ic_new = ic_obj(**ic["value"])
kwargs["instruct_content"] = ic_new
ic = ic_obj(**ic["value"])
return ic
kwargs["id"] = kwargs.get("id", uuid.uuid4().hex)
kwargs["content"] = kwargs.get("content", content)
kwargs["cause_by"] = any_to_str(
kwargs.get("cause_by", import_class("UserRequirement", "metagpt.actions.add_requirement"))
)
kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", ""))
kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL}))
super(Message, self).__init__(**kwargs)
@field_validator("cause_by", mode="before")
@classmethod
def check_cause_by(cls, cause_by: Any) -> str:
return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement"))
@field_validator("sent_from", mode="before")
@classmethod
def check_sent_from(cls, sent_from: Any) -> str:
return any_to_str(sent_from if sent_from else "")
@field_validator("send_to", mode="before")
@classmethod
def check_send_to(cls, send_to: Any) -> set:
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
@field_serializer("instruct_content", mode="plain")
def ser_instruct_content(self, ic: BaseModel) -> Union[str, None]:
ic_dict = None
if ic:
# compatible with custom-defined ActionOutput
schema = ic.model_json_schema()
# `Documents` contain definitions
if "definitions" not in schema:
# TODO refine with nested BaseModel
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)
ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
return ic_dict
def __init__(self, content: str = "", **data: Any):
data["content"] = data.get("content", content)
super().__init__(**data)
def __setattr__(self, key, val):
"""Override `@property.setter`, convert non-string parameters into string parameters."""
@ -142,26 +239,10 @@ class Message(BaseModel):
new_val = val
super().__setattr__(key, new_val)
def dict(self, *args, **kwargs) -> "DictStrAny":
"""overwrite the `dict` to dump dynamic pydantic model"""
obj_dict = super(Message, self).dict(*args, **kwargs)
ic = self.instruct_content
if ic:
# compatible with custom-defined ActionOutput
schema = ic.schema()
# `Documents` contain definitions
if "definitions" not in schema:
# TODO refine with nested BaseModel
mapping = actionoutout_schema_to_mapping(schema)
mapping = actionoutput_mapping_to_str(mapping)
obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
return obj_dict
def __str__(self):
# prefix = '-'.join([self.role, str(self.cause_by)])
if self.instruct_content:
return f"{self.role}: {self.instruct_content.dict()}"
return f"{self.role}: {self.instruct_content.model_dump()}"
return f"{self.role}: {self.content}"
def __repr__(self):
@ -173,7 +254,7 @@ class Message(BaseModel):
def dump(self) -> str:
"""Convert the object to json string"""
return self.json(exclude_none=True)
return self.model_dump_json(exclude_none=True, warnings=False)
@staticmethod
@handle_exception(exception_type=JSONDecodeError, default_return=None)
@ -224,19 +305,9 @@ class AIMessage(Message):
class MessageQueue(BaseModel):
"""Message queue which supports asynchronous updates."""
_queue: Queue = Field(default_factory=Queue)
model_config = ConfigDict(arbitrary_types_allowed=True)
_private_attributes = {"_queue": Queue()}
class Config:
arbitrary_types_allowed = True
def __init__(self, **kwargs: Any):
for key in self._private_attributes.keys():
if key in kwargs:
object.__setattr__(self, key, kwargs[key])
else:
object.__setattr__(self, key, Queue())
_queue: Queue = PrivateAttr(default_factory=Queue)
def pop(self) -> Message | None:
"""Pop one message from the queue."""
@ -312,28 +383,28 @@ class BaseContext(BaseModel, ABC):
class CodingContext(BaseContext):
filename: str
design_doc: Optional[Document]
task_doc: Optional[Document]
code_doc: Optional[Document]
design_doc: Optional[Document] = None
task_doc: Optional[Document] = None
code_doc: Optional[Document] = None
class TestingContext(BaseContext):
filename: str
code_doc: Document
test_doc: Optional[Document]
test_doc: Optional[Document] = None
class RunCodeContext(BaseContext):
mode: str = "script"
code: Optional[str]
code: Optional[str] = None
code_filename: str = ""
test_code: Optional[str]
test_code: Optional[str] = None
test_filename: str = ""
command: List[str] = Field(default_factory=list)
working_directory: str = ""
additional_python_paths: List[str] = Field(default_factory=list)
output_filename: Optional[str]
output: Optional[str]
output_filename: Optional[str] = None
output: Optional[str] = None
class RunCodeResult(BaseContext):

View file

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 4:51 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

108
metagpt/strategy/base.py Normal file
View file

@ -0,0 +1,108 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 9:16 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from typing import List
from anytree import Node, RenderTree
from pydantic import BaseModel
class BaseParser(BaseModel):
def __call__(self, *args, **kwargs):
raise NotImplementedError
def propose(self, current_state: str, **kwargs) -> str:
raise NotImplementedError
def sample(self, current_state: str, **kwargs) -> str:
raise NotImplementedError
def value(self, input: str, **kwargs) -> str:
raise NotImplementedError
class BaseEvaluator(BaseModel):
def __call__(self, *args, **kwargs):
raise NotImplementedError
def status_verify(self, *args, **kwargs):
raise NotImplementedError
class ThoughtNode(Node):
"""A node representing a thought in the thought tree."""
name: str = ""
value: int = 0
id: int = 0
valid_status: bool = True
def update_value(self, value) -> None:
"""Update the value of the thought node."""
self.value = value
def update_valid_status(self, status) -> None:
"""Update the validity status of the thought node."""
self.valid_status = status
class ThoughtTree(RenderTree):
"""A tree structure to represent thoughts."""
@property
def all_nodes(self) -> List[ThoughtNode]:
"""
Get a list of all nodes in the thought tree.
Returns:
List[ThoughtNode]: A list containing all nodes in the thought tree.
"""
all_nodes = [node for _, _, node in self]
return all_nodes
def update_node(self, thought: List[dict] = [], current_node: ThoughtNode = None) -> List[ThoughtNode]:
"""
Update the tree with new thoughts.
Args:
thought (List[dict]): A list of dictionaries representing thought information.
current_node (ThoughtNode): The current node under which new thoughts will be added.
Returns:
List[ThoughtNode]: A list of ThoughtNode instances representing the updated tree nodes.
"""
nodes = []
for node_info in thought:
node = ThoughtNode(
name=node_info["node_state_instruction"], parent=current_node, id=int(node_info["node_id"])
)
nodes.append(node)
return nodes
def parse_node_path(self, node) -> List[str]:
"""
Parse and retrieve the hierarchical path of the given thought node.
This method traverses the parent nodes of the provided 'node' and constructs
the full path from the root node to the given node.
Args:
node: The thought node for which the hierarchical path needs to be parsed.
Returns:
List[str]: A list representing the full hierarchical path of the given thought node.
The list is ordered from the root node to the provided node.
"""
full_node_path = []
while node is not None:
full_node_path.append(node.name)
node = node.parent
full_node_path.reverse()
return full_node_path
def show(self) -> None:
"""Print the updated tree."""
print("\nUpdated Tree:")
for pre, _, node in self:
print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")

View file

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/26/2023 3:32 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,73 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 1:06 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import re
from metagpt.strategy.prompt_templates.creative_writing import cot_prompt, vote_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
class TextGenParser(BaseParser):
propose_prompt: str = cot_prompt
value_prompt: str = vote_prompt
def __call__(self, input_text: str) -> str:
return input_text
def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)
def value(self, input: str = "", **kwargs) -> str:
# node_result = self(input)
id = kwargs.get("node_id", "0")
return self.value_prompt + f"Choice {id}:\n{input}\n"
class TextGenEvaluator(BaseEvaluator):
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map = {val: key for key, val in value_map.items()}
def __call__(self, evaluation: str, **kwargs) -> float:
try:
value = 0
node_id = kwargs.get("node_id", "0")
pattern = r".*best choice is .*(\d+).*"
match = re.match(pattern, evaluation, re.DOTALL)
if match:
vote = int(match.groups()[0])
print(vote)
if vote == int(node_id):
value = 1
except:
value = 0
return value
def status_verify(self, value):
status = False
if value in self.status_map:
status_value = self.status_map[value]
if status_value != "impossible":
status = True
return status
if __name__ == "__main__":
import asyncio
initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didnt like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are."""
parser = TextGenParser()
evaluator = TextGenEvaluator()
config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator)
tot_base = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot_base.solve(init_prompt=initial_prompt))

View file

@ -0,0 +1,64 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 1:36 AM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import re
from metagpt.strategy.prompt_templates.game24 import propose_prompt, value_prompt
from metagpt.strategy.tot import TreeofThought
from metagpt.strategy.tot_schema import (
BaseEvaluator,
BaseParser,
Strategy,
ThoughtSolverConfig,
)
class Game24Parser(BaseParser):
propose_prompt: str = propose_prompt
value_prompt: str = value_prompt
def __call__(self, input_text: str) -> str:
last_line = input_text.strip().split("\n")[-1]
return last_line.split("left: ")[-1].split(")")[0]
def propose(self, current_state: str, **kwargs) -> str:
return self.propose_prompt.format(input=current_state, **kwargs)
def value(self, input: str = "", **kwargs) -> str:
node_result = self(input)
return self.value_prompt.format(input=node_result)
class Game24Evaluator(BaseEvaluator):
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
status_map = {val: key for key, val in value_map.items()}
def __call__(self, evaluation: str, **kwargs) -> float:
try:
matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation)
value = self.value_map[matches[0]]
except:
value = 0.001
return value
def status_verify(self, value):
status = False
if value in self.status_map:
status_value = self.status_map[value]
if status_value != "impossible":
status = True
return status
if __name__ == "__main__":
import asyncio
initial_prompt = """4 5 6 10"""
parser = Game24Parser()
evaluator = Game24Evaluator()
config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator)
tot = TreeofThought(strategy=Strategy.BFS, config=config)
asyncio.run(tot.solve(init_prompt=initial_prompt))

View file

@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 5:21 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :

View file

@ -0,0 +1,25 @@
standard_prompt = """
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
"""
cot_prompt = """
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
Make a plan then write. Your output should be of the following format:
Plan:
Your plan here.
Passage:
Your passage here.
"""
vote_prompt = """Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice.
"""
compare_prompt = """Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
"""
score_prompt = """Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
"""

View file

@ -0,0 +1,139 @@
# 5-shot
standard_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Input: 1 4 8 8
Answer: (8 / 4 + 1) * 8 = 24
Input: 5 5 5 9
Answer: 5 + 5 + 5 + 9 = 24
Input: {input}
"""
# 5-shot
cot_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
Input: 4 4 6 8
Steps:
4 + 8 = 12 (left: 4 6 12)
6 - 4 = 2 (left: 2 12)
2 * 12 = 24 (left: 24)
Answer: (6 - 4) * (4 + 8) = 24
Input: 2 9 10 12
Steps:
12 * 2 = 24 (left: 9 10 24)
10 - 9 = 1 (left: 1 24)
24 * 1 = 24 (left: 24)
Answer: (12 * 2) * (10 - 9) = 24
Input: 4 9 10 13
Steps:
13 - 10 = 3 (left: 3 4 9)
9 - 3 = 6 (left: 4 6)
4 * 6 = 24 (left: 24)
Answer: 4 * (9 - (13 - 10)) = 24
Input: 1 4 8 8
Steps:
8 / 4 = 2 (left: 1 2 8)
1 + 2 = 3 (left: 3 8)
3 * 8 = 24 (left: 24)
Answer: (1 + 8 / 4) * 8 = 24
Input: 5 5 5 9
Steps:
5 + 5 = 10 (left: 5 9 10)
10 + 5 = 15 (left: 9 15)
15 + 9 = 24 (left: 24)
Answer: ((5 + 5) + 5) + 9 = 24
Input: {input}
"""
# 1-shot
propose_prompt = """Here is an Example for 1 input and 8 possible thoughts:
Input: 2 8 8 14
Possible next steps:
2 + 8 = 10 (left: 8 10 14)
8 / 2 = 4 (left: 4 8 14)
14 + 2 = 16 (left: 8 8 16)
2 * 8 = 16 (left: 8 14 16)
8 - 2 = 6 (left: 6 8 14)
14 - 8 = 6 (left: 2 6 8)
14 / 2 = 7 (left: 7 8 8)
14 - 2 = 12 (left: 8 8 12)
Here is my task for 1 input and {n_generate_sample} possible thoughts:
Input: {input}
Possible next steps:
"""
value_prompt = """Evaluate if given numbers can reach 24 (sure/likely/impossible)
10 14
10 + 14 = 24
sure
11 12
11 + 12 = 23
12 - 11 = 1
11 * 12 = 132
11 / 12 = 0.91
impossible
4 4 10
4 + 4 + 10 = 8 + 10 = 18
4 * 10 - 4 = 40 - 4 = 36
(10 - 4) * 4 = 6 * 4 = 24
sure
4 9 11
9 + 11 + 4 = 20 + 4 = 24
sure
5 7 8
5 + 7 + 8 = 12 + 8 = 20
(8 - 5) * 7 = 3 * 7 = 21
I cannot obtain 24 now, but numbers are within a reasonable range
likely
5 6 6
5 + 6 + 6 = 17
(6 - 5) * 6 = 1 * 6 = 6
I cannot obtain 24 now, but numbers are within a reasonable range
likely
10 10 11
10 + 10 + 11 = 31
(11 - 10) * 10 = 10
10 10 10 are all too big
impossible
1 3 3
1 * 3 * 3 = 9
(1 + 3) * 3 = 12
1 3 3 are all too small
impossible
{input}
"""
value_last_step_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) = 24
Judge:
sure
Input: 2 9 10 12
Answer: 2 * 12 * (10 - 9) = 24
Judge:
sure
Input: 4 9 10 13
Answer: (13 - 9) * (10 - 4) = 24
Judge:
sure
Input: 4 4 6 8
Answer: (4 + 8) * (6 - 4) + 1 = 25
Judge:
impossible
Input: 2 9 10 12
Answer: 2 * (12 - 10) = 24
Judge:
impossible
Input: 4 9 10 13
Answer: (13 - 4) * (10 - 9) = 24
Judge:
impossible
Input: {input}
Answer: {answer}
Judge:"""

272
metagpt/strategy/tot.py Normal file
View file

@ -0,0 +1,272 @@
# -*- coding: utf-8 -*-
# @Date : 12/23/2023 4:51 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import asyncio
from typing import Any, List
from pydantic import BaseModel, Field
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.strategy.base import ThoughtNode, ThoughtTree
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
from metagpt.utils.common import CodeParser
OUTPUT_FORMAT = """
Output a list of jsons following the format:
```json
[
{
"node_id": str = "unique identifier for a solution, can be an ordinal",
"node_state_instruction": "specified sample of solution",
},
...
]
```
"""
class ThoughtSolverBase(BaseModel):
thought_tree: str = ""
llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.llm.use_system_prompt = False
async def solve(self, init_prompt):
"""
Solve method for subclasses to implement.
"""
raise NotImplementedError("Subclasses must implement the solve method")
async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
"""
Generate children thoughts based on the current state.
Args:
current_state (str): The current state for which thoughts are generated.
current_node (ThoughtNode): The current node in the thought tree.
Returns:
List[ThoughtNode]: List of nodes representing the generated thoughts.
"""
state_prompt = self.config.parser.propose(
current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
)
rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
thoughts = CodeParser.parse_code(block=None, text=rsp)
thoughts = eval(thoughts)
# fixme 避免不跟随生成过多nodes
# valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
return self.thought_tree.update_node(thoughts, current_node=current_node)
async def evaluate_node(self, node, parent_value) -> None:
"""
Evaluate a node and update its status and value.
Args:
node (ThoughtNode): The node to be evaluated.
parent_value (float): The parent node's value.
Returns:
None
"""
eval_prompt = self.config.parser.value(input=node.name, **{"node_id": node.id})
evaluation = await self.llm.aask(msg=eval_prompt)
value = self.config.evaluator(evaluation, **{"node_id": node.id})
status = self.config.evaluator.status_verify(value)
node.update_valid_status(status=status)
# 累计分数
node.update_value(parent_value + value)
def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
"""
Select nodes based on the configured selection method.
Args:
thought_nodes (List[ThoughtNode]): List of nodes to be selected.
Returns:
List[ThoughtNode]: List of selected nodes.
"""
# selection
if self.config.method_select == MethodSelect.SAMPLE:
raise NotImplementedError
elif self.config.method_select == MethodSelect.GREEDY:
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
for node in thought_nodes:
if node not in select_nodes:
node.parent = None # 从树中删除节点
return select_nodes
def update_solution(self):
"""
Select the result with the highest score.
Returns:
- List[ThoughtNode]: List of nodes representing the best solution.
- List[str]: List of node names forming the best solution path.
"""
best_node = max(self.thought_tree.all_nodes, key=lambda x: x.value, default=None)
best_solution_path = self.thought_tree.parse_node_path(best_node)
return [best_node], best_solution_path
class BFSSolver(ThoughtSolverBase):
async def solve(self, init_prompt=""):
"""
Solve the problem using Breadth-First Search (BFS) strategy.
Args:
init_prompt (str): The initial prompt for the solver.
Returns:
List[str]: The best solution path obtained through BFS.
"""
root = ThoughtNode(init_prompt)
self.thought_tree = ThoughtTree(root)
current_nodes = [root]
for step in range(self.config.max_steps):
solutions = await self._bfs_build(current_nodes)
selected_nodes = self.select_nodes(solutions)
current_nodes = selected_nodes
self.thought_tree.show()
best_solution, best_solution_path = self.update_solution()
logger.info(f"best solution is: {best_solution_path}")
return best_solution_path
async def _bfs_build(self, current_nodes):
"""
Build the thought tree using Breadth-First Search (BFS) strategy.
Args:
current_nodes (List[ThoughtNode]): Current nodes to expand.
Returns:
List[ThoughtNode]: The solutions obtained after expanding the current nodes.
"""
tasks = []
for node in current_nodes:
current_state = self.config.parser(node.name)
current_value = node.value
tasks.append(self.generate_and_evaluate_nodes(current_state, current_value, node))
thought_nodes_list = await asyncio.gather(*tasks)
solutions = [child_node for thought_nodes in thought_nodes_list for child_node in thought_nodes]
return solutions
async def generate_and_evaluate_nodes(self, current_state, current_value, node):
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
await asyncio.gather(
*(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes)
)
return thought_nodes
class DFSSolver(ThoughtSolverBase):
async def _dfs(self, root_node):
"""
Perform Depth-First Search (DFS) on the thought tree.
Args:
root_node (ThoughtNode): The root node of the thought tree.
Returns:
List[str]: The solution path obtained through DFS.
"""
impossible_state_cnt = 0
node = root_node
for step in range(self.max_steps):
current_state = self.config.parser(node.name)
current_value = node.value
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
await self.evaluate_node(thought_nodes[0], parent_value=current_value)
if thought_nodes[0].valid_status is False:
impossible_state_cnt += 1
if impossible_state_cnt >= 2:
logger.info("impossible state reached, break")
break
node = thought_nodes[0]
_solution_path = self.thought_tree.parse_node_path(node)
self.thought_tree.show()
return _solution_path
async def solve(self, init_prompt="", root=ThoughtNode("")):
"""
Solve the problem using Depth-First Search (DFS) strategy.
Args:
init_prompt (str): The initial prompt for the solver.
Returns:
List[str]: The best solution path obtained through DFS.
"""
root = ThoughtNode(init_prompt)
self.thought_tree = ThoughtTree(root)
for n in range(self.config.n_solution_sample):
# fixme: 需要产生回退,当前节点不可用时回退到父节点,产生新的节点继续探索
await self._dfs(root)
best_solution, best_solution_path = self.update_solution()
logger.info(f"best solution is: {best_solution_path}")
return best_solution_path
class MCTSSolver(ThoughtSolverBase):
async def solve(self, init_prompt=""):
raise NotImplementedError
class TreeofThought(BaseModel):
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
strategy: Strategy = Field(default=Strategy.BFS)
class Config:
arbitrary_types_allowed = True
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._initialize_solver(self.strategy)
def _initialize_solver(self, strategy):
"""
Initialize the solver based on the chosen strategy.
Args:
strategy (Strategy): The strategy to use for solving.
Returns:
ThoughtSolverBase: An instance of the appropriate solver.
"""
if strategy == Strategy.BFS:
self.solver = BFSSolver(config=self.config)
elif strategy == Strategy.DFS:
self.solver = DFSSolver(config=self.config)
elif strategy == Strategy.MCTS:
self.solver = MCTSSolver(config=self.config)
else:
raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!")
async def solve(self, init_prompt=""):
"""
Solve the problem using the specified strategy.
Args:
init_prompt (str): The initial prompt for the solver.
strategy (str): The strategy to use for solving.
Returns:
Any: The solution obtained using the selected strategy.
"""
await self.solver.solve(init_prompt)

View file

@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# @Date : 12/25/2023 9:14 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from enum import Enum
from pydantic import BaseModel, Field
from metagpt.strategy.base import BaseEvaluator, BaseParser
class MethodSelect(Enum):
SAMPLE = "sample"
GREEDY = "greedy"
class Strategy(Enum):
BFS = "BFS"
DFS = "DFS"
MCTS = "MCTS"
class ThoughtSolverConfig(BaseModel):
max_steps: int = 3
method_select: str = MethodSelect.GREEDY # ["sample"/"greedy"]
n_generate_sample: int = 5 # per node
n_select_sample: int = 3 # per path
n_solution_sample: int = 5 # only for dfs
parser: BaseParser = Field(default_factory=BaseParser)
evaluator: BaseEvaluator = Field(default_factory=BaseEvaluator)

View file

@ -1,7 +1,7 @@
import asyncio
from typing import AsyncGenerator, Awaitable, Callable
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from metagpt.logs import logger
from metagpt.roles import Role
@ -33,10 +33,9 @@ class SubscriptionRunner(BaseModel):
>>> asyncio.run(main())
"""
tasks: dict[Role, asyncio.Task] = Field(default_factory=dict)
model_config = ConfigDict(arbitrary_types_allowed=True)
class Config:
arbitrary_types_allowed = True
tasks: dict[Role, asyncio.Task] = Field(default_factory=dict)
async def subscribe(
self,

View file

@ -10,8 +10,9 @@
import warnings
from pathlib import Path
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from metagpt.actions import UserRequirement
from metagpt.config import CONFIG
@ -34,32 +35,27 @@ class Team(BaseModel):
dedicated to env any multi-agent activity, such as collaboratively writing executable code.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
env: Environment = Field(default_factory=Environment)
investment: float = Field(default=10.0)
idea: str = Field(default="")
def __init__(self, **kwargs):
super().__init__(**kwargs)
if "roles" in kwargs:
self.hire(kwargs["roles"])
if "env_desc" in kwargs:
self.env.desc = kwargs["env_desc"]
class Config:
arbitrary_types_allowed = True
def __init__(self, **data: Any):
super(Team, self).__init__(**data)
if "roles" in data:
self.hire(data["roles"])
if "env_desc" in data:
self.env.desc = data["env_desc"]
def serialize(self, stg_path: Path = None):
stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
team_info_path = stg_path.joinpath("team_info.json")
write_json_file(team_info_path, self.dict(exclude={"env": True}))
write_json_file(team_info_path, self.model_dump(exclude={"env": True}))
self.env.serialize(stg_path.joinpath("environment")) # save environment alone
@classmethod
def recover(cls, stg_path: Path) -> "Team":
return cls.deserialize(stg_path)
@classmethod
def deserialize(cls, stg_path: Path) -> "Team":
"""stg_path = ./storage/team"""
@ -76,7 +72,6 @@ class Team(BaseModel):
# recover environment
environment = Environment.deserialize(stg_path=stg_path.joinpath("environment"))
team_info.update({"env": environment})
team = Team(**team_info)
return team
@ -121,7 +116,7 @@ class Team(BaseModel):
return self.run_project(idea=idea, send_to=send_to)
def _save(self):
logger.info(self.json(ensure_ascii=False))
logger.info(self.model_dump_json())
@serialize_decorator
async def run(self, n_round=3, idea="", send_to="", auto_archive=True):

View file

@ -9,7 +9,7 @@ from typing import Optional
from urllib.parse import urlparse
import httplib2
from pydantic import BaseModel, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config import CONFIG
from metagpt.logs import logger
@ -25,15 +25,14 @@ except ImportError:
class GoogleAPIWrapper(BaseModel):
google_api_key: Optional[str] = None
google_cse_id: Optional[str] = None
model_config = ConfigDict(arbitrary_types_allowed=True)
google_api_key: Optional[str] = Field(default=None, validate_default=True)
google_cse_id: Optional[str] = Field(default=None, validate_default=True)
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
class Config:
arbitrary_types_allowed = True
@validator("google_api_key", always=True)
@field_validator("google_api_key", mode="before")
@classmethod
def check_google_api_key(cls, val: str):
val = val or CONFIG.google_api_key
@ -45,7 +44,7 @@ class GoogleAPIWrapper(BaseModel):
)
return val
@validator("google_cse_id", always=True)
@field_validator("google_cse_id", mode="before")
@classmethod
def check_google_cse_id(cls, val: str):
val = val or CONFIG.google_cse_id

View file

@ -8,13 +8,15 @@
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config import CONFIG
class SerpAPIWrapper(BaseModel):
search_engine: Any #: :meta private:
model_config = ConfigDict(arbitrary_types_allowed=True)
search_engine: Any = None #: :meta private:
params: dict = Field(
default={
"engine": "google",
@ -23,13 +25,11 @@ class SerpAPIWrapper(BaseModel):
"hl": "en",
}
)
serpapi_api_key: Optional[str] = None
# should add `validate_default=True` to check with default value
serpapi_api_key: Optional[str] = Field(default=None, validate_default=True)
aiosession: Optional[aiohttp.ClientSession] = None
class Config:
arbitrary_types_allowed = True
@validator("serpapi_api_key", always=True)
@field_validator("serpapi_api_key", mode="before")
@classmethod
def check_serpapi_api_key(cls, val: str):
val = val or CONFIG.serpapi_api_key

View file

@ -9,21 +9,18 @@ import json
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator
from metagpt.config import CONFIG
class SerperWrapper(BaseModel):
search_engine: Any #: :meta private:
search_engine: Any = None #: :meta private:
payload: dict = Field(default={"page": 1, "num": 10})
serper_api_key: Optional[str] = None
serper_api_key: Optional[str] = Field(default=None, validate_default=True)
aiosession: Optional[aiohttp.ClientSession] = None
class Config:
arbitrary_types_allowed = True
@validator("serper_api_key", always=True)
@field_validator("serper_api_key", mode="before")
@classmethod
def check_serper_api_key(cls, val: str):
val = val or CONFIG.serper_api_key

View file

@ -27,7 +27,7 @@ from typing import Any, Callable, List, Tuple, Union, get_args, get_origin
import aiofiles
import loguru
from pydantic.json import pydantic_encoder
from pydantic_core import to_jsonable_python
from tenacity import RetryCallState, _utils
from metagpt.const import MESSAGE_ROUTE_TO_ALL
@ -472,7 +472,7 @@ def write_json_file(json_file: str, data: list, encoding=None):
folder_path.mkdir(parents=True, exist_ok=True)
with open(json_file, "w", encoding=encoding) as fout:
json.dump(data, fout, ensure_ascii=False, indent=4, default=pydantic_encoder)
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
def import_class(class_name: str, module_name: str) -> type:
@ -512,7 +512,7 @@ def role_raise_decorator(func):
except KeyboardInterrupt as kbi:
logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project")
if self.latest_observed_msg:
self._rc.memory.delete(self.latest_observed_msg)
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
raise Exception(format_trackback_info(limit=None))
except Exception:
@ -522,7 +522,7 @@ def role_raise_decorator(func):
"we delete the newest role communication message in the role's memory."
)
# remove role newest observed msg to make it observed again
self._rc.memory.delete(self.latest_observed_msg)
self.rc.memory.delete(self.latest_observed_msg)
# raise again to make it captured outside
raise Exception(format_trackback_info(limit=None))

View file

@ -5,7 +5,7 @@ from typing import Generator, Optional
from urllib.parse import urljoin, urlparse
from bs4 import BeautifulSoup
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr
class WebPage(BaseModel):
@ -13,11 +13,8 @@ class WebPage(BaseModel):
html: str
url: str
class Config:
underscore_attrs_are_private = True
_soup: Optional[BeautifulSoup] = None
_title: Optional[str] = None
_soup: Optional[BeautifulSoup] = PrivateAttr(default=None)
_title: Optional[str] = PrivateAttr(default=None)
@property
def soup(self) -> BeautifulSoup:

View file

@ -62,10 +62,10 @@ def serialize_message(message: "Message"):
ic = message_cp.instruct_content
if ic:
# model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly
schema = ic.schema()
schema = ic.model_json_schema()
mapping = actionoutout_schema_to_mapping(schema)
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
msg_ser = pickle.dumps(message_cp)
return msg_ser

View file

@ -10,7 +10,7 @@ fire==0.4.0
typer
# godot==0.1.1
# google_api_python_client==2.93.0 # Used by search_engine.py
lancedb==0.1.16
lancedb==0.4.0
langchain==0.0.352
loguru==0.6.0
meilisearch==0.21.0
@ -19,7 +19,7 @@ openai==1.6.0
openpyxl
beautifulsoup4==4.12.2
pandas==2.0.3
pydantic==1.10.8
pydantic==2.5.3
#pygame==2.1.3
#pymilvus==2.2.8
pytest==7.2.2
@ -33,16 +33,15 @@ tqdm==4.64.0
#unstructured[local-inference]
# selenium>4
# webdriver_manager<3.9
anthropic==0.3.6
anthropic==0.8.1
typing-inspect==0.8.0
aiofiles
typing_extensions==4.7.0
typing_extensions==4.9.0
libcst==1.0.1
qdrant-client==1.4.0
qdrant-client==1.7.0
pytest-mock==3.11.1
# open-interpreter==0.1.7; python_version>"3.9" # Conflict with openai 1.x
ta==0.10.2
semantic-kernel==0.4.0.dev0
semantic-kernel==0.4.3.dev0
wrapt==1.15.0
#aiohttp_jinja2
# azure-cognitiveservices-speech~=1.31.0 # Used by metagpt/tools/azure_tts.py

View file

@ -89,6 +89,7 @@ def loguru_caplog(caplog):
@pytest.fixture(scope="session", autouse=True)
def setup_and_teardown_git_repo(request):
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / "unittest")
CONFIG.git_reinit = True
# Destroy git repo at the end of the test session.
def fin():

View file

@ -12,6 +12,7 @@ import pytest
from metagpt.actions import Action
from metagpt.actions.action_node import ActionNode
from metagpt.environment import Environment
from metagpt.llm import LLM
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.team import Team
@ -76,18 +77,24 @@ async def test_action_node_one_layer():
assert "key-a" in markdown_template
assert node_dict["key-a"] == "instruction-b"
assert "key-a" in repr(node)
@pytest.mark.asyncio
async def test_action_node_two_layer():
node_a = ActionNode(key="key-a", expected_type=str, instruction="i-a", example="e-a")
node_b = ActionNode(key="key-b", expected_type=str, instruction="i-b", example="e-b")
node_a = ActionNode(key="reasoning", expected_type=str, instruction="reasoning step by step", example="")
node_b = ActionNode(key="answer", expected_type=str, instruction="the final answer", example="")
root = ActionNode.from_children(key="", nodes=[node_a, node_b])
assert "key-a" in root.children
root = ActionNode.from_children(key="detail answer", nodes=[node_a, node_b])
assert "reasoning" in root.children
assert node_b in root.children.values()
json_template = root.compile(context="123", schema="json", mode="auto")
assert "i-a" in json_template
# FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST.
answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
assert "579" in answer1.content
answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
assert "579" in answer2.content
t_dict = {
@ -116,16 +123,33 @@ WRITE_TASKS_OUTPUT_MAPPING = {
"Anything UNCLEAR": (str, ...),
}
WRITE_TASKS_OUTPUT_MAPPING_MISSING = {
"Required Python third-party packages": (str, ...),
}
def test_create_model_class():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
assert test_class.__name__ == "test_class"
output = test_class(**t_dict)
print(output.schema())
assert output.schema()["title"] == "test_class"
assert output.schema()["type"] == "object"
assert output.schema()["properties"]["Full API spec"]
def test_create_model_class_missing():
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING_MISSING)
assert test_class.__name__ == "test_class"
_ = test_class(**t_dict) # 这里应该要挂掉
def test_create_model_class_with_mapping():
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
t1 = t(**t_dict)
value = t1.dict()["Task list"]
value = t1.model_dump()["Task list"]
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]

View file

@ -1,101 +0,0 @@
import os
import tempfile
import pytest
from metagpt.actions.clone_function import (
CloneFunction,
run_function_code,
run_function_script,
)
source_code = """
import pandas as pd
import ta
def user_indicator():
# 读取股票数据
stock_data = pd.read_csv('./tests/data/baba_stock.csv')
stock_data.head()
# 计算简单移动平均线
stock_data['SMA'] = ta.trend.sma_indicator(stock_data['Close'], window=6)
stock_data[['Date', 'Close', 'SMA']].head()
# 计算布林带
stock_data['bb_upper'], stock_data['bb_middle'], stock_data['bb_lower'] = ta.volatility.bollinger_hband_indicator(stock_data['Close'], window=20), ta.volatility.bollinger_mavg(stock_data['Close'], window=20), ta.volatility.bollinger_lband_indicator(stock_data['Close'], window=20)
stock_data[['Date', 'Close', 'bb_upper', 'bb_middle', 'bb_lower']].head()
"""
template_code = """
def stock_indicator(stock_path: str, indicators=['Simple Moving Average', 'BollingerBands', 'MACD]) -> pd.DataFrame:
import pandas as pd
# here is your code.
"""
def get_expected_res():
import pandas as pd
import ta
# 读取股票数据
stock_data = pd.read_csv("./tests/data/baba_stock.csv")
stock_data.head()
# 计算简单移动平均线
stock_data["SMA"] = ta.trend.sma_indicator(stock_data["Close"], window=6)
stock_data[["Date", "Close", "SMA"]].head()
# 计算布林带
stock_data["bb_upper"], stock_data["bb_middle"], stock_data["bb_lower"] = (
ta.volatility.bollinger_hband_indicator(stock_data["Close"], window=20),
ta.volatility.bollinger_mavg(stock_data["Close"], window=20),
ta.volatility.bollinger_lband_indicator(stock_data["Close"], window=20),
)
stock_data[["Date", "Close", "bb_upper", "bb_middle", "bb_lower"]].head()
return stock_data
@pytest.mark.asyncio
async def test_clone_function():
clone = CloneFunction()
code = await clone.run(template_code, source_code)
assert "def " in code
stock_path = "./tests/data/baba_stock.csv"
df, msg = run_function_code(code, "stock_indicator", stock_path)
assert not msg
expected_df = get_expected_res()
assert df.equals(expected_df)
def test_run_function_script():
# 创建一个临时文件并写入脚本内容
script_content = """def valid_function(arg1, arg2):\n return arg1 + arg2\n"""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as temp_file:
temp_file.write(script_content)
temp_file_path = temp_file.name
invalid_script_content = """def valid_function(arg1, arg2)\n return arg1 + arg2\n"""
with tempfile.NamedTemporaryFile(mode="w+", suffix=".py", delete=False) as error_temp_file:
error_temp_file.write(invalid_script_content)
error_temp_file_path = error_temp_file.name
try:
# 正常情况下运行脚本
result, _ = run_function_script(temp_file_path, "valid_function", 1, arg2=2)
assert result == 3
# 不存在的脚本路径
with pytest.raises(FileNotFoundError):
run_function_script("nonexistent/path/script.py", "valid_function", 1, arg2=2)
# 无效的脚本内容
result, traceback = run_function_script(error_temp_file_path, "invalid_function", 1, arg2=2)
assert not result
assert "SyntaxError" in traceback
# 函数调用失败的情况
result, traceback = run_function_script(temp_file_path, "function_that_raises_exception", 1, arg2=2)
assert not result
assert "KeyError" in traceback
finally:
# 删除临时文件
if os.path.exists(temp_file_path):
os.remove(temp_file_path)

View file

@ -142,7 +142,7 @@ async def test_debug_error():
"Ran 5 tests in 0.007s\n\nFAILED (failures=1)\n;\n",
)
await FileRepository.save_file(
filename=ctx.output_filename, content=output_data.json(), relative_path=TEST_OUTPUTS_FILE_REPO
filename=ctx.output_filename, content=output_data.model_dump_json(), relative_path=TEST_OUTPUTS_FILE_REPO
)
debug_error = DebugError(context=ctx)

View file

@ -20,7 +20,7 @@ from metagpt.actions.invoice_ocr import GenerateTable, InvoiceOCR, ReplyQuestion
"invoice_path",
[
"../../data/invoices/invoice-3.jpg",
"../../data/invoices/invoice-4.zip",
# "../../data/invoices/invoice-4.zip",
],
)
async def test_invoice_ocr(invoice_path: str):

View file

@ -8,7 +8,7 @@
import pytest
from metagpt.actions import CollectLinks
from metagpt.actions import CollectLinks, research
@pytest.mark.asyncio
@ -18,5 +18,107 @@ async def test_action():
assert result
@pytest.mark.asyncio
async def test_collect_links(mocker):
async def mock_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
elif "Provide up to 4 queries related to your research topic" in prompt:
return (
'["MetaGPT use cases", "The roadmap of MetaGPT", '
'"The function of MetaGPT", "What llm MetaGPT support"]'
)
elif "sort the remaining search results" in prompt:
return "[1,2]"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
resp = await research.CollectLinks().run("The application of MetaGPT")
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
assert i in resp
@pytest.mark.asyncio
async def test_collect_links_with_rank_func(mocker):
rank_before = []
rank_after = []
url_per_query = 4
def rank_func(results):
results = results[:url_per_query]
rank_before.append(results)
results = results[::-1]
rank_after.append(results)
return results
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_collect_links_llm_ask)
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
for x, y, z in zip(rank_before, rank_after, resp.values()):
assert x[::-1] == y
assert [i["link"] for i in y] == z
@pytest.mark.asyncio
async def test_web_browse_and_summarize(mocker):
async def mock_llm_ask(*args, **kwargs):
return "metagpt"
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
url = "https://github.com/geekan/MetaGPT"
url2 = "https://github.com/trending"
query = "What's new in metagpt"
resp = await research.WebBrowseAndSummarize().run(url, query=query)
assert len(resp) == 1
assert url in resp
assert resp[url] == "metagpt"
resp = await research.WebBrowseAndSummarize().run(url, url2, query=query)
assert len(resp) == 2
async def mock_llm_ask(*args, **kwargs):
return "Not relevant."
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
resp = await research.WebBrowseAndSummarize().run(url, query=query)
assert len(resp) == 1
assert url in resp
assert resp[url] is None
@pytest.mark.asyncio
async def test_conduct_research(mocker):
data = None
async def mock_llm_ask(*args, **kwargs):
nonlocal data
data = f"# Research Report\n## Introduction\n{args} {kwargs}"
return data
mocker.patch("metagpt.provider.base_gpt_api.BaseGPTAPI.aask", mock_llm_ask)
content = (
"MetaGPT takes a one line requirement as input and "
"outputs user stories / competitive analysis / requirements / data structures / APIs / documents, etc."
)
resp = await research.ConductResearch().run("The application of MetaGPT", content)
assert resp == data
async def mock_collect_links_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
elif "Provide up to 4 queries related to your research topic" in prompt:
return (
'["MetaGPT use cases", "The roadmap of MetaGPT", ' '"The function of MetaGPT", "What llm MetaGPT support"]'
)
elif "sort the remaining search results" in prompt:
return "[1,2]"
return ""
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -14,13 +14,13 @@ from metagpt.schema import RunCodeContext
@pytest.mark.asyncio
async def test_run_text():
result, errs = await RunCode.run_text("result = 1 + 1")
assert result == 2
assert errs == ""
out, err = await RunCode.run_text("result = 1 + 1")
assert out == 2
assert err == ""
result, errs = await RunCode.run_text("result = 1 / 0")
assert result == ""
assert "ZeroDivisionError" in errs
out, err = await RunCode.run_text("result = 1 / 0")
assert out == ""
assert "division by zero" in err
@pytest.mark.asyncio

View file

@ -32,11 +32,11 @@ async def test_write_code():
context = CodingContext(
filename="task_filename.py", design_doc=Document(content="设计一个名为'add'的函数,该函数接受两个整数作为输入,并返回它们的和。")
)
doc = Document(content=context.json())
doc = Document(content=context.model_dump_json())
write_code = WriteCode(context=doc)
code = await write_code.run()
logger.info(code.json())
logger.info(code.model_dump_json())
# 我们不能精确地预测生成的代码,但我们可以检查某些关键字
assert "def add" in code.code_doc.content

View file

@ -29,7 +29,7 @@ async def test_write_test():
write_test = WriteTest(context=context)
context = await write_test.run()
logger.info(context.json())
logger.info(context.model_dump_json())
# We cannot exactly predict the generated test cases, but we can check if it is a string and if it is not empty
assert isinstance(context.test_doc.content, str)

View file

@ -1,36 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/6/11 21:08
@Author : alexanderwu
@File : test_milvus_store.py
"""
import random
import numpy as np
from metagpt.document_store.milvus_store import MilvusConnection, MilvusStore
from metagpt.logs import logger
book_columns = {"idx": int, "name": str, "desc": str, "emb": np.ndarray, "price": float}
book_data = [
[i for i in range(10)],
[f"book-{i}" for i in range(10)],
[f"book-desc-{i}" for i in range(10000, 10010)],
[[random.random() for _ in range(2)] for _ in range(10)],
[random.random() for _ in range(10)],
]
def test_milvus_store():
milvus_connection = MilvusConnection(alias="default", host="192.168.50.161", port="30530")
milvus_store = MilvusStore(milvus_connection)
milvus_store.drop("Book")
milvus_store.create_collection("Book", book_columns)
milvus_store.add(book_data)
milvus_store.build_index("emb")
milvus_store.load_collection()
results = milvus_store.search([[1.0, 1.0]], field="emb")
logger.info(results)
assert results

View file

@ -29,7 +29,7 @@ points = [
]
def test_milvus_store():
def test_qdrant_store():
qdrant_connection = QdrantConnection(memory=True)
vectors_config = VectorParams(size=2, distance=Distance.COSINE)
qdrant_store = QdrantStore(qdrant_connection)
@ -43,13 +43,13 @@ def test_milvus_store():
results = qdrant_store.search("Book", query=[1.0, 1.0])
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[1]["score"] == 7
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
results = qdrant_store.search("Book", query=[1.0, 1.0], return_vector=True)
assert results[0]["id"] == 2
assert results[0]["score"] == 0.999106722578389
assert results[0]["vector"] == [0.7363563179969788, 0.6765939593315125]
assert results[1]["score"] == 7
assert results[1]["id"] == 7
assert results[1]["score"] == 0.9961650411397226
assert results[1]["vector"] == [0.7662628889083862, 0.6425272226333618]
results = qdrant_store.search(

View file

@ -5,6 +5,7 @@
@Author : mashenquan
@File : test_brain_memory.py
"""
import pytest
from metagpt.config import LLMProviderEnum

View file

@ -86,31 +86,25 @@ class TestOpenAI:
def test_make_client_kwargs_without_proxy(self, config):
instance = OpenAILLM()
instance.config = config
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
instance = OpenAILLM()
instance.config = config_azure
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"}
assert "http_client" not in kwargs
assert "http_client" not in async_kwargs
def test_make_client_kwargs_with_proxy(self, config_proxy):
instance = OpenAILLM()
instance.config = config_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
instance = OpenAILLM()
instance.config = config_azure_proxy
kwargs, async_kwargs = instance._make_client_kwargs()
kwargs = instance._make_client_kwargs()
assert "http_client" in kwargs
assert "http_client" in async_kwargs

View file

@ -32,3 +32,19 @@ async def test_researcher(mocker):
researcher.RESEARCH_PATH = Path(dirname)
await researcher.Researcher().run(topic)
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
def test_write_report(mocker):
with TemporaryDirectory() as dirname:
for i, topic in enumerate(
[
("1./metagpt"),
('2.:"metagpt'),
("3.*?<>|metagpt"),
("4. metagpt\n"),
]
):
researcher.RESEARCH_PATH = Path(dirname)
content = "# Research Report"
researcher.Researcher().write_report(topic, content)
assert (researcher.RESEARCH_PATH / f"{i+1}. metagpt.md").read_text().startswith("# Research Report")

View file

@ -8,4 +8,4 @@ from metagpt.roles.role import Role
def test_role_desc():
role = Role(profile="Sales", desc="Best Seller")
assert role.profile == "Sales"
assert role._setting.desc == "Best Seller"
assert role.desc == "Best Seller"

View file

@ -10,15 +10,20 @@ from metagpt.llm import LLM
def test_action_serialize():
action = Action()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" not in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
assert "__module_class_name" not in ser_action_dict
action = Action(name="test")
ser_action_dict = action.model_dump()
assert "test" in ser_action_dict["name"]
@pytest.mark.asyncio
async def test_action_deserialize():
action = Action()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = Action(**serialized_data)

View file

@ -10,19 +10,19 @@ from metagpt.roles.architect import Architect
def test_architect_serialize():
role = Architect()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
async def test_architect_deserialize():
role = Architect()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
new_role = Architect(**ser_role_dict)
# new_role = Architect.deserialize(ser_role_dict)
assert new_role.name == "Bob"
assert len(new_role._actions) == 1
assert isinstance(new_role._actions[0], Action)
await new_role._actions[0].run(with_messages="write a cli snake game")
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], Action)
await new_role.actions[0].run(with_messages="write a cli snake game")

View file

@ -20,14 +20,15 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
def test_env_serialize():
env = Environment()
ser_env_dict = env.dict()
ser_env_dict = env.model_dump()
assert "roles" in ser_env_dict
assert len(ser_env_dict["roles"]) == 0
def test_env_deserialize():
env = Environment()
env.publish_message(message=Message(content="test env serialize"))
ser_env_dict = env.dict()
ser_env_dict = env.model_dump()
new_env = Environment(**ser_env_dict)
assert len(new_env.roles) == 0
assert len(new_env.history) == 25
@ -47,16 +48,16 @@ def test_environment_serdeser():
environment.add_role(role_c)
environment.publish_message(message)
ser_data = environment.dict()
ser_data = environment.model_dump()
assert ser_data["roles"]["Role C"]["name"] == "RoleC"
new_env: Environment = Environment(**ser_data)
assert len(new_env.roles) == 1
assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states
assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions
assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK)
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
assert list(new_env.roles.values())[0].states == list(environment.roles.values())[0].states
assert list(new_env.roles.values())[0].actions == list(environment.roles.values())[0].actions
assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
def test_environment_serdeser_v2():
@ -64,13 +65,13 @@ def test_environment_serdeser_v2():
pm = ProjectManager()
environment.add_role(pm)
ser_data = environment.dict()
ser_data = environment.model_dump()
new_env: Environment = Environment(**ser_data)
role = new_env.get_role(pm.profile)
assert isinstance(role, ProjectManager)
assert isinstance(role._actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks)
assert isinstance(role.actions[0], WriteTasks)
assert isinstance(list(new_env.roles.values())[0].actions[0], WriteTasks)
def test_environment_serdeser_save():
@ -85,4 +86,4 @@ def test_environment_serdeser_save():
new_env: Environment = Environment.deserialize(stg_path)
assert len(new_env.roles) == 1
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK

View file

@ -25,7 +25,7 @@ def test_memory_serdeser():
memory = Memory()
memory.add_batch([msg1, msg2])
ser_data = memory.dict()
ser_data = memory.model_dump()
new_memory = Memory(**ser_data)
assert new_memory.count() == 2
@ -35,6 +35,9 @@ def test_memory_serdeser():
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
assert new_msg2.role == "Boss"
memory = Memory(storage=[msg1, msg2], index={msg1.cause_by: [msg1], msg2.cause_by: [msg2]})
assert memory.count() == 2
def test_memory_serdeser_save():
msg1 = Message(role="User", content="write a 2048 game", cause_by=UserRequirement)

View file

@ -0,0 +1,58 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of polymorphic conditions
from pydantic import BaseModel, ConfigDict, SerializeAsAny
from metagpt.actions import Action
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOKV2,
ActionPass,
)
class ActionSubClasses(BaseModel):
actions: list[SerializeAsAny[Action]] = []
class ActionSubClassesNoSAA(BaseModel):
"""without SerializeAsAny"""
model_config = ConfigDict(arbitrary_types_allowed=True)
actions: list[Action] = []
def test_serialize_as_any():
"""test subclasses of action with different fields in ser&deser"""
# ActionOKV2 with a extra field `extra_field`
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
assert action_subcls_dict["actions"][0]["extra_field"] == ActionOKV2().extra_field
def test_no_serialize_as_any():
# ActionOKV2 with a extra field `extra_field`
action_subcls = ActionSubClassesNoSAA(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
# without `SerializeAsAny`, it will serialize as Action
assert "extra_field" not in action_subcls_dict["actions"][0]
def test_polymorphic():
_ = ActionOKV2(
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
)
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
action_subcls_dict = action_subcls.model_dump()
assert "__module_class_name" in action_subcls_dict["actions"][0]
new_action_subcls = ActionSubClasses(**action_subcls_dict)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert isinstance(new_action_subcls.actions[1], ActionPass)
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
assert isinstance(new_action_subcls.actions[1], ActionPass)

View file

@ -12,10 +12,10 @@ from metagpt.schema import Message
@pytest.mark.asyncio
async def test_product_manager_deserialize():
role = ProductManager()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
new_role = ProductManager(**ser_role_dict)
assert new_role.name == "Alice"
assert len(new_role._actions) == 2
assert isinstance(new_role._actions[0], Action)
await new_role._actions[0].run([Message(content="write a cli snake game")])
assert len(new_role.actions) == 2
assert isinstance(new_role.actions[0], Action)
await new_role.actions[0].run([Message(content="write a cli snake game")])

View file

@ -11,20 +11,20 @@ from metagpt.roles.project_manager import ProjectManager
def test_project_manager_serialize():
role = ProjectManager()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
async def test_project_manager_deserialize():
role = ProjectManager()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump(by_alias=True)
new_role = ProjectManager(**ser_role_dict)
assert new_role.name == "Eve"
assert len(new_role._actions) == 1
assert isinstance(new_role._actions[0], Action)
assert isinstance(new_role._actions[0], WriteTasks)
# await new_role._actions[0].run(context="write a cli snake game")
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], Action)
assert isinstance(new_role.actions[0], WriteTasks)
# await new_role.actions[0].run(context="write a cli snake game")

View file

@ -6,6 +6,7 @@
import shutil
import pytest
from pydantic import BaseModel, SerializeAsAny
from metagpt.actions import WriteCode
from metagpt.actions.add_requirement import UserRequirement
@ -17,48 +18,67 @@ from metagpt.roles.role import Role
from metagpt.schema import Message
from metagpt.utils.common import format_trackback_info
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
ActionOK,
RoleA,
RoleB,
RoleC,
RoleD,
serdeser_path,
)
def test_roles():
role_a = RoleA()
assert len(role_a._rc.watch) == 1
assert len(role_a.rc.watch) == 1
role_b = RoleB()
assert len(role_a._rc.watch) == 1
assert len(role_b._rc.watch) == 1
assert len(role_a.rc.watch) == 1
assert len(role_b.rc.watch) == 1
role_d = RoleD(actions=[ActionOK()])
assert len(role_d.actions) == 1
def test_role_subclasses():
"""test subclasses of role with same fields in ser&deser"""
class RoleSubClasses(BaseModel):
roles: list[SerializeAsAny[Role]] = []
role_subcls = RoleSubClasses(roles=[RoleA(), RoleB()])
role_subcls_dict = role_subcls.model_dump()
new_role_subcls = RoleSubClasses(**role_subcls_dict)
assert isinstance(new_role_subcls.roles[0], RoleA)
assert isinstance(new_role_subcls.roles[1], RoleB)
def test_role_serialize():
role = Role()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
def test_engineer_serialize():
role = Engineer()
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump()
assert "name" in ser_role_dict
assert "_states" in ser_role_dict
assert "_actions" in ser_role_dict
assert "states" in ser_role_dict
assert "actions" in ser_role_dict
@pytest.mark.asyncio
async def test_engineer_deserialize():
role = Engineer(use_code_review=True)
ser_role_dict = role.dict(by_alias=True)
ser_role_dict = role.model_dump()
new_role = Engineer(**ser_role_dict)
assert new_role.name == "Alex"
assert new_role.use_code_review is True
assert len(new_role._actions) == 1
assert isinstance(new_role._actions[0], WriteCode)
# await new_role._actions[0].run(context="write a cli snake game", filename="test_code")
assert len(new_role.actions) == 1
assert isinstance(new_role.actions[0], WriteCode)
# await new_role.actions[0].run(context="write a cli snake game", filename="test_code")
def test_role_serdeser_save():
@ -87,10 +107,10 @@ async def test_role_serdeser_interrupt():
logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}")
role_c.serialize(stg_path)
assert role_c._rc.memory.count() == 1
assert role_c.rc.memory.count() == 1
new_role_a: Role = Role.deserialize(stg_path)
assert new_role_a._rc.state == 1
assert new_role_a.rc.state == 1
with pytest.raises(Exception):
await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement))

View file

@ -4,9 +4,12 @@
from metagpt.actions.action_node import ActionNode
from metagpt.actions.write_code import WriteCode
from metagpt.schema import Message
from metagpt.schema import Document, Documents, Message
from metagpt.utils.common import any_to_str
from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
MockICMessage,
MockMessage,
)
def test_message_serdeser():
@ -15,14 +18,24 @@ def test_message_serdeser():
ic_obj = ActionNode.create_model_class("code", out_mapping)
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
ser_data = message.dict()
ser_data = message.model_dump()
assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode"
assert ser_data["instruct_content"]["class"] == "code"
new_message = Message(**ser_data)
assert new_message.cause_by == any_to_str(WriteCode)
assert new_message.cause_by in [any_to_str(WriteCode)]
assert new_message.instruct_content == ic_obj(**out_data)
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
message = Message(content="test_ic", instruct_content=MockICMessage())
ser_data = message.model_dump()
new_message = Message(**ser_data)
assert new_message.instruct_content != MockICMessage() # TODO
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
ser_data = message.model_dump()
assert "class" in ser_data["instruct_content"]
def test_message_without_postprocess():
@ -31,8 +44,9 @@ def test_message_without_postprocess():
out_data = {"field1": ["field1 value1", "field1 value2"]}
ic_obj = ActionNode.create_model_class("code", out_mapping)
message = MockMessage(content="code", instruct_content=ic_obj(**out_data))
ser_data = message.dict()
assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]}
ser_data = message.model_dump()
assert ser_data["instruct_content"] == {}
ser_data["instruct_content"] = None
new_message = MockMessage(**ser_data)
assert new_message.instruct_content != ic_obj(**out_data)

View file

@ -4,6 +4,7 @@
import asyncio
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, Field
@ -15,15 +16,19 @@ from metagpt.roles.role import Role, RoleReactMode
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
class MockICMessage(BaseModel):
content: str = "test_ic"
class MockMessage(BaseModel):
"""to test normal dict without postprocess"""
content: str = ""
instruct_content: BaseModel = Field(default=None)
instruct_content: Optional[BaseModel] = Field(default=None)
class ActionPass(Action):
name: str = Field(default="ActionPass")
name: str = "ActionPass"
async def run(self, messages: list["Message"]) -> ActionOutput:
await asyncio.sleep(5) # sleep to make other roles can watch the executed Message
@ -35,7 +40,7 @@ class ActionPass(Action):
class ActionOK(Action):
name: str = Field(default="ActionOK")
name: str = "ActionOK"
async def run(self, messages: list["Message"]) -> str:
await asyncio.sleep(5)
@ -43,12 +48,17 @@ class ActionOK(Action):
class ActionRaise(Action):
name: str = Field(default="ActionRaise")
name: str = "ActionRaise"
async def run(self, messages: list["Message"]) -> str:
raise RuntimeError("parse error in ActionRaise")
class ActionOKV2(Action):
name: str = "ActionOKV2"
extra_field: str = "ActionOKV2 Extra Info"
class RoleA(Role):
name: str = Field(default="RoleA")
profile: str = Field(default="Role A")
@ -71,7 +81,7 @@ class RoleB(Role):
super(RoleB, self).__init__(**kwargs)
self._init_actions([ActionOK, ActionRaise])
self._watch([ActionPass])
self._rc.react_mode = RoleReactMode.BY_ORDER
self.rc.react_mode = RoleReactMode.BY_ORDER
class RoleC(Role):
@ -84,5 +94,12 @@ class RoleC(Role):
super(RoleC, self).__init__(**kwargs)
self._init_actions([ActionOK, ActionRaise])
self._watch([UserRequirement])
self._rc.react_mode = RoleReactMode.BY_ORDER
self._rc.memory.ignore_id = True
self.rc.react_mode = RoleReactMode.BY_ORDER
self.rc.memory.ignore_id = True
class RoleD(Role):
name: str = Field(default="RoleD")
profile: str = Field(default="Role D")
goal: str = "RoleD's goal"
constraints: str = "RoleD's constraints"

View file

@ -33,8 +33,8 @@ def test_team_deserialize():
]
)
assert len(company.env.get_roles()) == 3
ser_company = company.dict()
new_company = Team(**ser_company)
ser_company = company.model_dump()
new_company = Team.model_validate(ser_company)
assert len(new_company.env.get_roles()) == 3
assert new_company.env.get_role(pm.profile) is not None
@ -47,6 +47,7 @@ def test_team_deserialize():
def test_team_serdeser_save():
company = Team()
company.hire([RoleC()])
stg_path = serdeser_path.joinpath("team")
@ -71,13 +72,13 @@ async def test_team_recover():
company.run_project(idea)
await company.run(n_round=4)
ser_data = company.dict()
ser_data = company.model_dump()
new_company = Team(**ser_data)
new_role_c = new_company.env.get_role(role_c.profile)
# assert new_role_c._rc.memory == role_c._rc.memory # TODO
assert new_role_c._rc.env != role_c._rc.env # TODO
assert type(list(new_company.env.roles.values())[0]._actions[0]) == ActionOK
new_company.env.get_role(role_c.profile)
# assert new_role_c.rc.memory == role_c.rc.memory # TODO
# assert new_role_c.rc.env != role_c.rc.env # TODO
assert type(list(new_company.env.roles.values())[0].actions[0]) == ActionOK
new_company.run_project(idea)
await new_company.run(n_round=4)
@ -97,11 +98,11 @@ async def test_team_recover_save():
new_company = Team.deserialize(stg_path)
new_role_c = new_company.env.get_role(role_c.profile)
# assert new_role_c._rc.memory == role_c._rc.memory
assert new_role_c._rc.env != role_c._rc.env
# assert new_role_c.rc.memory == role_c.rc.memory
# assert new_role_c.rc.env != role_c.rc.env
assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=`
assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo`
assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news`
assert new_role_c.rc.todo != role_c.rc.todo # serialize exclude `rc.todo`
assert new_role_c.rc.news != role_c.rc.news # serialize exclude `rc.news`
new_company.run_project(idea)
await new_company.run(n_round=4)
@ -116,10 +117,6 @@ async def test_team_recover_multi_roles_save():
role_a = RoleA()
role_b = RoleB()
assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA", "RoleA"}
assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB", "RoleB"}
assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"}
company = Team()
company.hire([role_a, role_b])
company.run_project(idea)
@ -130,6 +127,6 @@ async def test_team_recover_multi_roles_save():
new_company = Team.deserialize(stg_path)
new_company.run_project(idea)
assert new_company.env.get_role(role_b.profile)._rc.state == 1
assert new_company.env.get_role(role_b.profile).rc.state == 1
await new_company.run(n_round=4)

View file

@ -12,9 +12,9 @@ from metagpt.schema import CodingContext, Document
def test_write_design_serialize():
action = WriteCode()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert ser_action_dict["name"] == "WriteCode"
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
@ -22,9 +22,9 @@ async def test_write_code_deserialize():
context = CodingContext(
filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers")
)
doc = Document(content=context.json())
doc = Document(content=context.model_dump_json())
action = WriteCode(context=doc)
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WriteCode(**serialized_data)
assert new_action.name == "WriteCode"

View file

@ -22,7 +22,7 @@ def div(a: int, b: int = 0):
)
action = WriteCodeReview(context=context)
serialized_data = action.dict()
serialized_data = action.model_dump()
assert serialized_data["name"] == "WriteCodeReview"
new_action = WriteCodeReview(**serialized_data)

View file

@ -10,22 +10,22 @@ from metagpt.llm import LLM
def test_write_design_serialize():
action = WriteDesign()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
def test_write_task_serialize():
action = WriteTasks()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
async def test_write_design_deserialize():
action = WriteDesign()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WriteDesign(**serialized_data)
assert new_action.name == ""
assert new_action.llm == LLM()
@ -35,7 +35,7 @@ async def test_write_design_deserialize():
@pytest.mark.asyncio
async def test_write_task_deserialize():
action = WriteTasks()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WriteTasks(**serialized_data)
assert new_action.name == "CreateTasks"
assert new_action.llm == LLM()

View file

@ -12,15 +12,15 @@ from metagpt.schema import Message
def test_action_serialize():
action = WritePRD()
ser_action_dict = action.dict()
ser_action_dict = action.model_dump()
assert "name" in ser_action_dict
# assert "llm" in ser_action_dict # not export
assert "llm" not in ser_action_dict # not export
@pytest.mark.asyncio
async def test_action_deserialize():
action = WritePRD()
serialized_data = action.dict()
serialized_data = action.model_dump()
new_action = WritePRD(**serialized_data)
assert new_action.name == ""
assert new_action.llm == LLM()

View file

@ -33,6 +33,15 @@ class MockRole(Role):
self._init_actions([MockAction()])
def test_basic():
mock_role = MockRole()
assert mock_role.subscription == {"tests.metagpt.test_role.MockRole"}
assert mock_role.rc.watch == {"metagpt.actions.add_requirement.UserRequirement"}
mock_role = MockRole(name="mock_role")
assert mock_role.subscription == {"tests.metagpt.test_role.MockRole", "mock_role"}
@pytest.mark.asyncio
async def test_react():
class Input(BaseModel):
@ -60,12 +69,12 @@ async def test_react():
name=seed.name, profile=seed.profile, goal=seed.goal, constraints=seed.constraints, desc=seed.desc
)
role.subscribe({seed.subscription})
assert role._rc.watch == {any_to_str(UserRequirement)}
assert role.rc.watch == {any_to_str(UserRequirement)}
assert role.name == seed.name
assert role.profile == seed.profile
assert role._setting.goal == seed.goal
assert role._setting.constraints == seed.constraints
assert role._setting.desc == seed.desc
assert role.goal == seed.goal
assert role.constraints == seed.constraints
assert role.desc == seed.desc
assert role.is_idle
env = Environment()
env.add_role(role)

View file

@ -31,6 +31,8 @@ def test_messages():
def test_message():
Message("a", role="v1")
m = Message(content="a", role="v1")
v = m.dump()
d = json.loads(v)
@ -74,22 +76,22 @@ def test_message_serdeser():
ic_obj = ActionNode.create_model_class("code", out_mapping)
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
message_dict = message.dict()
message_dict = message.model_dump()
assert message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode"
assert message_dict["instruct_content"] == {
"class": "code",
"mapping": {"field3": "(<class 'str'>, Ellipsis)", "field4": "(list[str], Ellipsis)"},
"value": {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]},
}
new_message = Message(**message_dict)
new_message = Message.model_validate(message_dict)
assert new_message.content == message.content
assert new_message.instruct_content == message.instruct_content
assert new_message.instruct_content.model_dump() == message.instruct_content.model_dump()
assert new_message.instruct_content != message.instruct_content # TODO
assert new_message.cause_by == message.cause_by
assert new_message.instruct_content.field3 == out_data["field3"]
message = Message(content="code")
message_dict = message.dict()
message_dict = message.model_dump()
new_message = Message(**message_dict)
assert new_message.instruct_content is None
assert new_message.cause_by == "metagpt.actions.add_requirement.UserRequirement"

View file

@ -9,23 +9,25 @@ import pytest
from typer.testing import CliRunner
from metagpt.logs import logger
from metagpt.startup import app
from metagpt.team import Team
runner = CliRunner()
@pytest.mark.asyncio
async def test_team():
async def test_empty_team():
# FIXME: we're now using "metagpt" cli, so the entrance should be replaced instead.
company = Team()
company.run_project("做一个基础搜索引擎,可以支持知识库")
history = await company.run(n_round=5)
history = await company.run(idea="Build a simple search system. I will upload my files later.")
logger.info(history)
# def test_startup():
# args = ["Make a 2048 game"]
# result = runner.invoke(app, args)
def test_startup():
args = ["Make a cli snake game"]
result = runner.invoke(app, args)
logger.info(result)
logger.info(result.output)
if __name__ == "__main__":

View file

@ -10,4 +10,4 @@ def test_team():
company = Team()
company.hire([ProjectManager()])
assert len(company.environment.roles) == 1
assert len(company.env.roles) == 1

View file

@ -56,7 +56,7 @@ class TestGetProjectRoot:
def test_any_to_str(self):
class Input(BaseModel):
x: Any
x: Any = None
want: str
inputs = [
@ -74,7 +74,7 @@ class TestGetProjectRoot:
def test_any_to_str_set(self):
class Input(BaseModel):
x: Any
x: Any = None
want: Set
inputs = [

View file

@ -21,8 +21,8 @@ from metagpt.utils.dependency_file import DependencyFile
async def test_dependency_file():
class Input(BaseModel):
x: Union[Path, str]
deps: Optional[Set[Union[Path, str]]]
key: Optional[Union[Path, str]]
deps: Optional[Set[Union[Path, str]]] = None
key: Optional[Union[Path, str]] = None
want: Set[str]
inputs = [