diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index f0c7cba9d..a5d9fc6aa 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -109,8 +109,10 @@ class IndexRepo(BaseModel): continue filter_filenames.add(str(i)) if new_files: - await self.add(paths=list(new_files.keys()), file_datas=new_files) - filter_filenames.update([str(i) for i in new_files.keys()]) + added, others = await self.add(paths=list(new_files.keys()), file_datas=new_files) + filter_filenames.update([str(i) for i in added]) + for i in others: + result.append(TextScore(filename=str(i), text=new_files.get(i))) nodes = await self._search(query=query, filters=filter_filenames) return result + nodes @@ -146,12 +148,19 @@ class IndexRepo(BaseModel): scores.sort(key=lambda x: x[0], reverse=True) return [i[1] for i in scores][: self.recall_count] - async def add(self, paths: List[Path], file_datas: Dict[Union[str, Path], str] = None): + async def add( + self, paths: List[Path], file_datas: Dict[Union[str, Path], str] = None + ) -> Tuple[List[str], List[str]]: """Add new documents to the index. Args: paths (List[Path]): A list of paths to the documents to be added. file_datas (Dict[Union[str, Path], str]): A list of file content. + + Returns: + Tuple[List[str], List[str]]: A tuple containing two lists: + 1. The list of filenames that were successfully added to the index. + 2. The list of filenames that were not added to the index because they were not buildable. """ encoding = tiktoken.get_encoding("cl100k_base") filenames, _ = await self._filter(paths) @@ -170,6 +179,7 @@ class IndexRepo(BaseModel): delete_filenames.append(i) logger.debug(f"{i} not is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}") await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames, file_datas=file_datas) + return filter_filenames, delete_filenames async def _add_batch( self, diff --git a/tests/data/pdf/20210709逗你学云豆付费课程协议.pdf b/tests/data/pdf/20210709逗你学云豆付费课程协议.pdf new file mode 100644 index 000000000..278ab2160 Binary files /dev/null and b/tests/data/pdf/20210709逗你学云豆付费课程协议.pdf differ diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index 6dbd6f274..24c368c7b 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -703,5 +703,28 @@ async def test_index_repo(): shutil.rmtree(UPLOAD_ROOT) +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "filename"), + [ + ( + "In this document, who are the legal representatives of both parties?", + TEST_DATA_PATH / "pdf/20210709逗你学云豆付费课程协议.pdf", + ) + ], +) +async def test_similarity_search(query, filename): + filename = Path(filename) + save_to = Path(UPLOAD_ROOT) / filename.name + save_to.parent.mkdir(parents=True, exist_ok=True) + os.system(f"cp {str(filename)} {str(save_to)}") + + editor = Editor() + rsp = await editor.similarity_search(query=query, file_or_path=save_to) + assert rsp + + save_to.unlink(missing_ok=True) + + if __name__ == "__main__": pytest.main([__file__, "-s"])