add serializers to support serialization and deserialization.

This commit is contained in:
seehi 2024-07-10 10:24:04 +08:00
parent 086ef5e805
commit b5934a412b
19 changed files with 234 additions and 144 deletions

View file

@ -78,7 +78,6 @@ exp_pool:
enable_read: false
enable_write: false
persist_path: .chroma_exp_data # The directory.
init_exp: false # If set to true, basic experiences associated with the roles will be added to the experience pool.
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
azure_tts_region: "eastus"

View file

@ -19,6 +19,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action_outcls_registry import register_action_outcls
from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT
from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.serializers import ActionNodeSerializer
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
@ -466,29 +467,7 @@ class ActionNode:
return self
@classmethod
def deserialize_to_action_node(cls, serialized_data) -> "ActionNode":
"""Customized deserialization, it will be triggered when a perfect experience is found.
ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'.
"""
class InstructContent:
def __init__(self, json_data):
self.json_data = json_data
def model_dump_json(self):
return self.json_data
action_node = cls(key="", expected_type=Type[str], instruction="", example="")
action_node.instruct_content = InstructContent(serialized_data)
return action_node
@exp_cache(
resp_serialize=lambda action_node: action_node.instruct_content.model_dump_json(),
resp_deserialize=lambda resp: ActionNode.deserialize_to_action_node(resp),
)
@exp_cache(serializer=ActionNodeSerializer())
async def fill(
self,
*,

View file

@ -7,6 +7,3 @@ class ExperiencePoolConfig(YamlModel):
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
init_exp: bool = Field(
default=False, description="Put some basic experiences associated with the roles into the experience pool."
)

View file

@ -0,0 +1,33 @@
"""Action Node context builder."""
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
ACTION_NODE_CONTEXT_TEMPLATE = """
{req}
### Experiences
-----
{exps}
-----
## Instruction
Consider **Experiences** to generate a better answer.
"""
class ActionNodeContextBuilder(BaseContextBuilder):
async def build(self, **kwargs) -> str:
"""Builds the action node context string.
Args:
**kwargs: Arbitrary keyword arguments, expecting 'req' as a key.
Returns:
str: The formatted context string using the request and formatted experiences.
If no experiences are available, returns the request as is.
"""
req = kwargs.get("req", "")
exps = self.format_exps()
return ACTION_NODE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req

View file

@ -1,6 +1,5 @@
"""Base context builder."""
import re
from abc import ABC, abstractmethod
from typing import Any
@ -17,11 +16,19 @@ class BaseContextBuilder(BaseModel, ABC):
exps: list[Experience] = []
@abstractmethod
async def build(self, *args, **kwargs) -> Any:
async def build(self, **kwargs) -> Any:
"""Build context from parameters."""
def format_exps(self) -> str:
"""Format experiences into a numbered list of strings."""
"""Format experiences into a numbered list of strings.
Example:
1. Given the request: req1, We can get the response: resp1, Which scored: 8.
2. Given the request: req2, We can get the response: resp2, Which scored: 9.
Returns:
str: The formatted experiences as a string.
"""
result = []
for i, exp in enumerate(self.exps, start=1):
@ -29,25 +36,3 @@ class BaseContextBuilder(BaseModel, ABC):
result.append(f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=score_val))
return "\n".join(result)
@staticmethod
def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str:
"""Replace the content between `start_marker` and `end_marker` in the text with `new_content`.
Args:
text (str): The original text.
new_content (str): The new content to replace the old content.
start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'.
end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'.
Returns:
str: The text with the content replaced.
"""
pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL)
def replacement(match):
return f"{match.group(1)}{new_content}\n{match.group(3)}"
replaced_text = pattern.sub(replacement, text)
return replaced_text

View file

