mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-02 12:22:39 +02:00
update werewolf experience and add rag retrieve
This commit is contained in:
parent
5caaea2aeb
commit
a466bc9243
3 changed files with 44 additions and 29 deletions
|
|
@ -1,11 +1,10 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
from chromadb.utils import embedding_functions
|
||||
import chromadb
|
||||
from pydantic import model_validator
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.environment.werewolf.const import RoleType
|
||||
from metagpt.ext.werewolf.schema import RoleExperience
|
||||
|
|
@ -15,13 +14,7 @@ from metagpt.rag.schema import ChromaIndexConfig, ChromaRetrieverConfig
|
|||
from metagpt.utils.common import read_json_file, write_json_file
|
||||
|
||||
DEFAULT_COLLECTION_NAME = "role_reflection" # FIXME: some hard code for now
|
||||
EMB_FN = embedding_functions.OpenAIEmbeddingFunction(
|
||||
api_key=config.llm.api_key,
|
||||
api_base=config.llm.base_url,
|
||||
api_type=config.llm.api_type,
|
||||
model_name="text-embedding-ada-002",
|
||||
)
|
||||
PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("/werewolf_game/chroma")
|
||||
PERSIST_PATH = DEFAULT_WORKSPACE_ROOT.joinpath("werewolf_game/chroma")
|
||||
PERSIST_PATH.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
|
|
@ -35,6 +28,13 @@ class AddNewExperiences(Action):
|
|||
def validate_collection(self):
|
||||
if self.engine:
|
||||
return
|
||||
if self.delete_existing:
|
||||
try:
|
||||
# implement engine `DELETE` method later
|
||||
chromadb.PersistentClient(PERSIST_PATH.as_posix()).delete_collection(self.collection_name)
|
||||
except Exception as exp:
|
||||
logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}")
|
||||
|
||||
self.engine = SimpleEngine.from_objs(
|
||||
retriever_configs=[
|
||||
ChromaRetrieverConfig(
|
||||
|
|
@ -42,12 +42,6 @@ class AddNewExperiences(Action):
|
|||
)
|
||||
]
|
||||
)
|
||||
if self.delete_existing:
|
||||
try:
|
||||
# implement engine `DELETE` method later
|
||||
self.engine.retriever._index._vector_store._collection.delete_collection(name=self.collection_name)
|
||||
except Exception as exp:
|
||||
logger.error(f"delete chroma collection: {self.collection_name} failed, exp: {exp}")
|
||||
|
||||
def run(self, experiences: list[RoleExperience]):
|
||||
if not experiences:
|
||||
|
|
@ -64,7 +58,7 @@ class AddNewExperiences(Action):
|
|||
experiences = [RoleExperience.model_validate(item) for item in experiences]
|
||||
experiences = [exp for exp in experiences if len(exp.reflection) > 2] # not "" or not '""'
|
||||
|
||||
self.engine.add(experiences)
|
||||
self.engine.add_objs(experiences)
|
||||
|
||||
@staticmethod
|
||||
def _record_experiences_local(experiences: list[RoleExperience]):
|
||||
|
|
@ -85,18 +79,25 @@ class RetrieveExperiences(Action):
|
|||
collection_name: str = DEFAULT_COLLECTION_NAME
|
||||
has_experiences: bool = True
|
||||
engine: Optional[SimpleEngine] = None
|
||||
topk: int = 5
|
||||
topk: int = 10
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_collection(self):
|
||||
if self.engine:
|
||||
return
|
||||
try:
|
||||
self.engine.from_index(
|
||||
self.engine = SimpleEngine.from_index(
|
||||
index_config=ChromaIndexConfig(
|
||||
persist_path=PERSIST_PATH, collection_name=self.collection_name, metadata={"hnsw:space": "cosine"}
|
||||
),
|
||||
retriever_configs=ChromaRetrieverConfig(similarity_top_k=self.topk),
|
||||
retriever_configs=[
|
||||
ChromaRetrieverConfig(
|
||||
similarity_top_k=self.topk,
|
||||
persist_path=PERSIST_PATH,
|
||||
collection_name=self.collection_name,
|
||||
metadata={"hnsw:space": "cosine"},
|
||||
)
|
||||
],
|
||||
)
|
||||
except Exception as exp:
|
||||
logger.warning(f"No experience pool: {self.collection_name}, exp: {exp}")
|
||||
|
|
@ -123,8 +124,14 @@ class RetrieveExperiences(Action):
|
|||
results = self.engine.retrieve(query)
|
||||
|
||||
logger.info(f"retrieve {profile}'s experiences")
|
||||
past_experiences = [res.metadata["obj"] for res in results]
|
||||
if verbose:
|
||||
experiences = [res.metadata["obj"] for res in results]
|
||||
|
||||
past_experiences = [] # currently use post-process to filter, and later add `filters` in rag
|
||||
for exp in experiences:
|
||||
if exp.profile == profile and exp.version != excluded_version:
|
||||
past_experiences.append(exp)
|
||||
|
||||
if verbose and results:
|
||||
logger.info("past_experiences: {}".format("\n\n".join(past_experiences)))
|
||||
distances = results[0].score
|
||||
logger.info(f"distances: {distances}")
|
||||
|
|
|
|||
|
|
@ -161,6 +161,13 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
"""Implement tools.SearchInterface"""
|
||||
return await self.aquery(content)
|
||||
|
||||
def retrieve(self, query: QueryType) -> list[NodeWithScore]:
    """Retrieve nodes for ``query``, which may be a plain string.

    A ``str`` query is wrapped in a ``QueryBundle`` before being passed to the
    parent class's ``retrieve``; any other value is forwarded unchanged.
    ``_try_reconstruct_obj`` is called on the retrieved nodes before they are
    returned.
    """
    if isinstance(query, str):
        query = QueryBundle(query)

    retrieved = super().retrieve(query)
    self._try_reconstruct_obj(retrieved)
    return retrieved
|
||||
|
||||
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Allow query to be str."""
|
||||
query_bundle = QueryBundle(query) if isinstance(query, str) else query
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -69,22 +68,24 @@ class TestExperiencesOperation:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add(self):
|
||||
saved_file = f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
|
||||
if os.path.exists(saved_file):
|
||||
os.remove(saved_file)
|
||||
saved_file = DEFAULT_WORKSPACE_ROOT.joinpath(
|
||||
f"werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
|
||||
)
|
||||
if saved_file.exists():
|
||||
saved_file.unlink()
|
||||
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.run(self.samples_to_add)
|
||||
|
||||
# test insertion
|
||||
inserted = action.collection.get()
|
||||
inserted = action.engine.retriever._index._vector_store._collection.get()
|
||||
assert len(inserted["documents"]) == len(self.samples_to_add)
|
||||
|
||||
# test if we record the samples correctly to local file
|
||||
# & test if we could recover an embedding db from the file
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.add_from_file(saved_file)
|
||||
inserted = action.collection.get()
|
||||
inserted = action.engine.retriever._index._vector_store._collection.get()
|
||||
assert len(inserted["documents"]) == len(self.samples_to_add)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -123,8 +124,8 @@ class TestActualRetrieve:
|
|||
async def test_check_experience_pool(self):
|
||||
logger.info("check experience pool")
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
if action.collection:
|
||||
all_experiences = action.collection.get()
|
||||
if action.engine:
|
||||
all_experiences = action.engine.retriever._index._vector_store._collection.get()
|
||||
logger.info(f"{len(all_experiences['metadatas'])=}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue