diff --git a/metagpt/roles/architect.py b/metagpt/roles/architect.py index ccce75afa..b990808b3 100644 --- a/metagpt/roles/architect.py +++ b/metagpt/roles/architect.py @@ -41,7 +41,7 @@ class Architect(RoleZero): instruction: str = ARCHITECT_INSTRUCTION max_react_loop: int = 1 # FIXME: Read and edit files requires more steps, consider later tools: list[str] = [ - "Editor:write,read,write_content", + "Editor:write,read,write_content,similarity_search", "RoleZero", "WriteDesign", ] diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py index f9bead1ac..54ee3864b 100644 --- a/metagpt/roles/di/data_analyst.py +++ b/metagpt/roles/di/data_analyst.py @@ -30,8 +30,15 @@ class DataAnalyst(RoleZero): instruction: str = ROLE_INSTRUCTION + EXTRA_INSTRUCTION task_type_desc: str = TASK_TYPE_DESC - tools: list[str] = ["Plan", "DataAnalyst", "RoleZero", "Browser", "Editor:write,read", "SearchEnhancedQA"] - custom_tools: list[str] = ["web scraping", "Terminal", "Editor:write,read"] + tools: list[str] = [ + "Plan", + "DataAnalyst", + "RoleZero", + "Browser", + "Editor:write,read,similarity_search", + "SearchEnhancedQA", + ] + custom_tools: list[str] = ["web scraping", "Terminal", "Editor:write,read,similarity_search"] custom_tool_recommender: ToolRecommender = None experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = KeywordExpRetriever() diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 7cbac2d04..f636ecbc3 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -158,6 +158,7 @@ class RoleZero(Role): "scroll_up", "search_dir", "search_file", + "similarity_search", # "set_workdir", "write", ] diff --git a/metagpt/roles/project_manager.py b/metagpt/roles/project_manager.py index 228b38660..d64fb90a1 100644 --- a/metagpt/roles/project_manager.py +++ b/metagpt/roles/project_manager.py @@ -31,7 +31,7 @@ class ProjectManager(RoleZero): instruction: str = """Use WriteTasks tool to write a project task list""" max_react_loop: int = 1 # FIXME: Read and edit files requires more steps, consider later - tools: list[str] = ["Editor:write,read,write_content", "RoleZero", "WriteTasks"] + tools: list[str] = ["Editor:write,read,write_content,similarity_search", "RoleZero", "WriteTasks"] def __init__(self, **kwargs) -> None: super().__init__(**kwargs) diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index ca8cd9ccf..225697a60 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -10,10 +10,11 @@ import tempfile from pathlib import Path from typing import List, Optional, Union +import tiktoken from pydantic import BaseModel, ConfigDict from metagpt.const import DEFAULT_WORKSPACE_ROOT -from metagpt.tools.libs.index_repo import IndexRepo +from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool from metagpt.utils.file import File @@ -128,9 +129,18 @@ class Editor(BaseModel): async def read(self, path: str) -> FileBlock: """Read the whole content of a file. Using absolute paths as the argument for specifying the file location.""" + error = FileBlock( + file_path=str(path), + block_content="The file is too large to read. Use `Editor.similarity_search` to read the file instead.", + ) + path = Path(path) + if path.stat().st_size > 5 * DEFAULT_MIN_TOKEN_COUNT: + return error content = await File.read_text_file(path) if not content: return FileBlock(file_path=str(path), block_content="") + if self.is_large_file(content=content): + return error self.resource.report(str(path), "path") lines = content.splitlines(keepends=True) @@ -1086,19 +1096,33 @@ class Editor(BaseModel): return path @staticmethod - async def search_index_repo(query: str, file_or_path: Union[str, Path]) -> List[str]: - """Searches the index repository for a given query across specified files or paths. + async def similarity_search(query: str, file_or_path: Union[str, Path]) -> List[str]: + """Given a filename or a pathname, performs a similarity search for a given query across the specified file or path. - This method classifies the provided files or paths, performing a search on each cluster - of files while handling other types of files separately. It merges results from structured - indices with any results from non-indexed files. + This method searches the index repository for the provided query, classifying the specified + files or paths. It performs a search on each cluster of files and handles non-indexed files + separately, merging results from structured indices with any direct results from non-indexed files. + This function call does not depend on other functions. Args: query (str): The search query string to look for in the indexed files. - file_or_path (Union[str, Path]): A path or a filename to search within. + file_or_path (Union[str, Path]): A pathname or filename to search within. Returns: - List[str]: A list of search results as strings, containing the text from the merged results - and any direct results from other files. + List[str]: A list of results as strings, containing the text from the merged results + and any direct results from non-indexed files. + + Example: + >>> query = "The problem to be analyzed from the document" + >>> file_or_path = "The pathname or filename you want to search within" + >>> texts: List[str] = await Editor.similarity_search(query=query, file_or_path=file_or_path) + >>> print(texts) """ return await IndexRepo.cross_repo_search(query=query, file_or_path=file_or_path) + + @staticmethod + def is_large_file(content: str, mix_token_count: int = 0) -> bool: + encoding = tiktoken.get_encoding("cl100k_base") + token_count = len(encoding.encode(content)) + mix_token_count = mix_token_count or DEFAULT_MIN_TOKEN_COUNT + return token_count >= mix_token_count diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index 24065c4be..4c4e6c59b 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -12,12 +12,14 @@ from llama_index.core.schema import NodeWithScore from pydantic import BaseModel, Field, model_validator from metagpt.config2 import Config +from metagpt.context import Context from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.rag.factories.embedding import RAGEmbeddingFactory from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files from metagpt.utils.file import File +from metagpt.utils.report import EditorReporter UPLOADS_INDEX_ROOT = "/data/.index/uploads" DEFAULT_INDEX_ROOT = UPLOADS_INDEX_ROOT @@ -45,7 +47,7 @@ class TextScore(BaseModel): class IndexRepo(BaseModel): persist_path: str = DEFAULT_INDEX_ROOT # The persist path of the index repo, `/data/.index/uploads/` or `/data/.index/chats/{chat_id}/` root_path: str = ( - DEFAULT_ROOT # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo. + DEFAULT_ROOT # `/data/uploads` or r`/data/chats/[a-z0-9]+`, the root path of files indexed by the index repo. ) fingerprint_filename: str = "fingerprint.json" meta_filename: str = "meta.json" @@ -88,9 +90,19 @@ class IndexRepo(BaseModel): filenames, excludes = await self._filter(filenames) if not filenames: raise ValueError(f"Unsupported file types: {[str(i) for i in excludes]}") + resource = EditorReporter() + for i in filenames: + await resource.async_report(str(i), "path") filter_filenames = set() meta = await self._read_meta() + new_files = {} for i in filenames: + if Path(i).suffix.lower() in {".pdf", ".doc", ".docx"}: + if str(i) not in self.fingerprints: + new_files[i] = "" + logger.warning(f'file: "{i}" not indexed') + filter_filenames.add(str(i)) + continue content = await File.read_text_file(i) token_count = len(encoding.encode(content)) if not self._is_buildable( @@ -99,10 +111,17 @@ class IndexRepo(BaseModel): result.append(TextScore(filename=str(i), text=content)) continue file_fingerprint = generate_fingerprint(content) - if self.fingerprints.get(str(i)) != file_fingerprint and Path(i).suffix.lower() not in {".pdf"}: - logger.error(f'file: "{i}" changed but not indexed') + if str(i) not in self.fingerprints or (self.fingerprints.get(str(i)) != file_fingerprint): + new_files[i] = content + logger.warning(f'file: "{i}" changed but not indexed') continue filter_filenames.add(str(i)) + if new_files: + added, others = await self.add(paths=list(new_files.keys()), file_datas=new_files) + filter_filenames.update([str(i) for i in added]) + for i in others: + result.append(TextScore(filename=str(i), text=new_files.get(i))) + filter_filenames.discard(str(i)) nodes = await self._search(query=query, filters=filter_filenames) return result + nodes @@ -132,24 +151,48 @@ class IndexRepo(BaseModel): scores = [] query_embedding = await self.embedding.aget_text_embedding(query) for i in flat_nodes: - text_embedding = await self.embedding.aget_text_embedding(i.text) + try: + text_embedding = await self.embedding.aget_text_embedding(i.text) + except Exception as e: # 超过最大长度 + tenth = int(len(i.text) / 10) # DEFAULT_MIN_TOKEN_COUNT = 10000 + logger.warning( + f"{e}, tenth len={tenth}, pre_part_len={len(i.text[: tenth * 6])}, post_part_len={len(i.text[tenth * 4:])}" + ) + pre_win_part = await self.embedding.aget_text_embedding(i.text[: tenth * 6]) + post_win_part = await self.embedding.aget_text_embedding(i.text[tenth * 4 :]) + similarity = max( + self.embedding.similarity(query_embedding, pre_win_part), + self.embedding.similarity(query_embedding, post_win_part), + ) + scores.append((similarity, i)) + continue similarity = self.embedding.similarity(query_embedding, text_embedding) scores.append((similarity, i)) scores.sort(key=lambda x: x[0], reverse=True) return [i[1] for i in scores][: self.recall_count] - async def add(self, paths: List[Path]): + async def add( + self, paths: List[Path], file_datas: Dict[Union[str, Path], str] = None + ) -> Tuple[List[str], List[str]]: """Add new documents to the index. Args: paths (List[Path]): A list of paths to the documents to be added. + file_datas (Dict[Union[str, Path], str]): A list of file content. + + Returns: + Tuple[List[str], List[str]]: A tuple containing two lists: + 1. The list of filenames that were successfully added to the index. + 2. The list of filenames that were not added to the index because they were not buildable. """ encoding = tiktoken.get_encoding("cl100k_base") filenames, _ = await self._filter(paths) filter_filenames = [] delete_filenames = [] + file_datas = file_datas or {} for i in filenames: - content = await File.read_text_file(i) + content = file_datas.get(i) or await File.read_text_file(i) + file_datas[i] = content if not self._is_fingerprint_changed(filename=i, content=content): continue token_count = len(encoding.encode(content)) @@ -159,9 +202,15 @@ class IndexRepo(BaseModel): else: delete_filenames.append(i) logger.debug(f"{i} not is_buildable: {token_count}, {self.min_token_count}~{self.max_token_count}") - await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames) + await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames, file_datas=file_datas) + return filter_filenames, delete_filenames - async def _add_batch(self, filenames: List[Union[str, Path]], delete_filenames: List[Union[str, Path]]): + async def _add_batch( + self, + filenames: List[Union[str, Path]], + delete_filenames: List[Union[str, Path]], + file_datas: Dict[Union[str, Path], str], + ): """Add and remove documents in a batch operation. Args: @@ -172,6 +221,7 @@ class IndexRepo(BaseModel): return logger.info(f"update index repo, add {filenames}, remove {delete_filenames}") engine = None + Context() if Path(self.persist_path).exists(): logger.debug(f"load index from {self.persist_path}") engine = SimpleEngine.from_index( @@ -180,9 +230,9 @@ class IndexRepo(BaseModel): ) try: engine.delete_docs(filenames + delete_filenames) - logger.debug(f"delete docs {filenames + delete_filenames}") + logger.info(f"delete docs {filenames + delete_filenames}") engine.add_docs(input_files=filenames) - logger.debug(f"add docs {filenames}") + logger.info(f"add docs {filenames}") except NotImplementedError as e: logger.debug(f"{e}") filenames = list(set([str(i) for i in filenames] + list(self.fingerprints.keys()))) @@ -194,10 +244,10 @@ class IndexRepo(BaseModel): retriever_configs=[FAISSRetrieverConfig()], ranker_configs=[LLMRankerConfig()], ) - logger.debug(f"add docs {filenames}") + logger.info(f"add docs {filenames}") engine.persist(persist_dir=self.persist_path) for i in filenames: - content = await File.read_text_file(i) + content = file_datas.get(i) or await File.read_text_file(i) fp = generate_fingerprint(content) self.fingerprints[str(i)] = fp await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) @@ -270,10 +320,14 @@ class IndexRepo(BaseModel): Returns: List[NodeWithScore]: A list of nodes with scores matching the query. """ + if not filters: + return [] if not Path(self.persist_path).exists(): raise ValueError(f"IndexRepo {Path(self.persist_path).name} not exists.") + Context() engine = SimpleEngine.from_index( - index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()] + index_config=FAISSIndexConfig(persist_path=self.persist_path), + retriever_configs=[FAISSRetrieverConfig()], ) rsp = await engine.aretrieve(query) return [i for i in rsp if i.metadata.get("file_path") in filters] @@ -308,7 +362,7 @@ class IndexRepo(BaseModel): """ mappings = { UPLOADS_INDEX_ROOT: re.compile(r"^/data/uploads($|/.*)"), - CHATS_INDEX_ROOT: re.compile(r"^/data/chats/\d+($|/.*)"), + CHATS_INDEX_ROOT: re.compile(r"^/data/chats/[a-z0-9]+($|/.*)"), } clusters = {} @@ -396,6 +450,8 @@ class IndexRepo(BaseModel): result = [] v_result = [] for i in futures_results: + if not i: + continue if isinstance(i, str): result.append(i) else: diff --git a/tests/metagpt/roles/di/test_data_analyst.py b/tests/metagpt/roles/di/test_data_analyst.py new file mode 100644 index 000000000..0f285ecd7 --- /dev/null +++ b/tests/metagpt/roles/di/test_data_analyst.py @@ -0,0 +1,21 @@ +import pytest + +from metagpt.const import TEST_DATA_PATH +from metagpt.roles.di.data_analyst import DataAnalyst + + +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "filename"), [("similarity search about '有哪些需求描述?' in document ", TEST_DATA_PATH / "requirements/2.pdf")] +) +async def test_similarity_search(query, filename): + di = DataAnalyst() + query += f"'{str(filename)}'" + + rsp = await di.run(query) + assert rsp + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index b56f7bf0e..c6a45da06 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -9,8 +9,8 @@ from metagpt.tools.libs.editor import Editor from metagpt.tools.libs.index_repo import ( CHATS_INDEX_ROOT, CHATS_ROOT, + DEFAULT_MIN_TOKEN_COUNT, UPLOAD_ROOT, - UPLOADS_INDEX_ROOT, IndexRepo, ) from metagpt.utils.common import list_files @@ -756,8 +756,6 @@ async def mock_index_repo(): os.system(command) filenames = list_files(UPLOAD_ROOT) uploads_files = [i for i in filenames if Path(i).suffix in {".md", ".txt", ".json", ".pdf"}] - uploads_repo = IndexRepo(persist_path=UPLOADS_INDEX_ROOT, root_path=UPLOAD_ROOT, min_token_count=0) - await uploads_repo.add(uploads_files) assert uploads_files filenames = list_files(src_path) @@ -771,19 +769,63 @@ async def mock_index_repo(): @pytest.mark.asyncio async def test_index_repo(): # mock data - chat_path, UPLOAD_ROOT, src_path = await mock_index_repo() + chat_path, upload_path, src_path = await mock_index_repo() editor = Editor() - rsp = await editor.search_index_repo(query="业务线", file_or_path=chat_path) + rsp = await editor.similarity_search(query="业务线", file_or_path=chat_path) assert rsp - rsp = await editor.search_index_repo(query="业务线", file_or_path=UPLOAD_ROOT) + rsp = await editor.similarity_search(query="业务线", file_or_path=upload_path) assert rsp - rsp = await editor.search_index_repo(query="业务线", file_or_path=src_path) + rsp = await editor.similarity_search(query="业务线", file_or_path=src_path) assert rsp shutil.rmtree(CHATS_ROOT) shutil.rmtree(UPLOAD_ROOT) +@pytest.mark.skip +@pytest.mark.asyncio +@pytest.mark.parametrize( + ("query", "filename"), + [ + ( + "In this document, who are the legal representatives of both parties?", + TEST_DATA_PATH / "pdf/20210709逗你学云豆付费课程协议.pdf", + ), + ( + "What is the short name of the company in this document?", + TEST_DATA_PATH / "pdf/company_stock_code.pdf", + ), + ("平安创新推出中国版的什么模式,将差异化的医疗健康服务与作为支付方的金融业务无缝结合", TEST_DATA_PATH / "pdf/9112674.pdf"), + ( + "What principle is introduced by the author to explain the conditions necessary for the emergence of complexity?", + TEST_DATA_PATH / "pdf/9781444323498.ch2_1.pdf", + ), + ("行高的继承性的代码示例是?", TEST_DATA_PATH / "pdf/02-CSS.pdf"), + ], +) +async def test_similarity_search(query, filename): + filename = Path(filename) + save_to = Path(UPLOAD_ROOT) / filename.name + save_to.parent.mkdir(parents=True, exist_ok=True) + os.system(f"cp {str(filename)} {str(save_to)}") + + editor = Editor() + rsp = await editor.similarity_search(query=query, file_or_path=save_to) + assert rsp + + save_to.unlink(missing_ok=True) + + +@pytest.mark.skip +@pytest.mark.asyncio +async def test_read(): + editor = Editor() + filename = TEST_DATA_PATH / "pdf/9112674.pdf" + content = await editor.read(str(filename)) + size = filename.stat().st_size + assert "similarity_search" in content.block_content and size > 5 * DEFAULT_MIN_TOKEN_COUNT + + if __name__ == "__main__": pytest.main([__file__, "-s"])