diff --git a/.gitignore b/.gitignore index aa5edd74a..7c64829ad 100644 --- a/.gitignore +++ b/.gitignore @@ -162,6 +162,7 @@ examples/graph_store.json examples/image__vector_store.json examples/index_store.json .chroma +.chroma_exp_data *~$* workspace/* tmp diff --git a/config/config2.example.yaml b/config/config2.example.yaml index c5ca6e767..c7b2cae2c 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -74,9 +74,9 @@ s3: secure: false bucket: "test" -experience_pool: - enable_read: false - enable_write: false +exp_pool: + enable_read: true + enable_write: true azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/examples/exp_pool/decorator.py b/examples/exp_pool/decorator.py new file mode 100644 index 000000000..2f6397f80 --- /dev/null +++ b/examples/exp_pool/decorator.py @@ -0,0 +1,26 @@ +"""Decorator example of experience pool.""" + +import asyncio +import uuid + +from metagpt.exp_pool import exp_cache, exp_manager +from metagpt.logs import logger + + +@exp_cache +async def produce(req): + return f"{req} {uuid.uuid4().hex}" + + +async def main(): + req = "Water" + + resp = await produce(req) + logger.info(f"The resp of `produce{req}` is: {resp}") + + exps = await exp_manager.query_exps(req) + logger.info(f"Find experiences: {exps}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/exp_pool/manager.py b/examples/exp_pool/manager.py deleted file mode 100644 index f5766f9a5..000000000 --- a/examples/exp_pool/manager.py +++ /dev/null @@ -1,21 +0,0 @@ -from metagpt.exp_pool.manager import ExperiencePoolManager -from metagpt.exp_pool.schema import Experience -from pprint import pprint -import asyncio -# import logging -# logging.basicConfig(level=logging.DEBUG) - -async def main(): - req = "2048 game" - exp = Experience(req=req, resp="python code") - - manager = ExperiencePoolManager() - - # pprint(manager.storage.get()) - # manager.create_exp(exp) - result = await manager.query_exp(req) - print(result) - - -if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file diff --git a/examples/exp_pool/simple.py b/examples/exp_pool/simple.py new file mode 100644 index 000000000..bc20fbcdd --- /dev/null +++ b/examples/exp_pool/simple.py @@ -0,0 +1,29 @@ +"""Simple example of experience pool.""" + +import asyncio + +from metagpt.exp_pool import exp_manager +from metagpt.exp_pool.schema import EntryType, Experience +from metagpt.logs import logger + + +async def main(): + req = "Simple task." + + # 1. Find experiences. + exps = await exp_manager.query_exps(req) + if exps: + logger.info(f"Experiences already exist for the request `{req}`: {exps}") + return + + # 2. Create a new experience if none exist + exp_manager.create_exp(Experience(req=req, resp="Simple echo.", entry_type=EntryType.MANUAL)) + logger.info(f"New experience created for the request `{req}`.") + + # 3. Find again + exps = await exp_manager.query_exps(req) + logger.info(f"Updated experiences: {exps}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/config2.py b/metagpt/config2.py index 6f5a1add6..6588a6036 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -13,6 +13,7 @@ from pydantic import BaseModel, model_validator from metagpt.configs.browser_config import BrowserConfig from metagpt.configs.embedding_config import EmbeddingConfig +from metagpt.configs.exp_pool_config import ExperiencePoolConfig from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.configs.mermaid_config import MermaidConfig from metagpt.configs.redis_config import RedisConfig @@ -22,7 +23,6 @@ from metagpt.configs.search_config import SearchConfig from metagpt.configs.workspace_config import WorkspaceConfig from metagpt.const import CONFIG_ROOT, METAGPT_ROOT from metagpt.utils.yaml_model import YamlModel -from metagpt.configs.exp_pool_config import ExperiencePoolConfig class CLIParams(BaseModel): @@ -73,7 +73,7 @@ class Config(CLIParams, YamlModel): code_review_k_times: int = 2 # Experience Pool Parameters - experience_pool: Optional[ExperiencePoolConfig] = None + exp_pool: ExperiencePoolConfig = ExperiencePoolConfig() # Will be removed in the future metagpt_tti_url: str = "" diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py index f7312d2de..3f86173c1 100644 --- a/metagpt/configs/exp_pool_config.py +++ b/metagpt/configs/exp_pool_config.py @@ -1,6 +1,8 @@ +from pydantic import Field + from metagpt.utils.yaml_model import YamlModel class ExperiencePoolConfig(YamlModel): - enable_read: bool = False - enable_write: bool = False + enable_read: bool = Field(default=True, description="Enable to read from experience pool.") + enable_write: bool = Field(default=True, description="Enable to write to experience pool.") diff --git a/metagpt/exp_pool/__init__.py b/metagpt/exp_pool/__init__.py index e69de29bb..aeeb94b38 100644 --- a/metagpt/exp_pool/__init__.py +++ b/metagpt/exp_pool/__init__.py @@ -0,0 +1,6 @@ +"""Experience pool init.""" + +from metagpt.exp_pool.manager import exp_manager +from metagpt.exp_pool.decorator import exp_cache + +__all__ = ["exp_manager", "exp_cache"] diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index 6629e8377..1d691b8f3 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -1,4 +1,56 @@ +"""Experience Decorator.""" + +import asyncio +import functools +from typing import Any, Callable, Optional, TypeVar + +from metagpt.exp_pool.manager import exp_manager +from metagpt.exp_pool.schema import Experience +from metagpt.utils.async_helper import NestAsyncio + +ReturnType = TypeVar("ReturnType") -def exp_cache(func): - pass +def exp_cache(_func: Optional[Callable[..., ReturnType]] = None): + """Decorator to check for a perfect experience and returns it if exists. + + Otherwise, it executes the function, save the result as a new experience, and returns the result. + + This can be applied to both synchronous and asynchronous functions. + """ + + def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: + @functools.wraps(func) + async def get_or_create(args: Any, kwargs: Any, is_async: bool) -> ReturnType: + """Attempts to retrieve a cached experience or creates one if not found.""" + + req = f"{func.__name__}_{args}_{kwargs}" + exps = await exp_manager.query_exps(req) + if perfect_exp := exp_manager.extract_one_perfect_exp(exps): + return perfect_exp + + if is_async: + result = await func(*args, **kwargs) + else: + result = func(*args, **kwargs) + + exp_manager.create_exp(Experience(req=req, resp=result)) + + return result + + def sync_wrapper(*args: Any, **kwargs: Any) -> ReturnType: + NestAsyncio.apply_once() + return asyncio.get_event_loop().run_until_complete(get_or_create(args, kwargs, is_async=False)) + + async def async_wrapper(*args: Any, **kwargs: Any) -> ReturnType: + return await get_or_create(args, kwargs, is_async=True) + + if asyncio.iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + if _func is None: + return decorator + else: + return decorator(_func) diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index c32073a9f..4bc566104 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -1,32 +1,105 @@ -from pydantic import BaseModel, ConfigDict -from metagpt.exp_pool.schema import Experience -import uuid -import chromadb -from chromadb import Collection, QueryResult +"""Experience Manager.""" + from typing import Optional + +from pydantic import BaseModel, ConfigDict, model_validator + +from metagpt.config2 import Config, config +from metagpt.exp_pool.schema import MAX_SCORE, Experience from metagpt.rag.engines import SimpleEngine -from metagpt.rag.schema import ChromaRetrieverConfig +from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig -class ExperiencePoolManager(BaseModel): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._storage = None +class ExperienceManager(BaseModel): + """ExperienceManager manages the lifecycle of experiences, including CRUD and optimization. + + Attributes: + config (Config): Configuration for managing experiences. + storage (SimpleEngine): Engine to handle the storage and retrieval of experiences. + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + config: Config = config + storage: SimpleEngine = None + + @model_validator(mode="after") + def initialize(self): + if self.storage is None: + self.storage = SimpleEngine.from_objs( + retriever_configs=[ + ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data") + ], + ranker_configs=[LLMRankerConfig()], + ) + return self - @property - def storage(self) -> SimpleEngine: - if self._storage is None: - self._storage = SimpleEngine.from_objs(retriever_configs=[ChromaRetrieverConfig(collection_name="experience_pool", persist_path="./chroma_data")]) - return self._storage - def create_exp(self, exp: Experience): + """Adds an experience to the storage if writing is enabled. + + Args: + exp (Experience): The experience to add. + """ + if not self.config.exp_pool.enable_write: + return + self.storage.add_objs([exp]) - - async def query_exp(self, req: str) -> list[Experience]: + + async def query_exps(self, req: str, tag: str = "") -> list[Experience]: + """Retrieves and filters experiences. + + Args: + req (str): The query string to retrieve experiences. + tag (str): Optional tag to filter the experiences by. + + Returns: + list[Experience]: A list of experiences that match the args. + """ + if not self.config.exp_pool.enable_read: + return [] + nodes = await self.storage.aretrieve(req) - exps = [node.metadata["obj"] for node in nodes] + exps: list[Experience] = [node.metadata["obj"] for node in nodes] + + # TODO: filter by metadata + if tag: + exps = [exp for exp in exps if exp.tag == tag] return exps - + def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]: + """Extracts the first 'perfect' experience from a list of experiences. + Args: + exps (list[Experience]): The experiences to evaluate. + + Returns: + Optional[Experience]: The first perfect experience if found, otherwise None. + """ + for exp in exps: + if self.is_perfect_exp(exp): + return exp + + return None + + @staticmethod + def is_perfect_exp(exp: Experience) -> bool: + """Determines if an experience is considered 'perfect'. + + Args: + exp (Experience): The experience to evaluate. + + Returns: + bool: True if the experience is manually entered, otherwise False. + """ + if not exp: + return False + + # TODO: need more metrics + if exp.metric and exp.metric.score == MAX_SCORE: + return True + + return False + + +exp_manager = ExperienceManager() diff --git a/metagpt/exp_pool/schema.py b/metagpt/exp_pool/schema.py index 359268612..b51bc3c17 100644 --- a/metagpt/exp_pool/schema.py +++ b/metagpt/exp_pool/schema.py @@ -1,10 +1,46 @@ -from pydantic import BaseModel, Field +"""Experience schema.""" + +from enum import Enum +from typing import Optional + from llama_index.core.schema import TextNode +from pydantic import BaseModel, Field + +MAX_SCORE = 10 + + +class ExperienceType(str, Enum): + """Experience Type.""" + + SUCCESS = "success" + FAILURE = "failure" + INSIGHT = "insight" + + +class EntryType(Enum): + """Experience Entry Type.""" + + AUTOMATIC = "Automatic" + MANUAL = "Manual" + + +class Metric(BaseModel): + """Experience Metric.""" + + time_cost: float = Field(default=0.000, description="Time cost, the unit is milliseconds.") + money_cost: float = Field(default=0.000, description="Money cost, the unit is US dollars.") + score: int = Field(default=1, description="Score, a value between 1 and 10.") class Experience(BaseModel): + """Experience.""" + req: str = Field(..., description="") - resp: str = Field(..., description="") + resp: str = Field(..., description="The type is string/json/code.") + metric: Optional[Metric] = Field(default=None, description="Metric.") + exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.") + entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.") + tag: str = Field(default="", description="Tagging experience.") def rag_key(self): return self.req diff --git a/metagpt/rag/retrievers/bm25_retriever.py b/metagpt/rag/retrievers/bm25_retriever.py index 3b085cb73..dc75d87b0 100644 --- a/metagpt/rag/retrievers/bm25_retriever.py +++ b/metagpt/rag/retrievers/bm25_retriever.py @@ -46,4 +46,4 @@ class DynamicBM25Retriever(BM25Retriever): def persist(self, persist_dir: str, **kwargs) -> None: """Support persist.""" if self._index: - self._index.storage_context.persist(persist_dir) \ No newline at end of file + self._index.storage_context.persist(persist_dir) diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index a8ed482d9..8861f65dc 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -72,7 +72,6 @@ class File: class MemoryFileSystem(_MemoryFileSystem): - @classmethod def _strip_protocol(cls, path): return super()._strip_protocol(str(path)) diff --git a/metagpt/utils/reflection.py b/metagpt/utils/reflection.py index 688831f06..2683e5657 100644 --- a/metagpt/utils/reflection.py +++ b/metagpt/utils/reflection.py @@ -23,5 +23,5 @@ def get_func_full_name(func, *args) -> str: if inspect.ismethod(func) or (inspect.isfunction(func) and "self" in inspect.signature(func).parameters): cls_name = args[0].__class__.__name__ return f"{func.__module__}.{cls_name}.{func.__name__}" - + return f"{func.__module__}.{func.__name__}"