add exp_pool tests

This commit is contained in:
seehi 2024-06-05 22:15:09 +08:00
parent 1d8d85e9a5
commit c78cddd102
9 changed files with 391 additions and 43 deletions

View file

@ -4,56 +4,134 @@ import asyncio
import functools
from typing import Any, Callable, Optional, TypeVar
from metagpt.exp_pool.manager import exp_manager
from metagpt.exp_pool.schema import Experience
from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.manager import ExperienceManager, exp_manager
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
from metagpt.exp_pool.scorers import ExperienceScorer, SimpleScorer
from metagpt.utils.async_helper import NestAsyncio
ReturnType = TypeVar("ReturnType")
def exp_cache(
    _func: Optional[Callable[..., ReturnType]] = None,
    query_type: QueryType = QueryType.SEMANTIC,
    scorer: Optional[ExperienceScorer] = None,
    manager: Optional[ExperienceManager] = None,
    pass_exps_to_func: bool = False,
):
    """Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience.

    This can be applied to both synchronous and asynchronous functions.

    Args:
        _func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache().
        query_type: The type of query to be used when fetching experiences.
        scorer: Evaluate experience. Default SimpleScorer.
        manager: How to fetch, evaluate and save experience, etc. Default exp_manager.
        pass_exps_to_func: To control whether imperfect experiences are passed to the function, if True, the func must have a parameter named 'exps'.
    """

    def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
        @functools.wraps(func)
        async def get_or_create(args: Any, kwargs: Any) -> ReturnType:
            """Return a stored perfect experience, or run `func` and record a new one."""
            handler = ExpCacheHandler(
                func=func,
                args=args,
                kwargs=kwargs,
                exp_manager=manager or exp_manager,
                exp_scorer=scorer or SimpleScorer(),
                pass_exps=pass_exps_to_func,
            )

            # 1. Try to short-circuit with an already-perfect stored experience.
            await handler.fetch_experiences(query_type)
            if exp := handler.get_one_perfect_experience():
                return exp

            # 2. Cache miss: execute the function, score the outcome.
            await handler.execute_function()
            await handler.evaluate_experience()

            # 3. Persist the new experience, then return the fresh result.
            handler.save_experience()
            return handler._result

        # Wrap according to whether `func` is sync or async.
        return ExpCacheHandler.choose_wrapper(func, get_or_create)

    # Support both bare `@exp_cache` and parameterized `@exp_cache(...)` usage.
    return decorator(_func) if _func else decorator
class ExpCacheHandler(BaseModel):
    """Orchestrates one cached call: fetch experiences, execute, evaluate, and save.

    NOTE(review): relies on pydantic with arbitrary types; `_exps`/`_result`/`_score`
    are internal state filled in as the handler advances through its phases.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    func: Callable
    args: Any
    kwargs: Any
    exp_manager: ExperienceManager
    exp_scorer: ExperienceScorer
    pass_exps: bool

    # Internal state, populated by fetch_experiences / execute_function / evaluate_experience.
    _exps: list[Experience] = None
    _result: Any = None
    _score: Score = None

    async def fetch_experiences(self, query_type: QueryType):
        """Fetch potentially matching experiences for this call and cache them on the handler."""
        req = self.generate_req_identifier()
        self._exps = await self.exp_manager.query_exps(req, query_type=query_type)

    def get_one_perfect_experience(self) -> Optional[Experience]:
        """Return one perfect experience among the fetched ones, if any."""
        return self.exp_manager.extract_one_perfect_exp(self._exps)

    async def execute_function(self):
        """Execute the function, and save the result."""
        self._result = await self._execute_function()

    async def evaluate_experience(self):
        """Evaluate the experience, and save the score."""
        self._score = await self.exp_scorer.evaluate(self.func, self._result, self.args, self.kwargs)

    def save_experience(self):
        """Save the new experience."""
        req = self.generate_req_identifier()
        exp = Experience(req=req, resp=self._result, metric=Metric(score=self._score))
        self.exp_manager.create_exp(exp)

    def generate_req_identifier(self):
        """Generate a unique request identifier based on the function and its arguments."""
        return f"{self.func.__name__}_{self.args}_{self.kwargs}"

    @staticmethod
    def choose_wrapper(func, wrapped_func):
        """Choose how to run wrapped_func based on whether the function is asynchronous."""

        async def async_wrapper(*args, **kwargs):
            return await wrapped_func(args, kwargs)

        def sync_wrapper(*args, **kwargs):
            # Allow run_until_complete inside an already-running loop (e.g. notebooks).
            NestAsyncio.apply_once()
            return asyncio.get_event_loop().run_until_complete(wrapped_func(args, kwargs))

        return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper

    async def _execute_function(self):
        """Dispatch to the with/without-exps variant based on `pass_exps`."""
        if self.pass_exps:
            return await self._execute_function_with_exps()
        return await self._execute_function_without_exps()

    async def _execute_function_without_exps(self):
        # Handle both sync and async decorated functions.
        if asyncio.iscoroutinefunction(self.func):
            return await self.func(*self.args, **self.kwargs)
        return self.func(*self.args, **self.kwargs)

    async def _execute_function_with_exps(self):
        # The decorated function must accept an `exps` keyword argument.
        if asyncio.iscoroutinefunction(self.func):
            return await self.func(*self.args, **self.kwargs, exps=self._exps)
        return self.func(*self.args, **self.kwargs, exps=self._exps)

View file

@ -5,7 +5,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.config2 import Config, config
from metagpt.exp_pool.schema import MAX_SCORE, Experience
from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
@ -45,12 +45,13 @@ class ExperienceManager(BaseModel):
self.storage.add_objs([exp])
async def query_exps(self, req: str, tag: str = "") -> list[Experience]:
async def query_exps(self, req: str, tag: str = "", query_type: QueryType = QueryType.SEMANTIC) -> list[Experience]:
"""Retrieves and filters experiences.
Args:
req (str): The query string to retrieve experiences.
tag (str): Optional tag to filter the experiences by.
query_type (QueryType): Default semantic to vector matching. exact to same matching.
Returns:
list[Experience]: A list of experiences that match the args.
@ -65,6 +66,9 @@ class ExperienceManager(BaseModel):
if tag:
exps = [exp for exp in exps if exp.tag == tag]
if query_type == QueryType.EXACT:
exps = [exp for exp in exps if exp.req == req]
return exps
def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]:
@ -96,7 +100,7 @@ class ExperienceManager(BaseModel):
return False
# TODO: need more metrics
if exp.metric and exp.metric.score == MAX_SCORE:
if exp.metric and exp.metric.score.val == MAX_SCORE:
return True
return False

View file

@ -9,6 +9,13 @@ from pydantic import BaseModel, Field
MAX_SCORE = 10
class QueryType(str, Enum):
    """Matching strategy used when querying stored experiences."""

    EXACT = "exact"        # literal string equality against the request
    SEMANTIC = "semantic"  # vector / semantic similarity matching
class ExperienceType(str, Enum):
"""Experience Type."""
@ -24,12 +31,19 @@ class EntryType(Enum):
MANUAL = "Manual"
class Score(BaseModel):
    """Score in Metric."""

    # Integer grade; per the field description, between 1 and 10, higher is better.
    val: int = Field(default=1, description="Value of the score, Between 1 and 10, higher is better.")
    # Free-text justification for the value, produced by the scorer.
    reason: str = Field(default="", description="Reason for the value.")
class Metric(BaseModel):
    """Experience Metric."""

    time_cost: float = Field(default=0.000, description="Time cost, the unit is milliseconds.")
    money_cost: float = Field(default=0.000, description="Money cost, the unit is US dollars.")
    # Replaces the former plain-int `score`; None until an evaluation has run.
    score: Score = Field(default=None, description="Score, with value and reason.")
class Trajectory(BaseModel):

View file

@ -0,0 +1,6 @@
"""Experience scorers init."""
from metagpt.exp_pool.scorers.base import ExperienceScorer
from metagpt.exp_pool.scorers.simple import SimpleScorer
__all__ = ["ExperienceScorer", "SimpleScorer"]

View file

@ -0,0 +1,27 @@
"""Experience Scorers."""
from abc import abstractmethod
from typing import Any, Callable
from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.schema import Score
class ExperienceScorer(BaseModel):
    """Abstract base for experience scorers; subclasses must implement `evaluate`."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score:
        """Evaluate the quality of the result produced by the function and parameters.

        Args:
            func (Callable): The function whose result is to be evaluated.
            result (Any): The result produced by the function.
            args (Tuple[Any, ...]): The tuple of arguments that were passed to the function.
            kwargs (Dict[str, Any]): The dictionary of keyword arguments that were passed to the function.

        Example:
            result = await sample(5, name="foo")
            score = await scorer.evaluate(sample, result, args=(5,), kwargs={"name": "foo"})
        """

View file

@ -0,0 +1,73 @@
"""Evaluate by LLM."""
import inspect
import json
from typing import Any, Callable
from pydantic import Field
from metagpt.exp_pool.schema import Score
from metagpt.exp_pool.scorers.base import ExperienceScorer
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.common import parse_json_code_block
# Prompt used by SimpleScorer: asks the LLM to grade one function call's result
# and reply with a JSON object matching the Score schema ("val", "reason").
SIMPLE_SCORER_TEMPLATE = """
Role: You're an expert score evaluator. You specialize in assessing the output of the given function, based on its intended requirement and produced result.
## Context
### Function Name
{func_name}
### Function Document
{func_doc}
### Function Signature
{func_signature}
### Function Parameters
args: {func_args}
kwargs: {func_kwargs}
### Produced Result By Function and Parameters
{func_result}
## Format Example
```json
{{
"val": "the value of the score, int from 1 to 10, higher is better.",
"reason": "an explanation supporting the score."
}}
```
## Instructions
- Understand the function and requirements given by the user.
- Analyze the results produced by the function.
- Grade the results based on level of alignment with the requirements.
- Provide a score on a scale defined by user or a default scale (1 to 10).
## Constraint
Format: Just print the result in json format like **Format Example**.
## Action
Follow instructions, generate output and make sure it follows the **Constraint**.
"""
class SimpleScorer(ExperienceScorer):
    """LLM-backed scorer: prompts a model to grade a function call's output as JSON."""

    llm: BaseLLM = Field(default_factory=LLM)

    async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score:
        """Evaluate the quality of content."""
        # Gather everything the LLM needs to judge the call, then render the template.
        context = {
            "func_name": func.__name__,
            "func_doc": func.__doc__,
            "func_signature": inspect.signature(func),
            "func_args": args,
            "func_kwargs": kwargs,
            "func_result": result,
        }
        rendered = SIMPLE_SCORER_TEMPLATE.format(**context)

        raw_answer = await self.llm.aask(rendered)
        # The model replies with a fenced JSON block; take the first one.
        payload = parse_json_code_block(raw_answer)[0]
        return Score(**json.loads(payload))