feat: +Editor.search_index_repo

This commit is contained in:
莘权 马 2024-09-05 17:21:27 +08:00
parent 1eb3b8fb8c
commit 4523615dd9
4 changed files with 177 additions and 5 deletions

View file

@ -3,6 +3,7 @@ This file is borrowed from OpenDevin
You can find the original repository here:
https://github.com/All-Hands-AI/OpenHands/blob/main/openhands/runtime/plugins/agent_skills/file_ops/file_ops.py
"""
import asyncio
import base64
import os
import re
@ -16,6 +17,7 @@ from pydantic import BaseModel, ConfigDict
from metagpt.config2 import Config
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo
from metagpt.tools.libs.linter import Linter
from metagpt.tools.tool_registry import register_tool
from metagpt.utils import read_docx
@ -951,3 +953,52 @@ class Editor(BaseModel):
if not path.is_absolute():
path = self.working_dir / path
return path
@staticmethod
async def search_index_repo(
query: str, files_or_paths: List[Union[str, Path]], min_token_count: int = 0
) -> List[str]:
"""Searches the index repository for a given query across specified files or paths.
This method classifies the provided files or paths, performing a search on each cluster
of files while handling other types of files separately. It merges results from structured
indices with any results from non-indexed files.
Args:
query (str): The search query string to look for in the indexed files.
files_or_paths (List[Union[str, Path]]): A list of file paths or names to search within.
min_token_count (int, optional): The minimum token count to consider for indexing. Defaults to 0.
Returns:
List[str]: A list of search results as strings, containing the text from the merged results
and any direct results from other files.
"""
clusters, roots = IndexRepo.classify_path(files_or_paths)
futures = []
others = set()
for persist_path, filenames in clusters.items():
if persist_path == OTHER_TYPE:
others.update(filenames)
continue
root = roots[persist_path]
repo = IndexRepo(persist_path=persist_path, root_path=root, min_token_count=min_token_count)
futures.append(repo.search(query=query, filenames=list(filenames)))
for i in others:
futures.append(aread(filename=i))
futures_results = []
if futures:
futures_results = await asyncio.gather(*futures)
result = []
v_result = []
for i in futures_results:
if isinstance(i, str):
result.append(i)
else:
v_result.append(i)
repo = IndexRepo(min_token_count=min_token_count)
merged = await repo.merge(query=query, indices_list=v_result)
return [i.text for i in merged] + result

View file

@ -2,8 +2,9 @@
# -*- coding: utf-8 -*-
import json
import re
from pathlib import Path
from typing import Dict, List, Optional, Set, Union
from typing import Dict, List, Optional, Set, Tuple, Union
import tiktoken
from llama_index.core.base.embeddings.base import BaseEmbedding
@ -18,6 +19,14 @@ from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRanker
from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files
from metagpt.utils.repo_to_markdown import is_text_file
UPLOADS_INDEX_ROOT = "/data/.index/uploads"
DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT
UPLOAD_ROOT = "/data/uploads"
DEFAULT_ROOT = UPLOAD_ROOT
CHATS_INDEX_ROOT = "/data/.index/chats"
CHATS_ROOT = "/data/chats/"
OTHER_TYPE = "other"
class TextScore(BaseModel):
filename: str
@ -26,8 +35,10 @@ class TextScore(BaseModel):
class IndexRepo(BaseModel):
persist_path: str # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/`
root_path: str # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo.
persist_path: str = DEFAULT_INDEX_ROOT # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/`
root_path: str = (
DEFAULT_ROOT # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo.
)
fingerprint_filename: str = "fingerprint.json"
model: Optional[str] = None
min_token_count: int = 10000
@ -93,6 +104,10 @@ class IndexRepo(BaseModel):
Returns:
List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity.
"""
flat_nodes = [node for indices in indices_list for node in indices]
if len(flat_nodes) <= self.recall_count:
return flat_nodes
if not self.embedding:
config = Config.default()
if self.model:
@ -102,7 +117,6 @@ class IndexRepo(BaseModel):
scores = []
query_embedding = await self.embedding.aget_text_embedding(query)
flat_nodes = [node for indices in indices_list for node in indices]
for i in flat_nodes:
text_embedding = await self.embedding.aget_text_embedding(i.text)
similarity = self.embedding.similarity(query_embedding, text_embedding)
@ -262,3 +276,33 @@ class IndexRepo(BaseModel):
return True
fp = generate_fingerprint(content)
return old_fp != fp
@staticmethod
def classify_path(files_or_paths: List[Union[str, Path]]) -> Tuple[Dict[str, Set[Path]], Dict[str, str]]:
mappings = {
UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"),
CHATS_INDEX_ROOT: re.compile(r"^/data/chats/\d+($|/.*)"),
}
clusters = {}
roots = {}
for i in files_or_paths:
path = Path(i).absolute()
path_type = OTHER_TYPE
for type_, pattern in mappings.items():
if re.match(pattern, str(i)):
path_type = type_
break
if path_type == CHATS_INDEX_ROOT:
chat_id = path.parts[3]
path_type = str(Path(path_type) / chat_id)
roots[path_type] = str(Path(CHATS_ROOT) / chat_id)
elif path_type == UPLOADS_INDEX_ROOT:
roots[path_type] = UPLOAD_ROOT
if path_type in clusters:
clusters[path_type].add(path)
else:
clusters[path_type] = {path}
return clusters, roots

View file

@ -1,7 +1,19 @@
import os
import shutil
from pathlib import Path
import pytest
from metagpt.const import TEST_DATA_PATH
from metagpt.tools.libs.editor import Editor
from metagpt.tools.libs.index_repo import (
CHATS_INDEX_ROOT,
CHATS_ROOT,
UPLOAD_ROOT,
UPLOADS_INDEX_ROOT,
IndexRepo,
)
from metagpt.utils.common import list_files
TEST_FILE_CONTENT = """
# this is line one
@ -645,5 +657,47 @@ def test_append_to_single_empty_line_file():
assert n_added_lines == 1
async def mock_index_repo():
chat_id = "1"
chat_path = Path(CHATS_ROOT) / chat_id
chat_path.mkdir(parents=True, exist_ok=True)
src_path = TEST_DATA_PATH / "requirements"
command = f"cp -rf {str(src_path)} {str(chat_path)}"
os.system(command)
filenames = list_files(chat_path)
chat_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}]
chat_repo = IndexRepo(
persist_path=str(Path(CHATS_INDEX_ROOT) / chat_id), root_path=str(chat_path), min_token_count=0
)
await chat_repo.add(chat_files)
Path(UPLOAD_ROOT).mkdir(parents=True, exist_ok=True)
command = f"cp -rf {str(src_path)} {str(UPLOAD_ROOT)}"
os.system(command)
filenames = list_files(UPLOAD_ROOT)
uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}]
uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0)
await uploads_repo.add(uploads_files)
filenames = list_files(src_path)
other_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json"}]
return chat_files, uploads_files, other_files
@pytest.mark.skip
@pytest.mark.asyncio
async def test_index_repo():
# mock data
chat_files, uploads_files, other_files = await mock_index_repo()
editor = Editor()
rsp = await editor.vsearch(query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0)
assert rsp
shutil.rmtree(CHATS_ROOT)
shutil.rmtree(UPLOAD_ROOT)
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -1,11 +1,17 @@
import shutil
from pathlib import Path
import pytest
from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH
from metagpt.tools.libs.index_repo import IndexRepo
from metagpt.tools.libs.index_repo import (
CHATS_INDEX_ROOT,
UPLOADS_INDEX_ROOT,
IndexRepo,
)
@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")])
async def test_index_repo(path, query):
@ -28,5 +34,22 @@ async def test_index_repo(path, query):
shutil.rmtree(index_path)
@pytest.mark.parametrize(
("paths", "path_type", "root"),
[
(["/data/uploads"], UPLOADS_INDEX_ROOT, "/data/uploads"),
(["/data/uploads/"], UPLOADS_INDEX_ROOT, "/data/uploads"),
(["/data/chats/1/1.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"),
(["/data/chats/1/2.txt"], str(Path(CHATS_INDEX_ROOT) / "1"), "/data/chats/1"),
(["/data/chats/2/2.txt", "/data/chats/2/2.txt"], str(Path(CHATS_INDEX_ROOT) / "2"), "/data/chats/2"),
(["/data/chats.txt"], "other", ""),
],
)
def test_classify_path(paths, path_type, root):
result, result_root = IndexRepo.classify_path(paths)
assert path_type in set(result.keys())
assert root == result_root.get(path_type, "")
if __name__ == "__main__":
pytest.main([__file__, "-s"])