feat: rename the IndexRepo `filename` field to `persist_path`

This commit is contained in:
莘权 马 2024-09-04 19:11:45 +08:00
parent ed0b9e33bc
commit aa710129f7
2 changed files with 13 additions and 12 deletions

View file

@ -26,7 +26,7 @@ class TextScore(BaseModel):
class IndexRepo(BaseModel):
filename: str # The filename of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/
persist_path: str # The persist path of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/
root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo.
fingerprint_filename: str = "fingerprint.json"
model: Optional[str] = None
@ -44,7 +44,7 @@ class IndexRepo(BaseModel):
IndexRepo: The updated IndexRepo instance.
"""
if not self.fingerprints:
filename = Path(self.filename) / self.fingerprint_filename
filename = Path(self.persist_path) / self.fingerprint_filename
if not filename.exists():
return self
with open(str(filename), "r") as reader:
@ -144,10 +144,11 @@ class IndexRepo(BaseModel):
return
logger.info(f"update index repo, add {filenames}, remove {delete_filenames}")
engine = None
if Path(self.filename).exists():
logger.debug(f"load index from {self.filename}")
if Path(self.persist_path).exists():
logger.debug(f"load index from {self.persist_path}")
engine = SimpleEngine.from_index(
index_config=FAISSIndexConfig(persist_path=self.filename), retriever_configs=[FAISSRetrieverConfig()]
index_config=FAISSIndexConfig(persist_path=self.persist_path),
retriever_configs=[FAISSRetrieverConfig()],
)
try:
engine.delete_docs(filenames + delete_filenames)
@ -166,12 +167,12 @@ class IndexRepo(BaseModel):
ranker_configs=[LLMRankerConfig()],
)
logger.debug(f"add docs {filenames}")
engine.persist(persist_dir=self.filename)
engine.persist(persist_dir=self.persist_path)
for i in filenames:
content = await aread(i)
fp = generate_fingerprint(content)
self.fingerprints[str(i)] = fp
await awrite(filename=Path(self.filename) / self.fingerprint_filename, data=json.dumps(self.fingerprints))
await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints))
def __str__(self):
"""Return a string representation of the IndexRepo.
@ -179,7 +180,7 @@ class IndexRepo(BaseModel):
Returns:
str: The persist path of the index repository.
"""
return f"{self.filename}"
return f"{self.persist_path}"
def _is_buildable(self, token_count: int) -> bool:
"""Check if the token count is within the buildable range.
@ -238,10 +239,10 @@ class IndexRepo(BaseModel):
Returns:
List[NodeWithScore]: A list of nodes with scores matching the query.
"""
if not Path(self.filename).exists():
if not Path(self.persist_path).exists():
return []
engine = SimpleEngine.from_index(
index_config=FAISSIndexConfig(persist_path=self.filename), retriever_configs=[FAISSRetrieverConfig()]
index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()]
)
rsp = await engine.aretrieve(query)
return [i for i in rsp if i.metadata.get("file_path") in filters]

View file

@ -10,7 +10,7 @@ from metagpt.tools.libs.index_repo import IndexRepo
@pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")])
async def test_index_repo(path, query):
index_path = DEFAULT_WORKSPACE_ROOT / ".index"
repo = IndexRepo(filename=str(index_path), root_path=str(path), min_token_count=0)
repo = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0)
await repo.add([path])
await repo.add([path])
assert index_path.exists()
@ -18,7 +18,7 @@ async def test_index_repo(path, query):
rsp = await repo.search(query)
assert rsp
repo2 = IndexRepo(filename=str(index_path), root_path=str(path), min_token_count=0)
repo2 = IndexRepo(persist_path=str(index_path), root_path=str(path), min_token_count=0)
rsp2 = await repo2.search(query)
assert rsp2