experiment pool init

This commit is contained in:
seehi 2024-06-04 10:28:39 +08:00
parent 46aa6f1975
commit 471310f3b3
14 changed files with 258 additions and 55 deletions

View file

@ -13,6 +13,7 @@ from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.embedding_config import EmbeddingConfig
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.mermaid_config import MermaidConfig
from metagpt.configs.redis_config import RedisConfig
@ -22,7 +23,6 @@ from metagpt.configs.search_config import SearchConfig
from metagpt.configs.workspace_config import WorkspaceConfig
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
from metagpt.utils.yaml_model import YamlModel
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
class CLIParams(BaseModel):
@ -73,7 +73,7 @@ class Config(CLIParams, YamlModel):
code_review_k_times: int = 2
# Experience Pool Parameters
experience_pool: Optional[ExperiencePoolConfig] = None
exp_pool: ExperiencePoolConfig = ExperiencePoolConfig()
# Will be removed in the future
metagpt_tti_url: str = ""

View file

@ -1,6 +1,8 @@
from pydantic import Field
from metagpt.utils.yaml_model import YamlModel
class ExperiencePoolConfig(YamlModel):
enable_read: bool = False
enable_write: bool = False
enable_read: bool = Field(default=True, description="Enable to read from experience pool.")
enable_write: bool = Field(default=True, description="Enable to write to experience pool.")

View file

@ -0,0 +1,6 @@
"""Experience pool init."""
from metagpt.exp_pool.manager import exp_manager
from metagpt.exp_pool.decorator import exp_cache
__all__ = ["exp_manager", "exp_cache"]

View file

@ -1,4 +1,56 @@
"""Experience Decorator."""
import asyncio
import functools
from typing import Any, Callable, Optional, TypeVar
from metagpt.exp_pool.manager import exp_manager
from metagpt.exp_pool.schema import Experience
from metagpt.utils.async_helper import NestAsyncio
ReturnType = TypeVar("ReturnType")
def exp_cache(func):
pass
def exp_cache(_func: Optional[Callable[..., ReturnType]] = None):
"""Decorator to check for a perfect experience and returns it if exists.
Otherwise, it executes the function, save the result as a new experience, and returns the result.
This can be applied to both synchronous and asynchronous functions.
"""
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
@functools.wraps(func)
async def get_or_create(args: Any, kwargs: Any, is_async: bool) -> ReturnType:
"""Attempts to retrieve a cached experience or creates one if not found."""
req = f"{func.__name__}_{args}_{kwargs}"
exps = await exp_manager.query_exps(req)
if perfect_exp := exp_manager.extract_one_perfect_exp(exps):
return perfect_exp
if is_async:
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
exp_manager.create_exp(Experience(req=req, resp=result))
return result
def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
NestAsyncio.apply_once()
return asyncio.get_event_loop().run_until_complete(get_or_create(args, kwargs, is_async=False))
async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
return await get_or_create(args, kwargs, is_async=True)
if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
if _func is None:
return decorator
else:
return decorator(_func)

View file

