mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
add serializers to support serialization and deserialization.
This commit is contained in:
parent
086ef5e805
commit
b5934a412b
19 changed files with 234 additions and 144 deletions
|
|
@ -78,7 +78,6 @@ exp_pool:
|
|||
enable_read: false
|
||||
enable_write: false
|
||||
persist_path: .chroma_exp_data # The directory.
|
||||
init_exp: false # If set to true, basic experiences associated with the roles will be added to the experience pool.
|
||||
|
||||
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
|
||||
azure_tts_region: "eastus"
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from metagpt.actions.action_outcls_registry import register_action_outcls
|
||||
from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT
|
||||
from metagpt.exp_pool import exp_cache
|
||||
from metagpt.exp_pool.serializers import ActionNodeSerializer
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
|
|
@ -466,29 +467,7 @@ class ActionNode:
|
|||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def deserialize_to_action_node(cls, serialized_data) -> "ActionNode":
|
||||
"""Customized deserialization, it will be triggered when a perfect experience is found.
|
||||
|
||||
ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'.
|
||||
"""
|
||||
|
||||
class InstructContent:
|
||||
def __init__(self, json_data):
|
||||
self.json_data = json_data
|
||||
|
||||
def model_dump_json(self):
|
||||
return self.json_data
|
||||
|
||||
action_node = cls(key="", expected_type=Type[str], instruction="", example="")
|
||||
action_node.instruct_content = InstructContent(serialized_data)
|
||||
|
||||
return action_node
|
||||
|
||||
@exp_cache(
|
||||
resp_serialize=lambda action_node: action_node.instruct_content.model_dump_json(),
|
||||
resp_deserialize=lambda resp: ActionNode.deserialize_to_action_node(resp),
|
||||
)
|
||||
@exp_cache(serializer=ActionNodeSerializer())
|
||||
async def fill(
|
||||
self,
|
||||
*,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,3 @@ class ExperiencePoolConfig(YamlModel):
|
|||
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
|
||||
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
|
||||
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
|
||||
init_exp: bool = Field(
|
||||
default=False, description="Put some basic experiences associated with the roles into the experience pool."
|
||||
)
|
||||
|
|
|
|||
33
metagpt/exp_pool/context_builders/action_node.py
Normal file
33
metagpt/exp_pool/context_builders/action_node.py
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
"""Action Node context builder."""
|
||||
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
|
||||
ACTION_NODE_CONTEXT_TEMPLATE = """
|
||||
{req}
|
||||
|
||||
### Experiences
|
||||
-----
|
||||
{exps}
|
||||
-----
|
||||
|
||||
## Instruction
|
||||
Consider **Experiences** to generate a better answer.
|
||||
"""
|
||||
|
||||
|
||||
class ActionNodeContextBuilder(BaseContextBuilder):
|
||||
async def build(self, **kwargs) -> str:
|
||||
"""Builds the action node context string.
|
||||
|
||||
Args:
|
||||
**kwargs: Arbitrary keyword arguments, expecting 'req' as a key.
|
||||
|
||||
Returns:
|
||||
str: The formatted context string using the request and formatted experiences.
|
||||
If no experiences are available, returns the request as is.
|
||||
"""
|
||||
req = kwargs.get("req", "")
|
||||
exps = self.format_exps()
|
||||
|
||||
return ACTION_NODE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
"""Base context builder."""
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -17,11 +16,19 @@ class BaseContextBuilder(BaseModel, ABC):
|
|||
exps: list[Experience] = []
|
||||
|
||||
@abstractmethod
|
||||
async def build(self, *args, **kwargs) -> Any:
|
||||
async def build(self, **kwargs) -> Any:
|
||||
"""Build context from parameters."""
|
||||
|
||||
def format_exps(self) -> str:
|
||||
"""Format experiences into a numbered list of strings."""
|
||||
"""Format experiences into a numbered list of strings.
|
||||
|
||||
Example:
|
||||
1. Given the request: req1, We can get the response: resp1, Which scored: 8.
|
||||
2. Given the request: req2, We can get the response: resp2, Which scored: 9.
|
||||
|
||||
Returns:
|
||||
str: The formatted experiences as a string.
|
||||
"""
|
||||
|
||||
result = []
|
||||
for i, exp in enumerate(self.exps, start=1):
|
||||
|
|
@ -29,25 +36,3 @@ class BaseContextBuilder(BaseModel, ABC):
|
|||
result.append(f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=score_val))
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
@staticmethod
|
||||
def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str:
|
||||
"""Replace the content between `start_marker` and `end_marker` in the text with `new_content`.
|
||||
|
||||
Args:
|
||||
text (str): The original text.
|
||||
new_content (str): The new content to replace the old content.
|
||||
start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'.
|
||||
end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'.
|
||||
|
||||
Returns:
|
||||
str: The text with the content replaced.
|
||||
"""
|
||||
|
||||
pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL)
|
||||
|
||||
def replacement(match):
|
||||
return f"{match.group(1)}{new_content}\n{match.group(3)}"
|
||||
|
||||
replaced_text = pattern.sub(replacement, text)
|
||||
return replaced_text
|
||||
|
|
|
|||
|
|
@ -1,15 +1,19 @@
|
|||
"""RoleZero context builder."""
|
||||
import copy
|
||||
import json
|
||||
|
||||
import re
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
|
||||
|
||||
class RoleZeroContextBuilder(BaseContextBuilder):
|
||||
async def build(self, *args, **kwargs) -> list[dict]:
|
||||
async def build(self, **kwargs) -> list[dict]:
|
||||
"""Builds the context by updating the req with formatted experiences.
|
||||
|
||||
If there are no experiences, retains the original examples in req, otherwise replaces the examples with the formatted experiences.
|
||||
Args:
|
||||
**kwargs: Arbitrary keyword arguments, expecting 'req' as a key.
|
||||
|
||||
Returns:
|
||||
list[dict]: The updated request with formatted experiences or the original request if no experiences are available.
|
||||
"""
|
||||
|
||||
req = kwargs.get("req", [])
|
||||
|
|
@ -28,23 +32,23 @@ class RoleZeroContextBuilder(BaseContextBuilder):
|
|||
return self.replace_content_between_markers(text, "# Example", "# Instruction", new_example_content)
|
||||
|
||||
@staticmethod
|
||||
def req_serialize(req: list[dict]) -> str:
|
||||
"""Serialize the request for database storage, ensuring it is a string.
|
||||
def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str:
|
||||
"""Replace the content between `start_marker` and `end_marker` in the text with `new_content`.
|
||||
|
||||
This function deep copies the request and modifies the content of the last element
|
||||
to remove unnecessary sections, making the request more concise.
|
||||
Args:
|
||||
text (str): The original text.
|
||||
new_content (str): The new content to replace the old content.
|
||||
start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'.
|
||||
end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'.
|
||||
|
||||
Returns:
|
||||
str: The text with the content replaced.
|
||||
"""
|
||||
|
||||
req_copy = copy.deepcopy(req)
|
||||
pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL)
|
||||
|
||||
last_content = req_copy[-1]["content"]
|
||||
last_content = RoleZeroContextBuilder.replace_content_between_markers(
|
||||
last_content, "# Data Structure", "# Current Plan", ""
|
||||
)
|
||||
last_content = RoleZeroContextBuilder.replace_content_between_markers(
|
||||
last_content, "# Example", "# Instruction", ""
|
||||
)
|
||||
def replacement(match):
|
||||
return f"{match.group(1)}{new_content}\n{match.group(3)}"
|
||||
|
||||
req_copy[-1]["content"] = last_content
|
||||
|
||||
return json.dumps(req_copy)
|
||||
replaced_text = pattern.sub(replacement, text)
|
||||
return replaced_text
|
||||
|
|
|
|||
|
|
@ -4,21 +4,21 @@
|
|||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
|
||||
SIMPLE_CONTEXT_TEMPLATE = """
|
||||
{req}
|
||||
## Context
|
||||
|
||||
### Experiences
|
||||
-----
|
||||
{exps}
|
||||
-----
|
||||
|
||||
## User Requirement
|
||||
{req}
|
||||
|
||||
## Instruction
|
||||
Consider **Experiences** to generate a better answer.
|
||||
"""
|
||||
|
||||
|
||||
class SimpleContextBuilder(BaseContextBuilder):
|
||||
async def build(self, *args, **kwargs) -> str:
|
||||
req = kwargs.get("req", "")
|
||||
exps = self.format_exps()
|
||||
|
||||
return SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req
|
||||
async def build(self, **kwargs) -> str:
|
||||
return SIMPLE_CONTEXT_TEMPLATE.format(req=kwargs.get("req", ""), exps=self.format_exps())
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
"""Experience Decorator."""
|
||||
|
||||
import asyncio
|
||||
import copy
|
||||
import functools
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
|
|
@ -12,6 +13,7 @@ from metagpt.exp_pool.manager import ExperienceManager, exp_manager
|
|||
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
|
||||
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
|
||||
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
|
||||
from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
|
@ -26,9 +28,7 @@ def exp_cache(
|
|||
scorer: Optional[BaseScorer] = None,
|
||||
perfect_judge: Optional[BasePerfectJudge] = None,
|
||||
context_builder: Optional[BaseContextBuilder] = None,
|
||||
req_serialize: Optional[Callable[..., str]] = None,
|
||||
resp_serialize: Optional[Callable[..., str]] = None,
|
||||
resp_deserialize: Optional[Callable[[str], Any]] = None,
|
||||
serializer: Optional[BaseSerializer] = None,
|
||||
tag: Optional[str] = None,
|
||||
):
|
||||
"""Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience.
|
||||
|
|
@ -44,9 +44,7 @@ def exp_cache(
|
|||
scorer: Evaluate experience. Default to `SimpleScorer()`.
|
||||
perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`.
|
||||
context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`.
|
||||
req_serialize: Serializes the request for storage. Defaults to `lambda req: str(req)`.
|
||||
resp_serialize: Serializes the function's return value for storage. Defaults to `lambda resp: str(resp)`.
|
||||
resp_deserialize: Deserializes the stored response back to the function's return value. Defaults to `lambda resp: resp`.
|
||||
serializer: Serializes the request and the function's return value for storage, deserializes the stored response back to the function's return value. Defaults to `SimpleSerializer()`.
|
||||
tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`.
|
||||
"""
|
||||
|
||||
|
|
@ -65,9 +63,7 @@ def exp_cache(
|
|||
exp_scorer=scorer,
|
||||
exp_perfect_judge=perfect_judge,
|
||||
context_builder=context_builder,
|
||||
req_serialize=req_serialize,
|
||||
resp_serialize=resp_serialize,
|
||||
resp_deserialize=resp_deserialize,
|
||||
serializer=serializer,
|
||||
tag=tag,
|
||||
)
|
||||
|
||||
|
|
@ -96,9 +92,7 @@ class ExpCacheHandler(BaseModel):
|
|||
exp_scorer: Optional[BaseScorer] = None
|
||||
exp_perfect_judge: Optional[BasePerfectJudge] = None
|
||||
context_builder: Optional[BaseContextBuilder] = None
|
||||
req_serialize: Optional[Callable[..., str]] = None
|
||||
resp_serialize: Optional[Callable[..., str]] = None
|
||||
resp_deserialize: Optional[Callable[[str], Any]] = None
|
||||
serializer: Optional[BaseSerializer] = None
|
||||
tag: Optional[str] = None
|
||||
|
||||
_exps: list[Experience] = None
|
||||
|
|
@ -120,12 +114,10 @@ class ExpCacheHandler(BaseModel):
|
|||
self.exp_scorer = self.exp_scorer or SimpleScorer()
|
||||
self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge()
|
||||
self.context_builder = self.context_builder or SimpleContextBuilder()
|
||||
self.req_serialize = self.req_serialize or (lambda resp: str(resp))
|
||||
self.resp_serialize = self.resp_serialize or (lambda resp: str(resp))
|
||||
self.resp_deserialize = self.resp_deserialize or (lambda resp: resp)
|
||||
self.serializer = self.serializer or SimpleSerializer()
|
||||
self.tag = self.tag or self._generate_tag()
|
||||
|
||||
self._req = self.req_serialize(self.kwargs["req"])
|
||||
self._req = self.serializer.serialize_req(copy.deepcopy(self.kwargs["req"]))
|
||||
|
||||
return self
|
||||
|
||||
|
|
@ -140,7 +132,7 @@ class ExpCacheHandler(BaseModel):
|
|||
for exp in self._exps:
|
||||
if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
|
||||
logger.info(f"Get one perfect experience: {exp.req[:20]}...")
|
||||
return self.resp_deserialize(exp.resp)
|
||||
return self.serializer.deserialize_resp(exp.resp)
|
||||
|
||||
return None
|
||||
|
||||
|
|
@ -148,7 +140,7 @@ class ExpCacheHandler(BaseModel):
|
|||
"""Execute the function, and save resp."""
|
||||
|
||||
self._raw_resp = await self._execute_function()
|
||||
self._resp = self.resp_serialize(self._raw_resp)
|
||||
self._resp = self.serializer.serialize_resp(copy.deepcopy(self._raw_resp))
|
||||
|
||||
@handle_exception
|
||||
async def process_experience(self):
|
||||
|
|
@ -204,7 +196,7 @@ class ExpCacheHandler(BaseModel):
|
|||
async def _build_context(self) -> str:
|
||||
self.context_builder.exps = self._exps
|
||||
|
||||
return await self.context_builder.build(*self.args, **self.kwargs)
|
||||
return await self.context_builder.build(**self.kwargs)
|
||||
|
||||
async def _execute_function(self):
|
||||
self.kwargs["req"] = await self._build_context()
|
||||
|
|
|
|||
|
|
@ -47,14 +47,12 @@ class ExperienceManager(BaseModel):
|
|||
|
||||
self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
self.init_exp_pool()
|
||||
|
||||
logger.debug(f"exp_pool config: {self.config.exp_pool}")
|
||||
return self
|
||||
|
||||
@handle_exception
|
||||
def init_exp_pool(self):
|
||||
if not self.config.exp_pool.init_exp:
|
||||
if not self.config.exp_pool.enable_write:
|
||||
return
|
||||
|
||||
if self._has_exps():
|
||||
|
|
|
|||
9
metagpt/exp_pool/serializers/__init__.py
Normal file
9
metagpt/exp_pool/serializers/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""Serializers init."""
|
||||
|
||||
from metagpt.exp_pool.serializers.base import BaseSerializer
|
||||
from metagpt.exp_pool.serializers.simple import SimpleSerializer
|
||||
from metagpt.exp_pool.serializers.action_node import ActionNodeSerializer
|
||||
from metagpt.exp_pool.serializers.role_zero import RoleZeroSerializer
|
||||
|
||||
|
||||
__all__ = ["BaseSerializer", "SimpleSerializer", "ActionNodeSerializer", "RoleZeroSerializer"]
|
||||
36
metagpt/exp_pool/serializers/action_node.py
Normal file
36
metagpt/exp_pool/serializers/action_node.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""ActionNode Serializer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Type
|
||||
|
||||
# Import ActionNode only for type checking to avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
from metagpt.exp_pool.serializers.simple import SimpleSerializer
|
||||
|
||||
|
||||
class ActionNodeSerializer(SimpleSerializer):
|
||||
def serialize_resp(self, resp: ActionNode) -> str:
|
||||
return resp.instruct_content.model_dump_json()
|
||||
|
||||
def deserialize_resp(self, resp: str) -> ActionNode:
|
||||
"""Customized deserialization, it will be triggered when a perfect experience is found.
|
||||
|
||||
ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'.
|
||||
"""
|
||||
|
||||
class InstructContent:
|
||||
def __init__(self, json_data):
|
||||
self.json_data = json_data
|
||||
|
||||
def model_dump_json(self):
|
||||
return self.json_data
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
action_node = ActionNode(key="", expected_type=Type[str], instruction="", example="")
|
||||
action_node.instruct_content = InstructContent(resp)
|
||||
|
||||
return action_node
|
||||
22
metagpt/exp_pool/serializers/base.py
Normal file
22
metagpt/exp_pool/serializers/base.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""Base serializer."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class BaseSerializer(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
def serialize_req(self, req: Any) -> str:
|
||||
"""Serializes the request for storage."""
|
||||
|
||||
@abstractmethod
|
||||
def serialize_resp(self, resp: Any) -> str:
|
||||
"""Serializes the function's return value for storage."""
|
||||
|
||||
@abstractmethod
|
||||
def deserialize_resp(self, resp: str) -> Any:
|
||||
"""Deserializes the stored response back to the function's return value"""
|
||||
40
metagpt/exp_pool/serializers/role_zero.py
Normal file
40
metagpt/exp_pool/serializers/role_zero.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""RoleZero Serializer."""
|
||||
|
||||
import json
|
||||
|
||||
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
|
||||
from metagpt.exp_pool.serializers.simple import SimpleSerializer
|
||||
|
||||
|
||||
class RoleZeroSerializer(SimpleSerializer):
|
||||
def serialize_req(self, req: list[dict]) -> str:
|
||||
"""Serialize the request for database storage, ensuring it is a string.
|
||||
|
||||
This function modifies the content of the last element in the request to remove unnecessary sections,
|
||||
making the request more concise.
|
||||
|
||||
Args:
|
||||
req (list[dict]): The request to be serialized. Example:
|
||||
[
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."},
|
||||
{"role": "user", "content": "..."},
|
||||
]
|
||||
|
||||
Returns:
|
||||
str: The serialized request as a JSON string.
|
||||
"""
|
||||
if not req:
|
||||
return ""
|
||||
|
||||
last_content = req[-1]["content"]
|
||||
last_content = RoleZeroContextBuilder.replace_content_between_markers(
|
||||
last_content, "# Data Structure", "# Current Plan", ""
|
||||
)
|
||||
last_content = RoleZeroContextBuilder.replace_content_between_markers(
|
||||
last_content, "# Example", "# Instruction", ""
|
||||
)
|
||||
|
||||
req[-1]["content"] = last_content
|
||||
|
||||
return json.dumps(req)
|
||||
22
metagpt/exp_pool/serializers/simple.py
Normal file
22
metagpt/exp_pool/serializers/simple.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""Simple Serializer."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from metagpt.exp_pool.serializers.base import BaseSerializer
|
||||
|
||||
|
||||
class SimpleSerializer(BaseSerializer):
|
||||
def serialize_req(self, req: Any) -> str:
|
||||
"""Just use `str` to convert the request object into a string."""
|
||||
|
||||
return str(req)
|
||||
|
||||
def serialize_resp(self, resp: Any) -> str:
|
||||
"""Just use `str` to convert the response object into a string."""
|
||||
|
||||
return str(resp)
|
||||
|
||||
def deserialize_resp(self, resp: str) -> Any:
|
||||
"""Just return the string response as it is."""
|
||||
|
||||
return resp
|
||||
|
|
@ -12,6 +12,7 @@ from metagpt.actions import Action
|
|||
from metagpt.actions.di.run_command import RunCommand
|
||||
from metagpt.exp_pool import exp_cache
|
||||
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
|
||||
from metagpt.exp_pool.serializers import RoleZeroSerializer
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.di.role_zero import (
|
||||
CMD_PROMPT,
|
||||
|
|
@ -165,9 +166,7 @@ class RoleZero(Role):
|
|||
|
||||
return True
|
||||
|
||||
@exp_cache(
|
||||
context_builder=RoleZeroContextBuilder(), req_serialize=lambda req: RoleZeroContextBuilder.req_serialize(req)
|
||||
)
|
||||
@exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer())
|
||||
async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str]) -> str:
|
||||
return await self.llm.aask(req, system_msgs=system_msgs)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,16 +30,3 @@ class TestBaseContextBuilder:
|
|||
]
|
||||
)
|
||||
assert result == expected
|
||||
|
||||
def test_replace_content_between_markers(self):
|
||||
text = "Start\n# Example\nOld content\n# Instruction\nEnd"
|
||||
new_content = "New content"
|
||||
result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
|
||||
expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd"
|
||||
assert result == expected
|
||||
|
||||
def test_replace_content_between_markers_no_match(self):
|
||||
text = "Start\nNo markers\nEnd"
|
||||
new_content = "New content"
|
||||
result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
|
||||
assert result == text
|
||||
|
|
|
|||
|
|
@ -30,9 +30,22 @@ class TestRoleZeroContextBuilder:
|
|||
assert result == [{"content": "Updated content"}]
|
||||
|
||||
def test_replace_example_content(self, context_builder, mocker):
|
||||
mocker.patch.object(BaseContextBuilder, "replace_content_between_markers", return_value="Replaced content")
|
||||
mocker.patch.object(RoleZeroContextBuilder, "replace_content_between_markers", return_value="Replaced content")
|
||||
result = context_builder.replace_example_content("Original text", "New example content")
|
||||
assert result == "Replaced content"
|
||||
context_builder.replace_content_between_markers.assert_called_once_with(
|
||||
"Original text", "# Example", "# Instruction", "New example content"
|
||||
)
|
||||
|
||||
def test_replace_content_between_markers(self):
|
||||
text = "Start\n# Example\nOld content\n# Instruction\nEnd"
|
||||
new_content = "New content"
|
||||
result = RoleZeroContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
|
||||
expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd"
|
||||
assert result == expected
|
||||
|
||||
def test_replace_content_between_markers_no_match(self):
|
||||
text = "Start\nNo markers\nEnd"
|
||||
new_content = "New content"
|
||||
result = RoleZeroContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
|
||||
assert result == text
|
||||
|
|
|
|||
|
|
@ -32,7 +32,8 @@ class TestSimpleContextBuilder:
|
|||
req = "Test request"
|
||||
result = await context_builder.build(req=req)
|
||||
|
||||
assert result == req
|
||||
expected = SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps="")
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_without_req(self, context_builder, mocker):
|
||||
|
|
|
|||
|
|
@ -11,9 +11,7 @@ from metagpt.rag.engines import SimpleEngine
|
|||
class TestExperienceManager:
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
return Config(
|
||||
llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, init_exp=False)
|
||||
)
|
||||
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True))
|
||||
|
||||
@pytest.fixture
|
||||
def mock_storage(self, mocker):
|
||||
|
|
@ -65,30 +63,6 @@ class TestExperienceManager:
|
|||
result = await mock_experience_manager.query_exps("query")
|
||||
assert result == []
|
||||
|
||||
def test_init_exp_pool(self, mock_experience_manager, mock_config, mocker):
|
||||
mock_experience_manager._has_exps = mocker.MagicMock(return_value=False)
|
||||
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
|
||||
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
|
||||
|
||||
mock_config.exp_pool.init_exp = True
|
||||
mock_experience_manager.init_exp_pool()
|
||||
|
||||
mock_experience_manager._has_exps.assert_called_once()
|
||||
mock_experience_manager._init_teamleader_exps.assert_called_once()
|
||||
mock_experience_manager._init_engineer2_exps.assert_called_once()
|
||||
|
||||
def test_init_exp_pool_already_has_exps(self, mock_experience_manager, mock_config, mocker):
|
||||
mock_experience_manager._has_exps = mocker.MagicMock(return_value=True)
|
||||
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
|
||||
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
|
||||
|
||||
mock_config.exp_pool.init_exp = True
|
||||
mock_experience_manager.init_exp_pool()
|
||||
|
||||
mock_experience_manager._has_exps.assert_called_once()
|
||||
mock_experience_manager._init_teamleader_exps.assert_not_called()
|
||||
mock_experience_manager._init_engineer2_exps.assert_not_called()
|
||||
|
||||
def test_has_exps(self, mock_experience_manager, mock_storage):
|
||||
mock_storage._retriever._vector_store._get.return_value.ids = ["id1"]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue