From 3d6286ca9ace8162c98483b92586825fd9c82de3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 12 Sep 2024 11:19:36 +0800 Subject: [PATCH] feat: rename search_index_repo to similarity_search --- metagpt/roles/di/data_analyst.py | 2 +- metagpt/roles/di/role_zero.py | 2 +- metagpt/tools/libs/editor.py | 18 +++++++++--------- tests/metagpt/tools/libs/test_editor.py | 6 +++--- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/metagpt/roles/di/data_analyst.py b/metagpt/roles/di/data_analyst.py index ed742c725..b89c91ca3 100644 --- a/metagpt/roles/di/data_analyst.py +++ b/metagpt/roles/di/data_analyst.py @@ -35,7 +35,7 @@ class DataAnalyst(RoleZero): "DataAnalyst", "RoleZero", "Browser", - "Editor:write,read,search_index_repo", + "Editor:write,read,similarity_search", "SearchEnhancedQA", ] custom_tools: list[str] = ["web scraping", "Terminal", "Editor:write,read,search_index_repo"] diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 30a706b54..6a36dc65e 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -155,7 +155,7 @@ class RoleZero(Role): "scroll_up", "search_dir", "search_file", - "search_index_repo", + "similarity_search", # "set_workdir", "write", ] diff --git a/metagpt/tools/libs/editor.py b/metagpt/tools/libs/editor.py index 49e552d2d..9ba207bc6 100644 --- a/metagpt/tools/libs/editor.py +++ b/metagpt/tools/libs/editor.py @@ -936,25 +936,25 @@ class Editor(BaseModel): return path @staticmethod - async def search_index_repo(query: str, file_or_path: Union[str, Path]) -> List[str]: - """Searches the index repository for a given query across specified files or paths. + async def similarity_search(query: str, file_or_path: Union[str, Path]) -> List[str]: + """Performs a similarity search for a given query across specified files or paths. 
- This method classifies the provided files or paths, performing a search on each cluster - of files while handling other types of files separately. It merges results from structured - indices with any results from non-indexed files. + This method searches the index repository for the provided query, classifying the specified + files or paths. It performs a search on each cluster of files and handles non-indexed files + separately, merging results from structured indices with any direct results from non-indexed files. Args: query (str): The search query string to look for in the indexed files. - file_or_path (Union[str, Path]): A path or a filename to search within. + file_or_path (Union[str, Path]): A path or filename to search within. Returns: - List[str]: A list of search results as strings, containing the text from the merged results - and any direct results from other files. + List[str]: A list of results as strings, containing the text from the merged results + and any direct results from non-indexed files. 
Example: >>> query = "The problem to be analyzed from the document" >>> file_or_path = "The document or folder you want to query" - >>> texts: List[str] = await Editor.search_index_repo(query=query, file_or_path=file_or_path) + >>> texts: List[str] = await Editor.similarity_search(query=query, file_or_path=file_or_path) >>> print(texts) """ return await IndexRepo.cross_repo_search(query=query, file_or_path=file_or_path) diff --git a/tests/metagpt/tools/libs/test_editor.py b/tests/metagpt/tools/libs/test_editor.py index 1adcbc2b7..942eb4b57 100644 --- a/tests/metagpt/tools/libs/test_editor.py +++ b/tests/metagpt/tools/libs/test_editor.py @@ -692,11 +692,11 @@ async def test_index_repo(): chat_path, UPLOAD_ROOT, src_path = await mock_index_repo() editor = Editor() - rsp = await editor.search_index_repo(query="业务线", file_or_path=chat_path) + rsp = await editor.similarity_search(query="业务线", file_or_path=chat_path) assert rsp - rsp = await editor.search_index_repo(query="业务线", file_or_path=UPLOAD_ROOT) + rsp = await editor.similarity_search(query="业务线", file_or_path=UPLOAD_ROOT) assert rsp - rsp = await editor.search_index_repo(query="业务线", file_or_path=src_path) + rsp = await editor.similarity_search(query="业务线", file_or_path=src_path) assert rsp shutil.rmtree(CHATS_ROOT)