Merge branch 'fixbug/index_max_token' into 'mgx_ops'

feat: IndexRepo + config

See merge request pub/MetaGPT!376
This commit is contained in:
林义章 2024-09-25 07:14:55 +00:00
commit a9ec57dbbb
8 changed files with 185 additions and 34 deletions

View file

@ -41,7 +41,7 @@ class Architect(RoleZero):
instruction: str = ARCHITECT_INSTRUCTION
max_react_loop: int = 1 # FIXME: Read and edit files requires more steps, consider later
tools: list[str] = [
"Editor:write,read,write_content",
"Editor:write,read,write_content,similarity_search",
"RoleZero",
"WriteDesign",
]

View file

@ -30,8 +30,15 @@ class DataAnalyst(RoleZero):
instruction: str = ROLE_INSTRUCTION + EXTRA_INSTRUCTION
task_type_desc: str = TASK_TYPE_DESC
tools: list[str] = ["Plan", "DataAnalyst", "RoleZero", "Browser", "Editor:write,read", "SearchEnhancedQA"]
custom_tools: list[str] = ["web scraping", "Terminal", "Editor:write,read"]
tools: list[str] = [
"Plan",
"DataAnalyst",
"RoleZero",
"Browser",
"Editor:write,read,similarity_search",
"SearchEnhancedQA",
]
custom_tools: list[str] = ["web scraping", "Terminal", "Editor:write,read,similarity_search"]
custom_tool_recommender: ToolRecommender = None
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = KeywordExpRetriever()

View file

@ -158,6 +158,7 @@ class RoleZero(Role):
"scroll_up",
"search_dir",
"search_file",
"similarity_search",
# "set_workdir",
"write",
]

View file

@ -31,7 +31,7 @@ class ProjectManager(RoleZero):
instruction: str = """Use WriteTasks tool to write a project task list"""
max_react_loop: int = 1 # FIXME: Read and edit files requires more steps, consider later
tools: list[str] = ["Editor:write,read,write_content", "RoleZero", "WriteTasks"]
tools: list[str] = ["Editor:write,read,write_content,similarity_search", "RoleZero", "WriteTasks"]
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)

View file

@ -10,10 +10,11 @@ import tempfile
from pathlib import Path
from typing import List, Optional, Union
import tiktoken
from pydantic import BaseModel, ConfigDict
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.tools.libs.index_repo import IndexRepo
from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, IndexRepo
from metagpt.tools.libs.linter import Linter
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.file import File
@ -128,9 +129,18 @@ class Editor(BaseModel):
async def read(self, path: str) -> FileBlock:
"""Read the whole content of a file. Using absolute paths as the argument for specifying the file location."""
error = FileBlock(
file_path=str(path),
block_content="The file is too large to read. Use `Editor.similarity_search` to read the file instead.",
)
path = Path(path)
if path.stat().st_size > 5 * DEFAULT_MIN_TOKEN_COUNT:
return error
content = await File.read_text_file(path)
if not content:
return FileBlock(file_path=str(path), block_content="")
if self.is_large_file(content=content):
return error
self.resource.report(str(path), "path")
lines = content.splitlines(keepends=True)
@ -1086,19 +1096,33 @@ class Editor(BaseModel):
return path
@staticmethod
async def search_index_repo(query: str, file_or_path: Union[str, Path]) -> List[str]:
"""Searches the index repository for a given query across specified files or paths.
async def similarity_search(query: str, file_or_path: Union[str, Path]) -> List[str]:
"""Given a filename or a pathname, performs a similarity search for a given query across the specified file or path.
This method classifies the provided files or paths, performing a search on each cluster
of files while handling other types of files separately. It merges results from structured
indices with any results from non-indexed files.
This method searches the index repository for the provided query, classifying the specified
files or paths. It performs a search on each cluster of files and handles non-indexed files
separately, merging results from structured indices with any direct results from non-indexed files.
This function call does not depend on other functions.
Args:
query (str): The search query string to look for in the indexed files.
file_or_path (Union[str, Path]): A path or a filename to search within.
file_or_path (Union[str, Path]): A pathname or filename to search within.
Returns:
List[str]: A list of search results as strings, containing the text from the merged results
and any direct results from other files.
List[str]: A list of results as strings, containing the text from the merged results
and any direct results from non-indexed files.
Example:
>>> query = "The problem to be analyzed from the document"
>>> file_or_path = "The pathname or filename you want to search within"
>>> texts: List[str] = await Editor.similarity_search(query=query, file_or_path=file_or_path)
>>> print(texts)
"""
return await IndexRepo.cross_repo_search(query=query, file_or_path=file_or_path)
@staticmethod
def is_large_file(content: str, mix_token_count: int = 0) -> bool:
    """Tell whether `content` reaches the large-file token threshold.

    Args:
        content (str): The text whose size is measured.
        mix_token_count (int): The token threshold. A falsy value (the default `0`)
            selects `DEFAULT_MIN_TOKEN_COUNT`.

    Returns:
        bool: True if the number of `cl100k_base` tokens in `content` is greater than
            or equal to the threshold, False otherwise.
    """
    # Resolve the effective threshold first; 0/None both mean "use the module default".
    threshold = mix_token_count if mix_token_count else DEFAULT_MIN_TOKEN_COUNT
    tokens = tiktoken.get_encoding("cl100k_base").encode(content)
    return len(tokens) >= threshold

View file

@ -12,12 +12,14 @@ from llama_index.core.schema import NodeWithScore
from pydantic import BaseModel, Field, model_validator
from metagpt.config2 import Config
from metagpt.context import Context
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig
from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files
from metagpt.utils.file import File
from metagpt.utils.report import EditorReporter
UPLOADS_INDEX_ROOT = "/data/.index/uploads"
DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT
@ -45,7 +47,7 @@ class TextScore(BaseModel):
class IndexRepo(BaseModel):
persist_path: str = DEFAULT_INDEX_ROOT # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/`
root_path: str = (
DEFAULT_ROOT # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo.
DEFAULT_ROOT # `/data/uploads` or r`/data/chats/[a-z0-9]+`, the root path of files indexed by the index repo.
)
fingerprint_filename: str = "fingerprint.json"
meta_filename: str = "meta.json"
@ -88,9 +90,19 @@ class IndexRepo(BaseModel):
filenames, excludes = await self._filter(filenames)
if not filenames:
raise ValueError(f"Unsupported file types: {[str(i) for i in excludes]}")
resource = EditorReporter()
for i in filenames:
await resource.async_report(str(i), "path")
filter_filenames = set()
meta = await self._read_meta()
new_files = {}
for i in filenames:
if Path(i).suffix.lower() in {".pdf", ".doc", ".docx"}:
if str(i) not in self.fingerprints:
new_files[i] = ""
logger.warning(f'file: "{i}" not indexed')
filter_filenames.add(str(i))
continue
content = await File.read_text_file(i)
token_count = len(encoding.encode(content))
if not self._is_buildable(
@ -99,10 +111,17 @@ class IndexRepo(BaseModel):
result.append(TextScore(filename=str(i), text=content))
continue
file_fingerprint = generate_fingerprint(content)
if self.fingerprints.get(str(i)) != file_fingerprint and Path(i).suffix.lower() not in {".pdf"}:
logger.error(f'file: "{i}" changed but not indexed')
if str(i) not in self.fingerprints or (self.fingerprints.get(str(i)) != file_fingerprint):
new_files[i] = content
logger.warning(f'file: "{i}" changed but not indexed')
continue
filter_filenames.add(str(i))
if new_files:
added, others = await self.add(paths=list(new_files.keys()), file_datas=new_files)
filter_filenames.update([str(i) for i in added])
for i in others:
result.append(TextScore(filename=str(i), text=new_files.get(i)))
filter_filenames.discard(str(i))
nodes = await self._search(query=query, filters=filter_filenames)
return result + nodes
@ -132,24 +151,48 @@ class IndexRepo(BaseModel):
scores = []
query_embedding = await self.embedding.aget_text_embedding(query)
for i in flat_nodes:
text_embedding = await self.embedding.aget_text_embedding(i.text)
try:
text_embedding = await self.embedding.aget_text_embedding(i.text)
except Exception as e: # 超过最大长度
tenth = int(len(i.text) / 10) # DEFAULT_MIN_TOKEN_COUNT = 10000
logger.warning(
f"{e}, tenth len={tenth}, pre_part_len={len(i.text[: tenth * 6])}, post_part_len={len(i.text[tenth * 4:])}"
)
pre_win_part = await self.embedding.aget_text_embedding(i.text[: tenth * 6])
post_win_part = await self.embedding.aget_text_embedding(i.text[tenth * 4 :])
similarity = max(
self.embedding.similarity(query_embedding, pre_win_part),
self.embedding.similarity(query_embedding, post_win_part),
)
scores.append((similarity, i))
continue
similarity = self.embedding.similarity(query_embedding, text_embedding)
scores.append((similarity, i))
scores.sort(key=lambda x: x[0], reverse=True)
return [i[1] for i in scores][: self.recall_count]
async def add(self, paths: List[Path]):
async def add(
self, paths: List[Path], file_datas: Dict[Union[str, Path], str] = None
) -> Tuple[List[str], List[str]]:
"""Add new documents to the index.
Args:
paths (List[Path]): A list of paths to the documents to be added.
file_datas (Dict[Union[str, Path], str]): A list of file content.
Returns:
Tuple[List[str], List[str]]: A tuple containing two lists:
1. The list of filenames that were successfully added to the index.
2. The list of filenames that were not added to the index because they were not buildable.
"""
encoding = tiktoken.get_encoding("cl100k_base")
filenames, _ = await self._filter(paths)
filter_filenames = []
delete_filenames = []
file_datas = file_datas or {}
for i in filenames:
content = await File.read_text_file(i)
content = file_datas.get(i) or await File.read_text_file(i)
file_datas[i] = content
if not self._is_fingerprint_changed(filename=i, content=content):
continue
token_count = len(encoding.encode(content))
@ -159,9 +202,15 @@ class IndexRepo(BaseModel):
else:
delete_filenames.append(i)
logger.debug(f"{i} not is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}")
await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames)
await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames, file_datas=file_datas)
return filter_filenames, delete_filenames
async def _add_batch(self, filenames: List[Union[str, Path]], delete_filenames: List[Union[str, Path]]):
async def _add_batch(
self,
filenames: List[Union[str, Path]],
delete_filenames: List[Union[str, Path]],
file_datas: Dict[Union[str, Path], str],
):
"""Add and remove documents in a batch operation.
Args:
@ -172,6 +221,7 @@ class IndexRepo(BaseModel):
return
logger.info(f"update index repo, add {filenames}, remove {delete_filenames}")
engine = None
Context()
if Path(self.persist_path).exists():
logger.debug(f"load index from {self.persist_path}")
engine = SimpleEngine.from_index(
@ -180,9 +230,9 @@ class IndexRepo(BaseModel):
)
try:
engine.delete_docs(filenames + delete_filenames)
logger.debug(f"delete docs {filenames + delete_filenames}")
logger.info(f"delete docs {filenames + delete_filenames}")
engine.add_docs(input_files=filenames)
logger.debug(f"add docs {filenames}")
logger.info(f"add docs {filenames}")
except NotImplementedError as e:
logger.debug(f"{e}")
filenames = list(set([str(i) for i in filenames] + list(self.fingerprints.keys())))
@ -194,10 +244,10 @@ class IndexRepo(BaseModel):
retriever_configs=[FAISSRetrieverConfig()],
ranker_configs=[LLMRankerConfig()],
)
logger.debug(f"add docs {filenames}")
logger.info(f"add docs {filenames}")
engine.persist(persist_dir=self.persist_path)
for i in filenames:
content = await File.read_text_file(i)
content = file_datas.get(i) or await File.read_text_file(i)
fp = generate_fingerprint(content)
self.fingerprints[str(i)] = fp
await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints))
@ -270,10 +320,14 @@ class IndexRepo(BaseModel):
Returns:
List[NodeWithScore]: A list of nodes with scores matching the query.
"""
if not filters:
return []
if not Path(self.persist_path).exists():
raise ValueError(f"IndexRepo {Path(self.persist_path).name} not exists.")
Context()
engine = SimpleEngine.from_index(
index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()]
index_config=FAISSIndexConfig(persist_path=self.persist_path),
retriever_configs=[FAISSRetrieverConfig()],
)
rsp = await engine.aretrieve(query)
return [i for i in rsp if i.metadata.get("file_path") in filters]
@ -308,7 +362,7 @@ class IndexRepo(BaseModel):
"""
mappings = {
UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"),
CHATS_INDEX_ROOT: re.compile(r"^/data/chats/\d+($|/.*)"),
CHATS_INDEX_ROOT: re.compile(r"^/data/chats/[a-z0-9]+($|/.*)"),
}
clusters = {}
@ -396,6 +450,8 @@ class IndexRepo(BaseModel):
result = []
v_result = []
for i in futures_results:
if not i:
continue
if isinstance(i, str):
result.append(i)
else:

