feat: search_index_repo path or filename

This commit is contained in:
莘权 马 2024-09-06 15:35:47 +08:00
parent 95bf4c3e22
commit 6a57cb5e0a
3 changed files with 24 additions and 11 deletions

View file

@ -18,6 +18,7 @@ from metagpt.logs import logger
from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo
from metagpt.tools.libs.linter import Linter
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import list_files
from metagpt.utils.file import File
from metagpt.utils.report import EditorReporter
@ -866,7 +867,7 @@ class Editor(BaseModel):
return path
@staticmethod
async def search_index_repo(query: str, files_or_paths: List[Union[str, Path]]) -> List[str]:
async def search_index_repo(query: str, file_or_path: Union[str, Path]) -> List[str]:
"""Searches the index repository for a given query across specified files or paths.
This method classifies the provided files or paths, performing a search on each cluster
@ -875,13 +876,16 @@ class Editor(BaseModel):
Args:
query (str): The search query string to look for in the indexed files.
files_or_paths (List[Union[str, Path]]): A list of file paths or names to search within.
file_or_path (Union[str, Path]): A path or a filename to search within.
Returns:
List[str]: A list of search results as strings, containing the text from the merged results
and any direct results from other files.
"""
clusters, roots = IndexRepo.classify_path(files_or_paths)
if not file_or_path or not Path(file_or_path).exists():
raise ValueError(f'"{str(file_or_path)}" not exists')
files = [file_or_path] if not Path(file_or_path).is_dir() else list_files(file_or_path)
clusters, roots = IndexRepo.classify_path(files)
futures = []
others = set()
for persist_path, filenames in clusters.items():

View file

@ -85,7 +85,9 @@ class IndexRepo(BaseModel):
"""
encoding = tiktoken.get_encoding("cl100k_base")
result: List[Union[NodeWithScore, TextScore]] = []
filenames, _ = await self._filter(filenames)
filenames, excludes = await self._filter(filenames)
if not filenames:
raise ValueError(f"Unsupported file types: {[str(i) for i in excludes]}")
filter_filenames = set()
meta = await self._read_meta()
for i in filenames:
@ -269,7 +271,7 @@ class IndexRepo(BaseModel):
List[NodeWithScore]: A list of nodes with scores matching the query.
"""
if not Path(self.persist_path).exists():
return []
raise ValueError(f"IndexRepo {Path(self.persist_path).name} not exists.")
engine = SimpleEngine.from_index(
index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()]
)
@ -293,11 +295,11 @@ class IndexRepo(BaseModel):
return old_fp != fp
@staticmethod
def classify_path(files_or_paths: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]:
def classify_path(files: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]:
"""Classify a list of file paths or Path objects into different categories.
Args:
files_or_paths (List[Union[str, Path]]): A list of file paths or Path objects to be classified.
files (List[Union[str, Path]]): A list of file paths or Path objects to be classified.
Returns:
Tuple[Dict[str, Set[Path]], Dict[str, str]]:
@ -311,7 +313,7 @@ class IndexRepo(BaseModel):
clusters = {}
roots = {}
for i in files_or_paths:
for i in files:
path = Path(i).absolute()
path_type = OTHER_TYPE
for type_, pattern in mappings.items():

View file

@ -670,6 +670,7 @@ async def mock_index_repo():
persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0
)
await chat_repo.add(chat_files)
assert chat_files
Path(UPLOAD_ROOT).mkdir(parents=True, exist_ok=True)
command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}"
@ -678,21 +679,27 @@ async def mock_index_repo():
uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}]
uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0)
await uploads_repo.add(uploads_files)
assert uploads_files
filenames = list_files(src_path)
other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}]
assert other_files
return chat_files, uploads_files, other_files
return chat_path, UPLOAD_ROOT, src_path
@pytest.mark.skip
@pytest.mark.asyncio
async def test_index_repo():
# mock data
chat_files, uploads_files, other_files = await mock_index_repo()
chat_path, UPLOAD_ROOT, src_path = await mock_index_repo()
editor = Editor()
rsp = await editor.search_index_repo(query="业务线", files_or_paths=chat_files + uploads_files + other_files)
rsp = await editor.search_index_repo(query="业务线", file_or_path=chat_path)
assert rsp
rsp = await editor.search_index_repo(query="业务线", file_or_path=UPLOAD_ROOT)
assert rsp
rsp = await editor.search_index_repo(query="业务线", file_or_path=src_path)
assert rsp
shutil.rmtree(CHATS_ROOT)