@ -1,15 +1,19 @@
"""RoleZero context builder."""
import copy
import json
import re
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
class RoleZeroContextBuilder(BaseContextBuilder):
async def build(self, *args, **kwargs) -> list[dict]:
async def build(self, **kwargs) -> list[dict]:
"""Builds the context by updating the req with formatted experiences.
If there are no experiences, retains the original examples in req, otherwise replaces the examples with the formatted experiences.
Args:
**kwargs: Arbitrary keyword arguments, expecting 'req' as a key.
Returns:
list[dict]: The updated request with formatted experiences or the original request if no experiences are available.
"""
req = kwargs.get("req", [])
@ -28,23 +32,23 @@ class RoleZeroContextBuilder(BaseContextBuilder):
return self.replace_content_between_markers(text, "# Example", "# Instruction", new_example_content)
@staticmethod
def req_serialize(req: list[dict]) -> str:
"""Serialize the request for database storage, ensuring it is a string.
def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str:
"""Replace the content between `start_marker` and `end_marker` in the text with `new_content`.
This function deep copies the request and modifies the content of the last element
to remove unnecessary sections, making the request more concise.
Args:
text (str): The original text.
new_content (str): The new content to replace the old content.
start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'.
end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'.
Returns:
str: The text with the content replaced.
"""
req_copy = copy.deepcopy(req)
pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL)
last_content = req_copy[-1]["content"]
last_content = RoleZeroContextBuilder.replace_content_between_markers(
last_content, "# Data Structure", "# Current Plan", ""
)
last_content = RoleZeroContextBuilder.replace_content_between_markers(
last_content, "# Example", "# Instruction", ""
)
def replacement(match):
return f"{match.group(1)}{new_content}\n{match.group(3)}"
req_copy[-1]["content"] = last_content
return json.dumps(req_copy)
replaced_text = pattern.sub(replacement, text)
return replaced_text

View file

@ -4,21 +4,21 @@
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
SIMPLE_CONTEXT_TEMPLATE = """
{req}
## Context
### Experiences
-----
{exps}
-----
## User Requirement
{req}
## Instruction
Consider **Experiences** to generate a better answer.
"""
class SimpleContextBuilder(BaseContextBuilder):
async def build(self, *args, **kwargs) -> str:
req = kwargs.get("req", "")
exps = self.format_exps()
return SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req
async def build(self, **kwargs) -> str:
return SIMPLE_CONTEXT_TEMPLATE.format(req=kwargs.get("req", ""), exps=self.format_exps())

View file

@ -1,6 +1,7 @@
"""Experience Decorator."""
import asyncio
import copy
import functools
from typing import Any, Callable, Optional, TypeVar
@ -12,6 +13,7 @@ from metagpt.exp_pool.manager import ExperienceManager, exp_manager
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer
from metagpt.logs import logger
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.exceptions import handle_exception
@ -26,9 +28,7 @@ def exp_cache(
scorer: Optional[BaseScorer] = None,
perfect_judge: Optional[BasePerfectJudge] = None,
context_builder: Optional[BaseContextBuilder] = None,
req_serialize: Optional[Callable[..., str]] = None,
resp_serialize: Optional[Callable[..., str]] = None,
resp_deserialize: Optional[Callable[[str], Any]] = None,
serializer: Optional[BaseSerializer] = None,
tag: Optional[str] = None,
):
"""Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience.
@ -44,9 +44,7 @@ def exp_cache(
scorer: Evaluate experience. Default to `SimpleScorer()`.
perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`.
context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`.
req_serialize: Serializes the request for storage. Defaults to `lambda req: str(req)`.
resp_serialize: Serializes the function's return value for storage. Defaults to `lambda resp: str(resp)`.
resp_deserialize: Deserializes the stored response back to the function's return value. Defaults to `lambda resp: resp`.
serializer: Serializes the request and the function's return value for storage, deserializes the stored response back to the function's return value. Defaults to `SimpleSerializer()`.
tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`.
"""
@ -65,9 +63,7 @@ def exp_cache(
exp_scorer=scorer,
exp_perfect_judge=perfect_judge,
context_builder=context_builder,
req_serialize=req_serialize,
resp_serialize=resp_serialize,
resp_deserialize=resp_deserialize,
serializer=serializer,
tag=tag,
)
@ -96,9 +92,7 @@ class ExpCacheHandler(BaseModel):
exp_scorer: Optional[BaseScorer] = None
exp_perfect_judge: Optional[BasePerfectJudge] = None
context_builder: Optional[BaseContextBuilder] = None
req_serialize: Optional[Callable[..., str]] = None
resp_serialize: Optional[Callable[..., str]] = None
resp_deserialize: Optional[Callable[[str], Any]] = None
serializer: Optional[BaseSerializer] = None
tag: Optional[str] = None
_exps: list[Experience] = None
@ -120,12 +114,10 @@ class ExpCacheHandler(BaseModel):
self.exp_scorer = self.exp_scorer or SimpleScorer()
self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge()
self.context_builder = self.context_builder or SimpleContextBuilder()
self.req_serialize = self.req_serialize or (lambda resp: str(resp))
self.resp_serialize = self.resp_serialize or (lambda resp: str(resp))
self.resp_deserialize = self.resp_deserialize or (lambda resp: resp)
self.serializer = self.serializer or SimpleSerializer()
self.tag = self.tag or self._generate_tag()
self._req = self.req_serialize(self.kwargs["req"])
self._req = self.serializer.serialize_req(copy.deepcopy(self.kwargs["req"]))
return self
@ -140,7 +132,7 @@ class ExpCacheHandler(BaseModel):
for exp in self._exps:
if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
logger.info(f"Get one perfect experience: {exp.req[:20]}...")
return self.resp_deserialize(exp.resp)
return self.serializer.deserialize_resp(exp.resp)
return None
@ -148,7 +140,7 @@ class ExpCacheHandler(BaseModel):
"""Execute the function, and save resp."""
self._raw_resp = await self._execute_function()
self._resp = self.resp_serialize(self._raw_resp)
self._resp = self.serializer.serialize_resp(copy.deepcopy(self._raw_resp))
@handle_exception
async def process_experience(self):
@ -204,7 +196,7 @@ class ExpCacheHandler(BaseModel):
async def _build_context(self) -> str:
self.context_builder.exps = self._exps
return await self.context_builder.build(*self.args, **self.kwargs)
return await self.context_builder.build(**self.kwargs)
async def _execute_function(self):
self.kwargs["req"] = await self._build_context()

View file

@ -47,14 +47,12 @@ class ExperienceManager(BaseModel):
self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
self.init_exp_pool()
logger.debug(f"exp_pool config: {self.config.exp_pool}")
return self
@handle_exception
def init_exp_pool(self):
if not self.config.exp_pool.init_exp:
if not self.config.exp_pool.enable_write:
return
if self._has_exps():

View file

@ -0,0 +1,9 @@
"""Serializers init."""
from metagpt.exp_pool.serializers.base import BaseSerializer
from metagpt.exp_pool.serializers.simple import SimpleSerializer
from metagpt.exp_pool.serializers.action_node import ActionNodeSerializer
from metagpt.exp_pool.serializers.role_zero import RoleZeroSerializer
__all__ = ["BaseSerializer", "SimpleSerializer", "ActionNodeSerializer", "RoleZeroSerializer"]

View file

@ -0,0 +1,36 @@
"""ActionNode Serializer."""
from __future__ import annotations
from typing import TYPE_CHECKING, Type
# Import ActionNode only for type checking to avoid circular imports
if TYPE_CHECKING:
from metagpt.actions.action_node import ActionNode
from metagpt.exp_pool.serializers.simple import SimpleSerializer
class ActionNodeSerializer(SimpleSerializer):
def serialize_resp(self, resp: ActionNode) -> str:
return resp.instruct_content.model_dump_json()
def deserialize_resp(self, resp: str) -> ActionNode:
"""Customized deserialization, it will be triggered when a perfect experience is found.
ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'.
"""
class InstructContent:
def __init__(self, json_data):
self.json_data = json_data
def model_dump_json(self):
return self.json_data
from metagpt.actions.action_node import ActionNode
action_node = ActionNode(key="", expected_type=Type[str], instruction="", example="")
action_node.instruct_content = InstructContent(resp)
return action_node

View file

@ -0,0 +1,22 @@
"""Base serializer."""
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, ConfigDict
class BaseSerializer(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
def serialize_req(self, req: Any) -> str:
"""Serializes the request for storage."""
@abstractmethod
def serialize_resp(self, resp: Any) -> str:
"""Serializes the function's return value for storage."""
@abstractmethod
def deserialize_resp(self, resp: str) -> Any:
"""Deserializes the stored response back to the function's return value"""

View file

@ -0,0 +1,40 @@
"""RoleZero Serializer."""
import json
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.exp_pool.serializers.simple import SimpleSerializer
class RoleZeroSerializer(SimpleSerializer):
def serialize_req(self, req: list[dict]) -> str:
"""Serialize the request for database storage, ensuring it is a string.
This function modifies the content of the last element in the request to remove unnecessary sections,
making the request more concise.
Args:
req (list[dict]): The request to be serialized. Example:
[
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."},
{"role": "user", "content": "..."},
]
Returns:
str: The serialized request as a JSON string.
"""
if not req:
return ""
last_content = req[-1]["content"]
last_content = RoleZeroContextBuilder.replace_content_between_markers(
last_content, "# Data Structure", "# Current Plan", ""
)
last_content = RoleZeroContextBuilder.replace_content_between_markers(
last_content, "# Example", "# Instruction", ""
)
req[-1]["content"] = last_content
return json.dumps(req)

View file

@ -0,0 +1,22 @@
"""Simple Serializer."""
from typing import Any
from metagpt.exp_pool.serializers.base import BaseSerializer
class SimpleSerializer(BaseSerializer):
def serialize_req(self, req: Any) -> str:
"""Just use `str` to convert the request object into a string."""
return str(req)
def serialize_resp(self, resp: Any) -> str:
"""Just use `str` to convert the response object into a string."""
return str(resp)
def deserialize_resp(self, resp: str) -> Any:
"""Just return the string response as it is."""
return resp

View file

@ -12,6 +12,7 @@ from metagpt.actions import Action
from metagpt.actions.di.run_command import RunCommand
from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.exp_pool.serializers import RoleZeroSerializer
from metagpt.logs import logger
from metagpt.prompts.di.role_zero import (
CMD_PROMPT,
@ -165,9 +166,7 @@ class RoleZero(Role):
return True
@exp_cache(
context_builder=RoleZeroContextBuilder(), req_serialize=lambda req: RoleZeroContextBuilder.req_serialize(req)
)
@exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer())
async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str]) -> str:
return await self.llm.aask(req, system_msgs=system_msgs)

View file

@ -30,16 +30,3 @@ class TestBaseContextBuilder:
]
)
assert result == expected
def test_replace_content_between_markers(self):
text = "Start\n# Example\nOld content\n# Instruction\nEnd"
new_content = "New content"
result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd"
assert result == expected
def test_replace_content_between_markers_no_match(self):
text = "Start\nNo markers\nEnd"
new_content = "New content"
result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
assert result == text

View file

@ -30,9 +30,22 @@ class TestRoleZeroContextBuilder:
assert result == [{"content": "Updated content"}]
def test_replace_example_content(self, context_builder, mocker):
mocker.patch.object(BaseContextBuilder, "replace_content_between_markers", return_value="Replaced content")
mocker.patch.object(RoleZeroContextBuilder, "replace_content_between_markers", return_value="Replaced content")
result = context_builder.replace_example_content("Original text", "New example content")
assert result == "Replaced content"
context_builder.replace_content_between_markers.assert_called_once_with(
"Original text", "# Example", "# Instruction", "New example content"
)
def test_replace_content_between_markers(self):
text = "Start\n# Example\nOld content\n# Instruction\nEnd"
new_content = "New content"
result = RoleZeroContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd"
assert result == expected
def test_replace_content_between_markers_no_match(self):
text = "Start\nNo markers\nEnd"
new_content = "New content"
result = RoleZeroContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
assert result == text

View file

@ -32,7 +32,8 @@ class TestSimpleContextBuilder:
req = "Test request"
result = await context_builder.build(req=req)
assert result == req
expected = SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps="")
assert result == expected
@pytest.mark.asyncio
async def test_build_without_req(self, context_builder, mocker):

View file

@ -11,9 +11,7 @@ from metagpt.rag.engines import SimpleEngine
class TestExperienceManager:
@pytest.fixture
def mock_config(self):
return Config(
llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, init_exp=False)
)
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True))
@pytest.fixture
def mock_storage(self, mocker):
@ -65,30 +63,6 @@ class TestExperienceManager:
result = await mock_experience_manager.query_exps("query")
assert result == []
def test_init_exp_pool(self, mock_experience_manager, mock_config, mocker):
mock_experience_manager._has_exps = mocker.MagicMock(return_value=False)
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
mock_config.exp_pool.init_exp = True
mock_experience_manager.init_exp_pool()
mock_experience_manager._has_exps.assert_called_once()
mock_experience_manager._init_teamleader_exps.assert_called_once()
mock_experience_manager._init_engineer2_exps.assert_called_once()
def test_init_exp_pool_already_has_exps(self, mock_experience_manager, mock_config, mocker):
mock_experience_manager._has_exps = mocker.MagicMock(return_value=True)
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
mock_config.exp_pool.init_exp = True
mock_experience_manager.init_exp_pool()
mock_experience_manager._has_exps.assert_called_once()
mock_experience_manager._init_teamleader_exps.assert_not_called()
mock_experience_manager._init_engineer2_exps.assert_not_called()
def test_has_exps(self, mock_experience_manager, mock_storage):
mock_storage._retriever._vector_store._get.return_value.ids = ["id1"]