From 4523615dd92e2f889b8350ec6df8e82b209396b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 5 Sep 2024 17:21:27 +0800 Subject: [PATCH] feat: +Editor.search_index_repo --- metagpt/tools/libs/editor.py | 51 +++++++++++++++++++ metagpt/tools/libs/index_repo.py | 52 ++++++++++++++++++-- tests/metagpt/tools/libs/test_editor.py | 54 +++++++++++++++++++++ tests/metagpt/tools/libs/test_index_repo.py | 25 +++++++++- 4 files changed, 177 insertions(+), 5 deletions(-) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 8013b99c9..ab7b2efd6 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -3,6 +3,7 @@ This file is borrowed from OpenDevin You can find the original repository here: https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py """ +import asyncio import base64 import os import re @@ -16,6 +17,7 @@ from pydantic import BaseModel, ConfigDict from metagpt.config2 import Config from metagpt.const import DEFAULT_WORKSPACE_ROOT from metagpt.logs import logger +from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool from metagpt.utils import read_docx @@ -951,3 +953,52 @@ class Editor(BaseModel): if not path.is_absolute(): path = self.working_dir / path return path + + @staticmethod + async def search_index_repo( + query: str, files_or_paths: List[Union[str, Path]], min_token_count: int = 0 + ) -> List[str]: + """Searches the index repository for a given query across specified files or paths. + + This method classifies the provided files or paths, performing a search on each cluster + of files while handling other types of files separately. It merges results from structured + indices with any results from non-indexed files. + + Args: + query (str): The search query string to look for in the indexed files. + files_or_paths (List[Union[str, Path]]): A list of file paths or names to search within. + min_token_count (int, optional): The minimum token count to consider for indexing. Defaults to 0. + + Returns: + List[str]: A list of search results as strings, containing the text from the merged results + and any direct results from other files. + """ + clusters, roots = IndexRepo.classify_path(files_or_paths) + futures = [] + others = set() + for persist_path, filenames in clusters.items(): + if persist_path == OTHER_TYPE: + others.update(filenames) + continue + root = roots[persist_path] + repo = IndexRepo(persist_path=persist_path, root_path=root, min_token_count=min_token_count) + futures.append(repo.search(query=query, filenames=list(filenames))) + + for i in others: + futures.append(aread(filename=i)) + + futures_results = [] + if futures: + futures_results = await asyncio.gather(*futures) + + result = [] + v_result = [] + for i in futures_results: + if isinstance(i, str): + result.append(i) + else: + v_result.append(i) + + repo = IndexRepo(min_token_count=min_token_count) + merged = await repo.merge(query=query, indices_list=v_result) + return [i.text for i in merged] + result diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index fbe3c633b..211225507 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -2,8 +2,9 @@ # -*- coding: utf-8 -*- import json +import re from pathlib import Path -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Tuple, Union import tiktoken from llama_index.core.base.embeddings.base import BaseEmbedding @@ -18,6 +19,14 @@ from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRanker from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files from metagpt.utils.repo_to_markdown import is_text_file +UPLOADS_INDEX_ROOT = "/data/.index/uploads" +DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT +UPLOAD_ROOT = "/data/uploads" +DEFAULT_ROOT = UPLOAD_ROOT +CHATS_INDEX_ROOT = "/data/.index/chats" +CHATS_ROOT = "/data/chats/" +OTHER_TYPE = "other" + class TextScore(BaseModel): filename: str @@ -26,8 +35,10 @@ class TextScore(BaseModel): class IndexRepo(BaseModel): - persist_path: str # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/` - root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. + persist_path: str = DEFAULT_INDEX_ROOT # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/` + root_path: str = ( + DEFAULT_ROOT # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. + ) fingerprint_filename: str = "fingerprint.json" model: Optional[str] = None min_token_count: int = 10000 @@ -93,6 +104,10 @@ class IndexRepo(BaseModel): Returns: List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity. """ + flat_nodes = [node for indices in indices_list for node in indices] + if len(flat_nodes) <= self.recall_count: + return flat_nodes + if not self.embedding: config = Config.default() if self.model: @@ -102,7 +117,6 @@ class IndexRepo(BaseModel): scores = [] query_embedding = await self.embedding.aget_text_embedding(query) - flat_nodes = [node for indices in indices_list for node in indices] for i in flat_nodes: text_embedding = await self.embedding.aget_text_embedding(i.text) similarity = self.embedding.similarity(query_embedding, text_embedding) @@ -262,3 +276,33 @@ class IndexRepo(BaseModel): return True fp = generate_fingerprint(content) return old_fp != fp + + @staticmethod + def classify_path(files_or_paths: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]: + mappings = { + UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"), + CHATS_INDEX_ROOT: re.compile(r"^/data/chats/\d+($|/.*)"), + } + + clusters = {} + roots = {} + for i in files_or_paths: + path = Path(i).absolute() + path_type = OTHER_TYPE + for type_, pattern in mappings.items(): + if re.match(pattern, str(i)): + path_type = type_ + break + if path_type == CHATS_INDEX_ROOT: + chat_id = path.parts[3] + path_type = str(Path(path_type) / chat_id) + roots[path_type] = str(Path(CHATS_ROOT) / chat_id) + elif path_type == UPLOADS_INDEX_ROOT: + roots[path_type] = UPLOAD_ROOT + + if path_type in clusters: + clusters[path_type].add(path) + else: + clusters[path_type] = {path} + + return clusters, roots diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index bcef2b74e..26d53a703 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -1,7 +1,19 @@ +import os +import shutil +from pathlib import Path + import pytest from metagpt.const import TEST_DATA_PATH from metagpt.tools.libs.editor import Editor +from metagpt.tools.libs.index_repo import ( + CHATS_INDEX_ROOT, + CHATS_ROOT, + UPLOAD_ROOT, + UPLOADS_INDEX_ROOT, + IndexRepo, +) +from metagpt.utils.common import list_files TEST_FILE_CONTENT = """ # this is line one @@ -645,5 +657,47 @@ def test_append_to_single_empty_line_file(): assert n_added_lines == 1 +async def mock_index_repo(): + chat_id = "1" + chat_path = Path(CHATS_ROOT) / chat_id + chat_path.mkdir(parents=True, exist_ok=True) + src_path = TEST_DATA_PATH / "requirements" + command = f"cp -rf {str(src_path)} {str(chat_path)}" + os.system(command) + filenames = list_files(chat_path) + chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + chat_repo = IndexRepo( + persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0 + ) + await chat_repo.add(chat_files) + + Path(UPLOAD_ROOT).mkdir(parents=True, exist_ok=True) + command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}" + os.system(command) + filenames = list_files(UPLOAD_ROOT) + uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0) + await uploads_repo.add(uploads_files) + + filenames = list_files(src_path) + other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}] + + return chat_files, uploads_files, other_files + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_index_repo(): + # mock data + chat_files, uploads_files, other_files = await mock_index_repo() + + editor = Editor() + rsp = await editor.vsearch(query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0) + assert rsp + + shutil.rmtree(CHATS_ROOT) + shutil.rmtree(UPLOAD_ROOT) + + if __name__ == "__main__": pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_index_repo.py b/tests/metagpt/tools/libs/test_index_repo.py index 3cc8ad406..aec1e3f5e 100644 --- a/tests/metagpt/tools/libs/test_index_repo.py +++ b/tests/metagpt/tools/libs/test_index_repo.py @@ -1,11 +1,17 @@ import shutil +from pathlib import Path import pytest from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH -from metagpt.tools.libs.index_repo import IndexRepo +from metagpt.tools.libs.index_repo import ( + CHATS_INDEX_ROOT, + UPLOADS_INDEX_ROOT, + IndexRepo, +) +@pytest.mark.skip @pytest.mark.asyncio @pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")]) async def test_index_repo(path, query): @@ -28,5 +34,22 @@ async def test_index_repo(path, query): shutil.rmtree(index_path) +@pytest.mark.parametrize( + ("paths", "path_type", "root"), + [ + (["/data/uploads"], UPLOADS_INDEX_ROOT, "/data/uploads"), + (["/data/uploads/"], UPLOADS_INDEX_ROOT, "/data/uploads"), + (["/data/chats/1/1.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"), + (["/data/chats/1/2.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"), + (["/data/chats/2/2.txt", "/data/chats/2/2.txt"], str(Path(CHATS_INDEX_ROOT) / "2"), "/data/chats/2"), + (["/data/chats.txt"], "other", ""), + ], +) +def test_classify_path(paths, path_type, root): + result, result_root = IndexRepo.classify_path(paths) + assert path_type in set(result.keys()) + assert root == result_root.get(path_type, "") + + if __name__ == "__main__": pytest.main([__file__, "-s"])