mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-21 14:05:17 +02:00
Merge branch 'feat-exp-pool-opt' into 'mgx_ops'
Feat exp pool opt See merge request pub/MetaGPT!354
This commit is contained in:
commit
7f1fea10b8
13 changed files with 226 additions and 30 deletions
|
|
@ -81,6 +81,8 @@ exp_pool:
|
|||
persist_path: .chroma_exp_data # The directory.
|
||||
retrieval_type: bm25 # Default is `bm25`, can be set to `chroma` for vector storage, which requires setting up embedding.
|
||||
use_llm_ranker: true # Default is `true`; uses the LLM reranker to get better results.
|
||||
collection_name: experience_pool # When `retrieval_type` is `chroma`, `collection_name` is the collection name in chromadb.
|
||||
|
||||
|
||||
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
|
||||
azure_tts_region: "eastus"
|
||||
|
|
|
|||
|
|
@ -46,8 +46,8 @@ async def add_exp(req: str, resp: str, tag: str, metric: Metric = None):
|
|||
metric=metric or Metric(score=Score(val=10, reason="Manual")),
|
||||
)
|
||||
exp_manager = get_exp_manager()
|
||||
exp_manager.config.exp_pool.enabled = True
|
||||
exp_manager.config.exp_pool.enable_write = True
|
||||
exp_manager.is_writable = True
|
||||
|
||||
exp_manager.create_exp(exp)
|
||||
logger.info(f"New experience created for the request `{req[:10]}`.")
|
||||
|
||||
|
|
|
|||
85
examples/exp_pool/load_exps_from_log.py
Normal file
85
examples/exp_pool/load_exps_from_log.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
"""Load and save experiences from the log file."""
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.exp_pool import get_exp_manager
|
||||
from metagpt.exp_pool.schema import LOG_NEW_EXPERIENCE_PREFIX, Experience
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def load_exps(log_file_path: str) -> list[Experience]:
    """Load experiences recorded in a log file.

    Every line containing ``LOG_NEW_EXPERIENCE_PREFIX`` is expected to carry a JSON
    object right after the prefix; that object is expanded into ``Experience(**data)``.
    Entries that cannot be parsed are skipped with a warning so that one damaged
    line does not discard the rest of the log.

    Args:
        log_file_path (str): The path to the log file.

    Returns:
        list[Experience]: The experiences parsed from the log file. The list is empty
            when the file does not exist or holds no usable experience entries.
    """
    if not Path(log_file_path).exists():
        logger.warning(f"`load_exps` called with a non-existent log file path: {log_file_path}")
        # Return an empty list rather than None so the result always matches the annotation.
        return []

    exps = []
    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(log_file_path, "r", encoding="utf-8") as log_file:
        for line_no, line in enumerate(log_file, start=1):
            if LOG_NEW_EXPERIENCE_PREFIX not in line:
                continue

            # Everything after the first occurrence of the prefix is the JSON payload.
            json_str = line.split(LOG_NEW_EXPERIENCE_PREFIX, 1)[1].strip()
            try:
                exp = Experience(**json.loads(json_str))
            except (ValueError, TypeError) as e:
                # json.JSONDecodeError is a ValueError; TypeError covers a payload that is
                # not a JSON object. NOTE(review): model validation errors are presumably
                # ValueError subclasses as well - confirm against the pydantic version in use.
                logger.warning(f"Skipping malformed experience entry at line {line_no} of {log_file_path}: {e}")
                continue

            exps.append(exp)

    logger.info(f"Loaded {len(exps)} experiences from log file: {log_file_path}")

    return exps
|
||||
|
||||
|
||||
def save_exps(exps: list[Experience]):
    """Store the given experiences in the experience pool.

    Args:
        exps (list[Experience]): The experiences to save. When the list is empty
            (or otherwise falsy) only a warning is logged and nothing is written.
    """
    if exps:
        exp_manager = get_exp_manager()
        # Switch writing on before handing the batch to the manager.
        exp_manager.is_writable = True
        exp_manager.create_exps(exps)
        logger.info(f"Saved {len(exps)} experiences.")
        return

    logger.warning("`save_exps` called with an empty list of experiences.")
