mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
Merge pull request #590 from better629/new_main
use BaseModel uniformly and Ser&Deser
This commit is contained in:
commit
9229b5a7f9
53 changed files with 1663 additions and 358 deletions
|
|
@ -8,34 +8,47 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
from typing import Optional, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext
|
||||
|
||||
action_subclass_registry = {}
|
||||
|
||||
class Action(ABC):
|
||||
"""Action abstract class, requiring all inheritors to provide a series of standard capabilities"""
|
||||
|
||||
name: str
|
||||
llm: LLM
|
||||
# FIXME: simplify context
|
||||
context: dict | CodingContext | CodeSummarizeContext | TestingContext | RunCodeContext | str | None
|
||||
prefix: str
|
||||
desc: str
|
||||
node: ActionNode | None
|
||||
class Action(BaseModel):
|
||||
name: str = ""
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
|
||||
context: dict | CodingContext | CodeSummarizeContext | TestingContext | RunCodeContext | str | None = ""
|
||||
prefix = "" # aask*时会加上prefix,作为system_message
|
||||
desc = "" # for skill manager
|
||||
# node: ActionNode = Field(default_factory=ActionNode, exclude=True)
|
||||
|
||||
def __init__(self, name: str = "", context=None, llm: LLM = None):
|
||||
self.name: str = name
|
||||
if llm is None:
|
||||
llm = LLM()
|
||||
self.llm = llm
|
||||
self.context = context
|
||||
self.prefix = "" # aask*时会加上prefix,作为system_message
|
||||
self.desc = "" # for skill manager
|
||||
self.node = None
|
||||
# builtin variables
|
||||
builtin_class_name: str = ""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# deserialize child classes dynamically for inherited `action`
|
||||
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
|
||||
self.__fields__["builtin_class_name"].default = self.__class__.__name__
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
action_subclass_registry[cls.__name__] = cls
|
||||
|
||||
def dict(self, *args, **kwargs) -> "DictStrAny":
|
||||
obj_dict = super(Action, self).dict(*args, **kwargs)
|
||||
if "llm" in obj_dict:
|
||||
obj_dict.pop("llm")
|
||||
return obj_dict
|
||||
|
||||
def set_prefix(self, prefix):
|
||||
"""Set prefix for later usage"""
|
||||
|
|
|
|||
|
|
@ -10,11 +10,14 @@
|
|||
"""
|
||||
import re
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.llm import LLM, BaseGPTAPI
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeResult
|
||||
from metagpt.schema import RunCodeResult, RunCodeContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
|
@ -47,8 +50,9 @@ Now you should start rewriting the code:
|
|||
|
||||
|
||||
class DebugError(Action):
|
||||
def __init__(self, name="DebugError", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
name: str = "DebugError"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, *args, **kwargs) -> str:
|
||||
output_doc = await FileRepository.get_file(
|
||||
|
|
|
|||
|
|
@ -11,6 +11,9 @@
|
|||
"""
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.design_api_an import DESIGN_API_NODE
|
||||
|
|
@ -22,16 +25,13 @@ from metagpt.const import (
|
|||
SYSTEM_DESIGN_FILE_REPO,
|
||||
SYSTEM_DESIGN_PDF_FILE_REPO,
|
||||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Document, Documents
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Document, Documents, Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
# from metagpt.utils.get_template import get_template
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
|
||||
# from typing import List
|
||||
|
||||
|
||||
NEW_REQ_TEMPLATE = """
|
||||
### Legacy Content
|
||||
{old_design}
|
||||
|
|
@ -42,15 +42,14 @@ NEW_REQ_TEMPLATE = """
|
|||
|
||||
|
||||
class WriteDesign(Action):
|
||||
def __init__(self, name, context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
self.desc = (
|
||||
"Based on the PRD, think about the system design, and design the corresponding APIs, "
|
||||
"data structures, library tables, processes, and paths. Please provide your design, feedback "
|
||||
"clearly and in detail."
|
||||
)
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
desc: str = "Based on the PRD, think about the system design, and design the corresponding APIs, " \
|
||||
"data structures, library tables, processes, and paths. Please provide your design, feedback " \
|
||||
"clearly and in detail."
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema):
|
||||
async def run(self, with_messages: Message, schema: str = CONFIG.prompt_schema):
|
||||
# Use `git diff` to identify which PRD documents have been modified in the `docs/prds` directory.
|
||||
prds_file_repo = CONFIG.git_repo.new_file_repository(PRDS_FILE_REPO)
|
||||
changed_prds = prds_file_repo.changed_files
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from metagpt.actions import Action
|
|||
|
||||
class FixBug(Action):
|
||||
"""Fix bug action without any implementation details"""
|
||||
name: str = "FixBug"
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -9,10 +9,15 @@
|
|||
"""
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Document
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
|
@ -20,6 +25,9 @@ from metagpt.utils.git_repository import GitRepository
|
|||
|
||||
class PrepareDocuments(Action):
|
||||
"""PrepareDocuments Action: initialize project folder and add new requirements to docs/requirements.txt."""
|
||||
name: str = "PrepareDocuments"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
def _init_repo(self):
|
||||
"""Initialize the Git environment."""
|
||||
|
|
|
|||
|
|
@ -9,7 +9,11 @@
|
|||
2. Move the document storage operations related to WritePRD from the save operation of WriteDesign.
|
||||
3. According to the design in Section 2.2.3.5.4 of RFC 135, add incremental iteration functionality.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import ActionOutput
|
||||
from metagpt.actions.action import Action
|
||||
|
|
@ -21,14 +25,12 @@ from metagpt.const import (
|
|||
TASK_FILE_REPO,
|
||||
TASK_PDF_FILE_REPO,
|
||||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Document, Documents
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
# from typing import List
|
||||
|
||||
# from metagpt.utils.get_template import get_template
|
||||
|
||||
NEW_REQ_TEMPLATE = """
|
||||
### Legacy Content
|
||||
{old_tasks}
|
||||
|
|
@ -39,8 +41,9 @@ NEW_REQ_TEMPLATE = """
|
|||
|
||||
|
||||
class WriteTasks(Action):
|
||||
def __init__(self, name="CreateTasks", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
name: str = "CreateTasks"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema):
|
||||
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
|
|
|
|||
|
|
@ -18,10 +18,13 @@
|
|||
import subprocess
|
||||
from typing import Tuple
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM, BaseGPTAPI
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeResult
|
||||
from metagpt.schema import RunCodeResult, RunCodeContext
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
|
|
@ -74,8 +77,9 @@ standard errors:
|
|||
|
||||
|
||||
class RunCode(Action):
|
||||
def __init__(self, name="RunCode", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
name: str = "RunCode"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
|
|
|
|||
|
|
@ -6,12 +6,17 @@
|
|||
@File : search_google.py
|
||||
"""
|
||||
import pydantic
|
||||
from typing import Optional, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.config import Config, CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
from pydantic import root_validator
|
||||
|
||||
SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements
|
||||
1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation.
|
||||
|
|
@ -54,7 +59,6 @@ SEARCH_AND_SUMMARIZE_PROMPT = """
|
|||
|
||||
"""
|
||||
|
||||
|
||||
SEARCH_AND_SUMMARIZE_SALES_SYSTEM = """## Requirements
|
||||
1. Please summarize the latest dialogue based on the reference information (secondary) and dialogue history (primary). Do not include text that is irrelevant to the conversation.
|
||||
- The context is for reference only. If it is irrelevant to the user's search request history, please reduce its reference and usage.
|
||||
|
|
@ -101,23 +105,37 @@ You are a member of a professional butler team and will provide helpful suggesti
|
|||
|
||||
|
||||
class SearchAndSummarize(Action):
|
||||
def __init__(self, name="", context=None, llm=None, engine=None, search_func=None):
|
||||
self.config = Config()
|
||||
self.engine = engine or self.config.search_engine
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
config: None = Field(default_factory=Config)
|
||||
engine: Optional[str] = CONFIG.search_engine
|
||||
search_func: Optional[str] = None
|
||||
search_engine: SearchEngine = None
|
||||
|
||||
result = ""
|
||||
|
||||
@root_validator
|
||||
def validate_engine_and_run_func(cls, values):
|
||||
engine = values.get("engine")
|
||||
search_func = values.get("search_func")
|
||||
config = Config()
|
||||
|
||||
if engine is None:
|
||||
engine = config.search_engine
|
||||
try:
|
||||
self.search_engine = SearchEngine(self.engine, run_func=search_func)
|
||||
search_engine = SearchEngine(engine=engine, run_func=search_func)
|
||||
except pydantic.ValidationError:
|
||||
self.search_engine = None
|
||||
search_engine = None
|
||||
|
||||
self.result = ""
|
||||
super().__init__(name, context, llm)
|
||||
values["search_engine"] = search_engine
|
||||
return values
|
||||
|
||||
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:
|
||||
if self.search_engine is None:
|
||||
logger.warning("Configure one of SERPAPI_API_KEY, SERPER_API_KEY, GOOGLE_API_KEY to unlock full feature")
|
||||
return ""
|
||||
|
||||
|
||||
query = context[-1].content
|
||||
# logger.debug(query)
|
||||
rsp = await self.search_engine.run(query)
|
||||
|
|
@ -126,9 +144,9 @@ class SearchAndSummarize(Action):
|
|||
logger.error("empty rsp...")
|
||||
return ""
|
||||
# logger.info(rsp)
|
||||
|
||||
|
||||
system_prompt = [system_text]
|
||||
|
||||
|
||||
prompt = SEARCH_AND_SUMMARIZE_PROMPT.format(
|
||||
ROLE=self.prefix,
|
||||
CONTEXT=rsp,
|
||||
|
|
|
|||
|
|
@ -7,12 +7,15 @@
|
|||
"""
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.llm import LLM, BaseGPTAPI
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
PROMPT_TEMPLATE = """
|
||||
|
|
@ -89,8 +92,9 @@ flowchart TB
|
|||
|
||||
|
||||
class SummarizeCode(Action):
|
||||
def __init__(self, name="SummarizeCode", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
name: str = "SummarizeCode"
|
||||
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
@retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
|
||||
async def summarize_code(self, prompt):
|
||||
|
|
|
|||
|
|
@ -14,8 +14,10 @@
|
|||
3. Encapsulate the input of RunCode into RunCodeContext and encapsulate the output of RunCode into
|
||||
RunCodeResult to standardize and unify parameter passing between WriteCode, RunCode, and DebugError.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from pydantic import Field
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
|
|
@ -27,7 +29,9 @@ from metagpt.const import (
|
|||
TASK_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import CodingContext, Document, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -84,8 +88,9 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
|
||||
|
||||
class WriteCode(Action):
|
||||
def __init__(self, name="WriteCode", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
name: str = "WriteCode"
|
||||
context: Document = Field(default_factory=Document)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code(self, prompt) -> str:
|
||||
|
|
@ -126,7 +131,9 @@ class WriteCode(Action):
|
|||
logger.info(f"Writing {coding_context.filename}..")
|
||||
code = await self.write_code(prompt)
|
||||
if not coding_context.code_doc:
|
||||
coding_context.code_doc = Document(filename=coding_context.filename, root_path=CONFIG.src_workspace)
|
||||
# avoid root_path pydantic ValidationError if use WriteCode alone
|
||||
root_path = CONFIG.src_workspace if CONFIG.src_workspace else ""
|
||||
coding_context.code_doc = Document(filename=coding_context.filename, root_path=root_path)
|
||||
coding_context.code_doc.content = code
|
||||
return coding_context
|
||||
|
||||
|
|
|
|||
|
|
@ -8,12 +8,15 @@
|
|||
WriteCode object, rather than passing them in when calling the run function.
|
||||
"""
|
||||
|
||||
from pydantic import Field
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import CodingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
|
@ -32,7 +35,6 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
```
|
||||
"""
|
||||
|
||||
|
||||
EXAMPLE_AND_INSTRUCTION = """
|
||||
|
||||
{format_example}
|
||||
|
|
@ -119,8 +121,9 @@ REWRITE_CODE_TEMPLATE = """
|
|||
|
||||
|
||||
class WriteCodeReview(Action):
|
||||
def __init__(self, name="WriteCodeReview", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
name: str = "WriteCodeReview"
|
||||
context: CodingContext = Field(default_factory=CodingContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
|
||||
|
|
@ -139,9 +142,15 @@ class WriteCodeReview(Action):
|
|||
iterative_code = self.context.code_doc.content
|
||||
k = CONFIG.code_review_k_times or 1
|
||||
for i in range(k):
|
||||
format_example = FORMAT_EXAMPLE.format(filename=self.context.code_doc.filename)
|
||||
task_content = self.context.task_doc.content if self.context.task_doc else ""
|
||||
code_context = await WriteCode.get_codes(self.context.task_doc, exclude=self.context.filename)
|
||||
format_example = FORMAT_EXAMPLE.format(
|
||||
filename=self.context.code_doc.filename
|
||||
)
|
||||
task_content = (
|
||||
self.context.task_doc.content if self.context.task_doc else ""
|
||||
)
|
||||
code_context = await WriteCode.get_codes(
|
||||
self.context.task_doc, exclude=self.context.filename
|
||||
)
|
||||
context = "\n".join(
|
||||
[
|
||||
"## System Design\n" + str(self.context.design_doc) + "\n",
|
||||
|
|
@ -158,7 +167,8 @@ class WriteCodeReview(Action):
|
|||
format_example=format_example,
|
||||
)
|
||||
logger.info(
|
||||
f"Code review and rewrite {self.context.code_doc.filename}: {i+1}/{k} | {len(iterative_code)=}, {len(self.context.code_doc.content)=}"
|
||||
f"Code review and rewrite {self.context.code_doc.filename}: {i + 1}/{k} | {len(iterative_code)=}, "
|
||||
f"{len(self.context.code_doc.content)=}"
|
||||
)
|
||||
result, rewrited_code = await self.write_code_review_and_rewrite(
|
||||
context_prompt, cr_prompt, self.context.code_doc.filename
|
||||
|
|
|
|||
|
|
@ -10,10 +10,14 @@
|
|||
3. Move the document storage operations related to WritePRD from the save operation of WriteDesign.
|
||||
@Modified By: mashenquan, 2023/12/5. Move the generation logic of the project name to WritePRD.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
|
@ -32,17 +36,14 @@ from metagpt.const import (
|
|||
PRDS_FILE_REPO,
|
||||
REQUIREMENT_FILENAME,
|
||||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import BugFixContext, Document, Documents, Message
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
# from metagpt.utils.get_template import get_template
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
|
||||
# from typing import List
|
||||
|
||||
|
||||
CONTEXT_TEMPLATE = """
|
||||
### Project Name
|
||||
{project_name}
|
||||
|
|
@ -64,15 +65,16 @@ NEW_REQ_TEMPLATE = """
|
|||
|
||||
|
||||
class WritePRD(Action):
|
||||
def __init__(self, name="", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
|
||||
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
|
||||
# related to the PRD. If they are related, rewrite the PRD.
|
||||
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
|
||||
requirement_doc = await docs_file_repo.get(filename=REQUIREMENT_FILENAME)
|
||||
if await self._is_bugfix(requirement_doc.content):
|
||||
if requirement_doc and await self._is_bugfix(requirement_doc.content):
|
||||
await docs_file_repo.save(filename=BUGFIX_FILENAME, content=requirement_doc.content)
|
||||
await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="")
|
||||
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
|
||||
|
|
@ -141,7 +143,8 @@ class WritePRD(Action):
|
|||
|
||||
async def _update_prd(self, requirement_doc, prd_doc, prds_file_repo, *args, **kwargs) -> Document | None:
|
||||
if not prd_doc:
|
||||
prd = await self._run_new_requirement(requirements=[requirement_doc.content], *args, **kwargs)
|
||||
prd = await self._run_new_requirement(requirements=[requirement_doc.content if requirement_doc else ""],
|
||||
*args, **kwargs)
|
||||
new_prd_doc = Document(
|
||||
root_path=PRDS_FILE_REPO,
|
||||
filename=FileRepository.new_filename() + ".json",
|
||||
|
|
|
|||
|
|
@ -5,20 +5,28 @@
|
|||
@Author : alexanderwu
|
||||
@File : write_prd_review.py
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
|
||||
|
||||
class WritePRDReview(Action):
|
||||
def __init__(self, name, context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
self.prd = None
|
||||
self.desc = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
|
||||
self.prd_review_prompt_template = """
|
||||
Given the following Product Requirement Document (PRD):
|
||||
{prd}
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
prd: Optional[str] = None
|
||||
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
|
||||
prd_review_prompt_template: str = """
|
||||
Given the following Product Requirement Document (PRD):
|
||||
{prd}
|
||||
|
||||
As a project manager, please review it and provide your feedback and suggestions.
|
||||
"""
|
||||
As a project manager, please review it and provide your feedback and suggestions.
|
||||
"""
|
||||
|
||||
async def run(self, prd):
|
||||
self.prd = prd
|
||||
|
|
|
|||
|
|
@ -7,6 +7,12 @@
|
|||
@Modified By: mashenquan, 2023-11-27. Following the think-act principle, solidify the task parameters when creating the
|
||||
WriteTest object, rather than passing them in when calling the run function.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO
|
||||
|
|
@ -36,8 +42,9 @@ you should correctly import the necessary classes based on these file locations!
|
|||
|
||||
|
||||
class WriteTest(Action):
|
||||
def __init__(self, name="WriteTest", context=None, llm=None):
|
||||
super().__init__(name, context, llm)
|
||||
name: str = "WriteTest"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
async def write_code(self, prompt):
|
||||
code_rsp = await self._aask(prompt)
|
||||
|
|
|
|||
|
|
@ -55,11 +55,14 @@ DATA_PATH = METAGPT_ROOT / "data"
|
|||
RESEARCH_PATH = DATA_PATH / "research"
|
||||
TUTORIAL_PATH = DATA_PATH / "tutorial_docx"
|
||||
INVOICE_OCR_TABLE_PATH = DATA_PATH / "invoice_table"
|
||||
|
||||
UT_PATH = DATA_PATH / "ut"
|
||||
SWAGGER_PATH = UT_PATH / "files/api/"
|
||||
UT_PY_PATH = UT_PATH / "files/ut/"
|
||||
API_QUESTIONS_PATH = UT_PATH / "files/question/"
|
||||
|
||||
SERDESER_PATH = DEFAULT_WORKSPACE_ROOT / "storage" # TODO to store `storage` under the individual generated project
|
||||
|
||||
TMP = METAGPT_ROOT / "tmp"
|
||||
|
||||
SOURCE_ROOT = METAGPT_ROOT / "metagpt"
|
||||
|
|
|
|||
|
|
@ -12,29 +12,84 @@
|
|||
functionality is to be consolidated into the `Environment` class.
|
||||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.roles.role import Role, role_subclass_registry
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import is_subscribed
|
||||
from metagpt.utils.common import is_subscribed, read_json_file, write_json_file
|
||||
|
||||
|
||||
class Environment(BaseModel):
|
||||
"""环境,承载一批角色,角色可以向环境发布消息,可以被其他角色观察到
|
||||
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
|
||||
|
||||
"""
|
||||
|
||||
roles: dict[str, Role] = Field(default_factory=dict)
|
||||
members: dict[Role, Set] = Field(default_factory=dict)
|
||||
history: str = Field(default="") # For debug
|
||||
history: str = "" # For debug
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
roles = []
|
||||
for role_key, role in kwargs.get("roles", {}).items():
|
||||
current_role = kwargs["roles"][role_key]
|
||||
if isinstance(current_role, dict):
|
||||
item_class_name = current_role.get("builtin_class_name", None)
|
||||
for name, subclass in role_subclass_registry.items():
|
||||
registery_class_name = subclass.__fields__["builtin_class_name"].default
|
||||
if item_class_name == registery_class_name:
|
||||
current_role = subclass(**current_role)
|
||||
break
|
||||
kwargs["roles"][role_key] = current_role
|
||||
roles.append(current_role)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.add_roles(roles) # add_roles again to init the Role.set_env
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
roles_path = stg_path.joinpath("roles.json")
|
||||
roles_info = []
|
||||
for role_key, role in self.roles.items():
|
||||
roles_info.append({
|
||||
"role_class": role.__class__.__name__,
|
||||
"module_name": role.__module__,
|
||||
"role_name": role.name,
|
||||
"role_sub_tags": list(self.members.get(role))
|
||||
})
|
||||
role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}"))
|
||||
write_json_file(roles_path, roles_info)
|
||||
|
||||
history_path = stg_path.joinpath("history.json")
|
||||
write_json_file(history_path, {"content": self.history})
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Environment":
|
||||
""" stg_path: ./storage/team/environment/ """
|
||||
roles_path = stg_path.joinpath("roles.json")
|
||||
roles_info = read_json_file(roles_path)
|
||||
roles = []
|
||||
for role_info in roles_info:
|
||||
# role stored in ./environment/roles/{role_class}_{role_name}
|
||||
role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}")
|
||||
role = Role.deserialize(role_path)
|
||||
roles.append(role)
|
||||
|
||||
history = read_json_file(stg_path.joinpath("history.json"))
|
||||
history = history.get("content")
|
||||
|
||||
environment = Environment(**{
|
||||
"history": history
|
||||
})
|
||||
environment.add_roles(roles)
|
||||
|
||||
return environment
|
||||
|
||||
def add_role(self, role: Role):
|
||||
"""增加一个在当前环境的角色
|
||||
Add a role in the current environment
|
||||
|
|
|
|||
|
|
@ -4,6 +4,12 @@
|
|||
@Desc : the implement of Long-term memory
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
|
|
@ -16,12 +22,12 @@ class LongTermMemory(Memory):
|
|||
- recover memory when it staruped
|
||||
- update memory when it changed
|
||||
"""
|
||||
memory_storage: MemoryStorage = Field(default_factory=MemoryStorage)
|
||||
rc: Optional["RoleContext"] = None
|
||||
msg_from_recover: bool = False
|
||||
|
||||
def __init__(self):
|
||||
self.memory_storage: MemoryStorage = MemoryStorage()
|
||||
super().__init__()
|
||||
self.rc = None # RoleContext
|
||||
self.msg_from_recover = False
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def recover_memory(self, role_id: str, rc: "RoleContext"):
|
||||
messages = self.memory_storage.recover_memory(role_id)
|
||||
|
|
|
|||
|
|
@ -6,20 +6,46 @@
|
|||
@File : memory.py
|
||||
@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
|
||||
"""
|
||||
import copy
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set, read_json_file, write_json_file
|
||||
|
||||
|
||||
class Memory:
|
||||
class Memory(BaseModel):
|
||||
"""The most basic memory: super-memory"""
|
||||
storage: list[Message] = []
|
||||
index: dict[str, list[Message]] = Field(default_factory=defaultdict(list))
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize an empty storage list and an empty index dictionary"""
|
||||
self.storage: list[Message] = []
|
||||
self.index: dict[str, list[Message]] = defaultdict(list)
|
||||
def __init__(self, **kwargs):
|
||||
index = kwargs.get("index", {})
|
||||
new_index = defaultdict(list)
|
||||
for action_str, value in index.items():
|
||||
new_index[action_str] = [Message(**item_dict) for item_dict in value]
|
||||
kwargs["index"] = new_index
|
||||
super(Memory, self).__init__(**kwargs)
|
||||
self.index = new_index
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
""" stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/ """
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
storage = self.dict()
|
||||
write_json_file(memory_path, storage)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Memory":
|
||||
""" stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
|
||||
memory_dict = read_json_file(memory_path)
|
||||
memory = Memory(**memory_dict)
|
||||
|
||||
return memory
|
||||
|
||||
def add(self, message: Message):
|
||||
"""Add a new message to storage, while updating the index"""
|
||||
|
|
@ -41,6 +67,16 @@ class Memory:
|
|||
"""Return all messages containing a specified content"""
|
||||
return [message for message in self.storage if content in message.content]
|
||||
|
||||
def delete_newest(self) -> "Message":
|
||||
""" delete the newest message from the storage"""
|
||||
if len(self.storage) > 0:
|
||||
newest_msg = self.storage.pop()
|
||||
if newest_msg.cause_by and newest_msg in self.index[newest_msg.cause_by]:
|
||||
self.index[newest_msg.cause_by].remove(newest_msg)
|
||||
else:
|
||||
newest_msg = None
|
||||
return newest_msg
|
||||
|
||||
def delete(self, message: Message):
|
||||
"""Delete the specified message from storage, while updating the index"""
|
||||
self.storage.remove(message)
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class BasePostPrecessPlugin(object):
|
|||
|
||||
def run_retry_parse_json_text(self, content: str) -> Union[dict, list]:
|
||||
"""inherited class can re-implement the function"""
|
||||
logger.debug(f"extracted json CONTENT from output:\n{content}")
|
||||
# logger.info(f"extracted json CONTENT from output:\n{content}")
|
||||
parsed_data = retry_parse_json_text(output=content) # should use output=content
|
||||
return parsed_data
|
||||
|
||||
|
|
|
|||
|
|
@ -5,10 +5,11 @@
|
|||
@Author : alexanderwu
|
||||
@File : architect.py
|
||||
"""
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.roles import Role
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
class Architect(Role):
|
||||
|
|
@ -21,18 +22,14 @@ class Architect(Role):
|
|||
goal (str): Primary goal or responsibility of the architect.
|
||||
constraints (str): Constraints or guidelines for the architect.
|
||||
"""
|
||||
name: str = "Bob"
|
||||
profile: str = "Architect"
|
||||
goal: str = "design a concise, usable, complete software system"
|
||||
constraints: str = "make sure the architecture is simple enough and use appropriate open source " \
|
||||
"libraries. Use same language as user requirement"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "Bob",
|
||||
profile: str = "Architect",
|
||||
goal: str = "design a concise, usable, complete software system",
|
||||
constraints: str = "make sure the architecture is simple enough and use appropriate open source libraries."
|
||||
"Use same language as user requirement",
|
||||
) -> None:
|
||||
"""Initializes the Architect with given attributes."""
|
||||
super().__init__(name, profile, goal, constraints)
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
# Initialize actions specific to the Architect role
|
||||
self._init_actions([WriteDesign])
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,9 @@
|
|||
@Author : alexanderwu
|
||||
@File : sales.py
|
||||
"""
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.roles import Sales
|
||||
|
||||
# from metagpt.actions import SearchAndSummarize
|
||||
|
|
@ -24,5 +27,14 @@ DESC = """
|
|||
|
||||
|
||||
class CustomerService(Sales):
|
||||
def __init__(self, name="Xiaomei", profile="Human customer service", desc=DESC, store=None):
|
||||
super().__init__(name, profile, desc=desc, store=store)
|
||||
|
||||
name: str = "Xiaomei"
|
||||
profile: str = "Human customer service"
|
||||
desc: str = DESC
|
||||
|
||||
store: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
@Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results
|
||||
of SummarizeCode.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
|
@ -23,6 +24,8 @@ from collections import defaultdict
|
|||
from pathlib import Path
|
||||
from typing import Set
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action, WriteCode, WriteCodeReview, WriteTasks
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
|
|
@ -66,25 +69,21 @@ class Engineer(Role):
|
|||
n_borg (int): Number of borgs.
|
||||
use_code_review (bool): Whether to use code review.
|
||||
"""
|
||||
name: str = "Alex"
|
||||
profile: str = "Engineer"
|
||||
goal: str = "write elegant, readable, extensible, efficient code"
|
||||
constraints: str = "the code should conform to standards like google-style and be modular and maintainable. " \
|
||||
"Use same language as user requirement"
|
||||
n_borg: int = 1
|
||||
use_code_review: bool = False
|
||||
code_todos: list = []
|
||||
summarize_todos = []
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "Alex",
|
||||
profile: str = "Engineer",
|
||||
goal: str = "write elegant, readable, extensible, efficient code",
|
||||
constraints: str = "the code should conform to standards like google-style and be modular and maintainable. "
|
||||
"Use same language as user requirement",
|
||||
n_borg: int = 1,
|
||||
use_code_review: bool = False,
|
||||
) -> None:
|
||||
"""Initializes the Engineer role with given attributes."""
|
||||
super().__init__(name, profile, goal, constraints)
|
||||
self.use_code_review = use_code_review
|
||||
self._init_actions([WriteCode])
|
||||
self._watch([WriteTasks, SummarizeCode, WriteCode, WriteCodeReview, FixBug])
|
||||
self.code_todos = []
|
||||
self.summarize_todos = []
|
||||
self.n_borg = n_borg
|
||||
|
||||
@staticmethod
|
||||
def _parse_tasks(task_msg: Document) -> list[str]:
|
||||
|
|
@ -213,7 +212,7 @@ class Engineer(Role):
|
|||
|
||||
@staticmethod
|
||||
async def _new_coding_context(
|
||||
filename, src_file_repo, task_file_repo, design_file_repo, dependency
|
||||
filename, src_file_repo, task_file_repo, design_file_repo, dependency
|
||||
) -> CodingContext:
|
||||
old_code_doc = await src_file_repo.get(filename)
|
||||
if not old_code_doc:
|
||||
|
|
|
|||
|
|
@ -7,10 +7,12 @@
|
|||
@Modified By: mashenquan, 2023/11/27. Add `PrepareDocuments` action according to Section 2.2.3.5.1 of RFC 135.
|
||||
"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.prepare_documents import PrepareDocuments
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.roles import Role
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
class ProductManager(Role):
|
||||
|
|
@ -23,24 +25,13 @@ class ProductManager(Role):
|
|||
goal (str): Goal of the product manager.
|
||||
constraints (str): Constraints or limitations for the product manager.
|
||||
"""
|
||||
name: str = "Alice"
|
||||
profile: str = "Product Manager"
|
||||
goal: str = "efficiently create a successful product that meets market demands and user expectations"
|
||||
constraints: str = "utilize the same language as the user requirements for seamless communication"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "Alice",
|
||||
profile: str = "Product Manager",
|
||||
goal: str = "efficiently create a successful product that meets market demands and user expectations",
|
||||
constraints: str = "utilize the same language as the user requirements for seamless communication",
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the ProductManager role with given attributes.
|
||||
|
||||
Args:
|
||||
name (str): Name of the product manager.
|
||||
profile (str): Role profile.
|
||||
goal (str): Goal of the product manager.
|
||||
constraints (str): Constraints or limitations for the product manager.
|
||||
"""
|
||||
super().__init__(name, profile, goal, constraints)
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._init_actions([PrepareDocuments, WritePRD])
|
||||
self._watch([UserRequirement, PrepareDocuments])
|
||||
|
|
|
|||
|
|
@ -5,9 +5,11 @@
|
|||
@Author : alexanderwu
|
||||
@File : project_manager.py
|
||||
"""
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import WriteTasks
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.roles import Role
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
class ProjectManager(Role):
|
||||
|
|
@ -20,24 +22,14 @@ class ProjectManager(Role):
|
|||
goal (str): Goal of the project manager.
|
||||
constraints (str): Constraints or limitations for the project manager.
|
||||
"""
|
||||
name: str = "Eve"
|
||||
profile: str = "Project Manager"
|
||||
goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task " \
|
||||
"dependencies to start with the prerequisite modules"
|
||||
constraints: str = "use same language as user requirement"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "Eve",
|
||||
profile: str = "Project Manager",
|
||||
goal: str = "break down tasks according to PRD/technical design, generate a task list, and analyze task "
|
||||
"dependencies to start with the prerequisite modules",
|
||||
constraints: str = "use same language as user requirement",
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the ProjectManager role with given attributes.
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
Args:
|
||||
name (str): Name of the project manager.
|
||||
profile (str): Role profile.
|
||||
goal (str): Goal of the project manager.
|
||||
constraints (str): Constraints or limitations for the project manager.
|
||||
"""
|
||||
super().__init__(name, profile, goal, constraints)
|
||||
self._init_actions([WriteTasks])
|
||||
self._watch([WriteDesign])
|
||||
|
|
|
|||
|
|
@ -14,7 +14,14 @@
|
|||
@Modified By: mashenquan, 2023-12-5. Enhance the workflow to navigate to WriteCode or QaEngineer based on the results
|
||||
of SummarizeCode.
|
||||
"""
|
||||
from metagpt.actions import DebugError, RunCode, WriteTest
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import (
|
||||
DebugError,
|
||||
RunCode,
|
||||
WriteTest,
|
||||
)
|
||||
from metagpt.actions.summarize_code import SummarizeCode
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
|
|
@ -30,21 +37,21 @@ from metagpt.utils.file_repository import FileRepository
|
|||
|
||||
|
||||
class QaEngineer(Role):
|
||||
def __init__(
|
||||
self,
|
||||
name="Edward",
|
||||
profile="QaEngineer",
|
||||
goal="Write comprehensive and robust tests to ensure codes will work as expected without bugs",
|
||||
constraints="The test code you write should conform to code standard like PEP8, be modular, easy to read and maintain",
|
||||
test_round_allowed=5,
|
||||
):
|
||||
super().__init__(name, profile, goal, constraints)
|
||||
self._init_actions(
|
||||
[WriteTest]
|
||||
) # FIXME: a bit hack here, only init one action to circumvent _think() logic, will overwrite _think() in future updates
|
||||
name: str = "Edward"
|
||||
profile: str = "QaEngineer"
|
||||
goal: str = "Write comprehensive and robust tests to ensure codes will work as expected without bugs"
|
||||
constraints: str = "The test code you write should conform to code standard like PEP8, be modular, " \
|
||||
"easy to read and maintain"
|
||||
test_round_allowed: int = 5
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# FIXME: a bit hack here, only init one action to circumvent _think() logic,
|
||||
# will overwrite _think() in future updates
|
||||
self._init_actions([WriteTest])
|
||||
self._watch([SummarizeCode, WriteTest, RunCode, DebugError])
|
||||
self.test_round = 0
|
||||
self.test_round_allowed = test_round_allowed
|
||||
|
||||
async def _write_test(self, message: Message) -> None:
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
|
|
|
|||
|
|
@ -18,20 +18,26 @@
|
|||
@Modified By: mashenquan, 2023-11-4. According to the routing feature plan in Chapter 2.2.3.2 of RFC 113, the routing
|
||||
functionality is to be consolidated into the `Environment` class.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from typing import Iterable, Set, Type
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set, Type, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput, UserRequirement
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action import action_subclass_registry
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.llm import LLM, HumanProvider
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.schema import Message, MessageQueue
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.common import any_to_str, read_json_file, write_json_file, import_class, role_raise_decorator
|
||||
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
|
||||
|
||||
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """
|
||||
|
|
@ -74,37 +80,20 @@ class RoleReactMode(str, Enum):
|
|||
return [item.value for item in cls]
|
||||
|
||||
|
||||
class RoleSetting(BaseModel):
|
||||
"""Role Settings"""
|
||||
|
||||
name: str
|
||||
profile: str
|
||||
goal: str
|
||||
constraints: str
|
||||
desc: str
|
||||
is_human: bool
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.name}({self.profile})"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
|
||||
class RoleContext(BaseModel):
|
||||
"""Role Runtime Context"""
|
||||
|
||||
env: "Environment" = Field(default=None)
|
||||
msg_buffer: MessageQueue = Field(default_factory=MessageQueue) # Message Buffer with Asynchronous Updates
|
||||
# # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison`
|
||||
env: "Environment" = Field(default=None, exclude=True)
|
||||
# TODO judge if ser&deser
|
||||
msg_buffer: MessageQueue = Field(default_factory=MessageQueue,
|
||||
exclude=True) # Message Buffer with Asynchronous Updates
|
||||
memory: Memory = Field(default_factory=Memory)
|
||||
# long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory)
|
||||
state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None
|
||||
todo: Action = Field(default=None)
|
||||
todo: Action = Field(default=None, exclude=True)
|
||||
watch: set[str] = Field(default_factory=set)
|
||||
news: list[Type[Message]] = Field(default=[])
|
||||
react_mode: RoleReactMode = (
|
||||
RoleReactMode.REACT
|
||||
) # see `Role._set_react_mode` for definitions of the following two attributes
|
||||
news: list[Type[Message]] = Field(default=[], exclude=True) # TODO not used
|
||||
react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes
|
||||
max_react_loop: int = 1
|
||||
|
||||
class Config:
|
||||
|
|
@ -126,33 +115,146 @@ class RoleContext(BaseModel):
|
|||
return self.memory.get()
|
||||
|
||||
|
||||
class Role:
|
||||
"""Role/Agent"""
|
||||
role_subclass_registry = {}
|
||||
|
||||
|
||||
class Role(BaseModel):
|
||||
"""Role/Agent"""
|
||||
name: str = ""
|
||||
profile: str = ""
|
||||
goal: str = ""
|
||||
constraints: str = ""
|
||||
desc: str = ""
|
||||
is_human: bool = False
|
||||
|
||||
_llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
_role_id: str = ""
|
||||
_states: list[str] = []
|
||||
_actions: list[Action] = []
|
||||
_rc: RoleContext = Field(default_factory=RoleContext)
|
||||
_subscription: tuple[str] = set()
|
||||
|
||||
# builtin variables
|
||||
recovered: bool = False # to tag if a recovered role
|
||||
latest_observed_msg: Message = None # record the latest observed message when interrupted
|
||||
builtin_class_name: str = ""
|
||||
|
||||
_private_attributes = {
|
||||
"_llm": LLM() if not is_human else HumanProvider(),
|
||||
"_role_id": _role_id,
|
||||
"_states": [],
|
||||
"_actions": [],
|
||||
"_rc": RoleContext(),
|
||||
"_subscription": set()
|
||||
}
|
||||
|
||||
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
exclude = ["_llm"]
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
for index in range(len(kwargs.get("_actions", []))):
|
||||
current_action = kwargs["_actions"][index]
|
||||
if isinstance(current_action, dict):
|
||||
item_class_name = current_action.get("builtin_class_name", None)
|
||||
for name, subclass in action_subclass_registry.items():
|
||||
registery_class_name = subclass.__fields__["builtin_class_name"].default
|
||||
if item_class_name == registery_class_name:
|
||||
current_action = subclass(**current_action)
|
||||
break
|
||||
kwargs["_actions"][index] = current_action
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655
|
||||
self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider()
|
||||
self._private_attributes["_role_id"] = str(self._setting)
|
||||
self._private_attributes["_subscription"] = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
|
||||
|
||||
for key in self._private_attributes.keys():
|
||||
if key in kwargs:
|
||||
object.__setattr__(self, key, kwargs[key])
|
||||
if key == "_rc":
|
||||
_rc = RoleContext(**kwargs["_rc"])
|
||||
object.__setattr__(self, "_rc", _rc)
|
||||
else:
|
||||
if key == "_rc":
|
||||
# # Warning, if use self._private_attributes["_rc"],
|
||||
# # self._rc will be a shared object between roles, so init one or reset it inside `_reset`
|
||||
object.__setattr__(self, key, RoleContext())
|
||||
else:
|
||||
object.__setattr__(self, key, self._private_attributes[key])
|
||||
|
||||
def __init__(self, name="", profile="", goal="", constraints="", desc="", is_human=False):
|
||||
self._llm = LLM() if not is_human else HumanProvider()
|
||||
self._setting = RoleSetting(
|
||||
name=name, profile=profile, goal=goal, constraints=constraints, desc=desc, is_human=is_human
|
||||
)
|
||||
self._llm.system_prompt = self._get_prefix()
|
||||
self._states = []
|
||||
self._actions = []
|
||||
self._role_id = str(self._setting)
|
||||
self._rc = RoleContext(watch={any_to_str(UserRequirement)})
|
||||
self._subscription = {any_to_str(self), name} if name else {any_to_str(self)}
|
||||
|
||||
# deserialize child classes dynamically for inherited `role`
|
||||
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
|
||||
self.__fields__["builtin_class_name"].default = self.__class__.__name__
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
role_subclass_registry[cls.__name__] = cls
|
||||
|
||||
def _reset(self):
|
||||
self._states = []
|
||||
self._actions = []
|
||||
object.__setattr__(self, "_states", [])
|
||||
object.__setattr__(self, "_actions", [])
|
||||
|
||||
@property
|
||||
def _setting(self):
|
||||
return f"{self.name}({self.profile})"
|
||||
|
||||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}") \
|
||||
if stg_path is None else stg_path
|
||||
|
||||
role_info = self.dict(exclude={"_rc": {"memory": True, "msg_buffer": True}, "_llm": True})
|
||||
role_info.update({
|
||||
"role_class": self.__class__.__name__,
|
||||
"module_name": self.__module__
|
||||
})
|
||||
role_info_path = stg_path.joinpath("role_info.json")
|
||||
write_json_file(role_info_path, role_info)
|
||||
|
||||
self._rc.memory.serialize(stg_path) # serialize role's memory alone
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Role":
|
||||
""" stg_path = ./storage/team/environment/roles/{role_class}_{role_name}"""
|
||||
role_info_path = stg_path.joinpath("role_info.json")
|
||||
role_info = read_json_file(role_info_path)
|
||||
|
||||
role_class_str = role_info.pop("role_class")
|
||||
module_name = role_info.pop("module_name")
|
||||
role_class = import_class(class_name=role_class_str, module_name=module_name)
|
||||
|
||||
role = role_class(**role_info) # initiate particular Role
|
||||
role.set_recovered(True) # set True to make a tag
|
||||
|
||||
role_memory = Memory.deserialize(stg_path)
|
||||
role.set_memory(role_memory)
|
||||
|
||||
return role
|
||||
|
||||
def _init_action_system_message(self, action: Action):
|
||||
action.set_prefix(self._get_prefix())
|
||||
|
||||
def set_recovered(self, recovered: bool = False):
|
||||
self.recovered = recovered
|
||||
|
||||
def set_memory(self, memory: Memory):
|
||||
self._rc.memory = memory
|
||||
|
||||
def init_actions(self, actions):
|
||||
self._init_actions(actions)
|
||||
|
||||
def _init_actions(self, actions):
|
||||
self._reset()
|
||||
for idx, action in enumerate(actions):
|
||||
if not isinstance(action, Action):
|
||||
i = action("", llm=self._llm)
|
||||
## 默认初始化
|
||||
i = action(name="", llm=self._llm)
|
||||
else:
|
||||
if self._setting.is_human and not isinstance(action.llm, HumanProvider):
|
||||
logger.warning(
|
||||
|
|
@ -207,7 +309,7 @@ class Role:
|
|||
def _set_state(self, state: int):
|
||||
"""Update the current state."""
|
||||
self._rc.state = state
|
||||
logger.debug(self._actions)
|
||||
logger.debug(f"actions={self._actions}, state={state}")
|
||||
self._rc.todo = self._actions[self._rc.state] if state >= 0 else None
|
||||
|
||||
def set_env(self, env: "Environment"):
|
||||
|
|
@ -217,16 +319,6 @@ class Role:
|
|||
if env:
|
||||
env.set_subscription(self, self._subscription)
|
||||
|
||||
@property
|
||||
def profile(self):
|
||||
"""Get the role description (position)"""
|
||||
return self._setting.profile
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
"""Get virtual user name"""
|
||||
return self._setting.name
|
||||
|
||||
@property
|
||||
def subscription(self) -> Set:
|
||||
"""The labels for messages to be consumed by the Role object."""
|
||||
|
|
@ -234,9 +326,14 @@ class Role:
|
|||
|
||||
def _get_prefix(self):
|
||||
"""Get the role prefix"""
|
||||
if self._setting.desc:
|
||||
return self._setting.desc
|
||||
return PREFIX_TEMPLATE.format(**self._setting.dict())
|
||||
if self.desc:
|
||||
return self.desc
|
||||
return PREFIX_TEMPLATE.format(**{
|
||||
"profile": self.profile,
|
||||
"name": self.name,
|
||||
"goal": self.goal,
|
||||
"constraints": self.constraints
|
||||
})
|
||||
|
||||
async def _think(self) -> None:
|
||||
"""Think about what to do and decide on the next action"""
|
||||
|
|
@ -244,6 +341,11 @@ class Role:
|
|||
# If there is only one action, then only this one can be performed
|
||||
self._set_state(0)
|
||||
return
|
||||
if self.recovered and self._rc.state >= 0:
|
||||
self._set_state(self._rc.state) # action to run from recovered state
|
||||
self.recovered = False # avoid max_react_loop out of work
|
||||
return
|
||||
|
||||
prompt = self._get_prefix()
|
||||
prompt += STATE_TEMPLATE.format(
|
||||
history=self._rc.history,
|
||||
|
|
@ -251,10 +353,11 @@ class Role:
|
|||
n_states=len(self._states) - 1,
|
||||
previous_state=self._rc.state,
|
||||
)
|
||||
# print(prompt)
|
||||
|
||||
next_state = await self._llm.aask(prompt)
|
||||
next_state = extract_state_value_from_output(next_state)
|
||||
logger.debug(f"{prompt=}")
|
||||
|
||||
if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self._states)):
|
||||
logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1")
|
||||
next_state = -1
|
||||
|
|
@ -283,15 +386,30 @@ class Role:
|
|||
|
||||
return msg
|
||||
|
||||
def _find_news(self, observed: list[Message], existed: list[Message]) -> list[Message]:
|
||||
news = []
|
||||
# Warning, remove `id` here to make it work for recover
|
||||
observed_pure = [msg.dict(exclude={"id": True}) for msg in observed]
|
||||
existed_pure = [msg.dict(exclude={"id": True}) for msg in existed]
|
||||
for idx, new in enumerate(observed_pure):
|
||||
if new["cause_by"] in self._rc.watch and new not in existed_pure:
|
||||
news.append(observed[idx])
|
||||
return news
|
||||
|
||||
async def _observe(self, ignore_memory=False) -> int:
|
||||
"""Prepare new messages for processing from the message buffer and other sources."""
|
||||
# Read unprocessed messages from the msg buffer.
|
||||
news = self._rc.msg_buffer.pop_all()
|
||||
if self.recovered:
|
||||
news = [self.latest_observed_msg] if self.latest_observed_msg else []
|
||||
else:
|
||||
self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
|
||||
|
||||
# Store the read messages in your own memory to prevent duplicate processing.
|
||||
old_messages = [] if ignore_memory else self._rc.memory.get()
|
||||
self._rc.memory.add_batch(news)
|
||||
# Filter out messages of interest.
|
||||
self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages]
|
||||
self._rc.news = self._find_news(news, old_messages)
|
||||
|
||||
# Design Rules:
|
||||
# If you need to further categorize Message objects, you can do so using the Message.set_meta function.
|
||||
|
|
@ -322,7 +440,7 @@ class Role:
|
|||
Use llm to select actions in _think dynamically
|
||||
"""
|
||||
actions_taken = 0
|
||||
rsp = Message("No actions taken yet") # will be overwritten after Role _act
|
||||
rsp = Message(content="No actions taken yet") # will be overwritten after Role _act
|
||||
while actions_taken < self._rc.max_react_loop:
|
||||
# think
|
||||
await self._think()
|
||||
|
|
@ -336,7 +454,8 @@ class Role:
|
|||
|
||||
async def _act_by_order(self) -> Message:
|
||||
"""switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ..."""
|
||||
for i in range(len(self._states)):
|
||||
start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state
|
||||
for i in range(start_idx, len(self._states)):
|
||||
self._set_state(i)
|
||||
rsp = await self._act()
|
||||
return rsp # return output from the last action
|
||||
|
|
@ -344,7 +463,7 @@ class Role:
|
|||
async def _plan_and_act(self) -> Message:
|
||||
"""first plan, then execute an action sequence, i.e. _think (of a plan) -> _act -> _act -> ... Use llm to come up with the plan dynamically."""
|
||||
# TODO: to be implemented
|
||||
return Message("")
|
||||
return Message(content="")
|
||||
|
||||
async def react(self) -> Message:
|
||||
"""Entry to one of three strategies by which Role reacts to the observed Message"""
|
||||
|
|
@ -378,16 +497,17 @@ class Role:
|
|||
"""A wrapper to return the most recent k memories of this role, return all when k=0"""
|
||||
return self._rc.memory.get(k=k)
|
||||
|
||||
@role_raise_decorator
|
||||
async def run(self, with_message=None):
|
||||
"""Observe, and think and act based on the results of the observation"""
|
||||
if with_message:
|
||||
msg = None
|
||||
if isinstance(with_message, str):
|
||||
msg = Message(with_message)
|
||||
msg = Message(content=with_message)
|
||||
elif isinstance(with_message, Message):
|
||||
msg = with_message
|
||||
elif isinstance(with_message, list):
|
||||
msg = Message("\n".join(with_message))
|
||||
msg = Message(content="\n".join(with_message))
|
||||
if not msg.cause_by:
|
||||
msg.cause_by = UserRequirement
|
||||
self.put_message(msg)
|
||||
|
|
|
|||
|
|
@ -5,26 +5,31 @@
|
|||
@Author : alexanderwu
|
||||
@File : sales.py
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import SearchAndSummarize
|
||||
from metagpt.roles import Role
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
||||
class Sales(Role):
|
||||
def __init__(
|
||||
self,
|
||||
name="Xiaomei",
|
||||
profile="Retail sales guide",
|
||||
desc="I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I "
|
||||
"will answer questions only based on the information in the knowledge base."
|
||||
"If I feel that you can't get the answer from the reference material, then I will directly reply that"
|
||||
" I don't know, and I won't tell you that this is from the knowledge base,"
|
||||
"but pretend to be what I know. Note that each of my replies will be replied in the tone of a "
|
||||
"professional guide",
|
||||
store=None,
|
||||
):
|
||||
super().__init__(name, profile, desc=desc)
|
||||
self._set_store(store)
|
||||
|
||||
name: str = "Xiaomei"
|
||||
profile: str = "Retail sales guide"
|
||||
desc: str = "I am a sales guide in retail. My name is Xiaomei. I will answer some customer questions next, and I "
|
||||
"will answer questions only based on the information in the knowledge base."
|
||||
"If I feel that you can't get the answer from the reference material, then I will directly reply that"
|
||||
" I don't know, and I won't tell you that this is from the knowledge base,"
|
||||
"but pretend to be what I know. Note that each of my replies will be replied in the tone of a "
|
||||
"professional guide"
|
||||
|
||||
store: Optional[str] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._set_store(self.store)
|
||||
|
||||
def _set_store(self, store):
|
||||
if store:
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@
|
|||
@Modified By: mashenquan, 2023-11-1. According to Chapter 2.2.1 and 2.2.2 of RFC 116, change the data type of
|
||||
the `cause_by` value in the `Message` to a string to support the new message distribution feature.
|
||||
"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import ActionOutput, SearchAndSummarize
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -27,15 +30,13 @@ class Searcher(Role):
|
|||
engine (SearchEngineType): The type of search engine to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str = "Alice",
|
||||
profile: str = "Smart Assistant",
|
||||
goal: str = "Provide search services for users",
|
||||
constraints: str = "Answer is rich and complete",
|
||||
engine=SearchEngineType.SERPAPI_GOOGLE,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
name: str = Field(default="Alice")
|
||||
profile: str = Field(default="Smart Assistant")
|
||||
goal: str = "Provide search services for users"
|
||||
constraints: str = "Answer is rich and complete"
|
||||
engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""
|
||||
Initializes the Searcher role with given attributes.
|
||||
|
||||
|
|
@ -46,8 +47,8 @@ class Searcher(Role):
|
|||
constraints (str): Constraints or limitations for the searcher.
|
||||
engine (SearchEngineType): The type of search engine to use.
|
||||
"""
|
||||
super().__init__(name, profile, goal, constraints, **kwargs)
|
||||
self._init_actions([SearchAndSummarize(engine=engine)])
|
||||
super().__init__(**kwargs)
|
||||
self._init_actions([SearchAndSummarize(engine=self.engine)])
|
||||
|
||||
def set_search_func(self, search_func):
|
||||
"""Sets a custom search function for the searcher."""
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@
|
|||
between actions.
|
||||
3. Add `id` to `Message` according to Section 2.2.3.1.1 of RFC 135.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -22,7 +23,7 @@ from abc import ABC
|
|||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar
|
||||
from typing import Dict, List, Optional, Set, Type, TypedDict, TypeVar, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -36,7 +37,9 @@ from metagpt.const import (
|
|||
TASK_FILE_REPO,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set, import_class
|
||||
from metagpt.utils.serialize import actionoutout_schema_to_mapping, actionoutput_mapping_to_str, \
|
||||
actionoutput_str_to_mapping
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
|
|
@ -98,41 +101,29 @@ class Message(BaseModel):
|
|||
|
||||
id: str # According to Section 2.2.3.1.1 of RFC 135
|
||||
content: str
|
||||
instruct_content: BaseModel = Field(default=None)
|
||||
instruct_content: BaseModel = None
|
||||
role: str = "user" # system / user / assistant
|
||||
cause_by: str = ""
|
||||
sent_from: str = ""
|
||||
send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL})
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
content,
|
||||
instruct_content=None,
|
||||
role="user",
|
||||
cause_by="",
|
||||
sent_from="",
|
||||
send_to=MESSAGE_ROUTE_TO_ALL,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters not listed below will be stored as meta info, including custom parameters.
|
||||
:param content: Message content.
|
||||
:param instruct_content: Message content struct.
|
||||
:param cause_by: Message producer
|
||||
:param sent_from: Message route info tells who sent this message.
|
||||
:param send_to: Specifies the target recipient or consumer for message delivery in the environment.
|
||||
:param role: Message meta info tells who sent this message.
|
||||
"""
|
||||
super().__init__(
|
||||
id=uuid.uuid4().hex,
|
||||
content=content,
|
||||
instruct_content=instruct_content,
|
||||
role=role,
|
||||
cause_by=any_to_str(cause_by),
|
||||
sent_from=any_to_str(sent_from),
|
||||
send_to=any_to_str_set(send_to),
|
||||
**kwargs,
|
||||
)
|
||||
def __init__(self, **kwargs):
|
||||
ic = kwargs.get("instruct_content", None)
|
||||
if ic and not isinstance(ic, BaseModel) and "class" in ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
mapping = actionoutput_str_to_mapping(ic["mapping"])
|
||||
|
||||
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
|
||||
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
|
||||
ic_new = ic_obj(**ic["value"])
|
||||
kwargs["instruct_content"] = ic_new
|
||||
|
||||
kwargs["id"] = kwargs.get("id", uuid.uuid4().hex)
|
||||
kwargs["cause_by"] = any_to_str(kwargs.get("cause_by",
|
||||
import_class("UserRequirement", "metagpt.actions.add_requirement")))
|
||||
kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", ""))
|
||||
kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL}))
|
||||
super(Message, self).__init__(**kwargs)
|
||||
|
||||
def __setattr__(self, key, val):
|
||||
"""Override `@property.setter`, convert non-string parameters into string parameters."""
|
||||
|
|
@ -146,6 +137,22 @@ class Message(BaseModel):
|
|||
new_val = val
|
||||
super().__setattr__(key, new_val)
|
||||
|
||||
def dict(self, *args, **kwargs) -> "DictStrAny":
|
||||
""" overwrite the `dict` to dump dynamic pydantic model"""
|
||||
obj_dict = super(Message, self).dict(*args, **kwargs)
|
||||
ic = self.instruct_content
|
||||
if ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
schema = ic.schema()
|
||||
# `Documents` contain definitions
|
||||
if "definitions" not in schema:
|
||||
# TODO refine with nested BaseModel
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
mapping = actionoutput_mapping_to_str(mapping)
|
||||
|
||||
obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
|
||||
return obj_dict
|
||||
|
||||
def __str__(self):
|
||||
# prefix = '-'.join([self.role, str(self.cause_by)])
|
||||
return f"{self.role}: {self.content}"
|
||||
|
|
@ -196,11 +203,24 @@ class AIMessage(Message):
|
|||
super().__init__(content=content, role="assistant")
|
||||
|
||||
|
||||
class MessageQueue:
|
||||
class MessageQueue(BaseModel):
|
||||
"""Message queue which supports asynchronous updates."""
|
||||
|
||||
def __init__(self):
|
||||
self._queue = Queue()
|
||||
_queue: Queue = Field(default_factory=Queue)
|
||||
|
||||
_private_attributes = {
|
||||
"_queue": Queue()
|
||||
}
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
for key in self._private_attributes.keys():
|
||||
if key in kwargs:
|
||||
object.__setattr__(self, key, kwargs[key])
|
||||
else:
|
||||
object.__setattr__(self, key, Queue())
|
||||
|
||||
def pop(self) -> Message | None:
|
||||
"""Pop one message from the queue."""
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
import asyncio
|
||||
|
||||
import typer
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
|
@ -31,6 +32,7 @@ def startup(
|
|||
help="The maximum number of times the 'SummarizeCode' action is automatically invoked, with -1 indicating "
|
||||
"unlimited. This parameter is used for debugging the workflow.",
|
||||
),
|
||||
recover_path: str = typer.Option(default=None, help="recover the project from existing serialized storage")
|
||||
):
|
||||
"""Run a startup. Be a boss."""
|
||||
from metagpt.roles import (
|
||||
|
|
@ -44,20 +46,29 @@ def startup(
|
|||
|
||||
CONFIG.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code)
|
||||
|
||||
company = Team()
|
||||
company.hire(
|
||||
[
|
||||
ProductManager(),
|
||||
Architect(),
|
||||
ProjectManager(),
|
||||
]
|
||||
)
|
||||
if not recover_path:
|
||||
company = Team()
|
||||
company.hire(
|
||||
[
|
||||
ProductManager(),
|
||||
Architect(),
|
||||
ProjectManager(),
|
||||
]
|
||||
)
|
||||
|
||||
if implement or code_review:
|
||||
company.hire([Engineer(n_borg=5, use_code_review=code_review)])
|
||||
if implement or code_review:
|
||||
company.hire([Engineer(n_borg=5, use_code_review=code_review)])
|
||||
|
||||
if run_tests:
|
||||
company.hire([QaEngineer()])
|
||||
if run_tests:
|
||||
company.hire([QaEngineer()])
|
||||
else:
|
||||
# # stg_path = SERDESER_PATH.joinpath("team")
|
||||
stg_path = Path(recover_path)
|
||||
if not stg_path.exists() or not str(stg_path).endswith("team"):
|
||||
raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`")
|
||||
|
||||
company = Team.deserialize(stg_path=stg_path)
|
||||
idea = company.idea # use original idea
|
||||
|
||||
company.invest(investment)
|
||||
company.run_project(idea)
|
||||
|
|
|
|||
|
|
@ -7,17 +7,20 @@
|
|||
@Modified By: mashenquan, 2023/11/27. Add an archiving operation after completing the project, as specified in
|
||||
Section 2.2.3.3 of RFC 135.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import MESSAGE_ROUTE_TO_ALL
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import NoMoneyException
|
||||
from metagpt.utils.common import NoMoneyException, read_json_file, write_json_file, serialize_decorator
|
||||
|
||||
|
||||
class Team(BaseModel):
|
||||
|
|
@ -33,6 +36,36 @@ class Team(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
|
||||
|
||||
team_info_path = stg_path.joinpath("team_info.json")
|
||||
write_json_file(team_info_path, self.dict(exclude={"env": True}))
|
||||
|
||||
self.env.serialize(stg_path.joinpath("environment")) # save environment alone
|
||||
|
||||
@classmethod
|
||||
def recover(cls, stg_path: Path) -> "Team":
|
||||
return cls.deserialize(stg_path)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Team":
|
||||
""" stg_path = ./storage/team """
|
||||
# recover team_info
|
||||
team_info_path = stg_path.joinpath("team_info.json")
|
||||
if not team_info_path.exists():
|
||||
raise FileNotFoundError("recover storage meta file `team_info.json` not exist, "
|
||||
"not to recover and please start a new project.")
|
||||
|
||||
team_info: dict = read_json_file(team_info_path)
|
||||
|
||||
# recover environment
|
||||
environment = Environment.deserialize(stg_path=stg_path.joinpath("environment"))
|
||||
team_info.update({"env": environment})
|
||||
|
||||
team = Team(**team_info)
|
||||
return team
|
||||
|
||||
def hire(self, roles: list[Role]):
|
||||
"""Hire roles to cooperate"""
|
||||
self.env.add_roles(roles)
|
||||
|
|
@ -69,6 +102,7 @@ class Team(BaseModel):
|
|||
def _save(self):
|
||||
logger.info(self.json(ensure_ascii=False))
|
||||
|
||||
@serialize_decorator
|
||||
async def run(self, n_round=3):
|
||||
"""Run company until target round or no money"""
|
||||
while n_round > 0:
|
||||
|
|
@ -76,6 +110,7 @@ class Team(BaseModel):
|
|||
n_round -= 1
|
||||
logger.debug(f"max {n_round=} left.")
|
||||
self._check_balance()
|
||||
|
||||
await self.env.run()
|
||||
if CONFIG.git_repo:
|
||||
CONFIG.git_repo.archive()
|
||||
|
|
|
|||
|
|
@ -13,15 +13,21 @@ from __future__ import annotations
|
|||
|
||||
import ast
|
||||
import contextlib
|
||||
import importlib
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import platform
|
||||
import re
|
||||
import traceback
|
||||
import typing
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import List, Tuple, Union, get_args, get_origin
|
||||
|
||||
import aiofiles
|
||||
import loguru
|
||||
from pydantic.json import pydantic_encoder
|
||||
from tenacity import RetryCallState, _utils
|
||||
|
||||
from metagpt.const import MESSAGE_ROUTE_TO_ALL
|
||||
|
|
@ -213,7 +219,7 @@ class OutputParser:
|
|||
|
||||
if start_index != -1 and end_index != -1:
|
||||
# Extract the structure part
|
||||
structure_text = text[start_index : end_index + 1]
|
||||
structure_text = text[start_index: end_index + 1]
|
||||
|
||||
try:
|
||||
# Attempt to convert the text to a Python data type using ast.literal_eval
|
||||
|
|
@ -426,6 +432,77 @@ def general_after_log(i: "loguru.Logger", sec_format: str = "%0.3f") -> typing.C
|
|||
return log_it
|
||||
|
||||
|
||||
def read_json_file(json_file: str, encoding=None) -> list[Any]:
|
||||
if not Path(json_file).exists():
|
||||
raise FileNotFoundError(f"json_file: {json_file} not exist, return []")
|
||||
|
||||
with open(json_file, "r", encoding=encoding) as fin:
|
||||
try:
|
||||
data = json.load(fin)
|
||||
except Exception as exp:
|
||||
raise ValueError(f"read json file: {json_file} failed")
|
||||
return data
|
||||
|
||||
|
||||
def write_json_file(json_file: str, data: list, encoding=None):
|
||||
folder_path = Path(json_file).parent
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(json_file, "w", encoding=encoding) as fout:
|
||||
json.dump(data, fout, ensure_ascii=False, indent=4, default=pydantic_encoder)
|
||||
|
||||
|
||||
def import_class(class_name: str, module_name: str) -> type:
|
||||
module = importlib.import_module(module_name)
|
||||
a_class = getattr(module, class_name)
|
||||
return a_class
|
||||
|
||||
|
||||
def import_class_inst(class_name: str, module_name: str, *args, **kwargs) -> object:
|
||||
a_class = import_class(class_name, module_name)
|
||||
class_inst = a_class(*args, **kwargs)
|
||||
return class_inst
|
||||
|
||||
|
||||
def format_trackback_info(limit: int = 2):
|
||||
return traceback.format_exc(limit=limit)
|
||||
|
||||
|
||||
def serialize_decorator(func):
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
result = await func(self, *args, **kwargs)
|
||||
return result
|
||||
except KeyboardInterrupt as kbi:
|
||||
logger.error(f"KeyboardInterrupt occurs, start to serialize the project, exp:\n{format_trackback_info()}")
|
||||
except Exception as exp:
|
||||
logger.error(f"Exception occurs, start to serialize the project, exp:\n{format_trackback_info()}")
|
||||
self.serialize() # Team.serialize
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def role_raise_decorator(func):
|
||||
async def wrapper(self, *args, **kwargs):
|
||||
try:
|
||||
return await func(self, *args, **kwargs)
|
||||
except KeyboardInterrupt as kbi:
|
||||
logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project")
|
||||
if self.latest_observed_msg:
|
||||
self._rc.memory.delete(self.latest_observed_msg)
|
||||
raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside
|
||||
except Exception as exp:
|
||||
if self.latest_observed_msg:
|
||||
logger.warning("There is a exception in role's execution, in order to resume, "
|
||||
"we delete the newest role communication message in the role's memory.")
|
||||
# remove role newest observed msg to make it observed again
|
||||
self._rc.memory.delete(self.latest_observed_msg)
|
||||
raise Exception(format_trackback_info(limit=None)) # raise again to make it captured outside
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@handle_exception
|
||||
async def aread(file_path: str) -> str:
|
||||
"""Read file asynchronously."""
|
||||
|
|
|
|||
|
|
@ -253,7 +253,7 @@ def retry_parse_json_text(output: str) -> Union[list, dict]:
|
|||
if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
|
||||
it's a two-layer retry cycle
|
||||
"""
|
||||
logger.debug(f"output to json decode:\n{output}")
|
||||
# logger.debug(f"output to json decode:\n{output}")
|
||||
|
||||
# if CONFIG.repair_llm_output is True, it will try to fix output until the retry break
|
||||
parsed_data = CustomDecoder(strict=False).decode(output)
|
||||
|
|
|
|||
|
|
@ -4,13 +4,11 @@
|
|||
|
||||
import copy
|
||||
import pickle
|
||||
from typing import Dict, List
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
||||
def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
|
||||
def actionoutout_schema_to_mapping(schema: dict) -> dict:
|
||||
"""
|
||||
directly traverse the `properties` in the first level.
|
||||
schema structure likes
|
||||
|
|
@ -35,14 +33,31 @@ def actionoutout_schema_to_mapping(schema: Dict) -> Dict:
|
|||
if property["type"] == "string":
|
||||
mapping[field] = (str, ...)
|
||||
elif property["type"] == "array" and property["items"]["type"] == "string":
|
||||
mapping[field] = (List[str], ...)
|
||||
mapping[field] = (list[str], ...)
|
||||
elif property["type"] == "array" and property["items"]["type"] == "array":
|
||||
# here only consider the `List[List[str]]` situation
|
||||
mapping[field] = (List[List[str]], ...)
|
||||
# here only consider the `list[list[str]]` situation
|
||||
mapping[field] = (list[list[str]], ...)
|
||||
return mapping
|
||||
|
||||
|
||||
def serialize_message(message: Message):
|
||||
def actionoutput_mapping_to_str(mapping: dict) -> dict:
|
||||
new_mapping = {}
|
||||
for key, value in mapping.items():
|
||||
new_mapping[key] = str(value)
|
||||
return new_mapping
|
||||
|
||||
|
||||
def actionoutput_str_to_mapping(mapping: dict) -> dict:
|
||||
new_mapping = {}
|
||||
for key, value in mapping.items():
|
||||
if value == "(<class 'str'>, Ellipsis)":
|
||||
new_mapping[key] = (str, ...)
|
||||
else:
|
||||
new_mapping[key] = eval(value) # `"'(list[str], Ellipsis)"` to `(list[str], ...)`
|
||||
return new_mapping
|
||||
|
||||
|
||||
def serialize_message(message: "Message"):
|
||||
message_cp = copy.deepcopy(message) # avoid `instruct_content` value update by reference
|
||||
ic = message_cp.instruct_content
|
||||
if ic:
|
||||
|
|
@ -56,11 +71,12 @@ def serialize_message(message: Message):
|
|||
return msg_ser
|
||||
|
||||
|
||||
def deserialize_message(message_ser: str) -> Message:
|
||||
def deserialize_message(message_ser: str) -> "Message":
|
||||
message = pickle.loads(message_ser)
|
||||
if message.instruct_content:
|
||||
ic = message.instruct_content
|
||||
ic_obj = ActionNode.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
|
||||
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
|
||||
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
|
||||
ic_new = ic_obj(**ic["value"])
|
||||
message.instruct_content = ic_new
|
||||
|
||||
|
|
|
|||
11
tests/metagpt/roles/test_role.py
Normal file
11
tests/metagpt/roles/test_role.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of Role
|
||||
|
||||
from metagpt.roles.role import Role
|
||||
|
||||
|
||||
def test_role_desc():
|
||||
role = Role(profile="Sales", desc="Best Seller")
|
||||
assert role.profile == "Sales"
|
||||
assert role._setting.desc == "Best Seller"
|
||||
4
tests/metagpt/serialize_deserialize/__init__.py
Normal file
4
tests/metagpt/serialize_deserialize/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 11:48 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
27
tests/metagpt/serialize_deserialize/test_action.py
Normal file
27
tests/metagpt/serialize_deserialize/test_action.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 11:48 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_action_serialize():
|
||||
action = Action()
|
||||
ser_action_dict = action.dict()
|
||||
assert "name" in ser_action_dict
|
||||
# assert "llm" not in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = Action()
|
||||
serialized_data = action.dict()
|
||||
|
||||
new_action = Action(**serialized_data)
|
||||
|
||||
assert new_action.name == ""
|
||||
assert new_action.llm == LLM()
|
||||
assert len(await new_action._aask("who are you")) > 0
|
||||
|
|
@ -0,0 +1,28 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/26/2023 2:04 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.roles.architect import Architect
|
||||
|
||||
|
||||
def test_architect_serialize():
|
||||
role = Architect()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "_states" in ser_role_dict
|
||||
assert "_actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect_deserialize():
|
||||
role = Architect()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
new_role = Architect(**ser_role_dict)
|
||||
# new_role = Architect.deserialize(ser_role_dict)
|
||||
assert new_role.name == "Bob"
|
||||
assert len(new_role._actions) == 1
|
||||
assert isinstance(new_role._actions[0], Action)
|
||||
await new_role._actions[0].run(with_messages="write a cli snake game")
|
||||
87
tests/metagpt/serialize_deserialize/test_environment.py
Normal file
87
tests/metagpt/serialize_deserialize/test_environment.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.environment import Environment
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleC, ActionOK, serdeser_path
|
||||
|
||||
|
||||
def test_env_serialize():
|
||||
env = Environment()
|
||||
ser_env_dict = env.dict()
|
||||
assert "roles" in ser_env_dict
|
||||
|
||||
|
||||
def test_env_deserialize():
|
||||
env = Environment()
|
||||
env.publish_message(message=Message(content="test env serialize"))
|
||||
ser_env_dict = env.dict()
|
||||
new_env = Environment(**ser_env_dict)
|
||||
assert len(new_env.roles) == 0
|
||||
assert len(new_env.history) == 25
|
||||
|
||||
|
||||
def test_environment_serdeser():
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
message = Message(
|
||||
content="prd",
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role="product manager",
|
||||
cause_by=any_to_str(UserRequirement)
|
||||
)
|
||||
|
||||
environment = Environment()
|
||||
role_c = RoleC()
|
||||
environment.add_role(role_c)
|
||||
environment.publish_message(message)
|
||||
|
||||
ser_data = environment.dict()
|
||||
assert ser_data["roles"]["Role C"]["name"] == "RoleC"
|
||||
|
||||
new_env: Environment = Environment(**ser_data)
|
||||
assert len(new_env.roles) == 1
|
||||
|
||||
assert list(new_env.roles.values())[0]._states == list(environment.roles.values())[0]._states
|
||||
assert list(new_env.roles.values())[0]._actions == list(environment.roles.values())[0]._actions
|
||||
assert isinstance(list(environment.roles.values())[0]._actions[0], ActionOK)
|
||||
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
|
||||
|
||||
|
||||
def test_environment_serdeser_v2():
|
||||
environment = Environment()
|
||||
pm = ProjectManager()
|
||||
environment.add_role(pm)
|
||||
|
||||
ser_data = environment.dict()
|
||||
|
||||
new_env: Environment = Environment(**ser_data)
|
||||
role = new_env.get_role(pm.profile)
|
||||
assert isinstance(role, ProjectManager)
|
||||
assert isinstance(role._actions[0], WriteTasks)
|
||||
assert isinstance(list(new_env.roles.values())[0]._actions[0], WriteTasks)
|
||||
|
||||
|
||||
def test_environment_serdeser_save():
|
||||
environment = Environment()
|
||||
role_c = RoleC()
|
||||
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
environment.add_role(role_c)
|
||||
environment.serialize(stg_path)
|
||||
|
||||
new_env: Environment = Environment.deserialize(stg_path)
|
||||
assert len(new_env.roles) == 1
|
||||
assert type(list(new_env.roles.values())[0]._actions[0]) == ActionOK
|
||||
69
tests/metagpt/serialize_deserialize/test_memory.py
Normal file
69
tests/metagpt/serialize_deserialize/test_memory.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of memory
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path
|
||||
|
||||
|
||||
def test_memory_serdeser():
|
||||
msg1 = Message(role="Boss",
|
||||
content="write a snake game",
|
||||
cause_by=UserRequirement)
|
||||
|
||||
out_mapping = {"field2": (list[str], ...)}
|
||||
out_data = {"field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("system_design", out_mapping)
|
||||
msg2 = Message(role="Architect",
|
||||
instruct_content=ic_obj(**out_data),
|
||||
content="system design content",
|
||||
cause_by=WriteDesign)
|
||||
|
||||
memory = Memory()
|
||||
memory.add_batch([msg1, msg2])
|
||||
ser_data = memory.dict()
|
||||
|
||||
new_memory = Memory(**ser_data)
|
||||
assert new_memory.count() == 2
|
||||
new_msg2 = new_memory.get(2)[0]
|
||||
assert isinstance(new_msg2, BaseModel)
|
||||
assert isinstance(new_memory.storage[-1], BaseModel)
|
||||
assert new_memory.storage[-1].cause_by == any_to_str(WriteDesign)
|
||||
assert new_msg2.role == "Boss"
|
||||
|
||||
|
||||
def test_memory_serdeser_save():
|
||||
msg1 = Message(role="User",
|
||||
content="write a 2048 game",
|
||||
cause_by=UserRequirement)
|
||||
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("system_design", out_mapping)
|
||||
msg2 = Message(role="Architect",
|
||||
instruct_content=ic_obj(**out_data),
|
||||
content="system design content",
|
||||
cause_by=WriteDesign)
|
||||
|
||||
memory = Memory()
|
||||
memory.add_batch([msg1, msg2])
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
memory.serialize(stg_path)
|
||||
assert stg_path.joinpath("memory.json").exists()
|
||||
|
||||
new_memory = Memory.deserialize(stg_path)
|
||||
assert new_memory.count() == 2
|
||||
new_msg2 = new_memory.get(1)[0]
|
||||
assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"]
|
||||
assert new_msg2.cause_by == any_to_str(WriteDesign)
|
||||
assert len(new_memory.index) == 2
|
||||
|
||||
stg_path.joinpath("memory.json").unlink()
|
||||
21
tests/metagpt/serialize_deserialize/test_product_manager.py
Normal file
21
tests/metagpt/serialize_deserialize/test_product_manager.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/26/2023 2:07 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_product_manager_deserialize():
|
||||
role = ProductManager()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
new_role = ProductManager(**ser_role_dict)
|
||||
|
||||
assert new_role.name == "Alice"
|
||||
assert len(new_role._actions) == 2
|
||||
assert isinstance(new_role._actions[0], Action)
|
||||
await new_role._actions[0].run([Message(content="write a cli snake game")])
|
||||
30
tests/metagpt/serialize_deserialize/test_project_manager.py
Normal file
30
tests/metagpt/serialize_deserialize/test_project_manager.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/26/2023 2:06 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
|
||||
|
||||
def test_project_manager_serialize():
|
||||
role = ProjectManager()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "_states" in ser_role_dict
|
||||
assert "_actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_manager_deserialize():
|
||||
role = ProjectManager()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
|
||||
new_role = ProjectManager(**ser_role_dict)
|
||||
assert new_role.name == "Eve"
|
||||
assert len(new_role._actions) == 1
|
||||
assert isinstance(new_role._actions[0], Action)
|
||||
assert isinstance(new_role._actions[0], WriteTasks)
|
||||
# await new_role._actions[0].run(context="write a cli snake game")
|
||||
95
tests/metagpt/serialize_deserialize/test_role.py
Normal file
95
tests/metagpt/serialize_deserialize/test_role.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/23/2023 4:49 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import format_trackback_info
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path
|
||||
|
||||
|
||||
def test_roles():
|
||||
role_a = RoleA()
|
||||
assert len(role_a._rc.watch) == 1
|
||||
role_b = RoleB()
|
||||
assert len(role_a._rc.watch) == 1
|
||||
assert len(role_b._rc.watch) == 1
|
||||
|
||||
|
||||
def test_role_serialize():
|
||||
role = Role()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "_states" in ser_role_dict
|
||||
assert "_actions" in ser_role_dict
|
||||
|
||||
|
||||
def test_engineer_serialize():
|
||||
role = Engineer()
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "_states" in ser_role_dict
|
||||
assert "_actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer_deserialize():
|
||||
role = Engineer(use_code_review=True)
|
||||
ser_role_dict = role.dict(by_alias=True)
|
||||
|
||||
new_role = Engineer(**ser_role_dict)
|
||||
assert new_role.name == "Alex"
|
||||
assert new_role.use_code_review is True
|
||||
assert len(new_role._actions) == 1
|
||||
assert isinstance(new_role._actions[0], WriteCode)
|
||||
# await new_role._actions[0].run(context="write a cli snake game", filename="test_code")
|
||||
|
||||
|
||||
def test_role_serdeser_save():
|
||||
stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles")
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
pm = ProductManager()
|
||||
role_tag = f"{pm.__class__.__name__}_{pm.name}"
|
||||
stg_path = stg_path_prefix.joinpath(role_tag)
|
||||
pm.serialize(stg_path)
|
||||
|
||||
new_pm = Role.deserialize(stg_path)
|
||||
assert new_pm.name == pm.name
|
||||
assert len(new_pm.get_memories(1)) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_serdeser_interrupt():
|
||||
role_c = RoleC()
|
||||
shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = SERDESER_PATH.joinpath(f"team", "environment", "roles", "{role_c.__class__.__name__}_{role_c.name}")
|
||||
try:
|
||||
await role_c.run(
|
||||
with_message=Message(content="demo", cause_by=UserRequirement)
|
||||
)
|
||||
except Exception as exp:
|
||||
logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}")
|
||||
role_c.serialize(stg_path)
|
||||
|
||||
assert role_c._rc.memory.count() == 1
|
||||
|
||||
new_role_a: Role = Role.deserialize(stg_path)
|
||||
assert new_role_a._rc.state == 1
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await role_c.run(
|
||||
with_message=Message(content="demo", cause_by=UserRequirement)
|
||||
)
|
||||
46
tests/metagpt/serialize_deserialize/test_schema.py
Normal file
46
tests/metagpt/serialize_deserialize/test_schema.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of schema ser&deser
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import MockMessage
|
||||
|
||||
|
||||
def test_message_serdeser():
|
||||
out_mapping = {"field3": (str, ...), "field4": (list[str], ...)}
|
||||
out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
|
||||
message = Message(
|
||||
content="code",
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role="engineer",
|
||||
cause_by=WriteCode
|
||||
)
|
||||
ser_data = message.dict()
|
||||
assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode"
|
||||
assert ser_data["instruct_content"]["class"] == "code"
|
||||
|
||||
new_message = Message(**ser_data)
|
||||
assert new_message.cause_by == any_to_str(WriteCode)
|
||||
assert new_message.cause_by in [any_to_str(WriteCode)]
|
||||
assert new_message.instruct_content == ic_obj(**out_data)
|
||||
|
||||
|
||||
def test_message_without_postprocess():
|
||||
""" to explain `instruct_content` should be postprocessed """
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
message = MockMessage(
|
||||
content="code",
|
||||
instruct_content=ic_obj(**out_data)
|
||||
)
|
||||
ser_data = message.dict()
|
||||
assert ser_data["instruct_content"] == {"field1": ["field1 value1", "field1 value2"]}
|
||||
|
||||
new_message = MockMessage(**ser_data)
|
||||
assert new_message.instruct_content != ic_obj(**out_data)
|
||||
88
tests/metagpt/serialize_deserialize/test_serdeser_base.py
Normal file
88
tests/metagpt/serialize_deserialize/test_serdeser_base.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : base test actions / roles used in unittest
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.roles.role import Role, RoleReactMode
|
||||
|
||||
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
|
||||
|
||||
|
||||
class MockMessage(BaseModel):
|
||||
""" to test normal dict without postprocess """
|
||||
content: str = ""
|
||||
instruct_content: BaseModel = Field(default=None)
|
||||
|
||||
|
||||
class ActionPass(Action):
|
||||
name: str = Field(default="ActionPass")
|
||||
|
||||
async def run(self, messages: list["Message"]) -> ActionOutput:
|
||||
await asyncio.sleep(5) # sleep to make other roles can watch the executed Message
|
||||
output_mapping = {
|
||||
"result": (str, ...)
|
||||
}
|
||||
pass_class = ActionNode.create_model_class("pass", output_mapping)
|
||||
pass_output = ActionOutput("ActionPass run passed", pass_class(**{"result": "pass result"}))
|
||||
|
||||
return pass_output
|
||||
|
||||
|
||||
class ActionOK(Action):
|
||||
name: str = Field(default="ActionOK")
|
||||
|
||||
async def run(self, messages: list["Message"]) -> str:
|
||||
await asyncio.sleep(5)
|
||||
return "ok"
|
||||
|
||||
|
||||
class ActionRaise(Action):
|
||||
name: str = Field(default="ActionRaise")
|
||||
|
||||
async def run(self, messages: list["Message"]) -> str:
|
||||
raise RuntimeError("parse error in ActionRaise")
|
||||
|
||||
|
||||
class RoleA(Role):
|
||||
name: str = Field(default="RoleA")
|
||||
profile: str = Field(default="Role A")
|
||||
goal: str = "RoleA's goal"
|
||||
constraints: str = "RoleA's constraints"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(RoleA, self).__init__(**kwargs)
|
||||
self._init_actions([ActionPass])
|
||||
self._watch([UserRequirement])
|
||||
|
||||
|
||||
class RoleB(Role):
|
||||
name: str = Field(default="RoleB")
|
||||
profile: str = Field(default="Role B")
|
||||
goal: str = "RoleB's goal"
|
||||
constraints: str = "RoleB's constraints"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(RoleB, self).__init__(**kwargs)
|
||||
self._init_actions([ActionOK, ActionRaise])
|
||||
self._watch([ActionPass])
|
||||
self._rc.react_mode = RoleReactMode.BY_ORDER
|
||||
|
||||
|
||||
class RoleC(Role):
|
||||
name: str = Field(default="RoleC")
|
||||
profile: str = Field(default="Role C")
|
||||
goal: str = "RoleC's goal"
|
||||
constraints: str = "RoleC's constraints"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super(RoleC, self).__init__(**kwargs)
|
||||
self._init_actions([ActionOK, ActionRaise])
|
||||
self._watch([UserRequirement])
|
||||
self._rc.react_mode = RoleReactMode.BY_ORDER
|
||||
131
tests/metagpt/serialize_deserialize/test_team.py
Normal file
131
tests/metagpt/serialize_deserialize/test_team.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/27/2023 10:07 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.roles import ProjectManager, ProductManager, Architect
|
||||
from metagpt.team import Team
|
||||
from metagpt.logs import logger
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import RoleA, RoleB, RoleC, serdeser_path, ActionOK
|
||||
|
||||
|
||||
def test_team_deserialize():
|
||||
company = Team()
|
||||
|
||||
pm = ProductManager()
|
||||
arch = Architect()
|
||||
company.hire(
|
||||
[
|
||||
pm,
|
||||
arch,
|
||||
ProjectManager(),
|
||||
]
|
||||
)
|
||||
assert len(company.env.get_roles()) == 3
|
||||
ser_company = company.dict()
|
||||
new_company = Team(**ser_company)
|
||||
|
||||
assert len(new_company.env.get_roles()) == 3
|
||||
assert new_company.env.get_role(pm.profile) is not None
|
||||
|
||||
new_pm = new_company.env.get_role(pm.profile)
|
||||
assert type(new_pm) == ProductManager
|
||||
assert new_company.env.get_role(pm.profile) is not None
|
||||
assert new_company.env.get_role(arch.profile) is not None
|
||||
|
||||
|
||||
def test_team_serdeser_save():
|
||||
company = Team()
|
||||
company.hire([RoleC()])
|
||||
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company.serialize(stg_path=stg_path)
|
||||
|
||||
new_company = Team.deserialize(stg_path)
|
||||
|
||||
assert len(new_company.env.roles) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover():
|
||||
idea = "write a snake game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company = Team()
|
||||
role_c = RoleC()
|
||||
company.hire([role_c])
|
||||
company.run_project(idea)
|
||||
await company.run(n_round=4)
|
||||
|
||||
ser_data = company.dict()
|
||||
new_company = Team(**ser_data)
|
||||
|
||||
new_role_c = new_company.env.get_role(role_c.profile)
|
||||
# assert new_role_c._rc.memory == role_c._rc.memory # TODO
|
||||
assert new_role_c._rc.env != role_c._rc.env # TODO
|
||||
assert type(list(new_company.env.roles.values())[0]._actions[0]) == ActionOK
|
||||
|
||||
new_company.run_project(idea)
|
||||
await new_company.run(n_round=4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover_save():
|
||||
idea = "write a 2048 web game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company = Team()
|
||||
role_c = RoleC()
|
||||
company.hire([role_c])
|
||||
company.run_project(idea)
|
||||
await company.run(n_round=4)
|
||||
|
||||
new_company = Team.deserialize(stg_path)
|
||||
new_role_c = new_company.env.get_role(role_c.profile)
|
||||
# assert new_role_c._rc.memory == role_c._rc.memory
|
||||
assert new_role_c._rc.env != role_c._rc.env
|
||||
assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=`
|
||||
assert new_role_c._rc.todo != role_c._rc.todo # serialize exclude `_rc.todo`
|
||||
assert new_role_c._rc.news != role_c._rc.news # serialize exclude `_rc.news`
|
||||
|
||||
new_company.run_project(idea)
|
||||
await new_company.run(n_round=4)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover_multi_roles_save():
|
||||
idea = "write a snake game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
role_a = RoleA()
|
||||
role_b = RoleB()
|
||||
|
||||
assert role_a.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleA",
|
||||
"RoleA"}
|
||||
assert role_b.subscription == {"tests.metagpt.serialize_deserialize.test_serdeser_base.RoleB",
|
||||
"RoleB"}
|
||||
assert role_b._rc.watch == {"tests.metagpt.serialize_deserialize.test_serdeser_base.ActionPass"}
|
||||
|
||||
company = Team()
|
||||
company.hire([role_a, role_b])
|
||||
company.run_project(idea)
|
||||
await company.run(n_round=4)
|
||||
|
||||
logger.info("Team recovered")
|
||||
|
||||
new_company = Team.deserialize(stg_path)
|
||||
new_company.run_project(idea)
|
||||
|
||||
assert new_company.env.get_role(role_b.profile)._rc.state == 1
|
||||
|
||||
await new_company.run(n_round=4)
|
||||
31
tests/metagpt/serialize_deserialize/test_write_code.py
Normal file
31
tests/metagpt/serialize_deserialize/test_write_code.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/23/2023 10:56 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
|
||||
|
||||
def test_write_design_serialize():
|
||||
action = WriteCode()
|
||||
ser_action_dict = action.dict()
|
||||
assert ser_action_dict["name"] == "WriteCode"
|
||||
# assert "llm" in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_deserialize():
|
||||
context = CodingContext(filename="test_code.py",
|
||||
design_doc=Document(content="write add function to calculate two numbers"))
|
||||
doc = Document(content=context.json())
|
||||
action = WriteCode(context=doc)
|
||||
serialized_data = action.dict()
|
||||
new_action = WriteCode(**serialized_data)
|
||||
|
||||
assert new_action.name == "WriteCode"
|
||||
assert new_action.llm == LLM()
|
||||
await action.run()
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of WriteCodeReview SerDeser
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCodeReview
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_review_deserialize():
|
||||
code_content = """
|
||||
def div(a: int, b: int = 0):
|
||||
return a / b
|
||||
"""
|
||||
context = CodingContext(
|
||||
filename="test_op.py",
|
||||
design_doc=Document(content="divide two numbers"),
|
||||
code_doc=Document(content=code_content)
|
||||
)
|
||||
|
||||
action = WriteCodeReview(context=context)
|
||||
serialized_data = action.dict()
|
||||
assert serialized_data["name"] == "WriteCodeReview"
|
||||
|
||||
new_action = WriteCodeReview(**serialized_data)
|
||||
|
||||
assert new_action.name == "WriteCodeReview"
|
||||
assert new_action.llm == LLM()
|
||||
await new_action.run()
|
||||
42
tests/metagpt/serialize_deserialize/test_write_design.py
Normal file
42
tests/metagpt/serialize_deserialize/test_write_design.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 8:19 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WriteDesign, WriteTasks
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_write_design_serialize():
|
||||
action = WriteDesign()
|
||||
ser_action_dict = action.dict()
|
||||
assert "name" in ser_action_dict
|
||||
# assert "llm" in ser_action_dict # not export
|
||||
|
||||
|
||||
def test_write_task_serialize():
|
||||
action = WriteTasks()
|
||||
ser_action_dict = action.dict()
|
||||
assert "name" in ser_action_dict
|
||||
# assert "llm" in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_design_deserialize():
|
||||
action = WriteDesign()
|
||||
serialized_data = action.dict()
|
||||
new_action = WriteDesign(**serialized_data)
|
||||
assert new_action.name == ""
|
||||
assert new_action.llm == LLM()
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_task_deserialize():
|
||||
action = WriteTasks()
|
||||
serialized_data = action.dict()
|
||||
new_action = WriteTasks(**serialized_data)
|
||||
assert new_action.name == "CreateTasks"
|
||||
assert new_action.llm == LLM()
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
28
tests/metagpt/serialize_deserialize/test_write_prd.py
Normal file
28
tests/metagpt/serialize_deserialize/test_write_prd.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 11/22/2023 1:47 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_action_serialize():
|
||||
action = WritePRD()
|
||||
ser_action_dict = action.dict()
|
||||
assert "name" in ser_action_dict
|
||||
# assert "llm" in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = WritePRD()
|
||||
serialized_data = action.dict()
|
||||
new_action = WritePRD(**serialized_data)
|
||||
assert new_action.name == ""
|
||||
assert new_action.llm == LLM()
|
||||
action_output = await new_action.run(with_messages=Message(content="write a cli snake game"))
|
||||
assert len(action_output.content) > 0
|
||||
|
|
@ -7,6 +7,7 @@
|
|||
"""
|
||||
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.environment import Environment
|
||||
|
|
@ -16,38 +17,51 @@ from metagpt.roles import Architect, ProductManager, Role
|
|||
from metagpt.schema import Message
|
||||
|
||||
|
||||
serdeser_path = Path(__file__).absolute().parent.joinpath("../data/serdeser_storage")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def env():
|
||||
return Environment()
|
||||
|
||||
|
||||
def test_add_role(env: Environment):
|
||||
role = ProductManager("Alice", "product manager", "create a new product", "limited resources")
|
||||
role = ProductManager(name="Alice",
|
||||
profile="product manager",
|
||||
goal="create a new product",
|
||||
constraints="limited resources")
|
||||
env.add_role(role)
|
||||
assert env.get_role(role.profile) == role
|
||||
|
||||
|
||||
def test_get_roles(env: Environment):
|
||||
role1 = Role("Alice", "product manager", "create a new product", "limited resources")
|
||||
role2 = Role("Bob", "engineer", "develop the new product", "short deadline")
|
||||
role1 = Role(name="Alice",
|
||||
profile="product manager",
|
||||
goal="create a new product",
|
||||
constraints="limited resources")
|
||||
role2 = Role(name="Bob",
|
||||
profile="engineer",
|
||||
goal="develop the new product",
|
||||
constraints="short deadline")
|
||||
env.add_role(role1)
|
||||
env.add_role(role2)
|
||||
roles = env.get_roles()
|
||||
assert roles == {role1.profile: role1, role2.profile: role2}
|
||||
|
||||
|
||||
def test_set_manager(env: Environment):
|
||||
manager = Manager()
|
||||
env.set_manager(manager)
|
||||
assert env.manager == manager
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_and_process_message(env: Environment):
|
||||
product_manager = ProductManager("Alice", "Product Manager", "做AI Native产品", "资源有限")
|
||||
architect = Architect("Bob", "Architect", "设计一个可用、高效、较低成本的系统,包括数据结构与接口", "资源有限,需要节省成本")
|
||||
product_manager = ProductManager(name="Alice",
|
||||
profile="Product Manager",
|
||||
goal="做AI Native产品",
|
||||
constraints="资源有限")
|
||||
architect = Architect(name="Bob",
|
||||
profile="Architect",
|
||||
goal="设计一个可用、高效、较低成本的系统,包括数据结构与接口",
|
||||
constraints="资源有限,需要节省成本")
|
||||
|
||||
env.add_roles([product_manager, architect])
|
||||
|
||||
env.set_manager(Manager())
|
||||
env.publish_message(Message(role="User", content="需要一个基于LLM做总结的搜索引擎", cause_by=UserRequirement))
|
||||
|
||||
|
|
|
|||
|
|
@ -7,12 +7,14 @@
|
|||
@Modified By: mashenquan, 2023-11-1. In line with Chapter 2.2.1 and 2.2.2 of RFC 116, introduce unit tests for
|
||||
the utilization of the new feature of `Message` class.
|
||||
"""
|
||||
import json
|
||||
|
||||
import json
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.schema import AIMessage, Message, SystemMessage, UserMessage
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.utils.common import any_to_str
|
||||
|
||||
|
||||
|
|
@ -20,10 +22,10 @@ from metagpt.utils.common import any_to_str
|
|||
def test_messages():
|
||||
test_content = "test_message"
|
||||
msgs = [
|
||||
UserMessage(test_content),
|
||||
SystemMessage(test_content),
|
||||
AIMessage(test_content),
|
||||
Message(test_content, role="QA"),
|
||||
UserMessage(content=test_content),
|
||||
SystemMessage(content=test_content),
|
||||
AIMessage(content=test_content),
|
||||
Message(content=test_content, role="QA"),
|
||||
]
|
||||
text = str(msgs)
|
||||
roles = ["user", "system", "assistant", "QA"]
|
||||
|
|
@ -32,7 +34,7 @@ def test_messages():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
def test_message():
|
||||
m = Message("a", role="v1")
|
||||
m = Message(content="a", role="v1")
|
||||
v = m.dump()
|
||||
d = json.loads(v)
|
||||
assert d
|
||||
|
|
@ -45,7 +47,7 @@ def test_message():
|
|||
assert m.content == "a"
|
||||
assert m.role == "v2"
|
||||
|
||||
m = Message("a", role="b", cause_by="c", x="d", send_to="c")
|
||||
m = Message(content="a", role="b", cause_by="c", x="d", send_to="c")
|
||||
assert m.content == "a"
|
||||
assert m.role == "b"
|
||||
assert m.send_to == {"c"}
|
||||
|
|
@ -63,12 +65,46 @@ def test_message():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
def test_routes():
|
||||
m = Message("a", role="b", cause_by="c", x="d", send_to="c")
|
||||
m = Message(content="a", role="b", cause_by="c", x="d", send_to="c")
|
||||
m.send_to = "b"
|
||||
assert m.send_to == {"b"}
|
||||
m.send_to = {"e", Action}
|
||||
assert m.send_to == {"e", any_to_str(Action)}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
def test_message_serdeser():
|
||||
out_mapping = {"field3": (str, ...), "field4": (list[str], ...)}
|
||||
out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
|
||||
message = Message(
|
||||
content="code",
|
||||
instruct_content=ic_obj(**out_data),
|
||||
role="engineer",
|
||||
cause_by=WriteCode
|
||||
)
|
||||
message_dict = message.dict()
|
||||
assert message_dict["cause_by"] == "metagpt.actions.write_code.WriteCode"
|
||||
assert message_dict["instruct_content"] == {
|
||||
"class": "code",
|
||||
"mapping": {
|
||||
"field3": "(<class 'str'>, Ellipsis)",
|
||||
"field4": "(list[str], Ellipsis)"
|
||||
},
|
||||
"value": {
|
||||
"field3": "field3 value3",
|
||||
"field4": ["field4 value1", "field4 value2"]
|
||||
}
|
||||
}
|
||||
|
||||
new_message = Message(**message_dict)
|
||||
assert new_message.content == message.content
|
||||
assert new_message.instruct_content == message.instruct_content
|
||||
assert new_message.cause_by == message.cause_by
|
||||
assert new_message.instruct_content.field3 == out_data["field3"]
|
||||
|
||||
message = Message(content="code")
|
||||
message_dict = message.dict()
|
||||
new_message = Message(**message_dict)
|
||||
assert new_message.instruct_content is None
|
||||
assert new_message.cause_by == "metagpt.actions.add_requirement.UserRequirement"
|
||||
|
|
|
|||
13
tests/metagpt/test_team.py
Normal file
13
tests/metagpt/test_team.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of team
|
||||
|
||||
from metagpt.team import Team
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
|
||||
|
||||
def test_team():
|
||||
company = Team()
|
||||
company.hire([ProjectManager()])
|
||||
|
||||
assert len(company.environment.roles) == 1
|
||||
Loading…
Add table
Add a link
Reference in a new issue