refactor: cross_repo_search

This commit is contained in:
莘权 马 2024-09-06 18:31:11 +08:00
parent 6a57cb5e0a
commit 85c1d07990
2 changed files with 60 additions and 40 deletions

View file

@ -3,7 +3,6 @@ This file is borrowed from OpenDevin
You can find the original repository here:
https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py
"""
import asyncio
import os
import re
import shutil
@ -15,10 +14,9 @@ from pydantic import BaseModel, ConfigDict
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo
from metagpt.tools.libs.index_repo import IndexRepo
from metagpt.tools.libs.linter import Linter
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import list_files
from metagpt.utils.file import File
from metagpt.utils.report import EditorReporter
@ -882,35 +880,4 @@ class Editor(BaseModel):
List[str]: A list of search results as strings, containing the text from the merged results
and any direct results from other files.
"""
if not file_or_path or not Path(file_or_path).exists():
raise ValueError(f'"{str(file_or_path)}" not exists')
files = [file_or_path] if not Path(file_or_path).is_dir() else list_files(file_or_path)
clusters, roots = IndexRepo.classify_path(files)
futures = []
others = set()
for persist_path, filenames in clusters.items():
if persist_path == OTHER_TYPE:
others.update(filenames)
continue
root = roots[persist_path]
repo = IndexRepo(persist_path=persist_path, root_path=root)
futures.append(repo.search(query=query, filenames=list(filenames)))
for i in others:
futures.append(File.read_text_file(i))
futures_results = []
if futures:
futures_results = await asyncio.gather(*futures)
result = []
v_result = []
for i in futures_results:
if isinstance(i, str):
result.append(i)
else:
v_result.append(i)
repo = IndexRepo()
merged = await repo.merge(query=query, indices_list=v_result)
return [i.text for i in merged] + result
return await IndexRepo.cross_repo_search(query=query, file_or_path=file_or_path)

View file

@ -1,6 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import json
import re
from pathlib import Path
@ -295,16 +295,16 @@ class IndexRepo(BaseModel):
return old_fp != fp
@staticmethod
def classify_path(files: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]:
"""Classify a list of file paths or Path objects into different categories.
def find_index_repo_path(files: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]:
"""Map the file path to the corresponding index repo.
Args:
files (List[Union[str, Path]]): A list of file paths or Path objects to be classified.
Returns:
Tuple[Dict[str, Set[Path]], Dict[str, str]]:
- A dictionary mapping the classified path types to sets of corresponding Path objects.
- A dictionary mapping the classified path types to their corresponding root directories.
- A dictionary mapping the index repo path to the files.
- A dictionary mapping the index repo path to their corresponding root directories.
"""
mappings = {
UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"),
@ -351,3 +351,56 @@ class IndexRepo(BaseModel):
except Exception as e:
logger.warning(f"Load meta error: {e}")
return default_meta
@staticmethod
async def cross_repo_search(query: str, file_or_path: Union[str, Path]) -> List[str]:
"""Search for a query across multiple repositories.
This asynchronous function searches for the specified query in files
located at the given path or file.
Args:
query (str): The search term to look for in the files.
file_or_path (Union[str, Path]): The path to the file or directory
where the search should be conducted. This can be a string path
or a Path object.
Returns:
List[str]: A list of strings containing the paths of files that
contain the query results.
Raises:
ValueError: If the query string is empty.
"""
if not file_or_path or not Path(file_or_path).exists():
raise ValueError(f'"{str(file_or_path)}" not exists')
files = [file_or_path] if not Path(file_or_path).is_dir() else list_files(file_or_path)
clusters, roots = IndexRepo.find_index_repo_path(files)
futures = []
others = set()
for persist_path, filenames in clusters.items():
if persist_path == OTHER_TYPE:
others.update(filenames)
continue
root = roots[persist_path]
repo = IndexRepo(persist_path=persist_path, root_path=root)
futures.append(repo.search(query=query, filenames=list(filenames)))
for i in others:
futures.append(File.read_text_file(i))
futures_results = []
if futures:
futures_results = await asyncio.gather(*futures)
result = []
v_result = []
for i in futures_results:
if isinstance(i, str):
result.append(i)
else:
v_result.append(i)
repo = IndexRepo()
merged = await repo.merge(query=query, indices_list=v_result)
return [i.text for i in merged] + result