|
||||
|
||||
|
||||
def get_log_file_path() -> str:
    """Find the log file path among the logger's registered handlers.

    Returns:
        str: The name of the first handler whose name contains ``"log"``, with its
            first and last characters removed.

    Raises:
        ValueError: If no handler name contains ``"log"``.
    """
    # Relies on the logger's private internals (`_core.handlers`, `_name`).
    handler_names = (handler._name for handler in logger._core.handlers.values())
    matched_name = next((name for name in handler_names if "log" in name), None)

    if matched_name is None:
        raise ValueError("Log file not found")

    # Drop the surrounding characters - presumably the quotes placed around a
    # file sink's path in the handler name; TODO confirm.
    return matched_name[1:-1]
|
||||
|
||||
|
||||
def main():
    """Locate the current log file, load its experiences and save them to the pool."""
    save_exps(load_exps(get_log_file_path()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -22,3 +22,4 @@ class ExperiencePoolConfig(YamlModel):
|
|||
default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool."
|
||||
)
|
||||
use_llm_ranker: bool = Field(default=True, description="Use LLM Reranker to get better result.")
|
||||
collection_name: str = Field(default="experience_pool", description="The collection name in chromadb")
|
||||
|
|
|
|||
|
|
@ -10,7 +10,13 @@ from metagpt.config2 import Config
|
|||
from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder
|
||||
from metagpt.exp_pool.manager import ExperienceManager, get_exp_manager
|
||||
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
|
||||
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
|
||||
from metagpt.exp_pool.schema import (
|
||||
LOG_NEW_EXPERIENCE_PREFIX,
|
||||
Experience,
|
||||
Metric,
|
||||
QueryType,
|
||||
Score,
|
||||
)
|
||||
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
|
||||
from metagpt.exp_pool.serializers import BaseSerializer, SimpleSerializer
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -173,6 +179,7 @@ class ExpCacheHandler(BaseModel):
|
|||
|
||||
exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score))
|
||||
self.exp_manager.create_exp(exp)
|
||||
self._log_exp(exp)
|
||||
|
||||
@staticmethod
|
||||
def choose_wrapper(func, wrapped_func):
|
||||
|
|
@ -215,3 +222,8 @@ class ExpCacheHandler(BaseModel):
|
|||
return await self.func(*self.args, **self.kwargs)
|
||||
|
||||
return self.func(*self.args, **self.kwargs)
|
||||
|
||||
def _log_exp(self, exp: Experience):
    """Write a debug log line made of the new-experience prefix and the experience as JSON.

    Only the `uuid`, `req`, `resp` and `tag` fields are included in the dump.
    """
    logged_fields = {"uuid", "req", "resp", "tag"}
    logger.debug(f"{LOG_NEW_EXPERIENCE_PREFIX}{exp.model_dump_json(include=logged_fields)}")
|
||||
|
|
|
|||
|
|
@ -7,12 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType
|
||||
from metagpt.exp_pool.schema import (
|
||||
DEFAULT_COLLECTION_NAME,
|
||||
DEFAULT_SIMILARITY_TOP_K,
|
||||
Experience,
|
||||
QueryType,
|
||||
)
|
||||
from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, Experience, QueryType
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
|
@ -36,7 +31,7 @@ class ExperienceManager(BaseModel):
|
|||
_storage: Any = None
|
||||
|
||||
@property
|
||||
def storage(self):
|
||||
def storage(self) -> "SimpleEngine":
|
||||
if self._storage is None:
|
||||
logger.info(f"exp_pool config: {self.config.exp_pool}")
|
||||
|
||||
|
|
@ -44,6 +39,34 @@ class ExperienceManager(BaseModel):
|
|||
|
||||
return self._storage
|
||||
|
||||
@storage.setter
|
||||
def storage(self, value):
|
||||
self._storage = value
|
||||
|
||||
@property
def is_readable(self) -> bool:
    """Whether the pool is enabled and reading is switched on."""
    pool_cfg = self.config.exp_pool
    return pool_cfg.enabled and pool_cfg.enable_read

