From 6052d8b9ac8095514e07dd58b4c64f46f238f693 Mon Sep 17 00:00:00 2001 From: seehi <6580@pm.me> Date: Tue, 11 Jun 2024 21:40:51 +0800 Subject: [PATCH] update exp_pool decorator --- metagpt/exp_pool/decorator.py | 61 ++++++++---- metagpt/exp_pool/manager.py | 3 +- metagpt/utils/reflection.py | 25 +++-- tests/metagpt/exp_pool/test_decorator.py | 112 +++++++++++++++++++---- tests/metagpt/utils/test_reflection.py | 46 ++++++---- 5 files changed, 173 insertions(+), 74 deletions(-) diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index e559797a3..446220a47 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -2,9 +2,11 @@ import asyncio import functools +import inspect +import json from typing import Any, Callable, Optional, TypeVar -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, model_validator from metagpt.exp_pool.manager import ExperienceManager, exp_manager from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score @@ -42,8 +44,8 @@ def exp_cache( func=func, args=args, kwargs=kwargs, - exp_manager=manager or exp_manager, - exp_scorer=scorer or SimpleScorer(), + exp_manager=manager, + exp_scorer=scorer, pass_exps_to_func=pass_exps_to_func, ) @@ -67,8 +69,8 @@ class ExpCacheHandler(BaseModel): func: Callable args: Any kwargs: Any - exp_manager: ExperienceManager - exp_scorer: ExperienceScorer + exp_manager: Optional[ExperienceManager] = None + exp_scorer: Optional[ExperienceScorer] = None pass_exps_to_func: bool = False _exps: list[Experience] = None @@ -76,11 +78,22 @@ class ExpCacheHandler(BaseModel): _score: Score = None _req: str = None + @model_validator(mode="after") + def initialize(self): + if self.exp_manager is None: + self.exp_manager = exp_manager + + if self.exp_scorer is None: + self.exp_scorer = SimpleScorer() + + self._req = self.generate_req_identifier(self.func, *self.args, **self.kwargs) + + return self + async def fetch_experiences(self, query_type: QueryType): """Fetch a potentially perfect existing experience.""" - req = self._get_req_identifier() - self._exps = await self.exp_manager.query_exps(req, query_type=query_type) + self._exps = await self.exp_manager.query_exps(self._req, query_type=query_type) def get_one_perfect_experience(self) -> Optional[Experience]: return self.exp_manager.extract_one_perfect_exp(self._exps) @@ -107,26 +120,29 @@ class ExpCacheHandler(BaseModel): def save_experience(self): """Save the new experience.""" - req = self._get_req_identifier() - exp = Experience(req=req, resp=self._result, metric=Metric(score=self._score)) + exp = Experience(req=self._req, resp=self._result, metric=Metric(score=self._score)) self.exp_manager.create_exp(exp) - def _get_req_identifier(self): - """Generate a unique request identifier based on the function and its arguments. + @classmethod + def generate_req_identifier(cls, func, *args, **kwargs) -> str: + """Generate a unique request identifier for any given function and its arguments. - Result Example: - - "write_prd-('2048',)-{}" - - "WritePRD.run-('2048',)-{}" + Serializing args and kwargs into JSON strings and replacing ',' with '~' and ':' with '!'. + + Return Example: + SimpleClass.test_method@[1~2]@{"c"!3} """ - if not self._req: - cls_name = get_class_name(self.func, *self.args) - func_name = f"{cls_name}.{self.func.__name__}" if cls_name else self.func.__name__ - args = self.args[1:] if cls_name and len(self.args) >= 1 else self.args + cls_name = get_class_name(func) + func_name = f"{cls_name}.{func.__name__}" if cls_name else func.__name__ - self._req = f"{func_name}-{args}-{self.kwargs}" + if cls_name and args and inspect.isfunction(func): + args = args[1:] - return self._req + args = cls._serialize_and_replace(args) + kwargs = cls._serialize_and_replace(kwargs) + + return f"{func_name}@{args}@{kwargs}" @staticmethod def choose_wrapper(func, wrapped_func): @@ -141,6 +157,11 @@ class ExpCacheHandler(BaseModel): return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper + @classmethod + def _serialize_and_replace(cls, data): + json_str = json.dumps(data) + return json_str.replace(", ", "~").replace(": ", "!") + async def _execute_function(self): if self.pass_exps_to_func: return await self._execute_function_with_exps() diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 7382fe8f1..35ee5fdac 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, model_validator from metagpt.config2 import Config, config from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType from metagpt.rag.engines import SimpleEngine -from metagpt.rag.schema import ChromaRetrieverConfig +from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig from metagpt.utils.exceptions import handle_exception @@ -31,6 +31,7 @@ class ExperienceManager(BaseModel): retriever_configs=[ ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data") ], + ranker_configs=[LLMRankerConfig()], ) return self diff --git a/metagpt/utils/reflection.py b/metagpt/utils/reflection.py index 9b10a4b3e..fe852635f 100644 --- a/metagpt/utils/reflection.py +++ b/metagpt/utils/reflection.py @@ -19,24 +19,23 @@ def check_methods(C, *methods): return True -def get_class_name(func, *args) -> str: +def get_class_name(func) -> str: """Returns the class name of the object that a method belongs to. - - If `func` is a bound method, extracts the class name directly from the method. - - If `func` is an unbound method and `args` are provided, assumes the first argument is `self` and extracts the class name. - - Returns an empty string if neither condition is met. + - If `func` is a bound method or a class method, extracts the class name directly from the method. + - Returns an empty string if it's a regular function or cannot determine the class. """ if inspect.ismethod(func): + if inspect.isclass(func.__self__): + return func.__self__.__name__ + return func.__self__.__class__.__name__ - if inspect.isfunction(func) and "self" in inspect.signature(func).parameters and args: - return args[0].__class__.__name__ + if inspect.isfunction(func): + qualname_parts = func.__qualname__.split(".") + if len(qualname_parts) > 1: + class_name = qualname_parts[-2] + if class_name.isidentifier(): + return class_name return "" - - -def get_func_or_method_name(func, *args) -> str: - """Function name, or method name with class name.""" - cls_name = get_class_name(func, *args) - - return f"{cls_name}.{func.__name__}" if cls_name else f"{func.__name__}" diff --git a/tests/metagpt/exp_pool/test_decorator.py b/tests/metagpt/exp_pool/test_decorator.py index 508229d18..bedc4e391 100644 --- a/tests/metagpt/exp_pool/test_decorator.py +++ b/tests/metagpt/exp_pool/test_decorator.py @@ -1,14 +1,28 @@ import asyncio +import inspect import pytest -from metagpt.exp_pool.decorator import ExpCacheHandler +from metagpt.exp_pool.decorator import ExpCacheHandler, exp_cache from metagpt.exp_pool.manager import ExperienceManager from metagpt.exp_pool.schema import Experience, QueryType, Score from metagpt.exp_pool.scorers import SimpleScorer from metagpt.rag.engines import SimpleEngine +def for_test_function(a, b, c=None): + return a + b if c is None else a + b + c + + +class ForTestClass: + def for_test_method(self, x, y): + return x * y + + @classmethod + def for_test_class_method(cls, x, y): + return x**y + + class TestExpCache: @pytest.fixture def mock_func(self, mocker): @@ -46,7 +60,7 @@ class TestExpCache: perfect_exp = Experience(req="req", resp="resp") mock_exp_manager.extract_one_perfect_exp.return_value = perfect_exp - # Execute + # Exec exp_cache_handler._exps = [perfect_exp] # Simulate fetched experiences result = exp_cache_handler.get_one_perfect_experience() @@ -60,7 +74,7 @@ class TestExpCache: mock_exp_manager.extract_one_perfect_exp.return_value = None mock_func.return_value = "Computed result" - # Execute + # Exec await exp_cache_handler.execute_function() # Assert @@ -73,7 +87,7 @@ class TestExpCache: mock_scorer.evaluate.return_value = Score(value=100) exp_cache_handler._result = "Computed result" - # Execute + # Exec await exp_cache_handler.evaluate_experience() exp_cache_handler.save_experience() @@ -84,12 +98,12 @@ class TestExpCache: @pytest.mark.asyncio async def test_async_function_execution_with_exps(self, exp_cache_handler, mock_exp_manager, mock_func): # Setup - exp_cache_handler.pass_exps = True + exp_cache_handler.pass_exps_to_func = True mock_func.return_value = "Async result with exps" mock_exp_manager.extract_one_perfect_exp.return_value = None exp_cache_handler._exps = [Experience(req="req", resp="resp")] - # Execute + # Exec await exp_cache_handler.execute_function() # Assert @@ -99,11 +113,11 @@ class TestExpCache: def test_sync_function_execution_with_exps(self, mocker, exp_cache_handler, mock_exp_manager, mock_func): # Setup exp_cache_handler.func = mocker.Mock(return_value="Sync result with exps") - exp_cache_handler.pass_exps = True + exp_cache_handler.pass_exps_to_func = True mock_exp_manager.extract_one_perfect_exp.return_value = None exp_cache_handler._exps = [Experience(req="req", resp="resp")] - # Execute + # Exec asyncio.get_event_loop().run_until_complete(exp_cache_handler.execute_function()) # Assert @@ -114,7 +128,7 @@ class TestExpCache: # Setup mock_func = mocker.AsyncMock() - # Execute + # Exec wrapper = ExpCacheHandler.choose_wrapper(mock_func, exp_cache_handler.execute_function) # Assert @@ -124,22 +138,80 @@ class TestExpCache: # Setup sync_func = mocker.Mock() - # Execute + # Exec wrapper = ExpCacheHandler.choose_wrapper(sync_func, exp_cache_handler.execute_function) # Assert assert not asyncio.iscoroutinefunction(wrapper), "Wrapper should be synchronous" - @pytest.mark.asyncio - async def test_generate_req_identifier(self, exp_cache_handler): - # Setup - exp_cache_handler.func = lambda x: x - exp_cache_handler.args = (42,) - exp_cache_handler.kwargs = {"y": 3.14} + @pytest.mark.parametrize( + "func, args, kwargs, expected", + [ + (for_test_function, (1, 2), {"c": 3}, 'for_test_function@[1~2]@{"c"!3}'), + (ForTestClass().for_test_method, (4, 5), {}, "ForTestClass.for_test_method@[4~5]@{}"), + (ForTestClass.for_test_class_method, (6, 7), {}, "ForTestClass.for_test_class_method@[6~7]@{}"), + (for_test_function, (), {}, "for_test_function@[]@{}"), + ( + for_test_function, + ("hello", [1, 2]), + {"key": "value"}, + 'for_test_function@["hello"~[1~2]]@{"key"!"value"}', + ), + ], + ) + def test_generate_req_identifier(self, func, args, kwargs, expected): + req_identifier = ExpCacheHandler.generate_req_identifier(func, *args, **kwargs) + assert req_identifier == expected - # Execute - req_id = exp_cache_handler.generate_req_identifier() + @pytest.mark.asyncio + async def test_exp_cache_with_perfect_experience(self, mocker, mock_exp_manager): + # Mock perfect experience + perfect_exp = Experience(req="test_req", resp="perfect_response") + mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[perfect_exp]) + mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=perfect_exp) + async_mock_func = mocker.AsyncMock() + + # Setup + decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager) + + # Exec + result: Experience = await decorated_func() # Assert - expected_id = "_(42,)_{'y': 3.14}" - assert req_id == expected_id, "Request identifier should match the expected format" + assert result.resp == "perfect_response", "Should return the perfect experience response" + async_mock_func.assert_not_called() + + @pytest.mark.asyncio + async def test_exp_cache_without_perfect_experience(self, mocker, mock_exp_manager): + # Mock + mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[]) + mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None) + async_mock_func = mocker.AsyncMock(return_value="computed_response") + async_mock_func.__signature__ = inspect.signature(for_test_function) + + # Setup + decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager) + + # Exec + result = await decorated_func() + + # Assert + assert result == "computed_response", "Should execute and return the function's response" + async_mock_func.assert_called_once() + + @pytest.mark.asyncio + async def test_exp_cache_saves_new_experience(self, mocker, mock_exp_manager, mock_scorer): + # Mock + mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[]) + mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None) + async_mock_func = mocker.AsyncMock(return_value="computed_response") + mock_scorer.evaluate = mocker.AsyncMock(return_value=Score(value=100)) + + # Setup + decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager, scorer=mock_scorer) + + # Exec + await decorated_func() + + # Assert + mock_exp_manager.create_exp.assert_called_once() diff --git a/tests/metagpt/utils/test_reflection.py b/tests/metagpt/utils/test_reflection.py index e78e1b400..58fd81619 100644 --- a/tests/metagpt/utils/test_reflection.py +++ b/tests/metagpt/utils/test_reflection.py @@ -1,29 +1,35 @@ -from metagpt.utils.reflection import get_func_or_method_name +from metagpt.utils.reflection import get_class_name -def simple_function(): - pass - - -class SampleClass: - def method(self): +class SimpleFunction: + def function(self): pass -class TestFunctionOrMethodName: - def test_simple_function(self): - assert get_func_or_method_name(simple_function) == "simple_function" +class SampleClass: + @classmethod + def class_method(cls): + pass - def test_class_method_without_args(self): - sample_instance = SampleClass() - assert get_func_or_method_name(sample_instance.method) == "SampleClass.method" + def instance_method(self): + pass - def test_class_method_with_args(self): - sample_instance = SampleClass() - assert get_func_or_method_name(SampleClass.method, sample_instance) == "SampleClass.method" - def test_function_with_no_args(self): - assert get_func_or_method_name(simple_function) == "simple_function" +def standalone_function(): + pass - def test_method_without_instance(self): - assert get_func_or_method_name(SampleClass.method) == "method" + +class TestGetClassName: + def test_instance_method(self): + instance = SampleClass() + assert get_class_name(instance.instance_method) == "SampleClass" + + def test_class_method(self): + assert get_class_name(SampleClass.class_method) == "SampleClass" + + def test_standalone_function(self): + assert get_class_name(standalone_function) == "" + + def test_function_within_simple_class(self): + instance = SimpleFunction() + assert get_class_name(instance.function) == "SimpleFunction"