diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py
index cf56eb292..6bb43458c 100644
--- a/metagpt/tools/libs/editor.py
+++ b/metagpt/tools/libs/editor.py
@@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict
 
 from metagpt.const import DEFAULT_WORKSPACE_ROOT
 from metagpt.logs import logger
-from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, OTHER_TYPE, IndexRepo
+from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo
 from metagpt.tools.libs.linter import Linter
 from metagpt.tools.tool_registry import register_tool
 from metagpt.utils.file import File
@@ -866,9 +866,7 @@ class Editor(BaseModel):
         return path
 
     @staticmethod
-    async def search_index_repo(
-        query: str, files_or_paths: List[Union[str, Path]], min_token_count: int = DEFAULT_MIN_TOKEN_COUNT
-    ) -> List[str]:
+    async def search_index_repo(query: str, files_or_paths: List[Union[str, Path]]) -> List[str]:
         """Searches the index repository for a given query across specified files or paths.
 
         This method classifies the provided files or paths, performing a search on each cluster
@@ -878,7 +876,6 @@ class Editor(BaseModel):
         Args:
             query (str): The search query string to look for in the indexed files.
             files_or_paths (List[Union[str, Path]]): A list of file paths or names to search within.
-            min_token_count (int, optional): The minimum token count to consider for indexing. Defaults to 0.
 
         Returns:
             List[str]: A list of search results as strings, containing the text from the merged results
@@ -892,7 +889,7 @@ class Editor(BaseModel):
                 others.update(filenames)
                 continue
             root = roots[persist_path]
-            repo = IndexRepo(persist_path=persist_path, root_path=root, min_token_count=min_token_count)
+            repo = IndexRepo(persist_path=persist_path, root_path=root)
             futures.append(repo.search(query=query, filenames=list(filenames)))
 
         for i in others:
@@ -910,6 +907,6 @@ class Editor(BaseModel):
             else:
                 v_result.append(i)
 
-        repo = IndexRepo(min_token_count=min_token_count)
+        repo = IndexRepo()
         merged = await repo.merge(query=query, indices_list=v_result)
         return [i.text for i in merged] + result
diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py
index d29254318..216304e93 100644
--- a/metagpt/tools/libs/index_repo.py
+++ b/metagpt/tools/libs/index_repo.py
@@ -16,7 +16,7 @@ from metagpt.logs import logger
 from metagpt.rag.engines import SimpleEngine
 from metagpt.rag.factories.embedding import RAGEmbeddingFactory
 from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig
-from metagpt.utils.common import awrite, generate_fingerprint, list_files
+from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files
 from metagpt.utils.file import File
 
 UPLOADS_INDEX_ROOT = "/data/.index/uploads"
@@ -31,6 +31,13 @@ DEFAULT_MIN_TOKEN_COUNT = 10000
 DEFAULT_MAX_TOKEN_COUNT = 100000000
 
 
+class IndexRepoMeta(BaseModel):
+    """Token-count thresholds stored next to an index so that later searches can reuse them."""
+
+    min_token_count: int
+    max_token_count: int
+
+
 class TextScore(BaseModel):
     filename: str
     text: str
@@ -43,6 +50,7 @@ class IndexRepo(BaseModel):
         DEFAULT_ROOT  # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo.
     )
     fingerprint_filename: str = "fingerprint.json"
+    meta_filename: str = "meta.json"
     model: Optional[str] = None
     min_token_count: int = DEFAULT_MIN_TOKEN_COUNT
     max_token_count: int = DEFAULT_MAX_TOKEN_COUNT
@@ -81,10 +89,13 @@ class IndexRepo(BaseModel):
         result: List[Union[NodeWithScore, TextScore]] = []
         filenames, _ = await self._filter(filenames)
         filter_filenames = set()
+        meta = await self._read_meta()
         for i in filenames:
             content = await File.read_text_file(i)
             token_count = len(encoding.encode(content))
-            if not self._is_buildable(token_count):
+            if not self._is_buildable(
+                token_count, min_token_count=meta.min_token_count, max_token_count=meta.max_token_count
+            ):
                 result.append(TextScore(filename=str(i), text=content))
                 continue
             file_fingerprint = generate_fingerprint(content)
@@ -190,6 +201,7 @@ class IndexRepo(BaseModel):
             fp = generate_fingerprint(content)
             self.fingerprints[str(i)] = fp
         await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints))
+        await self._save_meta()
 
     def __str__(self):
         """Return a string representation of the IndexRepo.
@@ -199,7 +211,7 @@ class IndexRepo(BaseModel):
         """
         return f"{self.persist_path}"
 
-    def _is_buildable(self, token_count: int) -> bool:
+    def _is_buildable(self, token_count: int, min_token_count: int = -1, max_token_count: int = -1) -> bool:
         """Check if the token count is within the buildable range.
 
         Args:
@@ -208,7 +220,10 @@ class IndexRepo(BaseModel):
         Returns:
             bool: True if buildable, False otherwise.
         """
-        if token_count < self.min_token_count or token_count > self.max_token_count:
+        # A negative bound means "not supplied": fall back to this instance's own threshold.
+        min_token_count = min_token_count if min_token_count >= 0 else self.min_token_count
+        max_token_count = max_token_count if max_token_count >= 0 else self.max_token_count
+        if token_count < min_token_count or token_count > max_token_count:
             return False
         return True
 
@@ -319,3 +334,28 @@ class IndexRepo(BaseModel):
                 clusters[path_type] = {path}
 
         return clusters, roots
+
+    async def _save_meta(self) -> None:
+        """Write this instance's token-count thresholds to `meta_filename` under `persist_path`."""
+        meta = IndexRepoMeta(min_token_count=self.min_token_count, max_token_count=self.max_token_count)
+        await awrite(filename=Path(self.persist_path) / self.meta_filename, data=meta.model_dump_json())
+
+    async def _read_meta(self) -> IndexRepoMeta:
+        """Load the token-count thresholds stored in `meta_filename` under `persist_path`.
+
+        Returns:
+            IndexRepoMeta: The stored thresholds, or this instance's own thresholds when the meta file
+                does not exist or cannot be parsed.
+        """
+        default_meta = IndexRepoMeta(min_token_count=self.min_token_count, max_token_count=self.max_token_count)
+
+        filename = Path(self.persist_path) / self.meta_filename
+        if not filename.exists():
+            return default_meta
+        meta_data = await aread(filename=filename)
+        try:
+            return IndexRepoMeta.model_validate_json(meta_data)
+        except Exception as e:
+            # Best effort: an unparsable meta file falls back to the instance thresholds instead of raising.
+            logger.warning(f"Load meta error: {e}")
+            return default_meta
diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py
index 3a4cf65fe..20857b97f 100644
--- a/tests/metagpt/tools/libs/test_editor.py
+++ b/tests/metagpt/tools/libs/test_editor.py
@@ -692,9 +692,7 @@ async def test_index_repo():
     chat_files, uploads_files, other_files = await mock_index_repo()
 
     editor = Editor()
-    rsp = await editor.search_index_repo(
-        query="业务线", files_or_paths=chat_files + uploads_files + other_files, min_token_count=0
-    )
+    rsp = await editor.search_index_repo(query="业务线", files_or_paths=chat_files + uploads_files + other_files)
     assert rsp
 
     shutil.rmtree(CHATS_ROOT)