@is_readable.setter
def is_readable(self, value: bool):
    """Set the read switch; turning it on also enables the pool itself."""
    pool_cfg = self.config.exp_pool
    pool_cfg.enable_read = value

    # A truthy value must leave the pool enabled; a falsy one leaves `enabled` untouched.
    if value:
        pool_cfg.enabled = True
|
||||
|
||||
@property
def is_writable(self) -> bool:
    """Whether the pool is enabled and writing is switched on."""
    pool_cfg = self.config.exp_pool
    return pool_cfg.enabled and pool_cfg.enable_write

@is_writable.setter
def is_writable(self, value: bool):
    """Set the write switch; turning it on also enables the pool itself."""
    pool_cfg = self.config.exp_pool
    pool_cfg.enable_write = value

    # A truthy value must leave the pool enabled; a falsy one leaves `enabled` untouched.
    if value:
        pool_cfg.enabled = True
|
||||
|
||||
@handle_exception
|
||||
def create_exp(self, exp: Experience):
|
||||
"""Adds an experience to the storage if writing is enabled.
|
||||
|
|
@ -52,10 +75,19 @@ class ExperienceManager(BaseModel):
|
|||
exp (Experience): The experience to add.
|
||||
"""
|
||||
|
||||
if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_write:
|
||||
self.create_exps([exp])
|
||||
|
||||
@handle_exception
|
||||
def create_exps(self, exps: list[Experience]):
|
||||
"""Adds multiple experiences to the storage if writing is enabled.
|
||||
|
||||
Args:
|
||||
exps (list[Experience]): A list of experiences to add.
|
||||
"""
|
||||
if not self.is_writable:
|
||||
return
|
||||
|
||||
self.storage.add_objs([exp])
|
||||
self.storage.add_objs(exps)
|
||||
self.storage.persist(self.config.exp_pool.persist_path)
|
||||
|
||||
@handle_exception(default_return=[])
|
||||
|
|
@ -71,7 +103,7 @@ class ExperienceManager(BaseModel):
|
|||
list[Experience]: A list of experiences that match the args.
|
||||
"""
|
||||
|
||||
if not self.config.exp_pool.enabled or not self.config.exp_pool.enable_read:
|
||||
if not self.is_readable:
|
||||
return []
|
||||
|
||||
nodes = await self.storage.aretrieve(req)
|
||||
|
|
@ -86,6 +118,15 @@ class ExperienceManager(BaseModel):
|
|||
|
||||
return exps
|
||||
|
||||
@handle_exception
def delete_all_exps(self):
    """Delete all experiences.

    Nothing happens unless the pool is writable; otherwise the storage is
    cleared, with the configured persist path passed along as `persist_dir`.
    """
    if self.is_writable:
        self.storage.clear(persist_dir=self.config.exp_pool.persist_path)
|
||||
|
||||
def get_exps_count(self) -> int:
|
||||
"""Get the total number of experiences."""
|
||||
|
||||
|
|
@ -166,7 +207,7 @@ class ExperienceManager(BaseModel):
|
|||
retriever_configs = [
|
||||
ChromaRetrieverConfig(
|
||||
persist_path=self.config.exp_pool.persist_path,
|
||||
collection_name=DEFAULT_COLLECTION_NAME,
|
||||
collection_name=self.config.exp_pool.collection_name,
|
||||
similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
|
||||
)
|
||||
]
|
||||
|
|
@ -194,7 +235,7 @@ class ExperienceManager(BaseModel):
|
|||
_exp_manager = None
|
||||
|
||||
|
||||
def get_exp_manager():
|
||||
def get_exp_manager() -> ExperienceManager:
|
||||
global _exp_manager
|
||||
if _exp_manager is None:
|
||||
_exp_manager = ExperienceManager()
|
||||
|
|
|
|||
|
|
@ -2,14 +2,16 @@
|
|||
import time
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
MAX_SCORE = 10
|
||||
|
||||
DEFAULT_COLLECTION_NAME = "experience_pool"
|
||||
DEFAULT_SIMILARITY_TOP_K = 2
|
||||
|
||||
LOG_NEW_EXPERIENCE_PREFIX = "New experience: "
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
"""Type of query experiences."""
|
||||
|
|
@ -68,6 +70,7 @@ class Experience(BaseModel):
|
|||
tag: str = Field(default="", description="Tagging experience.")
|
||||
traj: Optional[Trajectory] = Field(default=None, description="Trajectory.")
|
||||
timestamp: Optional[float] = Field(default_factory=time.time)
|
||||
uuid: Optional[UUID] = Field(default_factory=uuid4)
|
||||
|
||||
def rag_key(self):
|
||||
return self.req
|
||||
|
|
|
|||
|
|
@ -38,6 +38,7 @@ from metagpt.rag.factories import (
|
|||
)
|
||||
from metagpt.rag.interface import NoEmbedding, RAGObject
|
||||
from metagpt.rag.retrievers.base import (
|
||||
DeletableRAGRetriever,
|
||||
ModifiableRAGRetriever,
|
||||
PersistableRAGRetriever,
|
||||
QueryableRAGRetriever,
|
||||
|
|
@ -218,7 +219,13 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Count."""
|
||||
self._ensure_retriever_queryable()
|
||||
|
||||
return self._retriever.query_total_count()
|
||||
return self.retriever.query_total_count()
|
||||
|
||||
def clear(self, **kwargs):
    """Clear the retriever, forwarding all keyword arguments to its `clear`.

    The deletable check runs first, then the retriever's result is returned as is.
    """
    self._ensure_retriever_deletable()

    active_retriever = self.retriever
    return active_retriever.clear(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
|
||||
|
|
@ -277,6 +284,9 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
def _ensure_retriever_queryable(self):
|
||||
self._ensure_retriever_of_type(QueryableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_deletable(self):
|
||||
self._ensure_retriever_of_type(DeletableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
|
||||
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.
|
||||
|
||||
|
|
|
|||
|
|
@ -58,4 +58,18 @@ class QueryableRAGRetriever(RAGRetriever):
|
|||
|
||||
@abstractmethod
|
||||
def query_total_count(self) -> int:
|
||||
"""To support querying total count, must implement this func"""
|
||||
"""To support querying total count, must implement this func."""
|
||||
|
||||
|
||||
class DeletableRAGRetriever(RAGRetriever):
    """Support deleting all nodes.

    Any class that provides a `clear` method is treated as a subclass through
    `__subclasshook__`, so explicit inheritance is not required.
    """

    @classmethod
    def __subclasshook__(cls, C):
        # Apply the structural check only to this exact class, not to its subclasses.
        if cls is DeletableRAGRetriever:
            return check_methods(C, "clear")
        return NotImplemented

    @abstractmethod
    def clear(self, **kwargs) -> None:
        """To support deleting all nodes, must implement this func.

        Annotated as returning None, matching the retriever implementations.
        """
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
"""BM25 retriever."""
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
|
@ -36,6 +37,7 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes."""
|
||||
|
||||
self._nodes.extend(nodes)
|
||||
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
|
||||
self.bm25 = BM25Okapi(self._corpus)
|
||||
|
|
@ -45,6 +47,7 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist."""
|
||||
|
||||
if self._index:
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
|
||||
|
|
@ -52,3 +55,19 @@ class DynamicBM25Retriever(BM25Retriever):
|
|||
"""Support query total count."""
|
||||
|
||||
return len(self._nodes)
|
||||
|
||||
def clear(self, **kwargs) -> None:
    """Support deleting all nodes.

    Removes the JSON files under the optional `persist_dir` keyword argument,
    then resets the in-memory node list.
    """
    persist_dir = kwargs.get("persist_dir")
    self._delete_json_files(persist_dir)

    # NOTE(review): only `_nodes` is reset here; `_corpus` and `bm25` are rebuilt in
    # `add_nodes`, so they may stay stale until nodes are added again - confirm.
    self._nodes = []
|
||||
|
||||
@staticmethod
def _delete_json_files(directory: str):
    """Remove every `*.json` file located directly in `directory`.

    Does nothing when `directory` is empty or None.
    """
    if directory:
        for json_path in Path(directory).glob("*.json"):
            json_path.unlink()
|
||||
|
|
|
|||
|
|
@ -8,6 +8,10 @@ from llama_index.vector_stores.chroma import ChromaVectorStore
|
|||
class ChromaRetriever(VectorIndexRetriever):
|
||||
"""Chroma retriever."""
|
||||
|
||||
@property
|
||||
def vector_store(self) -> ChromaVectorStore:
|
||||
return self._vector_store
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes."""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
|
@ -20,6 +24,11 @@ class ChromaRetriever(VectorIndexRetriever):
|
|||
def query_total_count(self) -> int:
|
||||
"""Support query total count."""
|
||||
|
||||
vector_store: ChromaVectorStore = self._vector_store
|
||||
return self.vector_store._collection.count()
|
||||
|
||||
return vector_store._collection.count()
|
||||
def clear(self, **kwargs) -> None:
    """Support deleting all nodes.

    Looks up the ids reported by the underlying collection's `get()` and, when
    there are any, deletes them in a single call. Keyword arguments are ignored.
    """
    collection = self.vector_store._collection

    # Skip the delete call entirely when the collection reports no ids.
    if stored_ids := collection.get()["ids"]:
        collection.delete(ids=stored_ids)
|
||||
|
|
|
|||
|
|
@ -147,8 +147,8 @@ class SerializationMixin(BaseModel, extra="forbid"):
|
|||
|
||||
serialized_data = self.model_dump()
|
||||
|
||||
write_json_file(file_path, serialized_data)
|
||||
logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
|
||||
write_json_file(file_path, serialized_data, use_fallback=True)
|
||||
logger.debug(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
|
||||
|
||||
return file_path
|
||||
|
||||
|
|
@ -171,7 +171,7 @@ class SerializationMixin(BaseModel, extra="forbid"):
|
|||
data: dict = read_json_file(file_path)
|
||||
|
||||
model = cls(**data)
|
||||
logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
|
||||
logger.debug(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
|
||||
|
||||
return model
|
||||
|
||||
|
|
|
|||
|
|
@ -581,20 +581,20 @@ def read_json_file(json_file: str, encoding: str = "utf-8") -> list[Any]:
|
|||
|
||||
|
||||
def handle_unknown_serialization(x: Any) -> str:
|
||||
"""For `to_jsonable_python` debug, unknown values will be logged instead of raising an exception."""
|
||||
"""For `to_jsonable_python` debug, get more detail about the x."""
|
||||
|
||||
if inspect.ismethod(x):
|
||||
logger.error(f"Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}")
|
||||
tip = f"Cannot serialize method '{x.__func__.__name__}' of class '{x.__self__.__class__.__name__}'"
|
||||
elif inspect.isfunction(x):
|
||||
logger.error(f"Function: {x.__name__}")
|
||||
tip = f"Cannot serialize function '{x.__name__}'"
|
||||
elif hasattr(x, "__class__"):
|
||||
logger.error(f"Instance of: {x.__class__.__name__}")
|
||||
tip = f"Cannot serialize instance of '{x.__class__.__name__}'"
|
||||
elif hasattr(x, "__name__"):
|
||||
logger.error(f"Class or module: {x.__name__}")
|
||||
tip = f"Cannot serialize class or module '{x.__name__}'"
|
||||
else:
|
||||
logger.error(f"Unknown type: {type(x)}")
|
||||
tip = f"Cannot serialize object of type '{type(x).__name__}'"
|
||||
|
||||
return f"<Unserializable {type(x).__name__} object>"
|
||||
raise TypeError(tip)
|
||||
|
||||
|
||||
def write_json_file(json_file: str, data: Any, encoding: str = "utf-8", indent: int = 4, use_fallback: bool = False):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue