Merge branch 'feat-exp-pool-opt' into 'mgx_ops'

Feat exp pool opt

See merge request pub/MetaGPT!354
This commit is contained in:
张雷 2024-08-28 05:42:24 +00:00
commit 7f1fea10b8
13 changed files with 226 additions and 30 deletions

View file

@ -81,6 +81,8 @@ exp_pool:
persist_path: .chroma_exp_data # The directory.
retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding.
use_llm_ranker: true # Default is `true`, it will use LLM Reranker to get better result.
collection_name: experience_pool # When `retrieval_type` is `chroma`, `collection_name` is the collection name in chromadb.
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
azure_tts_region: "eastus"

View file

@ -46,8 +46,8 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None):
metric=metric or Metric(score=Score(val=10, reason="Manual")),
)
exp_manager = get_exp_manager()
exp_manager.config.exp_pool.enabled = True
exp_manager.config.exp_pool.enable_write = True
exp_manager.is_writable = True
exp_manager.create_exp(exp)
logger.info(f"New experience created for the request `{req[:10]}`.")

View file

@ -0,0 +1,85 @@
"""Load and save experiences from the log file."""
import json
from pathlib import Path
from metagpt.exp_pool import get_exp_manager
from metagpt.exp_pool.schema import LOG_NEW_EXPERIENCE_PREFIX, Experience
from metagpt.logs import logger
def load_exps(log_file_path: str) -> list[Experience]:
"""Loads experiences from a log file.
Args:
log_file_path (str): The path to the log file.
Returns:
list[Experience]: A list of Experience objects loaded from the log file.
"""
if not Path(log_file_path).exists():
logger.warning(f"`load_exps` called with a non-existent log file path: {log_file_path}")
return
exps = []
with open(log_file_path, "r") as log_file:
for line in log_file:
if LOG_NEW_EXPERIENCE_PREFIX in line:
json_str = line.split(LOG_NEW_EXPERIENCE_PREFIX, 1)[1].strip()
exp_data = json.loads(json_str)
exp = Experience(**exp_data)
exps.append(exp)
logger.info(f"Loaded {len(exps)} experiences from log file: {log_file_path}")
return exps
def save_exps(exps: list[Experience]):
"""Saves a list of experiences to the experience pool.
Args:
exps (list[Experience]): The list of experiences to save.
"""
if not exps:
logger.warning("`save_exps` called with an empty list of experiences.")
return
manager = get_exp_manager()
manager.is_writable = True
manager.create_exps(exps)
logger.info(f"Saved {len(exps)} experiences.")
def get_log_file_path() -> str:
"""Retrieves the path to the log file.
Returns:
str: The path to the log file.
Raises:
ValueError: If the log file path cannot be found.
"""
handlers = logger._core.handlers
for handler in handlers.values():
if "log" in handler._name:
return handler._name[1:-1]
raise ValueError("Log file not found")
def main():
log_file_path = get_log_file_path()
exps = load_exps(log_file_path)
save_exps(exps)
if __name__ == "__main__":
main()

View file

@ -22,3 +22,4 @@ class ExperiencePoolConfig(YamlModel):
default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool."
)
use_llm_ranker: bool = Field(default=True, description="Use LLM Reranker to get better result.")
collection_name: str = Field(default="experience_pool", description="The collection name in chromadb")

View file

@ -10,7 +10,13 @@ from metagpt.config2 import Config
from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder
from metagpt.exp_pool.manager import ExperienceManager, get_exp_manager
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
from metagpt.exp_pool.schema import (
LOG_NEW_EXPERIENCE_PREFIX,
Experience,
Metric,
QueryType,
Score,
)
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer
from metagpt.logs import logger
@ -173,6 +179,7 @@ class ExpCacheHandler(BaseModel):
exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score))
self.exp_manager.create_exp(exp)
self._log_exp(exp)
@staticmethod
def choose_wrapper(func, wrapped_func):
@ -215,3 +222,8 @@ class ExpCacheHandler(BaseModel):
return await self.func(*self.args, **self.kwargs)
return self.func(*self.args, **self.kwargs)
def _log_exp(self, exp: Experience):
log_entry = exp.model_dump_json(include={"uuid", "req", "resp", "tag"})
logger.debug(f"{LOG_NEW_EXPERIENCE_PREFIX}{log_entry}")

View file

@ -7,12 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType
from metagpt.exp_pool.schema import (
DEFAULT_COLLECTION_NAME,
DEFAULT_SIMILARITY_TOP_K,
Experience,
QueryType,
)
from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, Experience, QueryType
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
@ -36,7 +31,7 @@ class ExperienceManager(BaseModel):
_storage: Any = None
@property
def storage(self):
def storage(self) -> "SimpleEngine":
if self._storage is None:
logger.info(f"exp_pool config: {self.config.exp_pool}")
@ -44,6 +39,34 @@ class ExperienceManager(BaseModel):
return self._storage
@storage.setter
def storage(self, value):
self._storage = value
@property
def is_readable(self) -> bool:
return self.config.exp_pool.enabled and self.config.exp_pool.enable_read
@is_readable.setter
def is_readable(self, value: bool):
self.config.exp_pool.enable_read = value
# If set to True, ensure that enabled is also True.
if value:
self.config.exp_pool.enabled = True
@property
def is_writable(self) -> bool:
return self.config.exp_pool.enabled and self.config.exp_pool.enable_write
@is_writable.setter
def is_writable(self, value: bool):
self.config.exp_pool.enable_write = value
# If set to True, ensure that enabled is also True.
if value:
self.config.exp_pool.enabled = True
@handle_exception
def create_exp(self, exp: Experience):
"""Adds an experience to the storage if writing is enabled.
@ -52,10 +75,19 @@ class ExperienceManager(BaseModel):
exp (Experience): The experience to add.
"""
if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_write:
self.create_exps([exp])
@handle_exception
def create_exps(self, exps: list[Experience]):
"""Adds multiple experiences to the storage if writing is enabled.
Args:
exps (list[Experience]): A list of experiences to add.
"""
if not self.is_writable:
return
self.storage.add_objs([exp])
self.storage.add_objs(exps)
self.storage.persist(self.config.exp_pool.persist_path)
@handle_exception(default_return=[])
@ -71,7 +103,7 @@ class ExperienceManager(BaseModel):
list[Experience]: A list of experiences that match the args.
"""
if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_read:
if not self.is_readable:
return []
nodes = await self.storage.aretrieve(req)
@ -86,6 +118,15 @@ class ExperienceManager(BaseModel):
return exps
@handle_exception
def delete_all_exps(self):
"""Delete the all experiences."""
if not self.is_writable:
return
self.storage.clear(persist_dir=self.config.exp_pool.persist_path)
def get_exps_count(self) -> int:
"""Get the total number of experiences."""
@ -166,7 +207,7 @@ class ExperienceManager(BaseModel):
retriever_configs = [
ChromaRetrieverConfig(
persist_path=self.config.exp_pool.persist_path,
collection_name=DEFAULT_COLLECTION_NAME,
collection_name=self.config.exp_pool.collection_name,
similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
)
]
@ -194,7 +235,7 @@ class ExperienceManager(BaseModel):
_exp_manager = None
def get_exp_manager():
def get_exp_manager() -> ExperienceManager:
global _exp_manager
if _exp_manager is None:
_exp_manager = ExperienceManager()

View file

@ -2,14 +2,16 @@
import time
from enum import Enum
from typing import Optional
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
MAX_SCORE = 10
DEFAULT_COLLECTION_NAME = "experience_pool"
DEFAULT_SIMILARITY_TOP_K = 2
LOG_NEW_EXPERIENCE_PREFIX = "New experience: "
class QueryType(str, Enum):
"""Type of query experiences."""
@ -68,6 +70,7 @@ class Experience(BaseModel):
tag: str = Field(default="", description="Tagging experience.")
traj: Optional[Trajectory] = Field(default=None, description="Trajectory.")
timestamp: Optional[float] = Field(default_factory=time.time)
uuid: Optional[UUID] = Field(default_factory=uuid4)
def rag_key(self):
return self.req

View file

@ -38,6 +38,7 @@ from metagpt.rag.factories import (
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.retrievers.base import (
DeletableRAGRetriever,
ModifiableRAGRetriever,
PersistableRAGRetriever,
QueryableRAGRetriever,
@ -218,7 +219,13 @@ class SimpleEngine(RetrieverQueryEngine):
"""Count."""
self._ensure_retriever_queryable()
return self._retriever.query_total_count()
return self.retriever.query_total_count()
def clear(self, **kwargs):
"""Clear."""
self._ensure_retriever_deletable()
return self.retriever.clear(**kwargs)
@staticmethod
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
@ -277,6 +284,9 @@ class SimpleEngine(RetrieverQueryEngine):
def _ensure_retriever_queryable(self):
self._ensure_retriever_of_type(QueryableRAGRetriever)
def _ensure_retriever_deletable(self):
self._ensure_retriever_of_type(DeletableRAGRetriever)
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.

View file

@ -58,4 +58,18 @@ class QueryableRAGRetriever(RAGRetriever):
@abstractmethod
def query_total_count(self) -> int:
"""To support querying total count, must implement this func"""
"""To support querying total count, must implement this func."""
class DeletableRAGRetriever(RAGRetriever):
"""Support deleting all nodes."""
@classmethod
def __subclasshook__(cls, C):
if cls is DeletableRAGRetriever:
return check_methods(C, "clear")
return NotImplemented
@abstractmethod
def clear(self, **kwargs) -> int:
"""To support deleting all nodes, must implement this func."""

View file

@ -1,4 +1,5 @@
"""BM25 retriever."""
from pathlib import Path
from typing import Callable, Optional
from llama_index.core import VectorStoreIndex
@ -36,6 +37,7 @@ class DynamicBM25Retriever(BM25Retriever):
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes."""
self._nodes.extend(nodes)
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
self.bm25 = BM25Okapi(self._corpus)
@ -45,6 +47,7 @@ class DynamicBM25Retriever(BM25Retriever):
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
if self._index:
self._index.storage_context.persist(persist_dir)
@ -52,3 +55,19 @@ class DynamicBM25Retriever(BM25Retriever):
"""Support query total count."""
return len(self._nodes)
def clear(self, **kwargs) -> None:
"""Support deleting all nodes."""
self._delete_json_files(kwargs.get("persist_dir"))
self._nodes = []
@staticmethod
def _delete_json_files(directory: str):
"""Delete all JSON files in the specified directory."""
if not directory:
return
for file in Path(directory).glob("*.json"):
file.unlink()

