make similarity_top_k configurable

This commit is contained in:
seehi 2024-10-11 15:37:33 +08:00
parent 2965a22e1d
commit dd955c848d
4 changed files with 9 additions and 1 deletions

View file

@ -87,6 +87,7 @@ role_zero:
enable_longterm_memory: false # Whether to use long-term memory. Default is `false`.
longterm_memory_persist_path: .role_memory_data # The directory to save data.
memory_k: 200 # The capacity of short-term memory.
similarity_top_k: 5 # The number of long-term memories to retrieve.
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
azure_tts_region: "eastus"

View file

@ -7,3 +7,4 @@ class RoleZeroConfig(YamlModel):
enable_longterm_memory: bool = Field(default=False, description="Whether to use long-term memory.")
longterm_memory_persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
memory_k: int = Field(default=200, description="The capacity of short-term memory.")
similarity_top_k: int = Field(default=5, description="The number of long-term memories to retrieve.")

View file

@ -31,6 +31,7 @@ class RoleZeroLongTermMemory(Memory):
persist_path: str = Field(default=".role_memory_data", description="The directory to save data.")
collection_name: str = Field(default="role_zero", description="The name of the collection, such as the role name.")
memory_k: int = Field(default=200, description="The capacity of short-term memory.")
similarity_top_k: int = Field(default=5, description="The number of long-term memories to retrieve.")
_rag_engine: Any = None
@ -54,7 +55,11 @@ class RoleZeroLongTermMemory(Memory):
raise ImportError("To use the RoleZeroMemory, you need to install the rag module.")
retriever_configs = [
ChromaRetrieverConfig(persist_path=self.persist_path, collection_name=self.collection_name)
ChromaRetrieverConfig(
persist_path=self.persist_path,
collection_name=self.collection_name,
similarity_top_k=self.similarity_top_k,
)
]
ranker_configs = []

View file

@ -185,6 +185,7 @@ class RoleZero(Role):
persist_path=self.config.role_zero.longterm_memory_persist_path,
collection_name=self.name.replace(" ", ""),
memory_k=self.config.role_zero.memory_k,
similarity_top_k=self.config.role_zero.similarity_top_k,
)
logger.info(f"Long-term memory set for role '{self.name}'")