diff --git a/config/config2.example.yaml b/config/config2.example.yaml index a3bd5c367..330b73680 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -78,7 +78,6 @@ exp_pool: enable_read: false enable_write: false persist_path: .chroma_exp_data # The directory. - init_exp: false # If set to true, basic experiences associated with the roles will be added to the experience pool. azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index e1e0bddbb..c1de16656 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -19,6 +19,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_outcls_registry import register_action_outcls from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT from metagpt.exp_pool import exp_cache +from metagpt.exp_pool.serializers import ActionNodeSerializer from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess @@ -466,29 +467,7 @@ class ActionNode: return self - @classmethod - def deserialize_to_action_node(cls, serialized_data) -> "ActionNode": - """Customized deserialization, it will be triggered when a perfect experience is found. - - ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'. - """ - - class InstructContent: - def __init__(self, json_data): - self.json_data = json_data - - def model_dump_json(self): - return self.json_data - - action_node = cls(key="", expected_type=Type[str], instruction="", example="") - action_node.instruct_content = InstructContent(serialized_data) - - return action_node - - @exp_cache( - resp_serialize=lambda action_node: action_node.instruct_content.model_dump_json(), - resp_deserialize=lambda resp: ActionNode.deserialize_to_action_node(resp), - ) + @exp_cache(serializer=ActionNodeSerializer()) async def fill( self, *, diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py index 0c92312da..786558ed9 100644 --- a/metagpt/configs/exp_pool_config.py +++ b/metagpt/configs/exp_pool_config.py @@ -7,6 +7,3 @@ class ExperiencePoolConfig(YamlModel): enable_read: bool = Field(default=False, description="Enable to read from experience pool.") enable_write: bool = Field(default=False, description="Enable to write to experience pool.") persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.") - init_exp: bool = Field( - default=False, description="Put some basic experiences associated with the roles into the experience pool." - ) diff --git a/metagpt/exp_pool/context_builders/action_node.py b/metagpt/exp_pool/context_builders/action_node.py new file mode 100644 index 000000000..ade157822 --- /dev/null +++ b/metagpt/exp_pool/context_builders/action_node.py @@ -0,0 +1,33 @@ +"""Action Node context builder.""" + + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder + +ACTION_NODE_CONTEXT_TEMPLATE = """ +{req} + +### Experiences +----- +{exps} +----- + +## Instruction +Consider **Experiences** to generate a better answer. +""" + + +class ActionNodeContextBuilder(BaseContextBuilder): + async def build(self, **kwargs) -> str: + """Builds the action node context string. + + Args: + **kwargs: Arbitrary keyword arguments, expecting 'req' as a key. + + Returns: + str: The formatted context string using the request and formatted experiences. + If no experiences are available, returns the request as is. + """ + req = kwargs.get("req", "") + exps = self.format_exps() + + return ACTION_NODE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req diff --git a/metagpt/exp_pool/context_builders/base.py b/metagpt/exp_pool/context_builders/base.py index a261e452e..d1133c2da 100644 --- a/metagpt/exp_pool/context_builders/base.py +++ b/metagpt/exp_pool/context_builders/base.py @@ -1,6 +1,5 @@ """Base context builder.""" -import re from abc import ABC, abstractmethod from typing import Any @@ -17,11 +16,19 @@ class BaseContextBuilder(BaseModel, ABC): exps: list[Experience] = [] @abstractmethod - async def build(self, *args, **kwargs) -> Any: + async def build(self, **kwargs) -> Any: """Build context from parameters.""" def format_exps(self) -> str: - """Format experiences into a numbered list of strings.""" + """Format experiences into a numbered list of strings. + + Example: + 1. Given the request: req1, We can get the response: resp1, Which scored: 8. + 2. Given the request: req2, We can get the response: resp2, Which scored: 9. + + Returns: + str: The formatted experiences as a string. + """ result = [] for i, exp in enumerate(self.exps, start=1): @@ -29,25 +36,3 @@ class BaseContextBuilder(BaseModel, ABC): result.append(f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=score_val)) return "\n".join(result) - - @staticmethod - def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str: - """Replace the content between `start_marker` and `end_marker` in the text with `new_content`. - - Args: - text (str): The original text. - new_content (str): The new content to replace the old content. - start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'. - end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'. - - Returns: - str: The text with the content replaced. - """ - - pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL) - - def replacement(match): - return f"{match.group(1)}{new_content}\n{match.group(3)}" - - replaced_text = pattern.sub(replacement, text) - return replaced_text diff --git a/metagpt/exp_pool/context_builders/role_zero.py b/metagpt/exp_pool/context_builders/role_zero.py index e9ab83d90..b492ca5ca 100644 --- a/metagpt/exp_pool/context_builders/role_zero.py +++ b/metagpt/exp_pool/context_builders/role_zero.py @@ -1,15 +1,19 @@ """RoleZero context builder.""" -import copy -import json + +import re from metagpt.exp_pool.context_builders.base import BaseContextBuilder class RoleZeroContextBuilder(BaseContextBuilder): - async def build(self, *args, **kwargs) -> list[dict]: + async def build(self, **kwargs) -> list[dict]: """Builds the context by updating the req with formatted experiences. - If there are no experiences, retains the original examples in req, otherwise replaces the examples with the formatted experiences. + Args: + **kwargs: Arbitrary keyword arguments, expecting 'req' as a key. + + Returns: + list[dict]: The updated request with formatted experiences or the original request if no experiences are available. """ req = kwargs.get("req", []) @@ -28,23 +32,23 @@ class RoleZeroContextBuilder(BaseContextBuilder): return self.replace_content_between_markers(text, "# Example", "# Instruction", new_example_content) @staticmethod - def req_serialize(req: list[dict]) -> str: - """Serialize the request for database storage, ensuring it is a string. + def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str: + """Replace the content between `start_marker` and `end_marker` in the text with `new_content`. - This function deep copies the request and modifies the content of the last element - to remove unnecessary sections, making the request more concise. + Args: + text (str): The original text. + new_content (str): The new content to replace the old content. + start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'. + end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'. + + Returns: + str: The text with the content replaced. """ - req_copy = copy.deepcopy(req) + pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL) - last_content = req_copy[-1]["content"] - last_content = RoleZeroContextBuilder.replace_content_between_markers( - last_content, "# Data Structure", "# Current Plan", "" - ) - last_content = RoleZeroContextBuilder.replace_content_between_markers( - last_content, "# Example", "# Instruction", "" - ) + def replacement(match): + return f"{match.group(1)}{new_content}\n{match.group(3)}" - req_copy[-1]["content"] = last_content - - return json.dumps(req_copy) + replaced_text = pattern.sub(replacement, text) + return replaced_text diff --git a/metagpt/exp_pool/context_builders/simple.py b/metagpt/exp_pool/context_builders/simple.py index 35e2e1c8a..565855664 100644 --- a/metagpt/exp_pool/context_builders/simple.py +++ b/metagpt/exp_pool/context_builders/simple.py @@ -4,21 +4,21 @@ from metagpt.exp_pool.context_builders.base import BaseContextBuilder SIMPLE_CONTEXT_TEMPLATE = """ -{req} +## Context ### Experiences ----- {exps} ----- +## User Requirement +{req} + ## Instruction Consider **Experiences** to generate a better answer. """ class SimpleContextBuilder(BaseContextBuilder): - async def build(self, *args, **kwargs) -> str: - req = kwargs.get("req", "") - exps = self.format_exps() - - return SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req + async def build(self, **kwargs) -> str: + return SIMPLE_CONTEXT_TEMPLATE.format(req=kwargs.get("req", ""), exps=self.format_exps()) diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index 62f766b9d..deb3faafc 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -1,6 +1,7 @@ """Experience Decorator.""" import asyncio +import copy import functools from typing import Any, Callable, Optional, TypeVar @@ -12,6 +13,7 @@ from metagpt.exp_pool.manager import ExperienceManager, exp_manager from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer +from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer from metagpt.logs import logger from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.exceptions import handle_exception @@ -26,9 +28,7 @@ def exp_cache( scorer: Optional[BaseScorer] = None, perfect_judge: Optional[BasePerfectJudge] = None, context_builder: Optional[BaseContextBuilder] = None, - req_serialize: Optional[Callable[..., str]] = None, - resp_serialize: Optional[Callable[..., str]] = None, - resp_deserialize: Optional[Callable[[str], Any]] = None, + serializer: Optional[BaseSerializer] = None, tag: Optional[str] = None, ): """Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience. @@ -44,9 +44,7 @@ def exp_cache( scorer: Evaluate experience. Default to `SimpleScorer()`. perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`. context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`. - req_serialize: Serializes the request for storage. Defaults to `lambda req: str(req)`. - resp_serialize: Serializes the function's return value for storage. Defaults to `lambda resp: str(resp)`. - resp_deserialize: Deserializes the stored response back to the function's return value. Defaults to `lambda resp: resp`. + serializer: Serializes the request and the function's return value for storage, deserializes the stored response back to the function's return value. Defaults to `SimpleSerializer()`. tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`. """ @@ -65,9 +63,7 @@ def exp_cache( exp_scorer=scorer, exp_perfect_judge=perfect_judge, context_builder=context_builder, - req_serialize=req_serialize, - resp_serialize=resp_serialize, - resp_deserialize=resp_deserialize, + serializer=serializer, tag=tag, ) @@ -96,9 +92,7 @@ class ExpCacheHandler(BaseModel): exp_scorer: Optional[BaseScorer] = None exp_perfect_judge: Optional[BasePerfectJudge] = None context_builder: Optional[BaseContextBuilder] = None - req_serialize: Optional[Callable[..., str]] = None - resp_serialize: Optional[Callable[..., str]] = None - resp_deserialize: Optional[Callable[[str], Any]] = None + serializer: Optional[BaseSerializer] = None tag: Optional[str] = None _exps: list[Experience] = None @@ -120,12 +114,10 @@ class ExpCacheHandler(BaseModel): self.exp_scorer = self.exp_scorer or SimpleScorer() self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge() self.context_builder = self.context_builder or SimpleContextBuilder() - self.req_serialize = self.req_serialize or (lambda resp: str(resp)) - self.resp_serialize = self.resp_serialize or (lambda resp: str(resp)) - self.resp_deserialize = self.resp_deserialize or (lambda resp: resp) + self.serializer = self.serializer or SimpleSerializer() self.tag = self.tag or self._generate_tag() - self._req = self.req_serialize(self.kwargs["req"]) + self._req = self.serializer.serialize_req(copy.deepcopy(self.kwargs["req"])) return self @@ -140,7 +132,7 @@ class ExpCacheHandler(BaseModel): for exp in self._exps: if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs): logger.info(f"Get one perfect experience: {exp.req[:20]}...") - return self.resp_deserialize(exp.resp) + return self.serializer.deserialize_resp(exp.resp) return None @@ -148,7 +140,7 @@ class ExpCacheHandler(BaseModel): """Execute the function, and save resp.""" self._raw_resp = await self._execute_function() - self._resp = self.resp_serialize(self._raw_resp) + self._resp = self.serializer.serialize_resp(copy.deepcopy(self._raw_resp)) @handle_exception async def process_experience(self): @@ -204,7 +196,7 @@ class ExpCacheHandler(BaseModel): async def _build_context(self) -> str: self.context_builder.exps = self._exps - return await self.context_builder.build(*self.args, **self.kwargs) + return await self.context_builder.build(**self.kwargs) async def _execute_function(self): self.kwargs["req"] = await self._build_context() diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 23198eb02..649210a79 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -47,14 +47,12 @@ class ExperienceManager(BaseModel): self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) - self.init_exp_pool() - logger.debug(f"exp_pool config: {self.config.exp_pool}") return self @handle_exception def init_exp_pool(self): - if not self.config.exp_pool.init_exp: + if not self.config.exp_pool.enable_write: return if self._has_exps(): diff --git a/metagpt/exp_pool/serializers/__init__.py b/metagpt/exp_pool/serializers/__init__.py new file mode 100644 index 000000000..8e1045588 --- /dev/null +++ b/metagpt/exp_pool/serializers/__init__.py @@ -0,0 +1,9 @@ +"""Serializers init.""" + +from metagpt.exp_pool.serializers.base import BaseSerializer +from metagpt.exp_pool.serializers.simple import SimpleSerializer +from metagpt.exp_pool.serializers.action_node import ActionNodeSerializer +from metagpt.exp_pool.serializers.role_zero import RoleZeroSerializer + + +__all__ = ["BaseSerializer", "SimpleSerializer", "ActionNodeSerializer", "RoleZeroSerializer"] diff --git a/metagpt/exp_pool/serializers/action_node.py b/metagpt/exp_pool/serializers/action_node.py new file mode 100644 index 000000000..7746d6be4 --- /dev/null +++ b/metagpt/exp_pool/serializers/action_node.py @@ -0,0 +1,36 @@ +"""ActionNode Serializer.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Type + +# Import ActionNode only for type checking to avoid circular imports +if TYPE_CHECKING: + from metagpt.actions.action_node import ActionNode + +from metagpt.exp_pool.serializers.simple import SimpleSerializer + + +class ActionNodeSerializer(SimpleSerializer): + def serialize_resp(self, resp: ActionNode) -> str: + return resp.instruct_content.model_dump_json() + + def deserialize_resp(self, resp: str) -> ActionNode: + """Customized deserialization, it will be triggered when a perfect experience is found. + + ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'. + """ + + class InstructContent: + def __init__(self, json_data): + self.json_data = json_data + + def model_dump_json(self): + return self.json_data + + from metagpt.actions.action_node import ActionNode + + action_node = ActionNode(key="", expected_type=Type[str], instruction="", example="") + action_node.instruct_content = InstructContent(resp) + + return action_node diff --git a/metagpt/exp_pool/serializers/base.py b/metagpt/exp_pool/serializers/base.py new file mode 100644 index 000000000..82a0ed8c4 --- /dev/null +++ b/metagpt/exp_pool/serializers/base.py @@ -0,0 +1,22 @@ +"""Base serializer.""" + +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class BaseSerializer(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + def serialize_req(self, req: Any) -> str: + """Serializes the request for storage.""" + + @abstractmethod + def serialize_resp(self, resp: Any) -> str: + """Serializes the function's return value for storage.""" + + @abstractmethod + def deserialize_resp(self, resp: str) -> Any: + """Deserializes the stored response back to the function's return value""" diff --git a/metagpt/exp_pool/serializers/role_zero.py b/metagpt/exp_pool/serializers/role_zero.py new file mode 100644 index 000000000..75e5d5ecb --- /dev/null +++ b/metagpt/exp_pool/serializers/role_zero.py @@ -0,0 +1,40 @@ +"""RoleZero Serializer.""" + +import json + +from metagpt.exp_pool.context_builders import RoleZeroContextBuilder +from metagpt.exp_pool.serializers.simple import SimpleSerializer + + +class RoleZeroSerializer(SimpleSerializer): + def serialize_req(self, req: list[dict]) -> str: + """Serialize the request for database storage, ensuring it is a string. + + This function modifies the content of the last element in the request to remove unnecessary sections, + making the request more concise. + + Args: + req (list[dict]): The request to be serialized. Example: + [ + {"role": "user", "content": "..."}, + {"role": "assistant", "content": "..."}, + {"role": "user", "content": "..."}, + ] + + Returns: + str: The serialized request as a JSON string. + """ + if not req: + return "" + + last_content = req[-1]["content"] + last_content = RoleZeroContextBuilder.replace_content_between_markers( + last_content, "# Data Structure", "# Current Plan", "" + ) + last_content = RoleZeroContextBuilder.replace_content_between_markers( + last_content, "# Example", "# Instruction", "" + ) + + req[-1]["content"] = last_content + + return json.dumps(req) diff --git a/metagpt/exp_pool/serializers/simple.py b/metagpt/exp_pool/serializers/simple.py new file mode 100644 index 000000000..32fe29c9f --- /dev/null +++ b/metagpt/exp_pool/serializers/simple.py @@ -0,0 +1,22 @@ +"""Simple Serializer.""" + +from typing import Any + +from metagpt.exp_pool.serializers.base import BaseSerializer + + +class SimpleSerializer(BaseSerializer): + def serialize_req(self, req: Any) -> str: + """Just use `str` to convert the request object into a string.""" + + return str(req) + + def serialize_resp(self, resp: Any) -> str: + """Just use `str` to convert the response object into a string.""" + + return str(resp) + + def deserialize_resp(self, resp: str) -> Any: + """Just return the string response as it is.""" + + return resp diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 103a77911..59c58861f 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -12,6 +12,7 @@ from metagpt.actions import Action from metagpt.actions.di.run_command import RunCommand from metagpt.exp_pool import exp_cache from metagpt.exp_pool.context_builders import RoleZeroContextBuilder +from metagpt.exp_pool.serializers import RoleZeroSerializer from metagpt.logs import logger from metagpt.prompts.di.role_zero import ( CMD_PROMPT, @@ -165,9 +166,7 @@ class RoleZero(Role): return True - @exp_cache( - context_builder=RoleZeroContextBuilder(), req_serialize=lambda req: RoleZeroContextBuilder.req_serialize(req) - ) + @exp_cache(context_builder=RoleZeroContextBuilder(), serializer=RoleZeroSerializer()) async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str]) -> str: return await self.llm.aask(req, system_msgs=system_msgs) diff --git a/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py b/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py index 17696e1b4..0a160fb42 100644 --- a/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py +++ b/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py @@ -30,16 +30,3 @@ class TestBaseContextBuilder: ] ) assert result == expected - - def test_replace_content_between_markers(self): - text = "Start\n# Example\nOld content\n# Instruction\nEnd" - new_content = "New content" - result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content) - expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd" - assert result == expected - - def test_replace_content_between_markers_no_match(self): - text = "Start\nNo markers\nEnd" - new_content = "New content" - result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content) - assert result == text diff --git a/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py b/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py index 0ea04432d..611d68211 100644 --- a/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py +++ b/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py @@ -30,9 +30,22 @@ class TestRoleZeroContextBuilder: assert result == [{"content": "Updated content"}] def test_replace_example_content(self, context_builder, mocker): - mocker.patch.object(BaseContextBuilder, "replace_content_between_markers", return_value="Replaced content") + mocker.patch.object(RoleZeroContextBuilder, "replace_content_between_markers", return_value="Replaced content") result = context_builder.replace_example_content("Original text", "New example content") assert result == "Replaced content" context_builder.replace_content_between_markers.assert_called_once_with( "Original text", "# Example", "# Instruction", "New example content" ) + + def test_replace_content_between_markers(self): + text = "Start\n# Example\nOld content\n# Instruction\nEnd" + new_content = "New content" + result = RoleZeroContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content) + expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd" + assert result == expected + + def test_replace_content_between_markers_no_match(self): + text = "Start\nNo markers\nEnd" + new_content = "New content" + result = RoleZeroContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content) + assert result == text diff --git a/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py b/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py index e96addab9..b6d0f642e 100644 --- a/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py +++ b/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py @@ -32,7 +32,8 @@ class TestSimpleContextBuilder: req = "Test request" result = await context_builder.build(req=req) - assert result == req + expected = SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps="") + assert result == expected @pytest.mark.asyncio async def test_build_without_req(self, context_builder, mocker): diff --git a/tests/metagpt/exp_pool/test_manager.py b/tests/metagpt/exp_pool/test_manager.py index c12fc7e8c..6d0693efd 100644 --- a/tests/metagpt/exp_pool/test_manager.py +++ b/tests/metagpt/exp_pool/test_manager.py @@ -11,9 +11,7 @@ from metagpt.rag.engines import SimpleEngine class TestExperienceManager: @pytest.fixture def mock_config(self): - return Config( - llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, init_exp=False) - ) + return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True)) @pytest.fixture def mock_storage(self, mocker): @@ -65,30 +63,6 @@ class TestExperienceManager: result = await mock_experience_manager.query_exps("query") assert result == [] - def test_init_exp_pool(self, mock_experience_manager, mock_config, mocker): - mock_experience_manager._has_exps = mocker.MagicMock(return_value=False) - mock_experience_manager._init_teamleader_exps = mocker.MagicMock() - mock_experience_manager._init_engineer2_exps = mocker.MagicMock() - - mock_config.exp_pool.init_exp = True - mock_experience_manager.init_exp_pool() - - mock_experience_manager._has_exps.assert_called_once() - mock_experience_manager._init_teamleader_exps.assert_called_once() - mock_experience_manager._init_engineer2_exps.assert_called_once() - - def test_init_exp_pool_already_has_exps(self, mock_experience_manager, mock_config, mocker): - mock_experience_manager._has_exps = mocker.MagicMock(return_value=True) - mock_experience_manager._init_teamleader_exps = mocker.MagicMock() - mock_experience_manager._init_engineer2_exps = mocker.MagicMock() - - mock_config.exp_pool.init_exp = True - mock_experience_manager.init_exp_pool() - - mock_experience_manager._has_exps.assert_called_once() - mock_experience_manager._init_teamleader_exps.assert_not_called() - mock_experience_manager._init_engineer2_exps.assert_not_called() - def test_has_exps(self, mock_experience_manager, mock_storage): mock_storage._retriever._vector_store._get.return_value.ids = ["id1"]