@ -1,32 +1,105 @@
from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.schema import Experience
import uuid
import chromadb
from chromadb import Collection, QueryResult
"""Experience Manager."""
from typing import Optional
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.config2 import Config, config
from metagpt.exp_pool.schema import MAX_SCORE, Experience
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
class ExperiencePoolManager(BaseModel):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._storage = None
class ExperienceManager(BaseModel):
"""ExperienceManager manages the lifecycle of experiences, including CRUD and optimization.
Attributes:
config (Config): Configuration for managing experiences.
storage (SimpleEngine): Engine to handle the storage and retrieval of experiences.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
config: Config = config
storage: SimpleEngine = None
@model_validator(mode="after")
def initialize(self):
if self.storage is None:
self.storage = SimpleEngine.from_objs(
retriever_configs=[
ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data")
],
ranker_configs=[LLMRankerConfig()],
)
return self
@property
def storage(self) -> SimpleEngine:
if self._storage is None:
self._storage = SimpleEngine.from_objs(retriever_configs=[ChromaRetrieverConfig(collection_name="experience_pool", persist_path="./chroma_data")])
return self._storage
def create_exp(self, exp: Experience):
"""Adds an experience to the storage if writing is enabled.
Args:
exp (Experience): The experience to add.
"""
if not self.config.exp_pool.enable_write:
return
self.storage.add_objs([exp])
async def query_exp(self, req: str) -> list[Experience]:
async def query_exps(self, req: str, tag: str = "") -> list[Experience]:
"""Retrieves and filters experiences.
Args:
req (str): The query string to retrieve experiences.
tag (str): Optional tag to filter the experiences by.
Returns:
list[Experience]: A list of experiences that match the args.
"""
if not self.config.exp_pool.enable_read:
return []
nodes = await self.storage.aretrieve(req)
exps = [node.metadata["obj"] for node in nodes]
exps: list[Experience] = [node.metadata["obj"] for node in nodes]
# TODO: filter by metadata
if tag:
exps = [exp for exp in exps if exp.tag == tag]
return exps
def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]:
"""Extracts the first 'perfect' experience from a list of experiences.
Args:
exps (list[Experience]): The experiences to evaluate.
Returns:
Optional[Experience]: The first perfect experience if found, otherwise None.
"""
for exp in exps:
if self.is_perfect_exp(exp):
return exp
return None
@staticmethod
def is_perfect_exp(exp: Experience) -> bool:
"""Determines if an experience is considered 'perfect'.
Args:
exp (Experience): The experience to evaluate.
Returns:
bool: True if the experience is manually entered, otherwise False.
"""
if not exp:
return False
# TODO: need more metrics
if exp.metric and exp.metric.score == MAX_SCORE:
return True
return False
exp_manager = ExperienceManager()

View file

@ -1,10 +1,46 @@
from pydantic import BaseModel, Field
"""Experience schema."""
from enum import Enum
from typing import Optional
from llama_index.core.schema import TextNode
from pydantic import BaseModel, Field
MAX_SCORE = 10
class ExperienceType(str, Enum):
"""Experience Type."""
SUCCESS = "success"
FAILURE = "failure"
INSIGHT = "insight"
class EntryType(Enum):
"""Experience Entry Type."""
AUTOMATIC = "Automatic"
MANUAL = "Manual"
class Metric(BaseModel):
"""Experience Metric."""
time_cost: float = Field(default=0.000, description="Time cost, the unit is milliseconds.")
money_cost: float = Field(default=0.000, description="Money cost, the unit is US dollars.")
score: int = Field(default=1, description="Score, a value between 1 and 10.")
class Experience(BaseModel):
"""Experience."""
req: str = Field(..., description="")
resp: str = Field(..., description="")
resp: str = Field(..., description="The type is string/json/code.")
metric: Optional[Metric] = Field(default=None, description="Metric.")
exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.")
entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.")
tag: str = Field(default="", description="Tagging experience.")
def rag_key(self):
return self.req

View file

@ -46,4 +46,4 @@ class DynamicBM25Retriever(BM25Retriever):
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
if self._index:
self._index.storage_context.persist(persist_dir)
self._index.storage_context.persist(persist_dir)

View file

@ -72,7 +72,6 @@ class File:
class MemoryFileSystem(_MemoryFileSystem):
@classmethod
def _strip_protocol(cls, path):
return super()._strip_protocol(str(path))

View file

@ -23,5 +23,5 @@ def get_func_full_name(func, *args) -> str:
if inspect.ismethod(func) or (inspect.isfunction(func) and "self" in inspect.signature(func).parameters):
cls_name = args[0].__class__.__name__
return f"{func.__module__}.{cls_name}.{func.__name__}"
return f"{func.__module__}.{func.__name__}"