update exp_pool decorator

This commit is contained in:
seehi 2024-06-11 21:40:51 +08:00
parent 4650b7bdf1
commit 6052d8b9ac
5 changed files with 173 additions and 74 deletions

View file

@ -2,9 +2,11 @@
import asyncio
import functools
import inspect
import json
from typing import Any, Callable, Optional, TypeVar
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.exp_pool.manager import ExperienceManager, exp_manager
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
@ -42,8 +44,8 @@ def exp_cache(
func=func,
args=args,
kwargs=kwargs,
exp_manager=manager or exp_manager,
exp_scorer=scorer or SimpleScorer(),
exp_manager=manager,
exp_scorer=scorer,
pass_exps_to_func=pass_exps_to_func,
)
@ -67,8 +69,8 @@ class ExpCacheHandler(BaseModel):
func: Callable
args: Any
kwargs: Any
exp_manager: ExperienceManager
exp_scorer: ExperienceScorer
exp_manager: Optional[ExperienceManager] = None
exp_scorer: Optional[ExperienceScorer] = None
pass_exps_to_func: bool = False
_exps: list[Experience] = None
@ -76,11 +78,22 @@ class ExpCacheHandler(BaseModel):
_score: Score = None
_req: str = None
@model_validator(mode="after")
def initialize(self):
if self.exp_manager is None:
self.exp_manager = exp_manager
if self.exp_scorer is None:
self.exp_scorer = SimpleScorer()
self._req = self.generate_req_identifier(self.func, *self.args, **self.kwargs)
return self
async def fetch_experiences(self, query_type: QueryType):
"""Fetch a potentially perfect existing experience."""
req = self._get_req_identifier()
self._exps = await self.exp_manager.query_exps(req, query_type=query_type)
self._exps = await self.exp_manager.query_exps(self._req, query_type=query_type)
def get_one_perfect_experience(self) -> Optional[Experience]:
return self.exp_manager.extract_one_perfect_exp(self._exps)
@ -107,26 +120,29 @@ class ExpCacheHandler(BaseModel):
def save_experience(self):
"""Save the new experience."""
req = self._get_req_identifier()
exp = Experience(req=req, resp=self._result, metric=Metric(score=self._score))
exp = Experience(req=self._req, resp=self._result, metric=Metric(score=self._score))
self.exp_manager.create_exp(exp)
def _get_req_identifier(self):
"""Generate a unique request identifier based on the function and its arguments.
@classmethod
def generate_req_identifier(cls, func, *args, **kwargs) -> str:
"""Generate a unique request identifier for any given function and its arguments.
Result Example:
- "write_prd-('2048',)-{}"
- "WritePRD.run-('2048',)-{}"
Serializing args and kwargs into JSON strings and replacing ',' with '~' and ':' with '!'.
Return Example:
SimpleClass.test_method@[1~2]@{"c"!3}
"""
if not self._req:
cls_name = get_class_name(self.func, *self.args)
func_name = f"{cls_name}.{self.func.__name__}" if cls_name else self.func.__name__
args = self.args[1:] if cls_name and len(self.args) >= 1 else self.args
cls_name = get_class_name(func)
func_name = f"{cls_name}.{func.__name__}" if cls_name else func.__name__
self._req = f"{func_name}-{args}-{self.kwargs}"
if cls_name and args and inspect.isfunction(func):
args = args[1:]
return self._req
args = cls._serialize_and_replace(args)
kwargs = cls._serialize_and_replace(kwargs)
return f"{func_name}@{args}@{kwargs}"
@staticmethod
def choose_wrapper(func, wrapped_func):
@ -141,6 +157,11 @@ class ExpCacheHandler(BaseModel):
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
@classmethod
def _serialize_and_replace(cls, data):
json_str = json.dumps(data)
return json_str.replace(", ", "~").replace(": ", "!")
async def _execute_function(self):
if self.pass_exps_to_func:
return await self._execute_function_with_exps()

View file

@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.config2 import Config, config
from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
from metagpt.utils.exceptions import handle_exception
@ -31,6 +31,7 @@ class ExperienceManager(BaseModel):
retriever_configs=[
ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data")
],
ranker_configs=[LLMRankerConfig()],
)
return self

View file

@ -19,24 +19,23 @@ def check_methods(C, *methods):
return True
def get_class_name(func, *args) -> str:
def get_class_name(func) -> str:
"""Returns the class name of the object that a method belongs to.
- If `func` is a bound method, extracts the class name directly from the method.
- If `func` is an unbound method and `args` are provided, assumes the first argument is `self` and extracts the class name.
- Returns an empty string if neither condition is met.
- If `func` is a bound method or a class method, extracts the class name directly from the method.
- Returns an empty string if it's a regular function or cannot determine the class.
"""
if inspect.ismethod(func):
if inspect.isclass(func.__self__):
return func.__self__.__name__
return func.__self__.__class__.__name__
if inspect.isfunction(func) and "self" in inspect.signature(func).parameters and args:
return args[0].__class__.__name__
if inspect.isfunction(func):
qualname_parts = func.__qualname__.split(".")
if len(qualname_parts) > 1:
class_name = qualname_parts[-2]
if class_name.isidentifier():
return class_name
return ""
def get_func_or_method_name(func, *args) -> str:
"""Function name, or method name with class name."""
cls_name = get_class_name(func, *args)
return f"{cls_name}.{func.__name__}" if cls_name else f"{func.__name__}"

View file

@ -1,14 +1,28 @@
import asyncio
import inspect
import pytest
from metagpt.exp_pool.decorator import ExpCacheHandler
from metagpt.exp_pool.decorator import ExpCacheHandler, exp_cache
from metagpt.exp_pool.manager import ExperienceManager
from metagpt.exp_pool.schema import Experience, QueryType, Score
from metagpt.exp_pool.scorers import SimpleScorer
from metagpt.rag.engines import SimpleEngine
def for_test_function(a, b, c=None):
return a + b if c is None else a + b + c
class ForTestClass:
def for_test_method(self, x, y):
return x * y
@classmethod
def for_test_class_method(cls, x, y):
return x**y
class TestExpCache:
@pytest.fixture
def mock_func(self, mocker):
@ -46,7 +60,7 @@ class TestExpCache:
perfect_exp = Experience(req="req", resp="resp")
mock_exp_manager.extract_one_perfect_exp.return_value = perfect_exp
# Execute
# Exec
exp_cache_handler._exps = [perfect_exp] # Simulate fetched experiences
result = exp_cache_handler.get_one_perfect_experience()
@ -60,7 +74,7 @@ class TestExpCache:
mock_exp_manager.extract_one_perfect_exp.return_value = None
mock_func.return_value = "Computed result"
# Execute
# Exec
await exp_cache_handler.execute_function()
# Assert
@ -73,7 +87,7 @@ class TestExpCache:
mock_scorer.evaluate.return_value = Score(value=100)
exp_cache_handler._result = "Computed result"
# Execute
# Exec
await exp_cache_handler.evaluate_experience()
exp_cache_handler.save_experience()
@ -84,12 +98,12 @@ class TestExpCache:
@pytest.mark.asyncio
async def test_async_function_execution_with_exps(self, exp_cache_handler, mock_exp_manager, mock_func):
# Setup
exp_cache_handler.pass_exps = True
exp_cache_handler.pass_exps_to_func = True
mock_func.return_value = "Async result with exps"
mock_exp_manager.extract_one_perfect_exp.return_value = None
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
# Execute
# Exec
await exp_cache_handler.execute_function()
# Assert
@ -99,11 +113,11 @@ class TestExpCache:
def test_sync_function_execution_with_exps(self, mocker, exp_cache_handler, mock_exp_manager, mock_func):
# Setup
exp_cache_handler.func = mocker.Mock(return_value="Sync result with exps")
exp_cache_handler.pass_exps = True
exp_cache_handler.pass_exps_to_func = True
mock_exp_manager.extract_one_perfect_exp.return_value = None
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
# Execute
# Exec
asyncio.get_event_loop().run_until_complete(exp_cache_handler.execute_function())
# Assert
@ -114,7 +128,7 @@ class TestExpCache:
# Setup
mock_func = mocker.AsyncMock()
# Execute
# Exec
wrapper = ExpCacheHandler.choose_wrapper(mock_func, exp_cache_handler.execute_function)
# Assert
@ -124,22 +138,80 @@ class TestExpCache:
# Setup
sync_func = mocker.Mock()
# Execute
# Exec
wrapper = ExpCacheHandler.choose_wrapper(sync_func, exp_cache_handler.execute_function)
# Assert
assert not asyncio.iscoroutinefunction(wrapper), "Wrapper should be synchronous"
@pytest.mark.asyncio
async def test_generate_req_identifier(self, exp_cache_handler):
# Setup
exp_cache_handler.func = lambda x: x
exp_cache_handler.args = (42,)
exp_cache_handler.kwargs = {"y": 3.14}
@pytest.mark.parametrize(
"func, args, kwargs, expected",
[
(for_test_function, (1, 2), {"c": 3}, 'for_test_function@[1~2]@{"c"!3}'),
(ForTestClass().for_test_method, (4, 5), {}, "ForTestClass.for_test_method@[4~5]@{}"),
(ForTestClass.for_test_class_method, (6, 7), {}, "ForTestClass.for_test_class_method@[6~7]@{}"),
(for_test_function, (), {}, "for_test_function@[]@{}"),
(
for_test_function,
("hello", [1, 2]),
{"key": "value"},
'for_test_function@["hello"~[1~2]]@{"key"!"value"}',
),
],
)
def test_generate_req_identifier(self, func, args, kwargs, expected):
req_identifier = ExpCacheHandler.generate_req_identifier(func, *args, **kwargs)
assert req_identifier == expected
# Execute
req_id = exp_cache_handler.generate_req_identifier()
@pytest.mark.asyncio
async def test_exp_cache_with_perfect_experience(self, mocker, mock_exp_manager):
# Mock perfect experience
perfect_exp = Experience(req="test_req", resp="perfect_response")
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[perfect_exp])
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=perfect_exp)
async_mock_func = mocker.AsyncMock()
# Setup
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager)
# Exec
result: Experience = await decorated_func()
# Assert
expected_id = "<lambda>_(42,)_{'y': 3.14}"
assert req_id == expected_id, "Request identifier should match the expected format"
assert result.resp == "perfect_response", "Should return the perfect experience response"
async_mock_func.assert_not_called()
@pytest.mark.asyncio
async def test_exp_cache_without_perfect_experience(self, mocker, mock_exp_manager):
# Mock
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[])
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None)
async_mock_func = mocker.AsyncMock(return_value="computed_response")
async_mock_func.__signature__ = inspect.signature(for_test_function)
# Setup
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager)
# Exec
result = await decorated_func()
# Assert
assert result == "computed_response", "Should execute and return the function's response"
async_mock_func.assert_called_once()
@pytest.mark.asyncio
async def test_exp_cache_saves_new_experience(self, mocker, mock_exp_manager, mock_scorer):
# Mock
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[])
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None)
async_mock_func = mocker.AsyncMock(return_value="computed_response")
mock_scorer.evaluate = mocker.AsyncMock(return_value=Score(value=100))
# Setup
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager, scorer=mock_scorer)
# Exec
await decorated_func()
# Assert
mock_exp_manager.create_exp.assert_called_once()

View file

@ -1,29 +1,35 @@
from metagpt.utils.reflection import get_func_or_method_name
from metagpt.utils.reflection import get_class_name
def simple_function():
pass
class SampleClass:
def method(self):
class SimpleFunction:
def function(self):
pass
class TestFunctionOrMethodName:
def test_simple_function(self):
assert get_func_or_method_name(simple_function) == "simple_function"
class SampleClass:
@classmethod
def class_method(cls):
pass
def test_class_method_without_args(self):
sample_instance = SampleClass()
assert get_func_or_method_name(sample_instance.method) == "SampleClass.method"
def instance_method(self):
pass
def test_class_method_with_args(self):
sample_instance = SampleClass()
assert get_func_or_method_name(SampleClass.method, sample_instance) == "SampleClass.method"
def test_function_with_no_args(self):
assert get_func_or_method_name(simple_function) == "simple_function"
def standalone_function():
pass
def test_method_without_instance(self):
assert get_func_or_method_name(SampleClass.method) == "method"
class TestGetClassName:
def test_instance_method(self):
instance = SampleClass()
assert get_class_name(instance.instance_method) == "SampleClass"
def test_class_method(self):
assert get_class_name(SampleClass.class_method) == "SampleClass"
def test_standalone_function(self):
assert get_class_name(standalone_function) == ""
def test_function_within_simple_class(self):
instance = SimpleFunction()
assert get_class_name(instance.function) == "SimpleFunction"