faiss store: add tests

This commit is contained in:
geekan 2024-01-02 21:30:35 +08:00
parent 54201b1459
commit c3dd03671d
5 changed files with 14 additions and 10 deletions

1
.gitignore vendored
View file

@ -171,3 +171,4 @@ tests/metagpt/utils/file_repo_git
*.png
htmlcov
htmlcov.*
*.pkl

View file

@ -101,6 +101,7 @@ class Document(BaseModel):
raise ValueError("File path is not set.")
self.path.parent.mkdir(parents=True, exist_ok=True)
# TODO: excel, csv, json, etc.
self.path.write_text(self.content, encoding="utf-8")
def persist(self):
@ -126,10 +127,12 @@ class IndexableDocument(Document):
if not data_path.exists():
raise FileNotFoundError(f"File {data_path} not found.")
data = read_data(data_path)
content = data_path.read_text()
if isinstance(data, pd.DataFrame):
validate_cols(content_col, data)
return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
return cls(data=data, content=str(data), content_col=content_col, meta_col=meta_col)
else:
content = data_path.read_text()
return cls(data=data, content=content, content_col=content_col, meta_col=meta_col)
def _get_docs_and_metadatas_by_df(self) -> (list, list):
df = self.data

View file

@ -14,7 +14,6 @@ from langchain.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
from metagpt.config import CONFIG
from metagpt.const import DATA_PATH
from metagpt.document import IndexableDocument
from metagpt.document_store.base_store import LocalStore
from metagpt.logs import logger
@ -76,10 +75,3 @@ class FaissStore(LocalStore):
def delete(self, *args, **kwargs):
    """Deletion is not supported by this store.

    Any positional or keyword arguments are accepted and ignored; the call
    always fails. Per the original author's note, langchain did not provide
    a delete interface at the time this was written.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError
if __name__ == "__main__":
    # Manual smoke run: search, add three variants of the query text, then
    # search again and log both result sets.
    query = "Oily Skin Facial Cleanser"
    demo_store = FaissStore(DATA_PATH / "qcs/qcs_4w.json")
    logger.info(demo_store.search(query))
    extra_texts = [f"Oily Skin Facial Cleanser-{i}" for i in range(3)]
    demo_store.add(extra_texts)
    logger.info(demo_store.search(query))

View file

@ -30,3 +30,11 @@ async def test_search_xlsx():
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
logger.info(result)
@pytest.mark.asyncio
async def test_write():
    """Check that writing a store built from the example spreadsheet yields a usable index.

    The store is built from ``example.xlsx`` with "Question" as the content
    column and "Answer" as the metadata column. The object returned by
    ``write()`` must expose a truthy ``docstore`` and a truthy ``index``.
    """
    xlsx_store = FaissStore(
        EXAMPLE_PATH / "example.xlsx",
        meta_col="Answer",
        content_col="Question",
    )
    written = xlsx_store.write()
    assert written.docstore
    assert written.index