From aa710129f7057e93c72eb67ef8070cb07f6a2df1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 4 Sep 2024 19:11:45 +0800 Subject: [PATCH] feat: persist_path --- metagpt/tools/libs/index_repo.py | 21 +++++++++++---------- tests/metagpt/tools/libs/test_index_repo.py | 4 ++-- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index 5944b5cfc..fadc11522 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -26,7 +26,7 @@ class TextScore(BaseModel): class IndexRepo(BaseModel): - filename: str # The filename of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/ + persist_path: str # The persist path of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/ root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. fingerprint_filename: str = "fingerprint.json" model: Optional[str] = None @@ -44,7 +44,7 @@ class IndexRepo(BaseModel): IndexRepo: The updated IndexRepo instance. """ if not self.fingerprints: - filename = Path(self.filename) / self.fingerprint_filename + filename = Path(self.persist_path) / self.fingerprint_filename if not filename.exists(): return self with open(str(filename), "r") as reader: @@ -144,10 +144,11 @@ class IndexRepo(BaseModel): return logger.info(f"update index repo, add {filenames}, remove {delete_filenames}") engine = None - if Path(self.filename).exists(): - logger.debug(f"load index from {self.filename}") + if Path(self.persist_path).exists(): + logger.debug(f"load index from {self.persist_path}") engine = SimpleEngine.from_index( - index_config=FAISSIndexConfig(persist_path=self.filename), retriever_configs=[FAISSRetrieverConfig()] + index_config=FAISSIndexConfig(persist_path=self.persist_path), + retriever_configs=[FAISSRetrieverConfig()], ) try: engine.delete_docs(filenames + delete_filenames) @@ -166,12 +167,12 @@ class IndexRepo(BaseModel): ranker_configs=[LLMRankerConfig()], ) logger.debug(f"add docs {filenames}") - engine.persist(persist_dir=self.filename) + engine.persist(persist_dir=self.persist_path) for i in filenames: content = await aread(i) fp = generate_fingerprint(content) self.fingerprints[str(i)] = fp - await awrite(filename=Path(self.filename) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) + await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) def __str__(self): """Return a string representation of the IndexRepo. @@ -179,7 +180,7 @@ class IndexRepo(BaseModel): Returns: str: The filename of the index repository. """ - return f"{self.filename}" + return f"{self.persist_path}" def _is_buildable(self, token_count: int) -> bool: """Check if the token count is within the buildable range. @@ -238,10 +239,10 @@ class IndexRepo(BaseModel): Returns: List[NodeWithScore]: A list of nodes with scores matching the query. """ - if not Path(self.filename).exists(): + if not Path(self.persist_path).exists(): return [] engine = SimpleEngine.from_index( - index_config=FAISSIndexConfig(persist_path=self.filename), retriever_configs=[FAISSRetrieverConfig()] + index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()] ) rsp = await engine.aretrieve(query) return [i for i in rsp if i.metadata.get("file_path") in filters] diff --git a/tests/metagpt/tools/libs/test_index_repo.py b/tests/metagpt/tools/libs/test_index_repo.py index 65c5f1af9..3cc8ad406 100644 --- a/tests/metagpt/tools/libs/test_index_repo.py +++ b/tests/metagpt/tools/libs/test_index_repo.py @@ -10,7 +10,7 @@ from metagpt.tools.libs.index_repo import IndexRepo @pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")]) async def test_index_repo(path, query): index_path = DEFAULT_WORKSPACE_ROOT / ".index" - repo = IndexRepo(filename=str(index_path), root_path=str(path), min_token_count=0) + repo = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0) await repo.add([path]) await repo.add([path]) assert index_path.exists() @@ -18,7 +18,7 @@ async def test_index_repo(path, query): rsp = await repo.search(query) assert rsp - repo2 = IndexRepo(filename=str(index_path), root_path=str(path), min_token_count=0) + repo2 = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0) rsp2 = await repo2.search(query) assert rsp2