View file

@ -8,6 +8,10 @@ from llama_index.vector_stores.chroma import ChromaVectorStore
class ChromaRetriever(VectorIndexRetriever):
"""Chroma retriever."""
@property
def vector_store(self) -> ChromaVectorStore:
return self._vector_store
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes."""
self._index.insert_nodes(nodes, **kwargs)
@ -20,6 +24,11 @@ class ChromaRetriever(VectorIndexRetriever):
def query_total_count(self) -> int:
"""Support query total count."""
vector_store: ChromaVectorStore = self._vector_store
return self.vector_store._collection.count()
return vector_store._collection.count()
def clear(self, **kwargs) -> None:
"""Support deleting all nodes."""
ids = self.vector_store._collection.get()["ids"]
if ids:
self.vector_store._collection.delete(ids=ids)

View file

@ -147,8 +147,8 @@ class SerializationMixin(BaseModel, extra="forbid"):
serialized_data = self.model_dump()
write_json_file(file_path, serialized_data)
logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
write_json_file(file_path, serialized_data, use_fallback=True)
logger.debug(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
return file_path
@ -171,7 +171,7 @@ class SerializationMixin(BaseModel, extra="forbid"):
data: dict = read_json_file(file_path)
model = cls(**data)
logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
logger.debug(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
return model

View file

@ -581,20 +581,20 @@ def read_json_file(json_file: str, encoding: str = "utf-8") -> list[Any]:
def handle_unknown_serialization(x: Any) -> str:
"""For `to_jsonable_python` debug, unknown values will be logged instead of raising an exception."""
"""For `to_jsonable_python` debug, get more detail about the x."""
if inspect.ismethod(x):
logger.error(f"Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}")
tip = f"Cannot serialize method '{x.__func__.__name__}' of class '{x.__self__.__class__.__name__}'"
elif inspect.isfunction(x):
logger.error(f"Function: {x.__name__}")
tip = f"Cannot serialize function '{x.__name__}'"
elif hasattr(x, "__class__"):
logger.error(f"Instance of: {x.__class__.__name__}")
tip = f"Cannot serialize instance of '{x.__class__.__name__}'"
elif hasattr(x, "__name__"):
logger.error(f"Class or module: {x.__name__}")
tip = f"Cannot serialize class or module '{x.__name__}'"
else:
logger.error(f"Unknown type: {type(x)}")
tip = f"Cannot serialize object of type '{type(x).__name__}'"
return f"<Unserializable {type(x).__name__} object>"
raise TypeError(tip)
def write_json_file(json_file: str, data: Any, encoding: str = "utf-8", indent: int = 4, use_fallback: bool = False):