add collection_name in exp_pool config

This commit is contained in:
seehi 2024-08-23 10:14:12 +08:00
parent cb3541f204
commit 9d327081ba
5 changed files with 7 additions and 10 deletions

View file

@ -81,6 +81,8 @@ exp_pool:
persist_path: .chroma_exp_data # The directory where the experience pool data is persisted.
retrieval_type: bm25 # Default is `bm25`; can be set to `chroma` for vector storage, which requires setting up embedding.
use_llm_ranker: true # Default is `true`; uses the LLM reranker to get better results.
collection_name: experience_pool # When `retrieval_type` is `chroma`, `collection_name` is the collection name in chromadb.
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
azure_tts_region: "eastus"

View file

@ -22,3 +22,4 @@ class ExperiencePoolConfig(YamlModel):
default=ExperiencePoolRetrievalType.BM25, description="The retrieval type for experience pool."
)
use_llm_ranker: bool = Field(default=True, description="Use LLM Reranker to get better result.")
collection_name: str = Field(default="experience_pool", description="The collection name in chromadb")

View file

@ -7,12 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolRetrievalType
from metagpt.exp_pool.schema import (
DEFAULT_COLLECTION_NAME,
DEFAULT_SIMILARITY_TOP_K,
Experience,
QueryType,
)
from metagpt.exp_pool.schema import DEFAULT_SIMILARITY_TOP_K, Experience, QueryType
from metagpt.logs import logger
from metagpt.utils.exceptions import handle_exception
@ -166,7 +161,7 @@ class ExperienceManager(BaseModel):
retriever_configs = [
ChromaRetrieverConfig(
persist_path=self.config.exp_pool.persist_path,
collection_name=DEFAULT_COLLECTION_NAME,
collection_name=self.config.exp_pool.collection_name,
similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
)
]

View file

@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
MAX_SCORE = 10
DEFAULT_COLLECTION_NAME = "experience_pool"
DEFAULT_SIMILARITY_TOP_K = 2

View file

@ -148,7 +148,7 @@ class SerializationMixin(BaseModel, extra="forbid"):
serialized_data = self.model_dump()
write_json_file(file_path, serialized_data)
logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
logger.debug(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
return file_path
@ -171,7 +171,7 @@ class SerializationMixin(BaseModel, extra="forbid"):
data: dict = read_json_file(file_path)
model = cls(**data)
logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
logger.debug(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
return model