diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index e358c2288..4448c3d02 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -3,24 +3,21 @@ This file is borrowed from OpenDevin You can find the original repository here: https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py """ -import base64 import os import re import shutil import tempfile from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union from pydantic import BaseModel, ConfigDict -from metagpt.config2 import Config from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger +from metagpt.tools.libs.index_repo import IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool -from metagpt.utils import read_docx -from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint -from metagpt.utils.repo_to_markdown import is_text_file +from metagpt.utils.file import File from metagpt.utils.report import EditorReporter # This is also used in unit tests! @@ -107,23 +104,12 @@ class Editor(BaseModel): async def read(self, path: str) -> FileBlock: """Read the whole content of a file. 
Using absolute paths as the argument for specifying the file location.""" - is_text, mime_type = await is_text_file(path) - if is_text: - lines = await self._read_text(path) - elif mime_type == "application/pdf": - lines = await self._read_pdf(path) - elif mime_type in { - "application/msword", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/vnd.ms-word.document.macroEnabled.12", - "application/vnd.openxmlformats-officedocument.wordprocessingml.template", - "application/vnd.ms-word.template.macroEnabled.12", - }: - lines = await self._read_docx(path) - else: + content = await File.read_text_file(path) + if not content: return FileBlock(file_path=str(path), block_content="") self.resource.report(str(path), "path") + lines = content.splitlines(keepends=True) lines_with_num = [f"{i + 1:03}|{line}" for i, line in enumerate(lines)] result = FileBlock( file_path=str(path), @@ -131,80 +117,6 @@ class Editor(BaseModel): ) return result - @staticmethod - async def _read_text(path: Union[str, Path]) -> List[str]: - content = await aread(path) - lines = content.split("\n") - return lines - - @staticmethod - async def _read_pdf(path: Union[str, Path]) -> List[str]: - result = await Editor._omniparse_read_file(path) - if result: - return result - - from llama_index.readers.file import PDFReader - - reader = PDFReader() - lines = reader.load_data(file=Path(path)) - return [i.text for i in lines] - - @staticmethod - async def _read_docx(path: Union[str, Path]) -> List[str]: - result = await Editor._omniparse_read_file(path) - if result: - return result - return read_docx(str(path)) - - @staticmethod - async def _omniparse_read_file(path: Union[str, Path]) -> Optional[List[str]]: - from metagpt.tools.libs import get_env_default - from metagpt.utils.omniparse_client import OmniParseClient - - env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") - env_timeout = await get_env_default(key="timeout", 
app_name="OmniParse", default_value="") - conf_base_url, conf_timeout = await Editor._read_omniparse_config() - - base_url = env_base_url or conf_base_url - if not base_url: - return None - api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="") - timeout = env_timeout or conf_timeout or 600 - try: - timeout = int(timeout) - except ValueError: - timeout = 600 - - try: - if not await check_http_endpoint(url=base_url): - logger.warning(f"{base_url}: NOT AVAILABLE") - return None - client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout) - file_data = await aread_bin(filename=path) - ret = await client.parse_document(file_input=file_data, bytes_filename=str(path)) - except (ValueError, Exception) as e: - logger.exception(f"{path}: {e}") - return None - if not ret.images: - return [ret.text] if ret.text else None - - result = [ret.text] - img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images") - img_dir.mkdir(parents=True, exist_ok=True) - for i in ret.images: - byte_data = base64.b64decode(i.image) - filename = img_dir / i.image_name - await awrite_bin(filename=filename, data=byte_data) - result.append(f"![{i.image_name}]({str(filename)})") - return result - - @staticmethod - async def _read_omniparse_config() -> Tuple[str, int]: - config = Config.default() - if config.omniparse and config.omniparse.url: - return config.omniparse.url, config.omniparse.timeout - return "", 0 - @staticmethod def _is_valid_filename(file_name: str) -> bool: if not file_name or not file_name.strip(): @@ -1023,3 +935,21 @@ class Editor(BaseModel): if not path.is_absolute(): path = self.working_dir / path return path + + @staticmethod + async def search_index_repo(query: str, file_or_path: Union[str, Path]) -> List[str]: + """Searches the index repository for a given query across specified files or paths. 
+ + This method classifies the provided files or paths, performing a search on each cluster + of files while handling other types of files separately. It merges results from structured + indices with any results from non-indexed files. + + Args: + query (str): The search query string to look for in the indexed files. + file_or_path (Union[str, Path]): A path or a filename to search within. + + Returns: + List[str]: A list of search results as strings, containing the text from the merged results + and any direct results from other files. + """ + return await IndexRepo.cross_repo_search(query=query, file_or_path=file_or_path) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index fadc11522..24065c4be 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -1,9 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - +import asyncio import json +import re from pathlib import Path -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Tuple, Union import tiktoken from llama_index.core.base.embeddings.base import BaseEmbedding @@ -16,7 +17,23 @@ from metagpt.rag.engines import SimpleEngine from metagpt.rag.factories.embedding import RAGEmbeddingFactory from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files -from metagpt.utils.repo_to_markdown import is_text_file +from metagpt.utils.file import File + +UPLOADS_INDEX_ROOT = "/data/.index/uploads" +DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT +UPLOAD_ROOT = "/data/uploads" +DEFAULT_ROOT = UPLOAD_ROOT +CHATS_INDEX_ROOT = "/data/.index/chats" +CHATS_ROOT = "/data/chats/" +OTHER_TYPE = "other" + +DEFAULT_MIN_TOKEN_COUNT = 10000 +DEFAULT_MAX_TOKEN_COUNT = 100000000 + + +class IndexRepoMeta(BaseModel): + min_token_count: int + max_token_count: int class TextScore(BaseModel): @@ -26,12 +43,15 @@ class 
TextScore(BaseModel): class IndexRepo(BaseModel): - persist_path: str # The persist path of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/ - root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. + persist_path: str = DEFAULT_INDEX_ROOT # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/` + root_path: str = ( + DEFAULT_ROOT # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. + ) fingerprint_filename: str = "fingerprint.json" + meta_filename: str = "meta.json" model: Optional[str] = None - min_token_count: int = 10000 - max_token_count: int = 100000000 + min_token_count: int = DEFAULT_MIN_TOKEN_COUNT + max_token_count: int = DEFAULT_MAX_TOKEN_COUNT recall_count: int = 5 embedding: Optional[BaseEmbedding] = Field(default=None, exclude=True) fingerprints: Dict[str, str] = Field(default_factory=dict) @@ -65,16 +85,21 @@ class IndexRepo(BaseModel): """ encoding = tiktoken.get_encoding("cl100k_base") result: List[Union[NodeWithScore, TextScore]] = [] - filenames, _ = await self._filter(filenames) + filenames, excludes = await self._filter(filenames) + if not filenames: + raise ValueError(f"Unsupported file types: {[str(i) for i in excludes]}") filter_filenames = set() + meta = await self._read_meta() for i in filenames: - content = await aread(filename=i) + content = await File.read_text_file(i) token_count = len(encoding.encode(content)) - if not self._is_buildable(token_count): + if not self._is_buildable( + token_count, min_token_count=meta.min_token_count, max_token_count=meta.max_token_count + ): result.append(TextScore(filename=str(i), text=content)) continue file_fingerprint = generate_fingerprint(content) - if self.fingerprints.get(str(i)) != file_fingerprint: + if self.fingerprints.get(str(i)) != file_fingerprint and Path(i).suffix.lower() not in {".pdf"}: logger.error(f'file: "{i}" changed but not 
indexed') continue filter_filenames.add(str(i)) @@ -93,6 +118,10 @@ class IndexRepo(BaseModel): Returns: List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity. """ + flat_nodes = [node for indices in indices_list if indices for node in indices if node] + if len(flat_nodes) <= self.recall_count: + return flat_nodes + if not self.embedding: config = Config.default() if self.model: @@ -102,7 +131,6 @@ class IndexRepo(BaseModel): scores = [] query_embedding = await self.embedding.aget_text_embedding(query) - flat_nodes = [node for indices in indices_list for node in indices] for i in flat_nodes: text_embedding = await self.embedding.aget_text_embedding(i.text) similarity = self.embedding.similarity(query_embedding, text_embedding) @@ -121,7 +149,7 @@ class IndexRepo(BaseModel): filter_filenames = [] delete_filenames = [] for i in filenames: - content = await aread(filename=i) + content = await File.read_text_file(i) if not self._is_fingerprint_changed(filename=i, content=content): continue token_count = len(encoding.encode(content)) @@ -169,10 +197,11 @@ class IndexRepo(BaseModel): logger.debug(f"add docs {filenames}") engine.persist(persist_dir=self.persist_path) for i in filenames: - content = await aread(i) + content = await File.read_text_file(i) fp = generate_fingerprint(content) self.fingerprints[str(i)] = fp await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) + await self._save_meta() def __str__(self): """Return a string representation of the IndexRepo. @@ -182,7 +211,7 @@ class IndexRepo(BaseModel): """ return f"{self.persist_path}" - def _is_buildable(self, token_count: int) -> bool: + def _is_buildable(self, token_count: int, min_token_count: int = -1, max_token_count=-1) -> bool: """Check if the token count is within the buildable range. Args: @@ -191,7 +220,9 @@ class IndexRepo(BaseModel): Returns: bool: True if buildable, False otherwise. 
""" - if token_count < self.min_token_count or token_count > self.max_token_count: + min_token_count = min_token_count if min_token_count >= 0 else self.min_token_count + max_token_count = max_token_count if max_token_count >= 0 else self.max_token_count + if token_count < min_token_count or token_count > max_token_count: return False return True @@ -216,13 +247,13 @@ class IndexRepo(BaseModel): logger.debug(f"{path} not is_relative_to {root_path})") continue if not path.is_dir(): - is_text, _ = await is_text_file(path) + is_text = await File.is_textual_file(path) if is_text: pathnames.append(path) continue subfiles = list_files(path) for j in subfiles: - is_text, _ = await is_text_file(j) + is_text = await File.is_textual_file(j) if is_text: pathnames.append(j) @@ -240,7 +271,7 @@ class IndexRepo(BaseModel): List[NodeWithScore]: A list of nodes with scores matching the query. """ if not Path(self.persist_path).exists(): - return [] + raise ValueError(f"IndexRepo {Path(self.persist_path).name} not exists.") engine = SimpleEngine.from_index( index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()] ) @@ -262,3 +293,114 @@ class IndexRepo(BaseModel): return True fp = generate_fingerprint(content) return old_fp != fp + + @staticmethod + def find_index_repo_path(files: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: + """Map the file path to the corresponding index repo. + + Args: + files (List[Union[str, Path]]): A list of file paths or Path objects to be classified. + + Returns: + Tuple[Dict[str, Set[Path]], Dict[str, str]]: + - A dictionary mapping the index repo path to the files. + - A dictionary mapping the index repo path to their corresponding root directories. 
+ """ + mappings = { + UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"), + CHATS_INDEX_ROOT: re.compile(r"^/data/chats/\d+($|/.*)"), + } + + clusters = {} + roots = {} + for i in files: + path = Path(i).absolute() + path_type = OTHER_TYPE + for type_, pattern in mappings.items(): + if re.match(pattern, str(i)): + path_type = type_ + break + if path_type == CHATS_INDEX_ROOT: + chat_id = path.parts[3] + path_type = str(Path(path_type) / chat_id) + roots[path_type] = str(Path(CHATS_ROOT) / chat_id) + elif path_type == UPLOADS_INDEX_ROOT: + roots[path_type] = UPLOAD_ROOT + + if path_type in clusters: + clusters[path_type].add(path) + else: + clusters[path_type] = {path} + + return clusters, roots + + async def _save_meta(self): + meta = IndexRepoMeta(min_token_count=self.min_token_count, max_token_count=self.max_token_count) + await awrite(filename=Path(self.persist_path) / self.meta_filename, data=meta.model_dump_json()) + + async def _read_meta(self) -> IndexRepoMeta: + default_meta = IndexRepoMeta(min_token_count=self.min_token_count, max_token_count=self.max_token_count) + + filename = Path(self.persist_path) / self.meta_filename + if not filename.exists(): + return default_meta + meta_data = await aread(filename=filename) + try: + meta = IndexRepoMeta.model_validate_json(meta_data) + return meta + except Exception as e: + logger.warning(f"Load meta error: {e}") + return default_meta + + @staticmethod + async def cross_repo_search(query: str, file_or_path: Union[str, Path]) -> List[str]: + """Search for a query across multiple repositories. + + This asynchronous function searches for the specified query in files + located at the given path or file. + + Args: + query (str): The search term to look for in the files. + file_or_path (Union[str, Path]): The path to the file or directory + where the search should be conducted. This can be a string path + or a Path object. 
+ + Returns: + List[str]: Recalled text snippets from indexed files, merged with + the raw contents of non-indexed files. + + Raises: + ValueError: If `file_or_path` does not exist. + """ + if not file_or_path or not Path(file_or_path).exists(): + raise ValueError(f'"{str(file_or_path)}" not exists') + files = [file_or_path] if not Path(file_or_path).is_dir() else list_files(file_or_path) + clusters, roots = IndexRepo.find_index_repo_path(files) + futures = [] + others = set() + for persist_path, filenames in clusters.items(): + if persist_path == OTHER_TYPE: + others.update(filenames) + continue + root = roots[persist_path] + repo = IndexRepo(persist_path=persist_path, root_path=root) + futures.append(repo.search(query=query, filenames=list(filenames))) + + for i in others: + futures.append(File.read_text_file(i)) + + futures_results = [] + if futures: + futures_results = await asyncio.gather(*futures) + + result = [] + v_result = [] + for i in futures_results: + if isinstance(i, str): + result.append(i) + else: + v_result.append(i) + + repo = IndexRepo() + merged = await repo.merge(query=query, indices_list=v_result) + return [i.text for i in merged] + result diff --git a/metagpt/utils/file.py index 8861f65dc..a3f612bcc 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -6,13 +6,19 @@ @File : file.py @Describe : General file operations. 
""" +import base64 from pathlib import Path +from typing import Optional, Tuple, Union import aiofiles from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem +from metagpt.config2 import Config from metagpt.logs import logger +from metagpt.utils import read_docx +from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint from metagpt.utils.exceptions import handle_exception +from metagpt.utils.repo_to_markdown import is_text_file class File: @@ -70,6 +76,125 @@ class File: logger.debug(f"Successfully read file, the path of file: {file_path}") return content + @staticmethod + async def is_textual_file(filename: Union[str, Path]) -> bool: + """Determines if a given file is a textual file. + + A file is considered a textual file if it is plain text or has a + specific set of MIME types associated with textual formats, + including PDF and Microsoft Word documents. + + Args: + filename (Union[str, Path]): The path to the file to be checked. + + Returns: + bool: True if the file is a textual file, False otherwise. + """ + is_text, mime_type = await is_text_file(filename) + if is_text: + return True + if mime_type == "application/pdf": + return True + if mime_type in { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-word.document.macroEnabled.12", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-word.template.macroEnabled.12", + }: + return True + return False + + @staticmethod + async def read_text_file(filename: Union[str, Path]) -> Optional[str]: + """Read the whole content of a file. 
Using absolute paths as the argument for specifying the file location.""" + is_text, mime_type = await is_text_file(filename) + if is_text: + return await File._read_text(filename) + if mime_type == "application/pdf": + return await File._read_pdf(filename) + if mime_type in { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-word.document.macroEnabled.12", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-word.template.macroEnabled.12", + }: + return await File._read_docx(filename) + return None + + @staticmethod + async def _read_text(path: Union[str, Path]) -> str: + return await aread(path) + + @staticmethod + async def _read_pdf(path: Union[str, Path]) -> str: + result = await File._omniparse_read_file(path) + if result: + return result + + from llama_index.readers.file import PDFReader + + reader = PDFReader() + lines = reader.load_data(file=Path(path)) + return "\n".join([i.text for i in lines]) + + @staticmethod + async def _read_docx(path: Union[str, Path]) -> str: + result = await File._omniparse_read_file(path) + if result: + return result + return "\n".join(read_docx(str(path))) + + @staticmethod + async def _omniparse_read_file(path: Union[str, Path], auto_save_image: bool = False) -> Optional[str]: + from metagpt.tools.libs import get_env_default + from metagpt.utils.omniparse_client import OmniParseClient + + env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") + env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="") + conf_base_url, conf_timeout = await File._read_omniparse_config() + + base_url = env_base_url or conf_base_url + if not base_url: + return None + api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="") + timeout = env_timeout or conf_timeout or 600 + try: + timeout = int(timeout) + except ValueError: + timeout 
= 600 + + try: + if not await check_http_endpoint(url=base_url): + logger.warning(f"{base_url}: NOT AVAILABLE") + return None + client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout) + file_data = await aread_bin(filename=path) + ret = await client.parse_document(file_input=file_data, bytes_filename=str(path)) + except (ValueError, Exception) as e: + logger.exception(f"{path}: {e}") + return None + if not ret.images or not auto_save_image: + return ret.text + + result = [ret.text] + img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images") + img_dir.mkdir(parents=True, exist_ok=True) + for i in ret.images: + byte_data = base64.b64decode(i.image) + filename = img_dir / i.image_name + await awrite_bin(filename=filename, data=byte_data) + result.append(f"![{i.image_name}]({str(filename)})") + return "\n".join(result) + + @staticmethod + async def _read_omniparse_config() -> Tuple[str, int]: + config = Config.default() + if config.omniparse and config.omniparse.url: + return config.omniparse.url, config.omniparse.timeout + return "", 0 + class MemoryFileSystem(_MemoryFileSystem): @classmethod diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index bcef2b74e..c601ee5a4 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -1,7 +1,19 @@ +import os +import shutil +from pathlib import Path + import pytest from metagpt.const import TEST_DATA_PATH from metagpt.tools.libs.editor import Editor +from metagpt.tools.libs.index_repo import ( + CHATS_INDEX_ROOT, + CHATS_ROOT, + UPLOAD_ROOT, + UPLOADS_INDEX_ROOT, + IndexRepo, +) +from metagpt.utils.common import list_files TEST_FILE_CONTENT = """ # this is line one @@ -645,5 +657,54 @@ def test_append_to_single_empty_line_file(): assert n_added_lines == 1 +async def mock_index_repo(): + chat_id = "1" + chat_path = Path(CHATS_ROOT) / chat_id + chat_path.mkdir(parents=True, exist_ok=True) + 
src_path = TEST_DATA_PATH / "requirements" + command = f"cp -rf {str(src_path)} {str(chat_path)}" + os.system(command) + filenames = list_files(chat_path) + chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] + chat_repo = IndexRepo( + persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0 + ) + await chat_repo.add(chat_files) + assert chat_files + + Path(UPLOAD_ROOT).mkdir(parents=True, exist_ok=True) + command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}" + os.system(command) + filenames = list_files(UPLOAD_ROOT) + uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] + uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0) + await uploads_repo.add(uploads_files) + assert uploads_files + + filenames = list_files(src_path) + other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] + assert other_files + + return chat_path, UPLOAD_ROOT, src_path + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_index_repo(): + # mock data + chat_path, UPLOAD_ROOT, src_path = await mock_index_repo() + + editor = Editor() + rsp = await editor.search_index_repo(query="业务线", file_or_path=chat_path) + assert rsp + rsp = await editor.search_index_repo(query="业务线", file_or_path=UPLOAD_ROOT) + assert rsp + rsp = await editor.search_index_repo(query="业务线", file_or_path=src_path) + assert rsp + + shutil.rmtree(CHATS_ROOT) + shutil.rmtree(UPLOAD_ROOT) + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_index_repo.py b/tests/metagpt/tools/libs/test_index_repo.py index 3cc8ad406..aec1e3f5e 100644 --- a/tests/metagpt/tools/libs/test_index_repo.py +++ b/tests/metagpt/tools/libs/test_index_repo.py @@ -1,11 +1,17 @@ import shutil +from pathlib import Path import pytest from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH -from 
metagpt.tools.libs.index_repo import IndexRepo +from metagpt.tools.libs.index_repo import ( + CHATS_INDEX_ROOT, + UPLOADS_INDEX_ROOT, + IndexRepo, +) +@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")]) async def test_index_repo(path, query): @@ -28,5 +34,22 @@ async def test_index_repo(path, query): shutil.rmtree(index_path) +@pytest.mark.parametrize( + ("paths", "path_type", "root"), + [ + (["/data/uploads"], UPLOADS_INDEX_ROOT, "/data/uploads"), + (["/data/uploads/"], UPLOADS_INDEX_ROOT, "/data/uploads"), + (["/data/chats/1/1.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"), + (["/data/chats/1/2.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"), + (["/data/chats/2/2.txt", "/data/chats/2/2.txt"], str(Path(CHATS_INDEX_ROOT) / "2"), "/data/chats/2"), + (["/data/chats.txt"], "other", ""), + ], +) +def test_classify_path(paths, path_type, root): + result, result_root = IndexRepo.find_index_repo_path(paths) + assert path_type in set(result.keys()) + assert root == result_root.get(path_type, "") + + if __name__ == "__main__": pytest.main([__file__, "-s"])