diff --git a/examples/exp_pool/decorator.py b/examples/exp_pool/decorator.py index 2f6397f80..3f6093e01 100644 --- a/examples/exp_pool/decorator.py +++ b/examples/exp_pool/decorator.py @@ -7,8 +7,9 @@ from metagpt.exp_pool import exp_cache, exp_manager from metagpt.logs import logger -@exp_cache -async def produce(req): +@exp_cache(pass_exps_to_func=True) +async def produce(req, exps=None): + logger.info(f"Previous experiences: {exps}") return f"{req} {uuid.uuid4().hex}" diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index e073ee494..9eb4d9e61 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -4,56 +4,134 @@ import asyncio import functools from typing import Any, Callable, Optional, TypeVar -from metagpt.exp_pool.manager import exp_manager -from metagpt.exp_pool.schema import Experience +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.manager import ExperienceManager, exp_manager +from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score +from metagpt.exp_pool.scorers import ExperienceScorer, SimpleScorer from metagpt.utils.async_helper import NestAsyncio ReturnType = TypeVar("ReturnType") -def exp_cache(_func: Optional[Callable[..., ReturnType]] = None): - """Decorator to check for a perfect experience and returns it if exists. - - Otherwise, it executes the function, save the result as a new experience, and returns the result. +def exp_cache( + _func: Optional[Callable[..., ReturnType]] = None, + query_type: QueryType = QueryType.SEMANTIC, + scorer: Optional[ExperienceScorer] = None, + manager: Optional[ExperienceManager] = None, + pass_exps_to_func: bool = False, +): + """Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience. This can be applied to both synchronous and asynchronous functions. 
+ + Args: + _func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache(). + query_type: The type of query to be used when fetching experiences. + scorer: Evaluate experience. Default SimpleScorer. + manager: How to fetch, evaluate and save experience, etc. Default exp_manager. + pass_exps_to_func: To control whether imperfect experiences are passed to the function, if True, the func must have a parameter named 'exps'. """ def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: @functools.wraps(func) - async def get_or_create(args: Any, kwargs: Any, is_async: bool) -> ReturnType: - """Attempts to retrieve a perfect experience or creates an experience if not found.""" + async def get_or_create(args: Any, kwargs: Any) -> ReturnType: + handler = ExpCacheHandler( + func=func, + args=args, + kwargs=kwargs, + exp_manager=manager or exp_manager, + exp_scorer=scorer or SimpleScorer(), + pass_exps=pass_exps_to_func, + ) - # 1. Get exps. - req = f"{func.__name__}_{args}_{kwargs}" - exps = await exp_manager.query_exps(req) - if perfect_exp := exp_manager.extract_one_perfect_exp(exps): - return perfect_exp + await handler.fetch_experiences(query_type) + if exp := handler.get_one_perfect_experience(): + return exp - # 2. Exec func. TODO: pass exps to func - if is_async: - result = await func(*args, **kwargs) - else: - result = func(*args, **kwargs) + await handler.execute_function() + await handler.evaluate_experience() + handler.save_experience() - # 3. Create an exp. 
- exp_manager.create_exp(Experience(req=req, resp=result)) + return handler._result - return result + return ExpCacheHandler.choose_wrapper(func, get_or_create) - def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType: + return decorator(_func) if _func else decorator + + +class ExpCacheHandler(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + func: Callable + args: Any + kwargs: Any + exp_manager: ExperienceManager + exp_scorer: ExperienceScorer + pass_exps: bool + + _exps: list[Experience] = None + _result: Any = None + _score: Score = None + + async def fetch_experiences(self, query_type: QueryType): + """Fetch a potentially perfect existing experience.""" + + req = self.generate_req_identifier() + self._exps = await self.exp_manager.query_exps(req, query_type=query_type) + + def get_one_perfect_experience(self) -> Optional[Experience]: + return self.exp_manager.extract_one_perfect_exp(self._exps) + + async def execute_function(self): + """Execute the function, and save the result.""" + self._result = await self._execute_function() + + async def evaluate_experience(self): + """Evaluate the experience, and save the score.""" + + self._score = await self.exp_scorer.evaluate(self.func, self._result, self.args, self.kwargs) + + def save_experience(self): + """Save the new experience.""" + + req = self.generate_req_identifier() + exp = Experience(req=req, resp=self._result, metric=Metric(score=self._score)) + + self.exp_manager.create_exp(exp) + + def generate_req_identifier(self): + """Generate a unique request identifier based on the function and its arguments.""" + + return f"{self.func.__name__}_{self.args}_{self.kwargs}" + + @staticmethod + def choose_wrapper(func, wrapped_func): + """Choose how to run wrapped_func based on whether the function is asynchronous.""" + + async def async_wrapper(*args, **kwargs): + return await wrapped_func(args, kwargs) + + def sync_wrapper(*args, **kwargs): NestAsyncio.apply_once() - return 
asyncio.get_event_loop().run_until_complete(get_or_create(args, kwargs, is_async=False)) + return asyncio.get_event_loop().run_until_complete(wrapped_func(args, kwargs)) - async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType: - return await get_or_create(args, kwargs, is_async=True) + return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper - if asyncio.iscoroutinefunction(func): - return async_wrapper - else: - return sync_wrapper + async def _execute_function(self): + if self.pass_exps: + return await self._execute_function_with_exps() - if _func is None: - return decorator - else: - return decorator(_func) + return await self._execute_function_without_exps() + + async def _execute_function_without_exps(self): + if asyncio.iscoroutinefunction(self.func): + return await self.func(*self.args, **self.kwargs) + + return self.func(*self.args, **self.kwargs) + + async def _execute_function_with_exps(self): + if asyncio.iscoroutinefunction(self.func): + return await self.func(*self.args, **self.kwargs, exps=self._exps) + + return self.func(*self.args, **self.kwargs, exps=self._exps) diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 4bc566104..58499104d 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -5,7 +5,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict, model_validator from metagpt.config2 import Config, config -from metagpt.exp_pool.schema import MAX_SCORE, Experience +from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType from metagpt.rag.engines import SimpleEngine from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig @@ -45,12 +45,13 @@ class ExperienceManager(BaseModel): self.storage.add_objs([exp]) - async def query_exps(self, req: str, tag: str = "") -> list[Experience]: + async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]: """Retrieves and filters 
experiences. Args: req (str): The query string to retrieve experiences. tag (str): Optional tag to filter the experiences by. + query_type (QueryType): Default semantic to vector matching. exact to same matching. Returns: list[Experience]: A list of experiences that match the args. @@ -65,6 +66,9 @@ class ExperienceManager(BaseModel): if tag: exps = [exp for exp in exps if exp.tag == tag] + if query_type == QueryType.EXACT: + exps = [exp for exp in exps if exp.req == req] + return exps def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]: @@ -96,7 +100,7 @@ class ExperienceManager(BaseModel): return False # TODO: need more metrics - if exp.metric and exp.metric.score == MAX_SCORE: + if exp.metric and exp.metric.score.val == MAX_SCORE: return True return False diff --git a/metagpt/exp_pool/schema.py b/metagpt/exp_pool/schema.py index 1afcc1508..9fc665cca 100644 --- a/metagpt/exp_pool/schema.py +++ b/metagpt/exp_pool/schema.py @@ -9,6 +9,13 @@ from pydantic import BaseModel, Field MAX_SCORE = 10 +class QueryType(str, Enum): + """Type of query experiences.""" + + EXACT = "exact" + SEMANTIC = "semantic" + + class ExperienceType(str, Enum): """Experience Type.""" @@ -24,12 +31,19 @@ class EntryType(Enum): MANUAL = "Manual" +class Score(BaseModel): + """Score in Metric.""" + + val: int = Field(default=1, description="Value of the score, Between 1 and 10, higher is better.") + reason: str = Field(default="", description="Reason for the value.") + + class Metric(BaseModel): """Experience Metric.""" time_cost: float = Field(default=0.000, description="Time cost, the unit is milliseconds.") money_cost: float = Field(default=0.000, description="Money cost, the unit is US dollars.") - score: int = Field(default=1, description="Score, a value between 1 and 10.") + score: Score = Field(default=None, description="Score, with value and reason.") class Trajectory(BaseModel): diff --git a/metagpt/exp_pool/scorers/__init__.py 
b/metagpt/exp_pool/scorers/__init__.py new file mode 100644 index 000000000..85bea88ff --- /dev/null +++ b/metagpt/exp_pool/scorers/__init__.py @@ -0,0 +1,6 @@ +"""Experience scorers init.""" + +from metagpt.exp_pool.scorers.base import ExperienceScorer +from metagpt.exp_pool.scorers.simple import SimpleScorer + +__all__ = ["ExperienceScorer", "SimpleScorer"] diff --git a/metagpt/exp_pool/scorers/base.py b/metagpt/exp_pool/scorers/base.py new file mode 100644 index 000000000..a9d30cffe --- /dev/null +++ b/metagpt/exp_pool/scorers/base.py @@ -0,0 +1,27 @@ +"""Experience Scorers.""" + +from abc import abstractmethod +from typing import Any, Callable + +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.schema import Score + + +class ExperienceScorer(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score: + """Evaluate the quality of the result produced by the function and parameters. + + Args: + func (Callable): The function whose result is to be evaluated. + result (Any): The result produced by the function. + args (Tuple[Any, ...]): The tuple of arguments that were passed to the function. + kwargs (Dict[str, Any]): The dictionary of keyword arguments that were passed to the function. 
+ + Example: + result = await sample(5, name="foo") + score = await scorer.evaluate(sample, result, args=(5,), kwargs={"name": "foo"}) + """ diff --git a/metagpt/exp_pool/scorers/simple.py b/metagpt/exp_pool/scorers/simple.py new file mode 100644 index 000000000..d0301cbc2 --- /dev/null +++ b/metagpt/exp_pool/scorers/simple.py @@ -0,0 +1,73 @@ +"""Evaluate by llm.""" +import inspect +import json +from typing import Any, Callable + +from pydantic import Field + +from metagpt.exp_pool.schema import Score +from metagpt.exp_pool.scorers.base import ExperienceScorer +from metagpt.llm import LLM +from metagpt.provider.base_llm import BaseLLM +from metagpt.utils.common import parse_json_code_block + +SIMPLE_SCORER_TEMPLATE = """ +Role: You're an expert score evaluator. You specialize in assessing the output of the given function, based on its intended requirement and produced result. + +## Context +### Function Name +{func_name} + +### Function Document +{func_doc} + +### Function Signature +{func_signature} + +### Function Parameters +args: {func_args} +kwargs: {func_kwargs} + +### Produced Result By Function and Parameters +{func_result} + +## Format Example +```json +{{ + "val": "the value of the score, int from 1 to 10, higher is better.", + "reason": "an explanation supporting the score." +}} +``` + +## Instructions +- Understand the function and requirements given by the user. +- Analyze the results produced by the function. +- Grade the results based on level of alignment with the requirements. +- Provide a score on a scale defined by user or a default scale (1 to 10). + +## Constraint +Format: Just print the result in json format like **Format Example**. + +## Action +Follow instructions, generate output and make sure it follows the **Constraint**. 
+""" + + +class SimpleScorer(ExperienceScorer): + llm: BaseLLM = Field(default_factory=LLM) + + async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score: + """Evaluate the quality of content.""" + + prompt = SIMPLE_SCORER_TEMPLATE.format( + func_name=func.__name__, + func_doc=func.__doc__, + func_signature=inspect.signature(func), + func_args=args, + func_kwargs=kwargs, + func_result=result, + ) + resp = await self.llm.aask(prompt) + resp_json = json.loads(parse_json_code_block(resp)[0]) + + return Score(**resp_json) diff --git a/tests/metagpt/exp_pool/test_decorator.py b/tests/metagpt/exp_pool/test_decorator.py new file mode 100644 index 000000000..508229d18 --- /dev/null +++ b/tests/metagpt/exp_pool/test_decorator.py @@ -0,0 +1,145 @@ +import asyncio + +import pytest + +from metagpt.exp_pool.decorator import ExpCacheHandler +from metagpt.exp_pool.manager import ExperienceManager +from metagpt.exp_pool.schema import Experience, QueryType, Score +from metagpt.exp_pool.scorers import SimpleScorer +from metagpt.rag.engines import SimpleEngine + + +class TestExpCache: + @pytest.fixture + def mock_func(self, mocker): + return mocker.AsyncMock() + + @pytest.fixture + def mock_exp_manager(self, mocker): + manager = mocker.MagicMock(spec=ExperienceManager) + manager.storage = mocker.MagicMock(spec=SimpleEngine) + manager.query_exps = mocker.AsyncMock() + manager.create_exp = mocker.MagicMock() + manager.extract_one_perfect_exp = mocker.MagicMock() + return manager + + @pytest.fixture + def mock_scorer(self, mocker): + scorer = mocker.MagicMock(spec=SimpleScorer) + scorer.evaluate = mocker.AsyncMock() + return scorer + + @pytest.fixture + def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer): + return ExpCacheHandler( + func=mock_func, args=(), kwargs={}, exp_manager=mock_exp_manager, exp_scorer=mock_scorer, pass_exps=False + ) + + @pytest.mark.asyncio + async def test_fetch_experiences(self, exp_cache_handler, 
mock_exp_manager): + await exp_cache_handler.fetch_experiences(QueryType.SEMANTIC) + mock_exp_manager.query_exps.assert_called_once() + + @pytest.mark.asyncio + async def test_perfect_experience_found(self, exp_cache_handler, mock_exp_manager, mock_func): + # Setup: Assume perfect experience is found + perfect_exp = Experience(req="req", resp="resp") + mock_exp_manager.extract_one_perfect_exp.return_value = perfect_exp + + # Execute + exp_cache_handler._exps = [perfect_exp] # Simulate fetched experiences + result = exp_cache_handler.get_one_perfect_experience() + + # Assert + assert result.resp == "resp" + mock_func.assert_not_called() # Function should not be called + + @pytest.mark.asyncio + async def test_execute_function_when_no_perfect_exp(self, exp_cache_handler, mock_exp_manager, mock_func): + # Setup: No perfect experience + mock_exp_manager.extract_one_perfect_exp.return_value = None + mock_func.return_value = "Computed result" + + # Execute + await exp_cache_handler.execute_function() + + # Assert + assert exp_cache_handler._result == "Computed result" + mock_func.assert_called_once() + + @pytest.mark.asyncio + async def test_evaluate_and_save_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager): + # Setup + mock_scorer.evaluate.return_value = Score(val=10) + exp_cache_handler._result = "Computed result" + + # Execute + await exp_cache_handler.evaluate_experience() + exp_cache_handler.save_experience() + + # Assert + mock_scorer.evaluate.assert_called_once() + mock_exp_manager.create_exp.assert_called_once() + + @pytest.mark.asyncio + async def test_async_function_execution_with_exps(self, exp_cache_handler, mock_exp_manager, mock_func): + # Setup + exp_cache_handler.pass_exps = True + mock_func.return_value = "Async result with exps" + mock_exp_manager.extract_one_perfect_exp.return_value = None + exp_cache_handler._exps = [Experience(req="req", resp="resp")] + + # Execute + await exp_cache_handler.execute_function() + + # Assert + 
mock_func.assert_called_once_with(exps=exp_cache_handler._exps) + assert exp_cache_handler._result == "Async result with exps" + + def test_sync_function_execution_with_exps(self, mocker, exp_cache_handler, mock_exp_manager, mock_func): + # Setup + exp_cache_handler.func = mocker.Mock(return_value="Sync result with exps") + exp_cache_handler.pass_exps = True + mock_exp_manager.extract_one_perfect_exp.return_value = None + exp_cache_handler._exps = [Experience(req="req", resp="resp")] + + # Execute + asyncio.get_event_loop().run_until_complete(exp_cache_handler.execute_function()) + + # Assert + exp_cache_handler.func.assert_called_once_with(exps=exp_cache_handler._exps) + assert exp_cache_handler._result == "Sync result with exps" + + def test_wrapper_selection_async(self, mocker, exp_cache_handler, mock_func): + # Setup + mock_func = mocker.AsyncMock() + + # Execute + wrapper = ExpCacheHandler.choose_wrapper(mock_func, exp_cache_handler.execute_function) + + # Assert + assert asyncio.iscoroutinefunction(wrapper), "Wrapper should be asynchronous" + + def test_wrapper_selection_sync(self, exp_cache_handler, mocker): + # Setup + sync_func = mocker.Mock() + + # Execute + wrapper = ExpCacheHandler.choose_wrapper(sync_func, exp_cache_handler.execute_function) + + # Assert + assert not asyncio.iscoroutinefunction(wrapper), "Wrapper should be synchronous" + + @pytest.mark.asyncio + async def test_generate_req_identifier(self, exp_cache_handler): + # Setup + exp_cache_handler.func = lambda x: x + exp_cache_handler.args = (42,) + exp_cache_handler.kwargs = {"y": 3.14} + + # Execute + req_id = exp_cache_handler.generate_req_identifier() + + # Assert + expected_id = "<lambda>_(42,)_{'y': 3.14}" + assert req_id == expected_id, "Request identifier should match the expected format" diff --git a/tests/metagpt/exp_pool/test_manager.py b/tests/metagpt/exp_pool/test_manager.py index a0d7005f5..3e8f47417 100644 --- a/tests/metagpt/exp_pool/test_manager.py +++ 
b/tests/metagpt/exp_pool/test_manager.py @@ -4,7 +4,7 @@ from metagpt.config2 import Config from metagpt.configs.exp_pool_config import ExperiencePoolConfig from metagpt.configs.llm_config import LLMConfig from metagpt.exp_pool.manager import ExperienceManager -from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric +from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score from metagpt.rag.engines import SimpleEngine @@ -62,15 +62,15 @@ class TestExperienceManager: def test_extract_one_perfect_exp(self, mock_experience_manager): experiences = [ - Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE)), + Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))), Experience(req="req", resp="resp"), ] perfect_exp: Experience = mock_experience_manager.extract_one_perfect_exp(experiences) assert perfect_exp is not None - assert perfect_exp.metric.score == MAX_SCORE + assert perfect_exp.metric.score.val == MAX_SCORE def test_is_perfect_exp(self): - exp = Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE)) + exp = Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))) assert ExperienceManager.is_perfect_exp(exp) == True exp = Experience(req="req", resp="resp")