mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-09 07:42:38 +02:00
experiment pool init
This commit is contained in:
parent
46aa6f1975
commit
471310f3b3
14 changed files with 258 additions and 55 deletions
|
|
@ -13,6 +13,7 @@ from pydantic import BaseModel, model_validator
|
|||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.configs.embedding_config import EmbeddingConfig
|
||||
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.configs.mermaid_config import MermaidConfig
|
||||
from metagpt.configs.redis_config import RedisConfig
|
||||
|
|
@ -22,7 +23,6 @@ from metagpt.configs.search_config import SearchConfig
|
|||
from metagpt.configs.workspace_config import WorkspaceConfig
|
||||
from metagpt.const import CONFIG_ROOT, METAGPT_ROOT
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
|
||||
|
||||
|
||||
class CLIParams(BaseModel):
|
||||
|
|
@ -73,7 +73,7 @@ class Config(CLIParams, YamlModel):
|
|||
code_review_k_times: int = 2
|
||||
|
||||
# Experience Pool Parameters
|
||||
experience_pool: Optional[ExperiencePoolConfig] = None
|
||||
exp_pool: ExperiencePoolConfig = ExperiencePoolConfig()
|
||||
|
||||
# Will be removed in the future
|
||||
metagpt_tti_url: str = ""
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
||||
class ExperiencePoolConfig(YamlModel):
    """Configuration switches for the experience pool.

    Controls whether experiences may be read from or written to the pool.
    Both switches default to enabled.
    """

    enable_read: bool = Field(default=True, description="Enable to read from experience pool.")
    enable_write: bool = Field(default=True, description="Enable to write to experience pool.")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,6 @@
|
|||
"""Experience pool init."""
|
||||
|
||||
from metagpt.exp_pool.manager import exp_manager
|
||||
from metagpt.exp_pool.decorator import exp_cache
|
||||
|
||||
__all__ = ["exp_manager", "exp_cache"]
|
||||
|
|
@ -1,4 +1,56 @@
|
|||
"""Experience Decorator."""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
from metagpt.exp_pool.manager import exp_manager
|
||||
from metagpt.exp_pool.schema import Experience
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
|
||||
ReturnType = TypeVar("ReturnType")
|
||||
|
||||
|
||||
def exp_cache(_func: Optional[Callable[..., ReturnType]] = None):
    """Decorator to check for a perfect experience and return it if it exists.

    Otherwise, it executes the function, saves the result as a new experience,
    and returns the result.

    This can be applied to both synchronous and asynchronous functions, and
    may be used either bare (``@exp_cache``) or called (``@exp_cache()``).
    """

    def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
        @functools.wraps(func)
        async def get_or_create(args: Any, kwargs: Any, is_async: bool) -> ReturnType:
            """Attempts to retrieve a cached experience or creates one if not found."""
            # Cache key: function name plus the repr of its call arguments.
            req = f"{func.__name__}_{args}_{kwargs}"
            exps = await exp_manager.query_exps(req)
            if perfect_exp := exp_manager.extract_one_perfect_exp(exps):
                # NOTE(review): this returns the Experience object itself, not
                # its `resp` payload — confirm callers expect the wrapper type.
                return perfect_exp

            if is_async:
                result = await func(*args, **kwargs)
            else:
                result = func(*args, **kwargs)

            exp_manager.create_exp(Experience(req=req, resp=result))

            return result

        def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            # Allow running the async cache lookup from sync code, even if an
            # event loop is already running in this thread.
            NestAsyncio.apply_once()
            return asyncio.get_event_loop().run_until_complete(get_or_create(args, kwargs, is_async=False))

        async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType:
            return await get_or_create(args, kwargs, is_async=True)

        if asyncio.iscoroutinefunction(func):
            return async_wrapper
        else:
            return sync_wrapper

    # Support both @exp_cache and @exp_cache() usage.
    if _func is None:
        return decorator
    else:
        return decorator(_func)
|
||||
|
|
|
|||
|
|
@ -1,32 +1,105 @@
|
|||
from pydantic import BaseModel, ConfigDict
|
||||
from metagpt.exp_pool.schema import Experience
|
||||
import uuid
|
||||
import chromadb
|
||||
from chromadb import Collection, QueryResult
|
||||
"""Experience Manager."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from metagpt.config2 import Config, config
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import ChromaRetrieverConfig
|
||||
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
|
||||
|
||||
|
||||
class ExperienceManager(BaseModel):
    """ExperienceManager manages the lifecycle of experiences, including CRUD and optimization.

    Attributes:
        config (Config): Configuration for managing experiences.
        storage (SimpleEngine): Engine to handle the storage and retrieval of experiences.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    config: Config = config
    storage: SimpleEngine = None

    @model_validator(mode="after")
    def initialize(self):
        """Builds the default storage engine when none was injected."""
        if self.storage is None:
            self.storage = SimpleEngine.from_objs(
                retriever_configs=[
                    ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data")
                ],
                ranker_configs=[LLMRankerConfig()],
            )
        return self

    def create_exp(self, exp: Experience):
        """Adds an experience to the storage if writing is enabled.

        Args:
            exp (Experience): The experience to add.
        """
        if not self.config.exp_pool.enable_write:
            return

        self.storage.add_objs([exp])

    async def query_exps(self, req: str, tag: str = "") -> list[Experience]:
        """Retrieves and filters experiences.

        Args:
            req (str): The query string to retrieve experiences.
            tag (str): Optional tag to filter the experiences by.

        Returns:
            list[Experience]: A list of experiences that match the args.
        """
        if not self.config.exp_pool.enable_read:
            return []

        nodes = await self.storage.aretrieve(req)
        exps: list[Experience] = [node.metadata["obj"] for node in nodes]

        # TODO: filter by metadata
        if tag:
            exps = [exp for exp in exps if exp.tag == tag]

        return exps

    def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]:
        """Extracts the first 'perfect' experience from a list of experiences.

        Args:
            exps (list[Experience]): The experiences to evaluate.

        Returns:
            Optional[Experience]: The first perfect experience if found, otherwise None.
        """
        for exp in exps:
            if self.is_perfect_exp(exp):
                return exp

        return None

    @staticmethod
    def is_perfect_exp(exp: Experience) -> bool:
        """Determines if an experience is considered 'perfect'.

        Args:
            exp (Experience): The experience to evaluate.

        Returns:
            bool: True if the experience's metric score equals MAX_SCORE, otherwise False.
        """
        if not exp:
            return False

        # TODO: need more metrics
        if exp.metric and exp.metric.score == MAX_SCORE:
            return True

        return False


# Module-level singleton used by the exp_cache decorator.
exp_manager = ExperienceManager()
|
||||
|
|
|
|||
|
|
@ -1,10 +1,46 @@
|
|||
from pydantic import BaseModel, Field
|
||||
"""Experience schema."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core.schema import TextNode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
MAX_SCORE = 10
|
||||
|
||||
|
||||
class ExperienceType(str, Enum):
    """Closed set of kinds an experience may be classified as."""

    SUCCESS = "success"
    FAILURE = "failure"
    INSIGHT = "insight"
|
||||
|
||||
|
||||
class EntryType(Enum):
    """How an experience entered the pool: created by code or by a human."""

    AUTOMATIC = "Automatic"
    MANUAL = "Manual"
|
||||
|
||||
|
||||
class Metric(BaseModel):
    """Quantitative metrics recorded alongside an experience."""

    # Elapsed time, in milliseconds.
    time_cost: float = Field(default=0.000, description="Time cost, the unit is milliseconds.")
    # Monetary cost, in US dollars.
    money_cost: float = Field(default=0.000, description="Money cost, the unit is US dollars.")
    # Quality score between 1 and 10; MAX_SCORE marks a 'perfect' experience.
    score: int = Field(default=1, description="Score, a value between 1 and 10.")
|
||||
|
||||
|
||||
class Experience(BaseModel):
    """A single stored experience: a request, its response, and metadata."""

    req: str = Field(..., description="The request content.")
    resp: str = Field(..., description="The type is string/json/code.")
    metric: Optional[Metric] = Field(default=None, description="Metric.")
    exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.")
    entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.")
    tag: str = Field(default="", description="Tagging experience.")

    def rag_key(self) -> str:
        """Key used by the RAG engine to index and retrieve this experience."""
        return self.req
||||
|
|
|
|||
|
|
@ -46,4 +46,4 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
def persist(self, persist_dir: str, **kwargs) -> None:
    """Support persist.

    Writes the retriever's index storage context to ``persist_dir``.
    """
    # NOTE(review): persists unconditionally; assumes self._index is always
    # set by construction — confirm, otherwise this raises AttributeError.
    self._index.storage_context.persist(persist_dir)
|
||||
|
|
|
|||
|
|
@ -72,7 +72,6 @@ class File:
|
|||
|
||||
|
||||
class MemoryFileSystem(_MemoryFileSystem):
    """In-memory filesystem that accepts path-like objects.

    Coerces the incoming path to ``str`` before delegating to the parent,
    so non-string path objects (e.g. ``pathlib.Path``) are handled.
    """

    @classmethod
    def _strip_protocol(cls, path):
        # Coerce first; the parent implementation expects a plain string.
        path_str = str(path)
        return super()._strip_protocol(path_str)
|
||||
|
|
|
|||
|
|
@ -23,5 +23,5 @@ def get_func_full_name(func, *args) -> str:
|
|||
if inspect.ismethod(func) or (inspect.isfunction(func) and "self" in inspect.signature(func).parameters):
|
||||
cls_name = args[0].__class__.__name__
|
||||
return f"{func.__module__}.{cls_name}.{func.__name__}"
|
||||
|
||||
|
||||
return f"{func.__module__}.{func.__name__}"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue