diff --git a/examples/exp_pool/init_exp_pool.py b/examples/exp_pool/init_exp_pool.py index 1601abe0b..321c38d78 100644 --- a/examples/exp_pool/init_exp_pool.py +++ b/examples/exp_pool/init_exp_pool.py @@ -8,7 +8,7 @@ import json from pathlib import Path from metagpt.const import EXAMPLE_DATA_PATH -from metagpt.exp_pool import exp_manager +from metagpt.exp_pool import get_exp_manager from metagpt.exp_pool.schema import EntryType, Experience, Metric, Score from metagpt.logs import logger from metagpt.utils.common import aread @@ -45,7 +45,7 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None): tag=tag, metric=metric or Metric(score=Score(val=10, reason="Manual")), ) - + exp_manager = get_exp_manager() exp_manager.config.exp_pool.enable_write = True exp_manager.create_exp(exp) logger.info(f"New experience created for the request `{req[:10]}`.") @@ -79,7 +79,7 @@ async def add_exps_from_file(tag: str, filepath: Path): def query_exps_count(): """Queries and logs the total count of experiences in the pool.""" - + exp_manager = get_exp_manager() count = exp_manager.get_exps_count() logger.info(f"Experiences Count: {count}") diff --git a/examples/exp_pool/manager.py b/examples/exp_pool/manager.py index ae998214a..5aead08e9 100644 --- a/examples/exp_pool/manager.py +++ b/examples/exp_pool/manager.py @@ -6,7 +6,7 @@ This script creates a new experience, logs its creation, and then queries for ex import asyncio -from metagpt.exp_pool import exp_manager +from metagpt.exp_pool import get_exp_manager from metagpt.exp_pool.schema import EntryType, Experience from metagpt.logs import logger @@ -15,6 +15,7 @@ async def main(): # Define the simple request and response req = "Simple req" resp = "Simple resp" + exp_manager = get_exp_manager() # Add the new experience exp = Experience(req=req, resp=resp, entry_type=EntryType.MANUAL) diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index 8a9d4bc95..d647883bd 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -29,7 +29,7 @@ class RAGEmbeddingFactory(GenericFactory): LLMType.AZURE: self._create_azure, } super().__init__(creators) - self.config = config if self.config else Config.default() + self.config = config if config else Config.default() def get_rag_embedding(self, key: EmbeddingType = None) -> BaseEmbedding: """Key is EmbeddingType.""" diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index 5d27cde3a..59f6db4d9 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -10,7 +10,7 @@ from llama_index.core.llms import ( LLMMetadata, ) from llama_index.core.llms.callbacks import llm_completion_callback -from pydantic import Field, model_validator +from pydantic import Field from metagpt.config2 import Config from metagpt.llm import LLM @@ -30,19 +30,30 @@ class RAGLLM(CustomLLM): num_output: int = -1 model_name: str = "" - @model_validator(mode="after") - def update_from_config(self): + def __init__( + self, + model_infer: BaseLLM, + context_window: int = -1, + num_output: int = -1, + model_name: str = "", + *args, + **kwargs + ): + super().__init__(*args, **kwargs) config = Config.default() - if self.context_window < 0: - self.context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) + if context_window < 0: + context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) - if self.num_output < 0: - self.num_output = config.llm.max_token + if num_output < 0: + num_output = config.llm.max_token - if not self.model_name: - self.model_name = config.llm.model + if not model_name: + model_name = config.llm.model - return self + self.model_infer = model_infer + self.context_window = context_window + self.num_output = num_output + self.model_name = model_name @property def metadata(self) -> LLMMetadata: diff --git a/tests/metagpt/exp_pool/test_decorator.py b/tests/metagpt/exp_pool/test_decorator.py index 0ca4c6ce1..9d104fca4 100644 --- a/tests/metagpt/exp_pool/test_decorator.py +++ b/tests/metagpt/exp_pool/test_decorator.py @@ -155,7 +155,10 @@ class TestExpCache: @pytest.fixture def mock_config(self, mocker): - return mocker.patch("metagpt.exp_pool.decorator.config") + config = Config.default().model_copy(deep=True) + default = mocker.patch("metagpt.config2.Config.default") + default.return_value = config + return config @pytest.mark.asyncio async def test_exp_cache_disabled(self, mock_config, mock_exp_manager): @@ -171,7 +174,9 @@ class TestExpCache: @pytest.mark.asyncio async def test_exp_cache_enabled_no_perfect_exp(self, mock_config, mock_exp_manager, mock_scorer): + mock_config.exp_pool.enabled = True mock_config.exp_pool.enable_read = True + mock_config.exp_pool.enable_write = True mock_exp_manager.query_exps.return_value = [] @exp_cache(manager=mock_exp_manager, scorer=mock_scorer) @@ -185,6 +190,7 @@ class TestExpCache: @pytest.mark.asyncio async def test_exp_cache_enabled_with_perfect_exp(self, mock_config, mock_exp_manager, mock_perfect_judge): + mock_config.exp_pool.enabled = True mock_config.exp_pool.enable_read = True perfect_exp = Experience(req="test", resp="perfect_result") mock_exp_manager.query_exps.return_value = [perfect_exp]