diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 6bb43458c..89a3f70cc 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -18,6 +18,7 @@ from metagpt.logs import logger from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import list_files from metagpt.utils.file import File from metagpt.utils.report import EditorReporter @@ -866,7 +867,7 @@ class Editor(BaseModel): return path @staticmethod - async def search_index_repo(query: str, files_or_paths: List[Union[str, Path]]) -> List[str]: + async def search_index_repo(query: str, file_or_path: Union[str, Path]) -> List[str]: """Searches the index repository for a given query across specified files or paths. This method classifies the provided files or paths, performing a search on each cluster @@ -875,13 +876,16 @@ class Editor(BaseModel): Args: query (str): The search query string to look for in the indexed files. - files_or_paths (List[Union[str, Path]]): A list of file paths or names to search within. + file_or_path (Union[str, Path]): A path or a filename to search within. Returns: List[str]: A list of search results as strings, containing the text from the merged results and any direct results from other files. """ - clusters, roots = IndexRepo.classify_path(files_or_paths) + if not file_or_path or not Path(file_or_path).exists(): + raise ValueError(f'"{str(file_or_path)}" not exists') + files = [file_or_path] if not Path(file_or_path).is_dir() else list_files(file_or_path) + clusters, roots = IndexRepo.classify_path(files) futures = [] others = set() for persist_path, filenames in clusters.items(): diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index 216304e93..7de6bce5e 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -85,7 +85,9 @@ class IndexRepo(BaseModel): """ encoding = tiktoken.get_encoding("cl100k_base") result: List[Union[NodeWithScore, TextScore]] = [] - filenames, _ = await self._filter(filenames) + filenames, excludes = await self._filter(filenames) + if not filenames: + raise ValueError(f"Unsupported file types: {[str(i) for i in excludes]}") filter_filenames = set() meta = await self._read_meta() for i in filenames: @@ -269,7 +271,7 @@ class IndexRepo(BaseModel): List[NodeWithScore]: A list of nodes with scores matching the query. """ if not Path(self.persist_path).exists(): - return [] + raise ValueError(f"IndexRepo {Path(self.persist_path).name} not exists.") engine = SimpleEngine.from_index( index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()] ) @@ -293,11 +295,11 @@ class IndexRepo(BaseModel): return old_fp != fp @staticmethod - def classify_path(files_or_paths: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: + def classify_path(files: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: """Classify a list of file paths or Path objects into different categories. Args: - files_or_paths (List[Union[str, Path]]): A list of file paths or Path objects to be classified. + files (List[Union[str, Path]]): A list of file paths or Path objects to be classified. Returns: Tuple[Dict[str, Set[Path]], Dict[str, str]]: @@ -311,7 +313,7 @@ class IndexRepo(BaseModel): clusters = {} roots = {} - for i in files_or_paths: + for i in files: path = Path(i).absolute() path_type = OTHER_TYPE for type_, pattern in mappings.items(): diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index 20857b97f..c601ee5a4 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -670,6 +670,7 @@ async def mock_index_repo(): persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0 ) await chat_repo.add(chat_files) + assert chat_files Path(UPLOAD_ROOT).mkdir(parents=True, exist_ok=True) command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}" @@ -678,21 +679,27 @@ async def mock_index_repo(): uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0) await uploads_repo.add(uploads_files) + assert uploads_files filenames = list_files(src_path) other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] + assert other_files - return chat_files, uploads_files, other_files + return chat_path, UPLOAD_ROOT, src_path @pytest.mark.skip @pytest.mark.asyncio async def test_index_repo(): # mock data - chat_files, uploads_files, other_files = await mock_index_repo() + chat_path, UPLOAD_ROOT, src_path = await mock_index_repo() editor = Editor() - rsp = await editor.search_index_repo(query="业务线", files_or_paths=chat_files + uploads_files + other_files) + rsp = await editor.search_index_repo(query="业务线", file_or_path=chat_path) + assert rsp + rsp = await editor.search_index_repo(query="业务线", file_or_path=UPLOAD_ROOT) + assert rsp + rsp = await editor.search_index_repo(query="业务线", file_or_path=src_path) assert rsp shutil.rmtree(CHATS_ROOT)