This commit is contained in:
luxiangtao 2024-04-16 20:09:11 +08:00
parent 258a0894b8
commit 200d47a5c0

View file

@ -3,7 +3,6 @@ import json
import chromadb
from pydantic import BaseModel
from examples.rag_pipeline import TRAVEL_DOC_PATH
from metagpt.actions import Action
from metagpt.const import SERDESER_PATH
from metagpt.logs import logger
@ -47,15 +46,9 @@ class AddNewTrajectories(Action):
def _init_engine(self, collection_name: str):
"""Initialize a collection for storing code experiences."""
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
engine = SimpleEngine.from_objs(
retriever_configs=[ChromaRetrieverConfig(persist_path=PERSIST_PATH, collection_name=collection_name)],
)
# due to an irrelevant record being added to the vector database when loading from SimpleEngine.from_docs(), it is necessary to remove it.
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
chroma_collection = db.get_or_create_collection(collection_name) # get chromadb collection
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"}) # delete the irrelevant record
return engine
async def run(self, planner: Planner, trajectory_collection_name: str = TRAJECTORY_COLLECTION_NAME):
@ -86,15 +79,9 @@ class AddNewExperiences(Action):
def _init_engine(self, collection_name: str):
"""Initialize a collection for storing code experiences."""
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
engine = SimpleEngine.from_objs(
retriever_configs=[ChromaRetrieverConfig(persist_path=PERSIST_PATH, collection_name=collection_name)],
)
# due to an irrelevant record being added to the vector database when loading from SimpleEngine.from_docs(), it is necessary to remove it.
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
chroma_collection = db.get_or_create_collection(collection_name)
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
return engine
async def _single_task_summary(self, trajectory_collection_name: str, experience_collection_name: str):
@ -183,19 +170,13 @@ class RetrieveExperiences(Action):
top_k (int): The number of eperiences to be retrieved.
"""
engine = SimpleEngine.from_docs(
input_files=[TRAVEL_DOC_PATH],
engine = SimpleEngine.from_objs(
retriever_configs=[
ChromaRetrieverConfig(
persist_path=PERSIST_PATH, collection_name=collection_name, similarity_top_k=top_k
)
],
)
# due to an irrelevant record being added to the vector database when loading from SimpleEngine.from_docs(), it is necessary to remove it.
db = chromadb.PersistentClient(path=str(PERSIST_PATH))
chroma_collection = db.get_or_create_collection(collection_name)
chroma_collection.delete(where_document={"$contains": "Bob likes traveling"})
return engine
async def run(