View file

@ -0,0 +1,21 @@
import pytest
from metagpt.const import TEST_DATA_PATH
from metagpt.roles.di.data_analyst import DataAnalyst
@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("query", "filename"),
    [
        (
            "similarity search about '有哪些需求描述?' in document ",
            TEST_DATA_PATH / "requirements/2.pdf",
        ),
    ],
)
async def test_similarity_search(query, filename):
    """Run a DataAnalyst on a similarity-search request about a document.

    The quoted file path is appended to the parametrized query and the combined
    text is passed to `DataAnalyst.run`; the test only checks that the response
    is truthy. Skipped by default via `pytest.mark.skip`.
    """
    # The query fixture ends with "in document ", so the quoted path completes the sentence.
    prompt = f"{query}'{filename}'"
    analyst = DataAnalyst()
    response = await analyst.run(prompt)
    assert response
if __name__ == "__main__":
pytest.main([__file__, "-s"])

View file

@ -9,8 +9,8 @@ from metagpt.tools.libs.editor import Editor
from metagpt.tools.libs.index_repo import (
CHATS_INDEX_ROOT,
CHATS_ROOT,
DEFAULT_MIN_TOKEN_COUNT,
UPLOAD_ROOT,
UPLOADS_INDEX_ROOT,
IndexRepo,
)
from metagpt.utils.common import list_files
@ -756,8 +756,6 @@ async def mock_index_repo():
os.system(command)
filenames = list_files(UPLOAD_ROOT)
uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}]
uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0)
await uploads_repo.add(uploads_files)
assert uploads_files
filenames = list_files(src_path)
@ -771,19 +769,63 @@ async def mock_index_repo():
@pytest.mark.asyncio
async def test_index_repo():
# mock data
chat_path, UPLOAD_ROOT, src_path = await mock_index_repo()
chat_path, upload_path, src_path = await mock_index_repo()
editor = Editor()
rsp = await editor.search_index_repo(query="业务线", file_or_path=chat_path)
rsp = await editor.similarity_search(query="业务线", file_or_path=chat_path)
assert rsp
rsp = await editor.search_index_repo(query="业务线", file_or_path=UPLOAD_ROOT)
rsp = await editor.similarity_search(query="业务线", file_or_path=upload_path)
assert rsp
rsp = await editor.search_index_repo(query="业务线", file_or_path=src_path)
rsp = await editor.similarity_search(query="业务线", file_or_path=src_path)
assert rsp
shutil.rmtree(CHATS_ROOT)
shutil.rmtree(UPLOAD_ROOT)
@pytest.mark.skip
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("query", "filename"),
    [
        (
            "In this document, who are the legal representatives of both parties?",
            TEST_DATA_PATH / "pdf/20210709逗你学云豆付费课程协议.pdf",
        ),
        (
            "What is the short name of the company in this document?",
            TEST_DATA_PATH / "pdf/company_stock_code.pdf",
        ),
        ("平安创新推出中国版的什么模式,将差异化的医疗健康服务与作为支付方的金融业务无缝结合", TEST_DATA_PATH / "pdf/9112674.pdf"),
        (
            "What principle is introduced by the author to explain the conditions necessary for the emergence of complexity?",
            TEST_DATA_PATH / "pdf/9781444323498.ch2_1.pdf",
        ),
        ("行高的继承性的代码示例是?", TEST_DATA_PATH / "pdf/02-CSS.pdf"),
    ],
)
async def test_similarity_search(query, filename):
    """Copy a PDF fixture under UPLOAD_ROOT and query it with `Editor.similarity_search`.

    Args:
        query: The question passed to the similarity search.
        filename: Path of the PDF fixture to copy and search.

    The copied file is removed afterwards, even when the search or the assertion fails.
    Skipped by default via `pytest.mark.skip`.
    """
    filename = Path(filename)
    save_to = Path(UPLOAD_ROOT) / filename.name
    save_to.parent.mkdir(parents=True, exist_ok=True)
    # Use shutil.copy rather than a shell `cp` string: it needs no quoting of the
    # paths, is portable, and raises on failure instead of being silently ignored.
    shutil.copy(filename, save_to)
    try:
        editor = Editor()
        rsp = await editor.similarity_search(query=query, file_or_path=save_to)
        assert rsp
    finally:
        # Always clean up the copied fixture so a failing run leaves no stray upload.
        save_to.unlink(missing_ok=True)
@pytest.mark.skip
@pytest.mark.asyncio
async def test_read():
    """Check that `Editor.read` points to `similarity_search` for an oversized PDF.

    Reads the `pdf/9112674.pdf` fixture and asserts that the returned block content
    mentions `similarity_search`, and that the fixture's size on disk is greater than
    `5 * DEFAULT_MIN_TOKEN_COUNT` bytes. Skipped by default via `pytest.mark.skip`.
    """
    pdf_path = TEST_DATA_PATH / "pdf/9112674.pdf"
    block = await Editor().read(str(pdf_path))
    byte_size = pdf_path.stat().st_size
    # Two separate assertions so a failure shows which condition did not hold.
    assert "similarity_search" in block.block_content
    assert byte_size > 5 * DEFAULT_MIN_TOKEN_COUNT
if __name__ == "__main__":
pytest.main([__file__, "-s"])