From 1eb3b8fb8c7ec51506237a90bc5dd3f2ea7b0af9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 4 Sep 2024 20:16:08 +0800 Subject: [PATCH 1/8] feat: /data/.index/chats/chat_id --- metagpt/tools/libs/index_repo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index fadc11522..fbe3c633b 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -26,7 +26,7 @@ class TextScore(BaseModel): class IndexRepo(BaseModel): - persist_path: str # The persist path of the index repo, {DEFAULT_WORKSPACE_ROOT}/.index/{chat_id or 'uploads'}/ + persist_path: str # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/` root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. fingerprint_filename: str = "fingerprint.json" model: Optional[str] = None From 4523615dd92e2f889b8350ec6df8e82b209396b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 5 Sep 2024 17:21:27 +0800 Subject: [PATCH 2/8] feat: +Editor.search_index_repo --- metagpt/tools/libs/editor.py | 51 +++++++++++++++++++ metagpt/tools/libs/index_repo.py | 52 ++++++++++++++++++-- tests/metagpt/tools/libs/test_editor.py | 54 +++++++++++++++++++++ tests/metagpt/tools/libs/test_index_repo.py | 25 +++++++++- 4 files changed, 177 insertions(+), 5 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 8013b99c9..ab7b2efd6 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -3,6 +3,7 @@ This file is borrowed from OpenDevin You can find the original repository here: https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py """ +import asyncio import base64 import os import re @@ -16,6 +17,7 @@ from pydantic import BaseModel, ConfigDict from 
metagpt.config2 import Config from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger +from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool from metagpt.utils import read_docx @@ -951,3 +953,52 @@ class Editor(BaseModel): if not path.is_absolute(): path = self.working_dir / path return path + + @staticmethod + async def search_index_repo( + query: str, files_or_paths: List[Union[str, Path]], min_token_count: int = 0 + ) -> List[str]: + """Searches the index repository for a given query across specified files or paths. + + This method classifies the provided files or paths, performing a search on each cluster + of files while handling other types of files separately. It merges results from structured + indices with any results from non-indexed files. + + Args: + query (str): The search query string to look for in the indexed files. + files_or_paths (List[Union[str, Path]]): A list of file paths or names to search within. + min_token_count (int, optional): The minimum token count to consider for indexing. Defaults to 0. + + Returns: + List[str]: A list of search results as strings, containing the text from the merged results + and any direct results from other files. 
+ """ + clusters, roots = IndexRepo.classify_path(files_or_paths) + futures = [] + others = set() + for persist_path, filenames in clusters.items(): + if persist_path == OTHER_TYPE: + others.update(filenames) + continue + root = roots[persist_path] + repo = IndexRepo(persist_path=persist_path, root_path=root, min_token_count=min_token_count) + futures.append(repo.search(query=query, filenames=list(filenames))) + + for i in others: + futures.append(aread(filename=i)) + + futures_results = [] + if futures: + futures_results = await asyncio.gather(*futures) + + result = [] + v_result = [] + for i in futures_results: + if isinstance(i, str): + result.append(i) + else: + v_result.append(i) + + repo = IndexRepo(min_token_count=min_token_count) + merged = await repo.merge(query=query, indices_list=v_result) + return [i.text for i in merged] + result diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index fbe3c633b..211225507 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -2,8 +2,9 @@ # -*- coding: utf-8 -*- import json +import re from pathlib import Path -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Tuple, Union import tiktoken from llama_index.core.base.embeddings.base import BaseEmbedding @@ -18,6 +19,14 @@ from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRanker from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files from metagpt.utils.repo_to_markdown import is_text_file +UPLOADS_INDEX_ROOT = "/data/.index/uploads" +DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT +UPLOAD_ROOT = "/data/uploads" +DEFAULT_ROOT = UPLOAD_ROOT +CHATS_INDEX_ROOT = "/data/.index/chats" +CHATS_ROOT = "/data/chats/" +OTHER_TYPE = "other" + class TextScore(BaseModel): filename: str @@ -26,8 +35,10 @@ class TextScore(BaseModel): class IndexRepo(BaseModel): - persist_path: str # The persist path of the index repo, 
`/data/.index/uploads/` or `/data/.index/chats/{chat_id}/` - root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. + persist_path: str = DEFAULT_INDEX_ROOT # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/` + root_path: str = ( + DEFAULT_ROOT # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. + ) fingerprint_filename: str = "fingerprint.json" model: Optional[str] = None min_token_count: int = 10000 @@ -93,6 +104,10 @@ class IndexRepo(BaseModel): Returns: List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity. """ + flat_nodes = [node for indices in indices_list for node in indices] + if len(flat_nodes) <= self.recall_count: + return flat_nodes + if not self.embedding: config = Config.default() if self.model: @@ -102,7 +117,6 @@ class IndexRepo(BaseModel): scores = [] query_embedding = await self.embedding.aget_text_embedding(query) - flat_nodes = [node for indices in indices_list for node in indices] for i in flat_nodes: text_embedding = await self.embedding.aget_text_embedding(i.text) similarity = self.embedding.similarity(query_embedding, text_embedding) @@ -262,3 +276,33 @@ class IndexRepo(BaseModel): return True fp = generate_fingerprint(content) return old_fp != fp + + @staticmethod + def classify_path(files_or_paths: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: + mappings = { + UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"), + CHATS_INDEX_ROOT: re.compile(r"^/data/chats/\d+($|/.*)"), + } + + clusters = {} + roots = {} + for i in files_or_paths: + path = Path(i).absolute() + path_type = OTHER_TYPE + for type_, pattern in mappings.items(): + if re.match(pattern, str(i)): + path_type = type_ + break + if path_type == CHATS_INDEX_ROOT: + chat_id = path.parts[3] + path_type = str(Path(path_type) / chat_id) + roots[path_type] = str(Path(CHATS_ROOT) / chat_id) + 
elif path_type == UPLOADS_INDEX_ROOT: + roots[path_type] = UPLOAD_ROOT + + if path_type in clusters: + clusters[path_type].add(path) + else: + clusters[path_type] = {path} + + return clusters, roots diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index bcef2b74e..26d53a703 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -1,7 +1,19 @@ +import os +import shutil +from pathlib import Path + import pytest from metagpt.const import TEST_DATA_PATH from metagpt.tools.libs.editor import Editor +from metagpt.tools.libs.index_repo import ( + CHATS_INDEX_ROOT, + CHATS_ROOT, + UPLOAD_ROOT, + UPLOADS_INDEX_ROOT, + IndexRepo, +) +from metagpt.utils.common import list_files TEST_FILE_CONTENT = """ # this is line one @@ -645,5 +657,47 @@ def test_append_to_single_empty_line_file(): assert n_added_lines == 1 +async def mock_index_repo(): + chat_id = "1" + chat_path = Path(CHATS_ROOT) / chat_id + chat_path.mkdir(parents=True, exist_ok=True) + src_path = TEST_DATA_PATH / "requirements" + command = f"cp -rf {str(src_path)} {str(chat_path)}" + os.system(command) + filenames = list_files(chat_path) + chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + chat_repo = IndexRepo( + persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0 + ) + await chat_repo.add(chat_files) + + Path(UPLOAD_ROOT).mkdir(parents=True, exist_ok=True) + command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}" + os.system(command) + filenames = list_files(UPLOAD_ROOT) + uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0) + await uploads_repo.add(uploads_files) + + filenames = list_files(src_path) + other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + + return chat_files, 
uploads_files, other_files + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_index_repo(): + # mock data + chat_files, uploads_files, other_files = await mock_index_repo() + + editor = Editor() + rsp = await editor.vsearch(query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0) + assert rsp + + shutil.rmtree(CHATS_ROOT) + shutil.rmtree(UPLOAD_ROOT) + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_index_repo.py b/tests/metagpt/tools/libs/test_index_repo.py index 3cc8ad406..aec1e3f5e 100644 --- a/tests/metagpt/tools/libs/test_index_repo.py +++ b/tests/metagpt/tools/libs/test_index_repo.py @@ -1,11 +1,17 @@ import shutil +from pathlib import Path import pytest from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH -from metagpt.tools.libs.index_repo import IndexRepo +from metagpt.tools.libs.index_repo import ( + CHATS_INDEX_ROOT, + UPLOADS_INDEX_ROOT, + IndexRepo, +) +@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")]) async def test_index_repo(path, query): @@ -28,5 +34,22 @@ async def test_index_repo(path, query): shutil.rmtree(index_path) +@pytest.mark.parametrize( + ("paths", "path_type", "root"), + [ + (["/data/uploads"], UPLOADS_INDEX_ROOT, "/data/uploads"), + (["/data/uploads/"], UPLOADS_INDEX_ROOT, "/data/uploads"), + (["/data/chats/1/1.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"), + (["/data/chats/1/2.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"), + (["/data/chats/2/2.txt", "/data/chats/2/2.txt"], str(Path(CHATS_INDEX_ROOT) / "2"), "/data/chats/2"), + (["/data/chats.txt"], "other", ""), + ], +) +def test_classify_path(paths, path_type, root): + result, result_root = IndexRepo.classify_path(paths) + assert path_type in set(result.keys()) + assert root == result_root.get(path_type, "") + + if __name__ == "__main__": pytest.main([__file__, "-s"]) From 
f0c980be241aa77bb58d7bf543a63193276cb250 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 5 Sep 2024 17:25:33 +0800 Subject: [PATCH 3/8] feat: +comments --- metagpt/tools/libs/index_repo.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index 211225507..9c1b0886a 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -279,6 +279,16 @@ class IndexRepo(BaseModel): @staticmethod def classify_path(files_or_paths: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: + """Classify a list of file paths or Path objects into different categories. + + Args: + files_or_paths (List[Union[str, Path]]): A list of file paths or Path objects to be classified. + + Returns: + Tuple[Dict[str, Set[Path]], Dict[str, str]]: + - A dictionary mapping the classified path types to sets of corresponding Path objects. + - A dictionary mapping the classified path types to their corresponding root directories. 
+ """ mappings = { UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"), CHATS_INDEX_ROOT: re.compile(r"^/data/chats/\d+($|/.*)"), From 285a6bf164f24db1472efa6cd0ab3f61fb2aaaff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 5 Sep 2024 17:32:18 +0800 Subject: [PATCH 4/8] feat: +DEFAULT_MIN_TOKEN_COUNT --- metagpt/tools/libs/editor.py | 4 ++-- metagpt/tools/libs/index_repo.py | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index ab7b2efd6..9a1cf63c3 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, ConfigDict from metagpt.config2 import Config from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger -from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo +from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, OTHER_TYPE, IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool from metagpt.utils import read_docx @@ -956,7 +956,7 @@ class Editor(BaseModel): @staticmethod async def search_index_repo( - query: str, files_or_paths: List[Union[str, Path]], min_token_count: int = 0 + query: str, files_or_paths: List[Union[str, Path]], min_token_count: int = DEFAULT_MIN_TOKEN_COUNT ) -> List[str]: """Searches the index repository for a given query across specified files or paths. 
diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index 9c1b0886a..b4907d74e 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -27,6 +27,9 @@ CHATS_INDEX_ROOT = "/data/.index/chats" CHATS_ROOT = "/data/chats/" OTHER_TYPE = "other" +DEFAULT_MIN_TOKEN_COUNT = 10000 +DEFAULT_MAX_TOKEN_COUNT = 100000000 + class TextScore(BaseModel): filename: str @@ -41,8 +44,8 @@ class IndexRepo(BaseModel): ) fingerprint_filename: str = "fingerprint.json" model: Optional[str] = None - min_token_count: int = 10000 - max_token_count: int = 100000000 + min_token_count: int = DEFAULT_MIN_TOKEN_COUNT + max_token_count: int = DEFAULT_MAX_TOKEN_COUNT recall_count: int = 5 embedding: Optional[BaseEmbedding] = Field(default=None, exclude=True) fingerprints: Dict[str, str] = Field(default_factory=dict) From ad80dab67859385e911816390034979e8a2f4ac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 5 Sep 2024 19:24:12 +0800 Subject: [PATCH 5/8] feat: +pdf --- metagpt/tools/libs/editor.py | 101 ++----------------- metagpt/tools/libs/index_repo.py | 18 ++-- metagpt/utils/file.py | 125 ++++++++++++++++++++++++ tests/metagpt/tools/libs/test_editor.py | 10 +- 4 files changed, 146 insertions(+), 108 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 9a1cf63c3..cf56eb292 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -4,25 +4,21 @@ You can find the original repository here: https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py """ import asyncio -import base64 import os import re import shutil import tempfile from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import List, Optional, Union from pydantic import BaseModel, ConfigDict -from metagpt.config2 import Config from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs 
import logger from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, OTHER_TYPE, IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool -from metagpt.utils import read_docx -from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint -from metagpt.utils.repo_to_markdown import is_text_file +from metagpt.utils.file import File from metagpt.utils.report import EditorReporter # This is also used in unit tests! @@ -72,23 +68,12 @@ class Editor(BaseModel): async def read(self, path: str) -> FileBlock: """Read the whole content of a file. Using absolute paths as the argument for specifying the file location.""" - is_text, mime_type = await is_text_file(path) - if is_text: - lines = await self._read_text(path) - elif mime_type == "application/pdf": - lines = await self._read_pdf(path) - elif mime_type in { - "application/msword", - "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - "application/vnd.ms-word.document.macroEnabled.12", - "application/vnd.openxmlformats-officedocument.wordprocessingml.template", - "application/vnd.ms-word.template.macroEnabled.12", - }: - lines = await self._read_docx(path) - else: + content = await File.read_text_file(path) + if not content: return FileBlock(file_path=str(path), block_content="") self.resource.report(str(path), "path") + lines = content.splitlines(keepends=True) lines_with_num = [f"{i + 1:03}|{line}" for i, line in enumerate(lines)] result = FileBlock( file_path=str(path), @@ -96,80 +81,6 @@ class Editor(BaseModel): ) return result - @staticmethod - async def _read_text(path: Union[str, Path]) -> List[str]: - content = await aread(path) - lines = content.split("\n") - return lines - - @staticmethod - async def _read_pdf(path: Union[str, Path]) -> List[str]: - result = await Editor._omniparse_read_file(path) - if result: - return result - - from llama_index.readers.file import PDFReader - - reader = PDFReader() - 
lines = reader.load_data(file=Path(path)) - return [i.text for i in lines] - - @staticmethod - async def _read_docx(path: Union[str, Path]) -> List[str]: - result = await Editor._omniparse_read_file(path) - if result: - return result - return read_docx(str(path)) - - @staticmethod - async def _omniparse_read_file(path: Union[str, Path]) -> Optional[List[str]]: - from metagpt.tools.libs import get_env_default - from metagpt.utils.omniparse_client import OmniParseClient - - env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") - env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="") - conf_base_url, conf_timeout = await Editor._read_omniparse_config() - - base_url = env_base_url or conf_base_url - if not base_url: - return None - api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="") - timeout = env_timeout or conf_timeout or 600 - try: - timeout = int(timeout) - except ValueError: - timeout = 600 - - try: - if not await check_http_endpoint(url=base_url): - logger.warning(f"{base_url}: NOT AVAILABLE") - return None - client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout) - file_data = await aread_bin(filename=path) - ret = await client.parse_document(file_input=file_data, bytes_filename=str(path)) - except (ValueError, Exception) as e: - logger.exception(f"{path}: {e}") - return None - if not ret.images: - return [ret.text] if ret.text else None - - result = [ret.text] - img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images") - img_dir.mkdir(parents=True, exist_ok=True) - for i in ret.images: - byte_data = base64.b64decode(i.image) - filename = img_dir / i.image_name - await awrite_bin(filename=filename, data=byte_data) - result.append(f"![{i.image_name}]({str(filename)})") - return result - - @staticmethod - async def _read_omniparse_config() -> Tuple[str, int]: - config = Config.default() - if config.omniparse 
and config.omniparse.url: - return config.omniparse.url, config.omniparse.timeout - return "", 0 - @staticmethod def _is_valid_filename(file_name: str) -> bool: if not file_name or not file_name.strip(): @@ -985,7 +896,7 @@ class Editor(BaseModel): futures.append(repo.search(query=query, filenames=list(filenames))) for i in others: - futures.append(aread(filename=i)) + futures.append(File.read_text_file(i)) futures_results = [] if futures: diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index b4907d74e..d29254318 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -16,8 +16,8 @@ from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.rag.factories.embedding import RAGEmbeddingFactory from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig -from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files -from metagpt.utils.repo_to_markdown import is_text_file +from metagpt.utils.common import awrite, generate_fingerprint, list_files +from metagpt.utils.file import File UPLOADS_INDEX_ROOT = "/data/.index/uploads" DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT @@ -82,13 +82,13 @@ class IndexRepo(BaseModel): filenames, _ = await self._filter(filenames) filter_filenames = set() for i in filenames: - content = await aread(filename=i) + content = await File.read_text_file(i) token_count = len(encoding.encode(content)) if not self._is_buildable(token_count): result.append(TextScore(filename=str(i), text=content)) continue file_fingerprint = generate_fingerprint(content) - if self.fingerprints.get(str(i)) != file_fingerprint: + if self.fingerprints.get(str(i)) != file_fingerprint and Path(i).suffix.lower() not in {".pdf"}: logger.error(f'file: "{i}" changed but not indexed') continue filter_filenames.add(str(i)) @@ -107,7 +107,7 @@ class IndexRepo(BaseModel): Returns: List[Union[NodeWithScore, TextScore]]: A list of merged results 
sorted by similarity. """ - flat_nodes = [node for indices in indices_list for node in indices] + flat_nodes = [node for indices in indices_list if indices for node in indices if node] if len(flat_nodes) <= self.recall_count: return flat_nodes @@ -138,7 +138,7 @@ class IndexRepo(BaseModel): filter_filenames = [] delete_filenames = [] for i in filenames: - content = await aread(filename=i) + content = await File.read_text_file(i) if not self._is_fingerprint_changed(filename=i, content=content): continue token_count = len(encoding.encode(content)) @@ -186,7 +186,7 @@ class IndexRepo(BaseModel): logger.debug(f"add docs {filenames}") engine.persist(persist_dir=self.persist_path) for i in filenames: - content = await aread(i) + content = await File.read_text_file(i) fp = generate_fingerprint(content) self.fingerprints[str(i)] = fp await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) @@ -233,13 +233,13 @@ class IndexRepo(BaseModel): logger.debug(f"{path} not is_relative_to {root_path})") continue if not path.is_dir(): - is_text, _ = await is_text_file(path) + is_text = await File.is_textual_file(path) if is_text: pathnames.append(path) continue subfiles = list_files(path) for j in subfiles: - is_text, _ = await is_text_file(j) + is_text = await File.is_textual_file(j) if is_text: pathnames.append(j) diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index 8861f65dc..a3f612bcc 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -6,13 +6,19 @@ @File : file.py @Describe : General file operations. 
""" +import base64 from pathlib import Path +from typing import Optional, Tuple, Union import aiofiles from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem +from metagpt.config2 import Config from metagpt.logs import logger +from metagpt.utils import read_docx +from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint from metagpt.utils.exceptions import handle_exception +from metagpt.utils.repo_to_markdown import is_text_file class File: @@ -70,6 +76,125 @@ class File: logger.debug(f"Successfully read file, the path of file: {file_path}") return content + @staticmethod + async def is_textual_file(filename: Union[str, Path]) -> bool: + """Determines if a given file is a textual file. + + A file is considered a textual file if it is plain text or has a + specific set of MIME types associated with textual formats, + including PDF and Microsoft Word documents. + + Args: + filename (Union[str, Path]): The path to the file to be checked. + + Returns: + bool: True if the file is a textual file, False otherwise. + """ + is_text, mime_type = await is_text_file(filename) + if is_text: + return True + if mime_type == "application/pdf": + return True + if mime_type in { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-word.document.macroEnabled.12", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-word.template.macroEnabled.12", + }: + return True + return False + + @staticmethod + async def read_text_file(filename: Union[str, Path]) -> Optional[str]: + """Read the whole content of a file. 
Using absolute paths as the argument for specifying the file location.""" + is_text, mime_type = await is_text_file(filename) + if is_text: + return await File._read_text(filename) + if mime_type == "application/pdf": + return await File._read_pdf(filename) + if mime_type in { + "application/msword", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "application/vnd.ms-word.document.macroEnabled.12", + "application/vnd.openxmlformats-officedocument.wordprocessingml.template", + "application/vnd.ms-word.template.macroEnabled.12", + }: + return await File._read_docx(filename) + return None + + @staticmethod + async def _read_text(path: Union[str, Path]) -> str: + return await aread(path) + + @staticmethod + async def _read_pdf(path: Union[str, Path]) -> str: + result = await File._omniparse_read_file(path) + if result: + return result + + from llama_index.readers.file import PDFReader + + reader = PDFReader() + lines = reader.load_data(file=Path(path)) + return "\n".join([i.text for i in lines]) + + @staticmethod + async def _read_docx(path: Union[str, Path]) -> str: + result = await File._omniparse_read_file(path) + if result: + return result + return "\n".join(read_docx(str(path))) + + @staticmethod + async def _omniparse_read_file(path: Union[str, Path], auto_save_image: bool = False) -> Optional[str]: + from metagpt.tools.libs import get_env_default + from metagpt.utils.omniparse_client import OmniParseClient + + env_base_url = await get_env_default(key="base_url", app_name="OmniParse", default_value="") + env_timeout = await get_env_default(key="timeout", app_name="OmniParse", default_value="") + conf_base_url, conf_timeout = await File._read_omniparse_config() + + base_url = env_base_url or conf_base_url + if not base_url: + return None + api_key = await get_env_default(key="api_key", app_name="OmniParse", default_value="") + timeout = env_timeout or conf_timeout or 600 + try: + timeout = int(timeout) + except ValueError: + timeout 
= 600 + + try: + if not await check_http_endpoint(url=base_url): + logger.warning(f"{base_url}: NOT AVAILABLE") + return None + client = OmniParseClient(api_key=api_key, base_url=base_url, max_timeout=timeout) + file_data = await aread_bin(filename=path) + ret = await client.parse_document(file_input=file_data, bytes_filename=str(path)) + except (ValueError, Exception) as e: + logger.exception(f"{path}: {e}") + return None + if not ret.images or not auto_save_image: + return ret.text + + result = [ret.text] + img_dir = Path(path).parent / (Path(path).name.replace(".", "_") + "_images") + img_dir.mkdir(parents=True, exist_ok=True) + for i in ret.images: + byte_data = base64.b64decode(i.image) + filename = img_dir / i.image_name + await awrite_bin(filename=filename, data=byte_data) + result.append(f"![{i.image_name}]({str(filename)})") + return "\n".join(result) + + @staticmethod + async def _read_omniparse_config() -> Tuple[str, int]: + config = Config.default() + if config.omniparse and config.omniparse.url: + return config.omniparse.url, config.omniparse.timeout + return "", 0 + class MemoryFileSystem(_MemoryFileSystem): @classmethod diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index 26d53a703..3a4cf65fe 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -665,7 +665,7 @@ async def mock_index_repo(): command = f"cp -rf {str(src_path)} {str(chat_path)}" os.system(command) filenames = list_files(chat_path) - chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] chat_repo = IndexRepo( persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0 ) @@ -675,12 +675,12 @@ async def mock_index_repo(): command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}" os.system(command) filenames = list_files(UPLOAD_ROOT) - 
uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0) await uploads_repo.add(uploads_files) filenames = list_files(src_path) - other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] return chat_files, uploads_files, other_files @@ -692,7 +692,9 @@ async def test_index_repo(): chat_files, uploads_files, other_files = await mock_index_repo() editor = Editor() - rsp = await editor.vsearch(query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0) + rsp = await editor.search_index_repo( + query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0 + ) assert rsp shutil.rmtree(CHATS_ROOT) From 95bf4c3e226f2b6e4788231ba3e505a6598484d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Fri, 6 Sep 2024 11:07:31 +0800 Subject: [PATCH 6/8] =?UTF-8?q?feat:=20search=E6=97=B6=E4=B8=8D=E5=BF=85?= =?UTF-8?q?=E8=AE=BE=E7=BD=AEIndexRepo=E7=9A=84min/max=5Ftoken=5Fcount?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- metagpt/tools/libs/editor.py | 11 +++---- metagpt/tools/libs/index_repo.py | 38 ++++++++++++++++++++++--- tests/metagpt/tools/libs/test_editor.py | 4 +-- 3 files changed, 39 insertions(+), 14 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index cf56eb292..6bb43458c 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger -from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, OTHER_TYPE, 
IndexRepo +from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool from metagpt.utils.file import File @@ -866,9 +866,7 @@ class Editor(BaseModel): return path @staticmethod - async def search_index_repo( - query: str, files_or_paths: List[Union[str, Path]], min_token_count: int = DEFAULT_MIN_TOKEN_COUNT - ) -> List[str]: + async def search_index_repo(query: str, files_or_paths: List[Union[str, Path]]) -> List[str]: """Searches the index repository for a given query across specified files or paths. This method classifies the provided files or paths, performing a search on each cluster @@ -878,7 +876,6 @@ class Editor(BaseModel): Args: query (str): The search query string to look for in the indexed files. files_or_paths (List[Union[str, Path]]): A list of file paths or names to search within. - min_token_count (int, optional): The minimum token count to consider for indexing. Defaults to 0. 
Returns: List[str]: A list of search results as strings, containing the text from the merged results @@ -892,7 +889,7 @@ class Editor(BaseModel): others.update(filenames) continue root = roots[persist_path] - repo = IndexRepo(persist_path=persist_path, root_path=root, min_token_count=min_token_count) + repo = IndexRepo(persist_path=persist_path, root_path=root) futures.append(repo.search(query=query, filenames=list(filenames))) for i in others: @@ -910,6 +907,6 @@ class Editor(BaseModel): else: v_result.append(i) - repo = IndexRepo(min_token_count=min_token_count) + repo = IndexRepo() merged = await repo.merge(query=query, indices_list=v_result) return [i.text for i in merged] + result diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index d29254318..216304e93 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -16,7 +16,7 @@ from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.rag.factories.embedding import RAGEmbeddingFactory from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig -from metagpt.utils.common import awrite, generate_fingerprint, list_files +from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files from metagpt.utils.file import File UPLOADS_INDEX_ROOT = "/data/.index/uploads" @@ -31,6 +31,11 @@ DEFAULT_MIN_TOKEN_COUNT = 10000 DEFAULT_MAX_TOKEN_COUNT = 100000000 +class IndexRepoMeta(BaseModel): + min_token_count: int + max_token_count: int + + class TextScore(BaseModel): filename: str text: str @@ -43,6 +48,7 @@ class IndexRepo(BaseModel): DEFAULT_ROOT # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. 
) fingerprint_filename: str = "fingerprint.json" + meta_filename: str = "meta.json" model: Optional[str] = None min_token_count: int = DEFAULT_MIN_TOKEN_COUNT max_token_count: int = DEFAULT_MAX_TOKEN_COUNT @@ -81,10 +87,13 @@ class IndexRepo(BaseModel): result: List[Union[NodeWithScore, TextScore]] = [] filenames, _ = await self._filter(filenames) filter_filenames = set() + meta = await self._read_meta() for i in filenames: content = await File.read_text_file(i) token_count = len(encoding.encode(content)) - if not self._is_buildable(token_count): + if not self._is_buildable( + token_count, min_token_count=meta.min_token_count, max_token_count=meta.max_token_count + ): result.append(TextScore(filename=str(i), text=content)) continue file_fingerprint = generate_fingerprint(content) @@ -190,6 +199,7 @@ class IndexRepo(BaseModel): fp = generate_fingerprint(content) self.fingerprints[str(i)] = fp await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) + await self._save_meta() def __str__(self): """Return a string representation of the IndexRepo. @@ -199,7 +209,7 @@ class IndexRepo(BaseModel): """ return f"{self.persist_path}" - def _is_buildable(self, token_count: int) -> bool: + def _is_buildable(self, token_count: int, min_token_count: int = -1, max_token_count=-1) -> bool: """Check if the token count is within the buildable range. Args: @@ -208,7 +218,9 @@ class IndexRepo(BaseModel): Returns: bool: True if buildable, False otherwise. 
""" - if token_count < self.min_token_count or token_count > self.max_token_count: + min_token_count = min_token_count if min_token_count >= 0 else self.min_token_count + max_token_count = max_token_count if max_token_count >= 0 else self.max_token_count + if token_count < min_token_count or token_count > max_token_count: return False return True @@ -319,3 +331,21 @@ class IndexRepo(BaseModel): clusters[path_type] = {path} return clusters, roots + + async def _save_meta(self): + meta = IndexRepoMeta(min_token_count=self.min_token_count, max_token_count=self.max_token_count) + await awrite(filename=Path(self.persist_path) / self.meta_filename, data=meta.model_dump_json()) + + async def _read_meta(self) -> IndexRepoMeta: + default_meta = IndexRepoMeta(min_token_count=self.min_token_count, max_token_count=self.max_token_count) + + filename = Path(self.persist_path) / self.meta_filename + if not filename.exists(): + return default_meta + meta_data = await aread(filename=filename) + try: + meta = IndexRepoMeta.model_validate_json(meta_data) + return meta + except Exception as e: + logger.warning(f"Load meta error: {e}") + return default_meta diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index 3a4cf65fe..20857b97f 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -692,9 +692,7 @@ async def test_index_repo(): chat_files, uploads_files, other_files = await mock_index_repo() editor = Editor() - rsp = await editor.search_index_repo( - query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0 - ) + rsp = await editor.search_index_repo(query="业务线", files_or_paths=chat_files + uploads_files + other_files) assert rsp shutil.rmtree(CHATS_ROOT) From 6a57cb5e0af335d4001a613c9f640e07ca8346cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Fri, 6 Sep 2024 15:35:47 +0800 Subject: [PATCH 7/8] feat: search_index_repo path or 
filename --- metagpt/tools/libs/editor.py | 10 +++++++--- metagpt/tools/libs/index_repo.py | 12 +++++++----- tests/metagpt/tools/libs/test_editor.py | 13 ++++++++++--- 3 files changed, 24 insertions(+), 11 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 6bb43458c..89a3f70cc 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -18,6 +18,7 @@ from metagpt.logs import logger from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool +from metagpt.utils.common import list_files from metagpt.utils.file import File from metagpt.utils.report import EditorReporter @@ -866,7 +867,7 @@ class Editor(BaseModel): return path @staticmethod - async def search_index_repo(query: str, files_or_paths: List[Union[str, Path]]) -> List[str]: + async def search_index_repo(query: str, file_or_path: Union[str, Path]) -> List[str]: """Searches the index repository for a given query across specified files or paths. This method classifies the provided files or paths, performing a search on each cluster @@ -875,13 +876,16 @@ class Editor(BaseModel): Args: query (str): The search query string to look for in the indexed files. - files_or_paths (List[Union[str, Path]]): A list of file paths or names to search within. + file_or_path (Union[str, Path]): A path or a filename to search within. Returns: List[str]: A list of search results as strings, containing the text from the merged results and any direct results from other files. 
""" - clusters, roots = IndexRepo.classify_path(files_or_paths) + if not file_or_path or not Path(file_or_path).exists(): + raise ValueError(f'"{str(file_or_path)}" not exists') + files = [file_or_path] if not Path(file_or_path).is_dir() else list_files(file_or_path) + clusters, roots = IndexRepo.classify_path(files) futures = [] others = set() for persist_path, filenames in clusters.items(): diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index 216304e93..7de6bce5e 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -85,7 +85,9 @@ class IndexRepo(BaseModel): """ encoding = tiktoken.get_encoding("cl100k_base") result: List[Union[NodeWithScore, TextScore]] = [] - filenames, _ = await self._filter(filenames) + filenames, excludes = await self._filter(filenames) + if not filenames: + raise ValueError(f"Unsupported file types: {[str(i) for i in excludes]}") filter_filenames = set() meta = await self._read_meta() for i in filenames: @@ -269,7 +271,7 @@ class IndexRepo(BaseModel): List[NodeWithScore]: A list of nodes with scores matching the query. """ if not Path(self.persist_path).exists(): - return [] + raise ValueError(f"IndexRepo {Path(self.persist_path).name} not exists.") engine = SimpleEngine.from_index( index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()] ) @@ -293,11 +295,11 @@ class IndexRepo(BaseModel): return old_fp != fp @staticmethod - def classify_path(files_or_paths: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: + def classify_path(files: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: """Classify a list of file paths or Path objects into different categories. Args: - files_or_paths (List[Union[str, Path]]): A list of file paths or Path objects to be classified. + files (List[Union[str, Path]]): A list of file paths or Path objects to be classified. 
Returns: Tuple[Dict[str, Set[Path]], Dict[str, str]]: @@ -311,7 +313,7 @@ class IndexRepo(BaseModel): clusters = {} roots = {} - for i in files_or_paths: + for i in files: path = Path(i).absolute() path_type = OTHER_TYPE for type_, pattern in mappings.items(): diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index 20857b97f..c601ee5a4 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -670,6 +670,7 @@ async def mock_index_repo(): persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0 ) await chat_repo.add(chat_files) + assert chat_files Path(UPLOAD_ROOT).mkdir(parents=True, exist_ok=True) command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}" @@ -678,21 +679,27 @@ async def mock_index_repo(): uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0) await uploads_repo.add(uploads_files) + assert uploads_files filenames = list_files(src_path) other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] + assert other_files - return chat_files, uploads_files, other_files + return chat_path, UPLOAD_ROOT, src_path @pytest.mark.skip @pytest.mark.asyncio async def test_index_repo(): # mock data - chat_files, uploads_files, other_files = await mock_index_repo() + chat_path, UPLOAD_ROOT, src_path = await mock_index_repo() editor = Editor() - rsp = await editor.search_index_repo(query="业务线", files_or_paths=chat_files + uploads_files + other_files) + rsp = await editor.search_index_repo(query="业务线", file_or_path=chat_path) + assert rsp + rsp = await editor.search_index_repo(query="业务线", file_or_path=UPLOAD_ROOT) + assert rsp + rsp = await editor.search_index_repo(query="业务线", file_or_path=src_path) assert rsp shutil.rmtree(CHATS_ROOT) From 
85c1d0799025c22a74a19668ef4b7d4baafbec70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Fri, 6 Sep 2024 18:31:11 +0800 Subject: [PATCH 8/8] refactor: cross_repo_search --- metagpt/tools/libs/editor.py | 37 +------------------ metagpt/tools/libs/index_repo.py | 63 +++++++++++++++++++++++++++++--- 2 files changed, 60 insertions(+), 40 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 89a3f70cc..29e04b56f 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -3,7 +3,6 @@ This file is borrowed from OpenDevin You can find the original repository here: https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py """ -import asyncio import os import re import shutil @@ -15,10 +14,9 @@ from pydantic import BaseModel, ConfigDict from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger -from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo +from metagpt.tools.libs.index_repo import IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool -from metagpt.utils.common import list_files from metagpt.utils.file import File from metagpt.utils.report import EditorReporter @@ -882,35 +880,4 @@ class Editor(BaseModel): List[str]: A list of search results as strings, containing the text from the merged results and any direct results from other files. 
""" - if not file_or_path or not Path(file_or_path).exists(): - raise ValueError(f'"{str(file_or_path)}" not exists') - files = [file_or_path] if not Path(file_or_path).is_dir() else list_files(file_or_path) - clusters, roots = IndexRepo.classify_path(files) - futures = [] - others = set() - for persist_path, filenames in clusters.items(): - if persist_path == OTHER_TYPE: - others.update(filenames) - continue - root = roots[persist_path] - repo = IndexRepo(persist_path=persist_path, root_path=root) - futures.append(repo.search(query=query, filenames=list(filenames))) - - for i in others: - futures.append(File.read_text_file(i)) - - futures_results = [] - if futures: - futures_results = await asyncio.gather(*futures) - - result = [] - v_result = [] - for i in futures_results: - if isinstance(i, str): - result.append(i) - else: - v_result.append(i) - - repo = IndexRepo() - merged = await repo.merge(query=query, indices_list=v_result) - return [i.text for i in merged] + result + return await IndexRepo.cross_repo_search(query=query, file_or_path=file_or_path) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index 7de6bce5e..24065c4be 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- - +import asyncio import json import re from pathlib import Path @@ -295,16 +295,16 @@ class IndexRepo(BaseModel): return old_fp != fp @staticmethod - def classify_path(files: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: - """Classify a list of file paths or Path objects into different categories. + def find_index_repo_path(files: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: + """Map the file path to the corresponding index repo. Args: files (List[Union[str, Path]]): A list of file paths or Path objects to be classified. 
Returns: Tuple[Dict[str, Set[Path]], Dict[str, str]]: - - A dictionary mapping the classified path types to sets of corresponding Path objects. - - A dictionary mapping the classified path types to their corresponding root directories. + - A dictionary mapping the index repo path to the files. + - A dictionary mapping the index repo path to their corresponding root directories. """ mappings = { UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"), @@ -351,3 +351,56 @@ class IndexRepo(BaseModel): except Exception as e: logger.warning(f"Load meta error: {e}") return default_meta + + @staticmethod + async def cross_repo_search(query: str, file_or_path: Union[str, Path]) -> List[str]: + """Search for a query across multiple repositories. + + This asynchronous function searches for the specified query in files + located at the given path or file. + + Args: + query (str): The search term to look for in the files. + file_or_path (Union[str, Path]): The path to the file or directory + where the search should be conducted. This can be a string path + or a Path object. + + Returns: + List[str]: A list of strings containing the text of the merged + search results, plus the raw text of any non-indexed files. + + Raises: + ValueError: If `file_or_path` is empty or does not exist.
+ """ + if not file_or_path or not Path(file_or_path).exists(): + raise ValueError(f'"{str(file_or_path)}" not exists') + files = [file_or_path] if not Path(file_or_path).is_dir() else list_files(file_or_path) + clusters, roots = IndexRepo.find_index_repo_path(files) + futures = [] + others = set() + for persist_path, filenames in clusters.items(): + if persist_path == OTHER_TYPE: + others.update(filenames) + continue + root = roots[persist_path] + repo = IndexRepo(persist_path=persist_path, root_path=root) + futures.append(repo.search(query=query, filenames=list(filenames))) + + for i in others: + futures.append(File.read_text_file(i)) + + futures_results = [] + if futures: + futures_results = await asyncio.gather(*futures) + + result = [] + v_result = [] + for i in futures_results: + if isinstance(i, str): + result.append(i) + else: + v_result.append(i) + + repo = IndexRepo() + merged = await repo.merge(query=query, indices_list=v_result) + return [i.text for i in merged] + result