mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-29 19:06:23 +02:00
update werewolf experience and add rag retrieve
This commit is contained in:
parent
5caaea2aeb
commit
a466bc9243
3 changed files with 44 additions and 29 deletions
|
|
@ -1,5 +1,4 @@
|
|||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -69,22 +68,24 @@ class TestExperiencesOperation:
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add(self):
|
||||
saved_file = f"{DEFAULT_WORKSPACE_ROOT}/werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
|
||||
if os.path.exists(saved_file):
|
||||
os.remove(saved_file)
|
||||
saved_file = DEFAULT_WORKSPACE_ROOT.joinpath(
|
||||
f"werewolf_game/experiences/{self.version}/{self.test_round_id}.json"
|
||||
)
|
||||
if saved_file.exists():
|
||||
saved_file.unlink()
|
||||
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.run(self.samples_to_add)
|
||||
|
||||
# test insertion
|
||||
inserted = action.collection.get()
|
||||
inserted = action.engine.retriever._index._vector_store._collection.get()
|
||||
assert len(inserted["documents"]) == len(self.samples_to_add)
|
||||
|
||||
# test if we record the samples correctly to local file
|
||||
# & test if we could recover a embedding db from the file
|
||||
action = AddNewExperiences(collection_name=self.collection_name, delete_existing=True)
|
||||
action.add_from_file(saved_file)
|
||||
inserted = action.collection.get()
|
||||
inserted = action.engine.retriever._index._vector_store._collection.get()
|
||||
assert len(inserted["documents"]) == len(self.samples_to_add)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -123,8 +124,8 @@ class TestActualRetrieve:
|
|||
async def test_check_experience_pool(self):
|
||||
logger.info("check experience pool")
|
||||
action = RetrieveExperiences(collection_name=self.collection_name)
|
||||
if action.collection:
|
||||
all_experiences = action.collection.get()
|
||||
if action.engine:
|
||||
all_experiences = action.engine.retriever._index._vector_store._collection.get()
|
||||
logger.info(f"{len(all_experiences['metadatas'])=}")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue