add exp_pool tests

This commit is contained in:
seehi 2024-06-05 22:15:09 +08:00
parent 1d8d85e9a5
commit c78cddd102
9 changed files with 391 additions and 43 deletions

View file

@ -7,8 +7,9 @@ from metagpt.exp_pool import exp_cache, exp_manager
from metagpt.logs import logger
@exp_cache
async def produce(req):
@exp_cache(pass_exps_to_func=True)
async def produce(req, exps=None):
logger.info(f"Previous experiences: {exps}")
return f"{req} {uuid.uuid4().hex}"

View file

@ -4,56 +4,134 @@ import asyncio
import functools
from typing import Any, Callable, Optional, TypeVar
from metagpt.exp_pool.manager import exp_manager
from metagpt.exp_pool.schema import Experience
from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.manager import ExperienceManager, exp_manager
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
from metagpt.exp_pool.scorers import ExperienceScorer, SimpleScorer
from metagpt.utils.async_helper import NestAsyncio
ReturnType = TypeVar("ReturnType")
def exp_cache(_func: Optional[Callable[..., ReturnType]] = None):
"""Decorator to check for a perfect experience and returns it if exists.
Otherwise, it executes the function, save the result as a new experience, and returns the result.
def exp_cache(
_func: Optional[Callable[..., ReturnType]] = None,
query_type: QueryType = QueryType.SEMANTIC,
scorer: Optional[ExperienceScorer] = None,
manager: Optional[ExperienceManager] = None,
pass_exps_to_func: bool = False,
):
"""Decorator to get a perfect experience; otherwise, it executes the function and creates a new experience.
This can be applied to both synchronous and asynchronous functions.
Args:
_func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache().
query_type: The type of query to be used when fetching experiences.
scorer: Evaluate experience. Default SimpleScorer.
manager: How to fetch, evaluate and save experience, etc. Default exp_manager.
pass_exps_to_func: To control whether imperfect experiences are passed to the function, if True, the func must have a parameter named 'exps'.
"""
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
@functools.wraps(func)
async def get_or_create(args: Any, kwargs: Any, is_async: bool) -> ReturnType:
"""Attempts to retrieve a perfect experience or creates an experience if not found."""
async def get_or_create(args: Any, kwargs: Any) -> ReturnType:
handler = ExpCacheHandler(
func=func,
args=args,
kwargs=kwargs,
exp_manager=manager or exp_manager,
exp_scorer=scorer or SimpleScorer(),
pass_exps=pass_exps_to_func,
)
# 1. Get exps.
req = f"{func.__name__}_{args}_{kwargs}"
exps = await exp_manager.query_exps(req)
if perfect_exp := exp_manager.extract_one_perfect_exp(exps):
return perfect_exp
await handler.fetch_experiences(query_type)
if exp := handler.get_one_perfect_experience():
return exp
# 2. Exec func. TODO: pass exps to func
if is_async:
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
await handler.execute_function()
await handler.evaluate_experience()
handler.save_experience()
# 3. Create an exp.
exp_manager.create_exp(Experience(req=req, resp=result))
return handler._result
return result
return ExpCacheHandler.choose_wrapper(func, get_or_create)
def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
return decorator(_func) if _func else decorator
class ExpCacheHandler(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
func: Callable
args: Any
kwargs: Any
exp_manager: ExperienceManager
exp_scorer: ExperienceScorer
pass_exps: bool
_exps: list[Experience] = None
_result: Any = None
_score: Score = None
async def fetch_experiences(self, query_type: QueryType):
"""Fetch a potentially perfect existing experience."""
req = self.generate_req_identifier()
self._exps = await self.exp_manager.query_exps(req, query_type=query_type)
def get_one_perfect_experience(self) -> Optional[Experience]:
return self.exp_manager.extract_one_perfect_exp(self._exps)
async def execute_function(self):
"""Execute the function, and save the result."""
self._result = await self._execute_function()
async def evaluate_experience(self):
"""Evaluate the experience, and save the score."""
self._score = await self.exp_scorer.evaluate(self.func, self._result, self.args, self.kwargs)
def save_experience(self):
"""Save the new experience."""
req = self.generate_req_identifier()
exp = Experience(req=req, resp=self._result, metric=Metric(score=self._score))
self.exp_manager.create_exp(exp)
def generate_req_identifier(self):
"""Generate a unique request identifier based on the function and its arguments."""
return f"{self.func.__name__}_{self.args}_{self.kwargs}"
@staticmethod
def choose_wrapper(func, wrapped_func):
"""Choose how to run wrapped_func based on whether the function is asynchronous."""
async def async_wrapper(*args, **kwargs):
return await wrapped_func(args, kwargs)
def sync_wrapper(*args, **kwargs):
NestAsyncio.apply_once()
return asyncio.get_event_loop().run_until_complete(get_or_create(args, kwargs, is_async=False))
return asyncio.get_event_loop().run_until_complete(wrapped_func(args, kwargs))
async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
return await get_or_create(args, kwargs, is_async=True)
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
async def _execute_function(self):
if self.pass_exps:
return await self._execute_function_with_exps()
if _func is None:
return decorator
else:
return decorator(_func)
return await self._execute_function_without_exps()
async def _execute_function_without_exps(self):
if asyncio.iscoroutinefunction(self.func):
return await self.func(*self.args, **self.kwargs)
return self.func(*self.args, **self.kwargs)
async def _execute_function_with_exps(self):
if asyncio.iscoroutinefunction(self.func):
return await self.func(*self.args, **self.kwargs, exps=self._exps)
return self.func(*self.args, **self.kwargs, exps=self._exps)

View file

@ -5,7 +5,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.config2 import Config, config
from metagpt.exp_pool.schema import MAX_SCORE, Experience
from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
@ -45,12 +45,13 @@ class ExperienceManager(BaseModel):
self.storage.add_objs([exp])
async def query_exps(self, req: str, tag: str = "") -> list[Experience]:
async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]:
"""Retrieves and filters experiences.
Args:
req (str): The query string to retrieve experiences.
tag (str): Optional tag to filter the experiences by.
query_type (QueryType): Defaults to SEMANTIC (vector similarity matching); use EXACT to match only experiences whose req is identical.
Returns:
list[Experience]: A list of experiences that match the args.
@ -65,6 +66,9 @@ class ExperienceManager(BaseModel):
if tag:
exps = [exp for exp in exps if exp.tag == tag]
if query_type == QueryType.EXACT:
exps = [exp for exp in exps if exp.req == req]
return exps
def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]:
@ -96,7 +100,7 @@ class ExperienceManager(BaseModel):
return False
# TODO: need more metrics
if exp.metric and exp.metric.score == MAX_SCORE:
if exp.metric and exp.metric.score.val == MAX_SCORE:
return True
return False

View file

@ -9,6 +9,13 @@ from pydantic import BaseModel, Field
MAX_SCORE = 10
class QueryType(str, Enum):
"""Type of query used to retrieve experiences."""
EXACT = "exact"
SEMANTIC = "semantic"
class ExperienceType(str, Enum):
"""Experience Type."""
@ -24,12 +31,19 @@ class EntryType(Enum):
MANUAL = "Manual"
class Score(BaseModel):
"""Score in Metric."""
val: int = Field(default=1, description="Value of the score, Between 1 and 10, higher is better.")
reason: str = Field(default="", description="Reason for the value.")
class Metric(BaseModel):
"""Experience Metric."""
time_cost: float = Field(default=0.000, description="Time cost, the unit is milliseconds.")
money_cost: float = Field(default=0.000, description="Money cost, the unit is US dollars.")
score: int = Field(default=1, description="Score, a value between 1 and 10.")
score: Score = Field(default=None, description="Score, with value and reason.")
class Trajectory(BaseModel):

View file

@ -0,0 +1,6 @@
"""Experience scorers init."""
from metagpt.exp_pool.scorers.base import ExperienceScorer
from metagpt.exp_pool.scorers.simple import SimpleScorer
__all__ = ["ExperienceScorer", "SimpleScorer"]

View file

@ -0,0 +1,27 @@
"""Experience Scorers."""
from abc import abstractmethod
from typing import Any, Callable
from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.schema import Score
class ExperienceScorer(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score:
"""Evaluate the quality of the result produced by the function and parameters.
Args:
func (Callable): The function whose result is to be evaluated.
result (Any): The result produced by the function.
args (Tuple[Any, ...]): The tuple of arguments that were passed to the function.
kwargs (Dict[str, Any]): The dictionary of keyword arguments that were passed to the function.
Example:
result = await sample(5, name="foo")
score = await scorer.evaluate(sample, result, args=(5,), kwargs={"name": "foo"})
"""

View file

@ -0,0 +1,73 @@
"""Evaluate by LLM."""
import inspect
import json
from typing import Any, Callable
from pydantic import Field
from metagpt.exp_pool.schema import Score
from metagpt.exp_pool.scorers.base import ExperienceScorer
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.common import parse_json_code_block
SIMPLE_SCORER_TEMPLATE = """
Role: You're an expert score evaluator. You specialize in assessing the output of the given function, based on its intended requirement and produced result.
## Context
### Function Name
{func_name}
### Function Document
{func_doc}
### Function Signature
{func_signature}
### Function Parameters
args: {func_args}
kwargs: {func_kwargs}
### Produced Result By Function and Parameters
{func_result}
## Format Example
```json
{{
"val": "the value of the score, int from 1 to 10, higher is better.",
"reason": "an explanation supporting the score."
}}
```
## Instructions
- Understand the function and requirements given by the user.
- Analyze the results produced by the function.
- Grade the results based on level of alignment with the requirements.
- Provide a score on a scale defined by user or a default scale (1 to 10).
## Constraint
Format: Just print the result in json format like **Format Example**.
## Action
Follow instructions, generate output and make sure it follows the **Constraint**.
"""
class SimpleScorer(ExperienceScorer):
llm: BaseLLM = Field(default_factory=LLM)
async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score:
"""Evaluate the quality of content."""
prompt = SIMPLE_SCORER_TEMPLATE.format(
func_name=func.__name__,
func_doc=func.__doc__,
func_signature=inspect.signature(func),
func_args=args,
func_kwargs=kwargs,
func_result=result,
)
resp = await self.llm.aask(prompt)
resp_json = json.loads(parse_json_code_block(resp)[0])
return Score(**resp_json)

View file

@ -0,0 +1,145 @@
import asyncio
import pytest
from metagpt.exp_pool.decorator import ExpCacheHandler
from metagpt.exp_pool.manager import ExperienceManager
from metagpt.exp_pool.schema import Experience, QueryType, Score
from metagpt.exp_pool.scorers import SimpleScorer
from metagpt.rag.engines import SimpleEngine
class TestExpCache:
@pytest.fixture
def mock_func(self, mocker):
return mocker.AsyncMock()
@pytest.fixture
def mock_exp_manager(self, mocker):
manager = mocker.MagicMock(spec=ExperienceManager)
manager.storage = mocker.MagicMock(spec=SimpleEngine)
manager.query_exps = mocker.AsyncMock()
manager.create_exp = mocker.MagicMock()
manager.extract_one_perfect_exp = mocker.MagicMock()
return manager
@pytest.fixture
def mock_scorer(self, mocker):
scorer = mocker.MagicMock(spec=SimpleScorer)
scorer.evaluate = mocker.AsyncMock()
return scorer
@pytest.fixture
def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer):
return ExpCacheHandler(
func=mock_func, args=(), kwargs={}, exp_manager=mock_exp_manager, exp_scorer=mock_scorer, pass_exps=False
)
@pytest.mark.asyncio
async def test_fetch_experiences(self, exp_cache_handler, mock_exp_manager):
await exp_cache_handler.fetch_experiences(QueryType.SEMANTIC)
mock_exp_manager.query_exps.assert_called_once()
@pytest.mark.asyncio
async def test_perfect_experience_found(self, exp_cache_handler, mock_exp_manager, mock_func):
# Setup: Assume perfect experience is found
perfect_exp = Experience(req="req", resp="resp")
mock_exp_manager.extract_one_perfect_exp.return_value = perfect_exp
# Execute
exp_cache_handler._exps = [perfect_exp] # Simulate fetched experiences
result = exp_cache_handler.get_one_perfect_experience()
# Assert
assert result.resp == "resp"
mock_func.assert_not_called() # Function should not be called
@pytest.mark.asyncio
async def test_execute_function_when_no_perfect_exp(self, exp_cache_handler, mock_exp_manager, mock_func):
# Setup: No perfect experience
mock_exp_manager.extract_one_perfect_exp.return_value = None
mock_func.return_value = "Computed result"
# Execute
await exp_cache_handler.execute_function()
# Assert
assert exp_cache_handler._result == "Computed result"
mock_func.assert_called_once()
@pytest.mark.asyncio
async def test_evaluate_and_save_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager):
# Setup
mock_scorer.evaluate.return_value = Score(val=10)  # NOTE: Score's field is `val` (1-10); `value=100` would be silently ignored by pydantic
exp_cache_handler._result = "Computed result"
# Execute
await exp_cache_handler.evaluate_experience()
exp_cache_handler.save_experience()
# Assert
mock_scorer.evaluate.assert_called_once()
mock_exp_manager.create_exp.assert_called_once()
@pytest.mark.asyncio
async def test_async_function_execution_with_exps(self, exp_cache_handler, mock_exp_manager, mock_func):
# Setup
exp_cache_handler.pass_exps = True
mock_func.return_value = "Async result with exps"
mock_exp_manager.extract_one_perfect_exp.return_value = None
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
# Execute
await exp_cache_handler.execute_function()
# Assert
mock_func.assert_called_once_with(exps=exp_cache_handler._exps)
assert exp_cache_handler._result == "Async result with exps"
def test_sync_function_execution_with_exps(self, mocker, exp_cache_handler, mock_exp_manager, mock_func):
# Setup
exp_cache_handler.func = mocker.Mock(return_value="Sync result with exps")
exp_cache_handler.pass_exps = True
mock_exp_manager.extract_one_perfect_exp.return_value = None
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
# Execute
asyncio.get_event_loop().run_until_complete(exp_cache_handler.execute_function())
# Assert
exp_cache_handler.func.assert_called_once_with(exps=exp_cache_handler._exps)
assert exp_cache_handler._result == "Sync result with exps"
def test_wrapper_selection_async(self, mocker, exp_cache_handler, mock_func):
# Setup
mock_func = mocker.AsyncMock()
# Execute
wrapper = ExpCacheHandler.choose_wrapper(mock_func, exp_cache_handler.execute_function)
# Assert
assert asyncio.iscoroutinefunction(wrapper), "Wrapper should be asynchronous"
def test_wrapper_selection_sync(self, exp_cache_handler, mocker):
# Setup
sync_func = mocker.Mock()
# Execute
wrapper = ExpCacheHandler.choose_wrapper(sync_func, exp_cache_handler.execute_function)
# Assert
assert not asyncio.iscoroutinefunction(wrapper), "Wrapper should be synchronous"
@pytest.mark.asyncio
async def test_generate_req_identifier(self, exp_cache_handler):
# Setup
exp_cache_handler.func = lambda x: x
exp_cache_handler.args = (42,)
exp_cache_handler.kwargs = {"y": 3.14}
# Execute
req_id = exp_cache_handler.generate_req_identifier()
# Assert
expected_id = "<lambda>_(42,)_{'y': 3.14}"
assert req_id == expected_id, "Request identifier should match the expected format"

View file

@ -4,7 +4,7 @@ from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
from metagpt.configs.llm_config import LLMConfig
from metagpt.exp_pool.manager import ExperienceManager
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score
from metagpt.rag.engines import SimpleEngine
@ -62,15 +62,15 @@ class TestExperienceManager:
def test_extract_one_perfect_exp(self, mock_experience_manager):
experiences = [
Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE)),
Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))),
Experience(req="req", resp="resp"),
]
perfect_exp: Experience = mock_experience_manager.extract_one_perfect_exp(experiences)
assert perfect_exp is not None
assert perfect_exp.metric.score == MAX_SCORE
assert perfect_exp.metric.score.val == MAX_SCORE
def test_is_perfect_exp(self):
exp = Experience(req="req", resp="resp", metric=Metric(score=MAX_SCORE))
exp = Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE)))
assert ExperienceManager.is_perfect_exp(exp) == True
exp = Experience(req="req", resp="resp")