From 05221a5dc0304f8a60e6b4565a45552e62c85356 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Wed, 11 Sep 2024 17:46:09 +0800 Subject: [PATCH] feat: IndexRepo + config feat: +search_index_repo --- metagpt/roles/di/data_analyst.py | 11 +++++++++-- metagpt/tools/libs/editor.py | 7 ++++++- metagpt/tools/libs/index_repo.py | 11 ++++++++++- 3 files changed, 25 insertions(+), 4 deletions(-) diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py index f9bead1ac..ed742c725 100644 --- a/metagpt/roles/di/data_analyst.py +++ b/metagpt/roles/di/data_analyst.py @@ -30,8 +30,15 @@ class DataAnalyst(RoleZero): instruction: str = ROLE_INSTRUCTION + EXTRA_INSTRUCTION task_type_desc: str = TASK_TYPE_DESC - tools: list[str] = ["Plan", "DataAnalyst", "RoleZero", "Browser", "Editor:write,read", "SearchEnhancedQA"] - custom_tools: list[str] = ["web scraping", "Terminal", "Editor:write,read"] + tools: list[str] = [ + "Plan", + "DataAnalyst", + "RoleZero", + "Browser", + "Editor:write,read,search_index_repo", + "SearchEnhancedQA", + ] + custom_tools: list[str] = ["web scraping", "Terminal", "Editor:write,read,search_index_repo"] custom_tool_recommender: ToolRecommender = None experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = KeywordExpRetriever() diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 4448c3d02..49e552d2d 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -13,7 +13,6 @@ from typing import List, Optional, Union from pydantic import BaseModel, ConfigDict from metagpt.const import DEFAULT_WORKSPACE_ROOT -from metagpt.logs import logger from metagpt.tools.libs.index_repo import IndexRepo from metagpt.tools.libs.linter import Linter from metagpt.tools.tool_registry import register_tool @@ -951,5 +950,11 @@ class Editor(BaseModel): Returns: List[str]: A list of search results as strings, containing the text from the merged results and any direct 
results from other files. + + Example: + >>> query = "The problem to be analyzed from the document" + >>> file_or_path = "The document or folder you want to query" + >>> texts: List[str] = await Editor.search_index_repo(query=query, file_or_path=file_or_path) + >>> print(texts) """ return await IndexRepo.cross_repo_search(query=query, file_or_path=file_or_path) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index 9f2fbdb27..c349fb53e 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -12,6 +12,7 @@ from llama_index.core.schema import NodeWithScore from pydantic import BaseModel, Field, model_validator from metagpt.config2 import Config +from metagpt.context import Context from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.rag.factories.embedding import RAGEmbeddingFactory @@ -150,6 +151,7 @@ class IndexRepo(BaseModel): Args: paths (List[Path]): A list of paths to the documents to be added. + file_datas (Dict[Union[str, Path], str]): A dict of file contents keyed by file name or path. """ encoding = tiktoken.get_encoding("cl100k_base") filenames, _ = await self._filter(paths) @@ -185,6 +187,7 @@ class IndexRepo(BaseModel): return logger.info(f"update index repo, add {filenames}, remove {delete_filenames}") engine = None + Context() if Path(self.persist_path).exists(): logger.debug(f"load index from {self.persist_path}") engine = SimpleEngine.from_index( @@ -283,10 +286,14 @@ class IndexRepo(BaseModel): Returns: List[NodeWithScore]: A list of nodes with scores matching the query. 
""" + if not filters: + return [] if not Path(self.persist_path).exists(): raise ValueError(f"IndexRepo {Path(self.persist_path).name} not exists.") + Context() engine = SimpleEngine.from_index( - index_config=FAISSIndexConfig(persist_path=self.persist_path), retriever_configs=[FAISSRetrieverConfig()] + index_config=FAISSIndexConfig(persist_path=self.persist_path), + retriever_configs=[FAISSRetrieverConfig()], ) rsp = await engine.aretrieve(query) return [i for i in rsp if i.metadata.get("file_path") in filters] @@ -409,6 +416,8 @@ class IndexRepo(BaseModel): result = [] v_result = [] for i in futures_results: + if not i: + continue if isinstance(i, str): result.append(i) else: