mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 12:22:39 +02:00
Merge pull request #728 from better629/feat_simple_ser
Feat simpler serialization in one file
This commit is contained in:
commit
9a42a14c91
29 changed files with 232 additions and 316 deletions
|
|
@ -27,7 +27,7 @@ from metagpt.schema import (
|
|||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
||||
class Action(SerializationMixin, is_polymorphic_base=True):
|
||||
class Action(SerializationMixin):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
|
||||
name: str = ""
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@
|
|||
functionality is to be consolidated into the `Environment` class.
|
||||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
|
@ -21,7 +20,7 @@ from metagpt.context import Context
|
|||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import is_send_to, read_json_file, write_json_file
|
||||
from metagpt.utils.common import is_send_to
|
||||
|
||||
|
||||
class Environment(BaseModel):
|
||||
|
|
@ -42,44 +41,6 @@ class Environment(BaseModel):
|
|||
self.add_roles(self.roles.values())
|
||||
return self
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
roles_path = stg_path.joinpath("roles.json")
|
||||
roles_info = []
|
||||
for role_key, role in self.roles.items():
|
||||
roles_info.append(
|
||||
{
|
||||
"role_class": role.__class__.__name__,
|
||||
"module_name": role.__module__,
|
||||
"role_name": role.name,
|
||||
"role_sub_tags": list(self.member_addrs.get(role)),
|
||||
}
|
||||
)
|
||||
role.serialize(stg_path=stg_path.joinpath(f"roles/{role.__class__.__name__}_{role.name}"))
|
||||
write_json_file(roles_path, roles_info)
|
||||
|
||||
history_path = stg_path.joinpath("history.json")
|
||||
write_json_file(history_path, {"content": self.history})
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Environment":
|
||||
"""stg_path: ./storage/team/environment/"""
|
||||
roles_path = stg_path.joinpath("roles.json")
|
||||
roles_info = read_json_file(roles_path)
|
||||
roles = []
|
||||
for role_info in roles_info:
|
||||
# role stored in ./environment/roles/{role_class}_{role_name}
|
||||
role_path = stg_path.joinpath(f"roles/{role_info.get('role_class')}_{role_info.get('role_name')}")
|
||||
role = Role.deserialize(role_path)
|
||||
roles.append(role)
|
||||
|
||||
history = read_json_file(stg_path.joinpath("history.json"))
|
||||
history = history.get("content")
|
||||
|
||||
environment = Environment(**{"history": history})
|
||||
environment.add_roles(roles)
|
||||
|
||||
return environment
|
||||
|
||||
def add_role(self, role: Role):
|
||||
"""增加一个在当前环境的角色
|
||||
Add a role in the current environment
|
||||
|
|
|
|||
|
|
@ -7,19 +7,13 @@
|
|||
@Modified By: mashenquan, 2023-11-1. According to RFC 116: Updated the type of index key.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import DefaultDict, Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, Field, SerializeAsAny
|
||||
|
||||
from metagpt.const import IGNORED_MESSAGE_ID
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import (
|
||||
any_to_str,
|
||||
any_to_str_set,
|
||||
read_json_file,
|
||||
write_json_file,
|
||||
)
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
|
|
@ -29,22 +23,6 @@ class Memory(BaseModel):
|
|||
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
|
||||
ignore_id: bool = False
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
storage = self.model_dump()
|
||||
write_json_file(memory_path, storage)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Memory":
|
||||
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
|
||||
memory_dict = read_json_file(memory_path)
|
||||
memory = Memory(**memory_dict)
|
||||
|
||||
return memory
|
||||
|
||||
def add(self, message: Message):
|
||||
"""Add a new message to storage, while updating the index"""
|
||||
if self.ignore_id:
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
|
@ -31,7 +30,6 @@ from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validat
|
|||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.context import Context, context
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -39,14 +37,7 @@ from metagpt.memory import Memory
|
|||
from metagpt.provider import HumanProvider
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin
|
||||
from metagpt.utils.common import (
|
||||
any_to_name,
|
||||
any_to_str,
|
||||
import_class,
|
||||
read_json_file,
|
||||
role_raise_decorator,
|
||||
write_json_file,
|
||||
)
|
||||
from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
|
||||
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
|
||||
|
||||
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}. """
|
||||
|
|
@ -128,7 +119,7 @@ class RoleContext(BaseModel):
|
|||
return self.memory.get()
|
||||
|
||||
|
||||
class Role(SerializationMixin, is_polymorphic_base=True):
|
||||
class Role(SerializationMixin):
|
||||
"""Role/Agent"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
|
|
@ -217,6 +208,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
self.llm.system_prompt = self._get_prefix()
|
||||
self._watch(data.get("watch") or [UserRequirement])
|
||||
|
||||
if self.latest_observed_msg:
|
||||
self.recovered = True
|
||||
|
||||
def _reset(self):
|
||||
self.states = []
|
||||
self.actions = []
|
||||
|
|
@ -225,47 +219,12 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
def _setting(self):
|
||||
return f"{self.name}({self.profile})"
|
||||
|
||||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = (
|
||||
SERDESER_PATH.joinpath(f"team/environment/roles/{self.__class__.__name__}_{self.name}")
|
||||
if stg_path is None
|
||||
else stg_path
|
||||
)
|
||||
|
||||
role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True})
|
||||
role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__})
|
||||
role_info_path = stg_path.joinpath("role_info.json")
|
||||
write_json_file(role_info_path, role_info)
|
||||
|
||||
self.rc.memory.serialize(stg_path) # serialize role's memory alone
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Role":
|
||||
"""stg_path = ./storage/team/environment/roles/{role_class}_{role_name}"""
|
||||
role_info_path = stg_path.joinpath("role_info.json")
|
||||
role_info = read_json_file(role_info_path)
|
||||
|
||||
role_class_str = role_info.pop("role_class")
|
||||
module_name = role_info.pop("module_name")
|
||||
role_class = import_class(class_name=role_class_str, module_name=module_name)
|
||||
|
||||
role = role_class(**role_info) # initiate particular Role
|
||||
role.set_recovered(True) # set True to make a tag
|
||||
|
||||
role_memory = Memory.deserialize(stg_path)
|
||||
role.set_memory(role_memory)
|
||||
|
||||
return role
|
||||
|
||||
def _init_action_system_message(self, action: Action):
|
||||
action.set_prefix(self._get_prefix())
|
||||
|
||||
def refresh_system_message(self):
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
|
||||
def set_recovered(self, recovered: bool = False):
|
||||
self.recovered = recovered
|
||||
|
||||
def set_memory(self, memory: Memory):
|
||||
self.rc.memory = memory
|
||||
|
||||
|
|
@ -376,7 +335,7 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
|
||||
if self.recovered and self.rc.state >= 0:
|
||||
self._set_state(self.rc.state) # action to run from recovered state
|
||||
self.set_recovered(False) # avoid max_react_loop out of work
|
||||
self.recovered = False # avoid max_react_loop out of work
|
||||
return True
|
||||
|
||||
prompt = self._get_prefix()
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ from abc import ABC
|
|||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
|
|
@ -32,8 +32,9 @@ from pydantic import (
|
|||
PrivateAttr,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from metagpt.const import (
|
||||
MESSAGE_ROUTE_CAUSE_BY,
|
||||
|
|
@ -53,7 +54,7 @@ from metagpt.utils.serialize import (
|
|||
)
|
||||
|
||||
|
||||
class SerializationMixin(BaseModel):
|
||||
class SerializationMixin(BaseModel, extra="forbid"):
|
||||
"""
|
||||
PolyMorphic subclasses Serialization / Deserialization Mixin
|
||||
- First of all, we need to know that pydantic is not designed for polymorphism.
|
||||
|
|
@ -68,49 +69,44 @@ class SerializationMixin(BaseModel):
|
|||
__is_polymorphic_base = False
|
||||
__subclasses_map__ = {}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = handler(source)
|
||||
og_schema_ref = schema["ref"]
|
||||
schema["ref"] += ":mixin"
|
||||
|
||||
return core_schema.no_info_before_validator_function(
|
||||
cls.__deserialize_with_real_type__,
|
||||
schema=schema,
|
||||
ref=og_schema_ref,
|
||||
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __serialize_add_class_type__(
|
||||
cls,
|
||||
value,
|
||||
handler: core_schema.SerializerFunctionWrapHandler,
|
||||
) -> Any:
|
||||
ret = handler(value)
|
||||
if not len(cls.__subclasses__()):
|
||||
# only subclass add `__module_class_name`
|
||||
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
|
||||
@model_serializer(mode="wrap")
|
||||
def __serialize_with_class_type__(self, default_serializer) -> Any:
|
||||
# default serializer, then append the `__module_class_name` field and return
|
||||
ret = default_serializer(self)
|
||||
ret["__module_class_name"] = f"{self.__class__.__module__}.{self.__class__.__qualname__}"
|
||||
return ret
|
||||
|
||||
@model_validator(mode="wrap")
|
||||
@classmethod
|
||||
def __deserialize_with_real_type__(cls, value: Any):
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
def __convert_to_real_type__(cls, value: Any, handler):
|
||||
if isinstance(value, dict) is False:
|
||||
return handler(value)
|
||||
|
||||
if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
|
||||
# add right condition to init BaseClass like Action()
|
||||
return value
|
||||
module_class_name = value.get("__module_class_name", None)
|
||||
if module_class_name is None:
|
||||
raise ValueError("Missing field: __module_class_name")
|
||||
# it is a dict so make sure to remove the __module_class_name
|
||||
# because we don't allow extra keywords but want to ensure
|
||||
# e.g Cat.model_validate(cat.model_dump()) works
|
||||
class_full_name = value.pop("__module_class_name", None)
|
||||
|
||||
class_type = cls.__subclasses_map__.get(module_class_name, None)
|
||||
# if it's not the polymorphic base we construct via default handler
|
||||
if not cls.__is_polymorphic_base:
|
||||
if class_full_name is None:
|
||||
return handler(value)
|
||||
elif str(cls) == f"<class '{class_full_name}'>":
|
||||
return handler(value)
|
||||
else:
|
||||
# f"Trying to instantiate {class_full_name} but this is not the polymorphic base class")
|
||||
pass
|
||||
|
||||
# otherwise we lookup the correct polymorphic type and construct that
|
||||
# instead
|
||||
if class_full_name is None:
|
||||
raise ValueError("Missing __module_class_name field")
|
||||
|
||||
class_type = cls.__subclasses_map__.get(class_full_name, None)
|
||||
|
||||
if class_type is None:
|
||||
raise TypeError("Trying to instantiate {module_class_name} which not defined yet.")
|
||||
# TODO could try dynamic import
|
||||
raise TypeError("Trying to instantiate {class_full_name}, which has not yet been defined!")
|
||||
|
||||
return class_type(**value)
|
||||
|
||||
|
|
@ -186,12 +182,17 @@ class Message(BaseModel):
|
|||
@field_validator("instruct_content", mode="before")
|
||||
@classmethod
|
||||
def check_instruct_content(cls, ic: Any) -> BaseModel:
|
||||
if ic and not isinstance(ic, BaseModel) and "class" in ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
mapping = actionoutput_str_to_mapping(ic["mapping"])
|
||||
|
||||
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
|
||||
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
|
||||
if ic and isinstance(ic, dict) and "class" in ic:
|
||||
if "mapping" in ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
mapping = actionoutput_str_to_mapping(ic["mapping"])
|
||||
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
|
||||
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
|
||||
elif "module" in ic:
|
||||
# subclasses of BaseModel
|
||||
ic_obj = import_class(ic["class"], ic["module"])
|
||||
else:
|
||||
raise KeyError("missing required key to init Message.instruct_content from dict")
|
||||
ic = ic_obj(**ic["value"])
|
||||
return ic
|
||||
|
||||
|
|
@ -216,13 +217,16 @@ class Message(BaseModel):
|
|||
if ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
schema = ic.model_json_schema()
|
||||
# `Documents` contain definitions
|
||||
if "definitions" not in schema:
|
||||
# TODO refine with nested BaseModel
|
||||
ic_type = str(type(ic))
|
||||
if "<class 'metagpt.actions.action_node" in ic_type:
|
||||
# instruct_content from AutoNode.create_model_class, for now, it's single level structure.
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
mapping = actionoutput_mapping_to_str(mapping)
|
||||
|
||||
ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
|
||||
else:
|
||||
# due to instruct_content can be assigned by subclasses of BaseModel
|
||||
ic_dict = {"class": schema["title"], "module": ic.__module__, "value": ic.model_dump()}
|
||||
return ic_dict
|
||||
|
||||
def __init__(self, content: str = "", **data: Any):
|
||||
|
|
|
|||
|
|
@ -49,28 +49,21 @@ class Team(BaseModel):
|
|||
|
||||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
|
||||
team_info_path = stg_path.joinpath("team.json")
|
||||
|
||||
team_info_path = stg_path.joinpath("team_info.json")
|
||||
write_json_file(team_info_path, self.model_dump(exclude={"env": True}))
|
||||
|
||||
self.env.serialize(stg_path.joinpath("environment")) # save environment alone
|
||||
write_json_file(team_info_path, self.model_dump())
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Team":
|
||||
"""stg_path = ./storage/team"""
|
||||
# recover team_info
|
||||
team_info_path = stg_path.joinpath("team_info.json")
|
||||
team_info_path = stg_path.joinpath("team.json")
|
||||
if not team_info_path.exists():
|
||||
raise FileNotFoundError(
|
||||
"recover storage meta file `team_info.json` not exist, "
|
||||
"not to recover and please start a new project."
|
||||
"recover storage meta file `team.json` not exist, " "not to recover and please start a new project."
|
||||
)
|
||||
|
||||
team_info: dict = read_json_file(team_info_path)
|
||||
|
||||
# recover environment
|
||||
environment = Environment.deserialize(stg_path=stg_path.joinpath("environment"))
|
||||
team_info.update({"env": environment})
|
||||
team = Team(**team_info)
|
||||
return team
|
||||
|
||||
|
|
|
|||
|
|
@ -18,12 +18,12 @@ from metagpt.config2 import config
|
|||
|
||||
def make_sk_kernel():
|
||||
kernel = sk.Kernel()
|
||||
if llm := config.get_openai_llm():
|
||||
if llm := config.get_azure_llm():
|
||||
kernel.add_chat_service(
|
||||
"chat_completion",
|
||||
AzureChatCompletion(llm.model, llm.base_url, llm.api_key),
|
||||
)
|
||||
else:
|
||||
elif llm := config.get_openai_llm():
|
||||
kernel.add_chat_service(
|
||||
"chat_completion",
|
||||
OpenAIChatCompletion(llm.model, llm.api_key),
|
||||
|
|
|
|||
|
|
@ -8,25 +8,20 @@ from metagpt.actions import Action
|
|||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_action_serialize():
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_serdeser():
|
||||
action = Action()
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
assert "__module_class_name" not in ser_action_dict
|
||||
assert "__module_class_name" in ser_action_dict
|
||||
|
||||
action = Action(name="test")
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "test" in ser_action_dict["name"]
|
||||
|
||||
new_action = Action(**ser_action_dict)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
action = Action()
|
||||
serialized_data = action.model_dump()
|
||||
|
||||
new_action = Action(**serialized_data)
|
||||
|
||||
assert new_action.name == "Action"
|
||||
assert new_action.name == "test"
|
||||
assert isinstance(new_action.llm, type(LLM()))
|
||||
assert len(await new_action._aask("who are you")) > 0
|
||||
|
|
|
|||
|
|
@ -8,20 +8,15 @@ from metagpt.actions.action import Action
|
|||
from metagpt.roles.architect import Architect
|
||||
|
||||
|
||||
def test_architect_serialize():
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect_serdeser():
|
||||
role = Architect()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_architect_deserialize():
|
||||
role = Architect()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
new_role = Architect(**ser_role_dict)
|
||||
# new_role = Architect.deserialize(ser_role_dict)
|
||||
assert new_role.name == "Bob"
|
||||
assert len(new_role.actions) == 1
|
||||
assert isinstance(new_role.actions[0], Action)
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import shutil
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
|
|
@ -10,7 +9,7 @@ from metagpt.actions.project_management import WriteTasks
|
|||
from metagpt.environment import Environment
|
||||
from metagpt.roles.project_manager import ProjectManager
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.common import any_to_str, read_json_file, write_json_file
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
ActionRaise,
|
||||
|
|
@ -19,17 +18,14 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
|||
)
|
||||
|
||||
|
||||
def test_env_serialize():
|
||||
def test_env_serdeser():
|
||||
env = Environment()
|
||||
env.publish_message(message=Message(content="test env serialize"))
|
||||
|
||||
ser_env_dict = env.model_dump()
|
||||
assert "roles" in ser_env_dict
|
||||
assert len(ser_env_dict["roles"]) == 0
|
||||
|
||||
|
||||
def test_env_deserialize():
|
||||
env = Environment()
|
||||
env.publish_message(message=Message(content="test env serialize"))
|
||||
ser_env_dict = env.model_dump()
|
||||
new_env = Environment(**ser_env_dict)
|
||||
assert len(new_env.roles) == 0
|
||||
assert len(new_env.history) == 25
|
||||
|
|
@ -79,12 +75,13 @@ def test_environment_serdeser_save():
|
|||
environment = Environment()
|
||||
role_c = RoleC()
|
||||
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
env_path = stg_path.joinpath("env.json")
|
||||
environment.add_role(role_c)
|
||||
environment.serialize(stg_path)
|
||||
|
||||
new_env: Environment = Environment.deserialize(stg_path)
|
||||
write_json_file(env_path, environment.model_dump())
|
||||
|
||||
env_dict = read_json_file(env_path)
|
||||
new_env: Environment = Environment(**env_dict)
|
||||
assert len(new_env.roles) == 1
|
||||
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from metagpt.actions.add_requirement import UserRequirement
|
|||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.memory.memory import Memory
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import any_to_str
|
||||
from metagpt.utils.common import any_to_str, read_json_file, write_json_file
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import serdeser_path
|
||||
|
||||
|
||||
|
|
@ -53,14 +53,14 @@ def test_memory_serdeser_save():
|
|||
memory.add_batch([msg1, msg2])
|
||||
|
||||
stg_path = serdeser_path.joinpath("team", "environment")
|
||||
memory.serialize(stg_path)
|
||||
assert stg_path.joinpath("memory.json").exists()
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
write_json_file(memory_path, memory.model_dump())
|
||||
assert memory_path.exists()
|
||||
|
||||
new_memory = Memory.deserialize(stg_path)
|
||||
memory_dict = read_json_file(memory_path)
|
||||
new_memory = Memory(**memory_dict)
|
||||
assert new_memory.count() == 2
|
||||
new_msg2 = new_memory.get(1)[0]
|
||||
assert new_msg2.instruct_content.field1 == ["field1 value1", "field1 value2"]
|
||||
assert new_msg2.cause_by == any_to_str(WriteDesign)
|
||||
assert len(new_memory.index) == 2
|
||||
|
||||
stg_path.joinpath("memory.json").unlink()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of polymorphic conditions
|
||||
import copy
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, SerializeAsAny
|
||||
|
||||
|
|
@ -12,6 +13,8 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
|||
|
||||
|
||||
class ActionSubClasses(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
actions: list[SerializeAsAny[Action]] = []
|
||||
|
||||
|
||||
|
|
@ -40,19 +43,21 @@ def test_no_serialize_as_any():
|
|||
|
||||
|
||||
def test_polymorphic():
|
||||
_ = ActionOKV2(
|
||||
ok_v2 = ActionOKV2(
|
||||
**{"name": "ActionOKV2", "context": "", "prefix": "", "desc": "", "extra_field": "ActionOKV2 Extra Info"}
|
||||
)
|
||||
|
||||
action_subcls = ActionSubClasses(actions=[ActionOKV2(), ActionPass()])
|
||||
action_subcls_dict = action_subcls.model_dump()
|
||||
action_subcls_dict2 = copy.deepcopy(action_subcls_dict)
|
||||
|
||||
assert "__module_class_name" in action_subcls_dict["actions"][0]
|
||||
|
||||
new_action_subcls = ActionSubClasses(**action_subcls_dict)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert new_action_subcls.actions[0].extra_field == ok_v2.extra_field
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
||||
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict)
|
||||
new_action_subcls = ActionSubClasses.model_validate(action_subcls_dict2)
|
||||
assert isinstance(new_action_subcls.actions[0], ActionOKV2)
|
||||
assert isinstance(new_action_subcls.actions[1], ActionPass)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from metagpt.actions.prepare_interview import PrepareInterview
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
async def test_action_serdeser():
|
||||
action = PrepareInterview()
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "PrepareInterview"
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from metagpt.schema import Message
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_product_manager_deserialize(new_filename):
|
||||
async def test_product_manager_serdeser(new_filename):
|
||||
role = ProductManager()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
new_role = ProductManager(**ser_role_dict)
|
||||
|
|
|
|||
|
|
@ -9,19 +9,14 @@ from metagpt.actions.project_management import WriteTasks
|
|||
from metagpt.roles.project_manager import ProjectManager
|
||||
|
||||
|
||||
def test_project_manager_serialize():
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_manager_serdeser():
|
||||
role = ProjectManager()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_project_manager_deserialize():
|
||||
role = ProjectManager()
|
||||
ser_role_dict = role.model_dump(by_alias=True)
|
||||
|
||||
new_role = ProjectManager(**ser_role_dict)
|
||||
assert new_role.name == "Eve"
|
||||
assert len(new_role.actions) == 1
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from metagpt.roles.researcher import Researcher
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tutorial_assistant_deserialize():
|
||||
async def test_tutorial_assistant_serdeser():
|
||||
role = Researcher()
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
|
|
|
|||
|
|
@ -10,13 +10,12 @@ from pydantic import BaseModel, SerializeAsAny
|
|||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.engineer import Engineer
|
||||
from metagpt.roles.product_manager import ProductManager
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import format_trackback_info
|
||||
from metagpt.utils.common import format_trackback_info, read_json_file, write_json_file
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
RoleA,
|
||||
|
|
@ -60,37 +59,31 @@ def test_role_serialize():
|
|||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
def test_engineer_serialize():
|
||||
def test_engineer_serdeser():
|
||||
role = Engineer()
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
assert "states" in ser_role_dict
|
||||
assert "actions" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_engineer_deserialize():
|
||||
role = Engineer(use_code_review=True)
|
||||
ser_role_dict = role.model_dump()
|
||||
|
||||
new_role = Engineer(**ser_role_dict)
|
||||
assert new_role.name == "Alex"
|
||||
assert new_role.use_code_review is True
|
||||
assert new_role.use_code_review is False
|
||||
assert len(new_role.actions) == 1
|
||||
assert isinstance(new_role.actions[0], WriteCode)
|
||||
# await new_role.actions[0].run(context="write a cli snake game", filename="test_code")
|
||||
|
||||
|
||||
def test_role_serdeser_save():
|
||||
stg_path_prefix = serdeser_path.joinpath("team", "environment", "roles")
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
pm = ProductManager()
|
||||
role_tag = f"{pm.__class__.__name__}_{pm.name}"
|
||||
stg_path = stg_path_prefix.joinpath(role_tag)
|
||||
pm.serialize(stg_path)
|
||||
|
||||
new_pm = Role.deserialize(stg_path)
|
||||
stg_path = serdeser_path.joinpath("team", "environment", "roles", f"{pm.__class__.__name__}_{pm.name}")
|
||||
role_path = stg_path.joinpath("role.json")
|
||||
write_json_file(role_path, pm.model_dump())
|
||||
|
||||
role_dict = read_json_file(role_path)
|
||||
new_pm = ProductManager(**role_dict)
|
||||
assert new_pm.name == pm.name
|
||||
assert len(new_pm.get_memories(1)) == 0
|
||||
|
||||
|
|
@ -98,22 +91,24 @@ def test_role_serdeser_save():
|
|||
@pytest.mark.asyncio
|
||||
async def test_role_serdeser_interrupt():
|
||||
role_c = RoleC()
|
||||
shutil.rmtree(SERDESER_PATH.joinpath("team"), ignore_errors=True)
|
||||
shutil.rmtree(serdeser_path.joinpath("team"), ignore_errors=True)
|
||||
|
||||
stg_path = SERDESER_PATH.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}")
|
||||
stg_path = serdeser_path.joinpath("team", "environment", "roles", f"{role_c.__class__.__name__}_{role_c.name}")
|
||||
role_path = stg_path.joinpath("role.json")
|
||||
try:
|
||||
await role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
except Exception:
|
||||
logger.error(f"Exception in `role_a.run`, detail: {format_trackback_info()}")
|
||||
role_c.serialize(stg_path)
|
||||
logger.error(f"Exception in `role_c.run`, detail: {format_trackback_info()}")
|
||||
write_json_file(role_path, role_c.model_dump())
|
||||
|
||||
assert role_c.rc.memory.count() == 1
|
||||
|
||||
new_role_a: Role = Role.deserialize(stg_path)
|
||||
assert new_role_a.rc.state == 1
|
||||
role_dict = read_json_file(role_path)
|
||||
new_role_c: Role = RoleC(**role_dict)
|
||||
assert new_role_c.rc.state == 1
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await new_role_a.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
await new_role_c.run(with_message=Message(content="demo", cause_by=UserRequirement))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : unittest of schema ser&deser
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.schema import Document, Documents, Message
|
||||
from metagpt.schema import CodingContext, Document, Documents, Message, TestingContext
|
||||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
MockICMessage,
|
||||
|
|
@ -12,12 +13,16 @@ from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
|||
)
|
||||
|
||||
|
||||
def test_message_serdeser():
|
||||
def test_message_serdeser_from_create_model():
|
||||
with pytest.raises(KeyError):
|
||||
_ = Message(content="code", instruct_content={"class": "test", "key": "value"})
|
||||
|
||||
out_mapping = {"field3": (str, ...), "field4": (list[str], ...)}
|
||||
out_data = {"field3": "field3 value3", "field4": ["field4 value1", "field4 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
ic_inst = ic_obj(**out_data)
|
||||
|
||||
message = Message(content="code", instruct_content=ic_obj(**out_data), role="engineer", cause_by=WriteCode)
|
||||
message = Message(content="code", instruct_content=ic_inst, role="engineer", cause_by=WriteCode)
|
||||
ser_data = message.model_dump()
|
||||
assert ser_data["cause_by"] == "metagpt.actions.write_code.WriteCode"
|
||||
assert ser_data["instruct_content"]["class"] == "code"
|
||||
|
|
@ -25,28 +30,67 @@ def test_message_serdeser():
|
|||
new_message = Message(**ser_data)
|
||||
assert new_message.cause_by == any_to_str(WriteCode)
|
||||
assert new_message.cause_by in [any_to_str(WriteCode)]
|
||||
|
||||
assert new_message.instruct_content != ic_obj(**out_data) # TODO find why `!=`
|
||||
assert new_message.instruct_content != ic_inst
|
||||
assert new_message.instruct_content.model_dump() == ic_obj(**out_data).model_dump()
|
||||
|
||||
message = Message(content="test_ic", instruct_content=MockICMessage())
|
||||
mock_msg = MockMessage()
|
||||
message = Message(content="test_ic", instruct_content=mock_msg)
|
||||
ser_data = message.model_dump()
|
||||
new_message = Message(**ser_data)
|
||||
assert new_message.instruct_content != MockICMessage() # TODO
|
||||
|
||||
message = Message(content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")}))
|
||||
ser_data = message.model_dump()
|
||||
assert "class" in ser_data["instruct_content"]
|
||||
assert new_message.instruct_content == mock_msg
|
||||
|
||||
|
||||
def test_message_without_postprocess():
|
||||
"""to explain `instruct_content` should be postprocessed"""
|
||||
"""to explain `instruct_content` from `create_model_class` should be postprocessed"""
|
||||
out_mapping = {"field1": (list[str], ...)}
|
||||
out_data = {"field1": ["field1 value1", "field1 value2"]}
|
||||
ic_obj = ActionNode.create_model_class("code", out_mapping)
|
||||
message = MockMessage(content="code", instruct_content=ic_obj(**out_data))
|
||||
message = MockICMessage(content="code", instruct_content=ic_obj(**out_data))
|
||||
ser_data = message.model_dump()
|
||||
assert ser_data["instruct_content"] == {}
|
||||
|
||||
ser_data["instruct_content"] = None
|
||||
new_message = MockMessage(**ser_data)
|
||||
new_message = MockICMessage(**ser_data)
|
||||
assert new_message.instruct_content != ic_obj(**out_data)
|
||||
|
||||
|
||||
def test_message_serdeser_from_basecontext():
|
||||
doc_msg = Message(content="test_document", instruct_content=Document(content="test doc"))
|
||||
ser_data = doc_msg.model_dump()
|
||||
assert ser_data["instruct_content"]["value"]["content"] == "test doc"
|
||||
assert ser_data["instruct_content"]["value"]["filename"] == ""
|
||||
|
||||
docs_msg = Message(
|
||||
content="test_documents", instruct_content=Documents(docs={"doc1": Document(content="test doc")})
|
||||
)
|
||||
ser_data = docs_msg.model_dump()
|
||||
assert ser_data["instruct_content"]["class"] == "Documents"
|
||||
assert ser_data["instruct_content"]["value"]["docs"]["doc1"]["content"] == "test doc"
|
||||
assert ser_data["instruct_content"]["value"]["docs"]["doc1"]["filename"] == ""
|
||||
|
||||
code_ctxt = CodingContext(
|
||||
filename="game.py",
|
||||
design_doc=Document(root_path="docs/system_design", filename="xx.json", content="xxx"),
|
||||
task_doc=Document(root_path="docs/tasks", filename="xx.json", content="xxx"),
|
||||
code_doc=Document(root_path="xxx", filename="game.py", content="xxx"),
|
||||
)
|
||||
code_ctxt_msg = Message(content="coding_context", instruct_content=code_ctxt)
|
||||
ser_data = code_ctxt_msg.model_dump()
|
||||
assert ser_data["instruct_content"]["class"] == "CodingContext"
|
||||
|
||||
new_code_ctxt_msg = Message(**ser_data)
|
||||
assert new_code_ctxt_msg.instruct_content == code_ctxt
|
||||
assert new_code_ctxt_msg.instruct_content.code_doc.filename == "game.py"
|
||||
|
||||
testing_ctxt = TestingContext(
|
||||
filename="test.py",
|
||||
code_doc=Document(root_path="xxx", filename="game.py", content="xxx"),
|
||||
test_doc=Document(root_path="docs/tests", filename="test.py", content="xxx"),
|
||||
)
|
||||
testing_ctxt_msg = Message(content="testing_context", instruct_content=testing_ctxt)
|
||||
ser_data = testing_ctxt_msg.model_dump()
|
||||
new_testing_ctxt_msg = Message(**ser_data)
|
||||
assert new_testing_ctxt_msg.instruct_content == testing_ctxt
|
||||
assert new_testing_ctxt_msg.instruct_content.test_doc.filename == "test.py"
|
||||
|
|
|
|||
|
|
@ -16,14 +16,14 @@ from metagpt.roles.role import Role, RoleReactMode
|
|||
serdeser_path = Path(__file__).absolute().parent.joinpath("..", "..", "data", "serdeser_storage")
|
||||
|
||||
|
||||
class MockICMessage(BaseModel):
|
||||
content: str = "test_ic"
|
||||
|
||||
|
||||
class MockMessage(BaseModel):
|
||||
content: str = "test_msg"
|
||||
|
||||
|
||||
class MockICMessage(BaseModel):
|
||||
"""to test normal dict without postprocess"""
|
||||
|
||||
content: str = ""
|
||||
content: str = "test_ic_msg"
|
||||
instruct_content: Optional[BaseModel] = Field(default=None)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,15 +5,8 @@ import pytest
|
|||
from metagpt.roles.sk_agent import SkAgent
|
||||
|
||||
|
||||
def test_sk_agent_serialize():
|
||||
role = SkAgent()
|
||||
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
|
||||
assert "name" in ser_role_dict
|
||||
assert "planner" in ser_role_dict
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sk_agent_deserialize():
|
||||
async def test_sk_agent_serdeser():
|
||||
role = SkAgent()
|
||||
ser_role_dict = role.model_dump(exclude={"import_semantic_skill_from_directory", "import_skill"})
|
||||
assert "name" in ser_role_dict
|
||||
|
|
|
|||
|
|
@ -4,13 +4,14 @@
|
|||
# @Desc :
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.const import SERDESER_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager
|
||||
from metagpt.team import Team
|
||||
from metagpt.utils.common import write_json_file
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
RoleA,
|
||||
|
|
@ -45,9 +46,16 @@ def test_team_deserialize():
|
|||
assert new_company.env.get_role(arch.profile) is not None
|
||||
|
||||
|
||||
def test_team_serdeser_save():
|
||||
company = Team()
|
||||
def mock_team_serialize(self, stg_path: Path = serdeser_path.joinpath("team")):
|
||||
team_info_path = stg_path.joinpath("team.json")
|
||||
|
||||
write_json_file(team_info_path, self.model_dump())
|
||||
|
||||
|
||||
def test_team_serdeser_save(mocker):
|
||||
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
|
||||
|
||||
company = Team()
|
||||
company.hire([RoleC()])
|
||||
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
|
|
@ -61,9 +69,11 @@ def test_team_serdeser_save():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover():
|
||||
async def test_team_recover(mocker):
|
||||
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
|
||||
|
||||
idea = "write a snake game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company = Team()
|
||||
|
|
@ -75,9 +85,9 @@ async def test_team_recover():
|
|||
ser_data = company.model_dump()
|
||||
new_company = Team(**ser_data)
|
||||
|
||||
new_company.env.get_role(role_c.profile)
|
||||
# assert new_role_c.rc.memory == role_c.rc.memory # TODO
|
||||
# assert new_role_c.rc.env != role_c.rc.env # TODO
|
||||
new_role_c = new_company.env.get_role(role_c.profile)
|
||||
assert new_role_c.rc.memory == role_c.rc.memory
|
||||
assert new_role_c.rc.env != role_c.rc.env
|
||||
assert type(list(new_company.env.roles.values())[0].actions[0]) == ActionOK
|
||||
|
||||
new_company.run_project(idea)
|
||||
|
|
@ -85,9 +95,11 @@ async def test_team_recover():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover_save():
|
||||
async def test_team_recover_save(mocker):
|
||||
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
|
||||
|
||||
idea = "write a 2048 web game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
company = Team()
|
||||
|
|
@ -98,8 +110,8 @@ async def test_team_recover_save():
|
|||
|
||||
new_company = Team.deserialize(stg_path)
|
||||
new_role_c = new_company.env.get_role(role_c.profile)
|
||||
# assert new_role_c.rc.memory == role_c.rc.memory
|
||||
# assert new_role_c.rc.env != role_c.rc.env
|
||||
assert new_role_c.rc.memory == role_c.rc.memory
|
||||
assert new_role_c.rc.env != role_c.rc.env
|
||||
assert new_role_c.recovered != role_c.recovered # here cause previous ut is `!=`
|
||||
assert new_role_c.rc.todo != role_c.rc.todo # serialize exclude `rc.todo`
|
||||
assert new_role_c.rc.news != role_c.rc.news # serialize exclude `rc.news`
|
||||
|
|
@ -109,9 +121,11 @@ async def test_team_recover_save():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_recover_multi_roles_save():
|
||||
async def test_team_recover_multi_roles_save(mocker):
|
||||
mocker.patch("metagpt.team.Team.serialize", mock_team_serialize)
|
||||
|
||||
idea = "write a snake game"
|
||||
stg_path = SERDESER_PATH.joinpath("team")
|
||||
stg_path = serdeser_path.joinpath("team")
|
||||
shutil.rmtree(stg_path, ignore_errors=True)
|
||||
|
||||
role_a = RoleA()
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from metagpt.roles.tutorial_assistant import TutorialAssistant
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tutorial_assistant_deserialize():
|
||||
async def test_tutorial_assistant_serdeser():
|
||||
role = TutorialAssistant()
|
||||
ser_role_dict = role.model_dump()
|
||||
assert "name" in ser_role_dict
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from metagpt.actions import WriteCode
|
|||
from metagpt.schema import CodingContext, Document
|
||||
|
||||
|
||||
def test_write_design_serialize():
|
||||
def test_write_design_serdeser():
|
||||
action = WriteCode()
|
||||
ser_action_dict = action.model_dump()
|
||||
assert ser_action_dict["name"] == "WriteCode"
|
||||
|
|
@ -17,7 +17,7 @@ def test_write_design_serialize():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_deserialize():
|
||||
async def test_write_code_serdeser():
|
||||
context = CodingContext(
|
||||
filename="test_code.py", design_doc=Document(content="write add function to calculate two numbers")
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from metagpt.schema import CodingContext, Document
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_code_review_deserialize():
|
||||
async def test_write_code_review_serdeser():
|
||||
code_content = """
|
||||
def div(a: int, b: int = 0):
|
||||
return a / b
|
||||
|
|
|
|||
|
|
@ -7,33 +7,25 @@ import pytest
|
|||
from metagpt.actions import WriteDesign, WriteTasks
|
||||
|
||||
|
||||
def test_write_design_serialize():
|
||||
action = WriteDesign()
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
|
||||
def test_write_task_serialize():
|
||||
action = WriteTasks()
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_design_deserialize():
|
||||
async def test_write_design_serialize():
|
||||
action = WriteDesign()
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WriteDesign(**serialized_data)
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
new_action = WriteDesign(**ser_action_dict)
|
||||
assert new_action.name == "WriteDesign"
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_task_deserialize():
|
||||
async def test_write_task_serialize():
|
||||
action = WriteTasks()
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WriteTasks(**serialized_data)
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
new_action = WriteTasks(**ser_action_dict)
|
||||
assert new_action.name == "WriteTasks"
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class Person:
|
|||
],
|
||||
ids=["google", "numpy", "sphinx"],
|
||||
)
|
||||
async def test_action_deserialize(style: str, part: str):
|
||||
async def test_action_serdeser(style: str, part: str):
|
||||
action = WriteDocstring()
|
||||
serialized_data = action.model_dump()
|
||||
|
||||
|
|
|
|||
|
|
@ -9,18 +9,14 @@ from metagpt.actions import WritePRD
|
|||
from metagpt.schema import Message
|
||||
|
||||
|
||||
def test_action_serialize(new_filename):
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_serdeser(new_filename):
|
||||
action = WritePRD()
|
||||
ser_action_dict = action.model_dump()
|
||||
assert "name" in ser_action_dict
|
||||
assert "llm" not in ser_action_dict # not export
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize(new_filename):
|
||||
action = WritePRD()
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WritePRD(**serialized_data)
|
||||
new_action = WritePRD(**ser_action_dict)
|
||||
assert new_action.name == "WritePRD"
|
||||
action_output = await new_action.run(with_messages=Message(content="write a cli snake game"))
|
||||
assert len(action_output.content) > 0
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ CONTEXT = """
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_action_deserialize():
|
||||
async def test_action_serdeser():
|
||||
action = WriteReview()
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "WriteReview"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from metagpt.actions.write_tutorial import WriteContent, WriteDirectory
|
|||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(("language", "topic"), [("English", "Write a tutorial about Python")])
|
||||
async def test_write_directory_deserialize(language: str, topic: str):
|
||||
async def test_write_directory_serdeser(language: str, topic: str):
|
||||
action = WriteDirectory()
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "WriteDirectory"
|
||||
|
|
@ -30,7 +30,7 @@ async def test_write_directory_deserialize(language: str, topic: str):
|
|||
("language", "topic", "directory"),
|
||||
[("English", "Write a tutorial about Python", {"Introduction": ["What is Python?", "Why learn Python?"]})],
|
||||
)
|
||||
async def test_write_content_deserialize(language: str, topic: str, directory: Dict):
|
||||
async def test_write_content_serdeser(language: str, topic: str, directory: Dict):
|
||||
action = WriteContent(language=language, directory=directory)
|
||||
serialized_data = action.model_dump()
|
||||
assert serialized_data["name"] == "WriteContent"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue