From 86cc30b081b0afba91b978a3b486ff501550788f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Mon, 2 Sep 2024 20:06:57 +0800 Subject: [PATCH] feat: + comments --- metagpt/tools/libs/index_repo.py | 73 ++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index b7787cbe4..a3efa6d27 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -38,6 +38,11 @@ class IndexRepo(BaseModel): @model_validator(mode="after") def _update_fingerprints(self) -> "IndexRepo": + """Load fingerprints from the fingerprint file if not already loaded. + + Returns: + IndexRepo: The updated IndexRepo instance. + """ if not self.fingerprints: filename = Path(self.filename) / self.fingerprint_filename if not filename.exists(): @@ -49,6 +54,15 @@ class IndexRepo(BaseModel): async def search( self, query: str, filenames: Optional[List[Path]] = None ) -> Optional[List[Union[NodeWithScore, TextScore]]]: + """Search for documents related to the given query. + + Args: + query (str): The search query. + filenames (Optional[List[Path]]): A list of filenames to filter the search. + + Returns: + Optional[List[Union[NodeWithScore, TextScore]]]: A list of search results containing NodeWithScore or TextScore. + """ encoding = tiktoken.get_encoding("cl100k_base") result: List[Union[NodeWithScore, TextScore]] = [] filenames, _ = await self._filter(filenames) @@ -70,6 +84,15 @@ class IndexRepo(BaseModel): async def merge( self, query: str, indices_list: List[List[Union[NodeWithScore, TextScore]]] ) -> List[Union[NodeWithScore, TextScore]]: + """Merge results from multiple indices based on the query. + + Args: + query (str): The search query. + indices_list (List[List[Union[NodeWithScore, TextScore]]]): A list of result lists from different indices. + + Returns: + List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity. + """ if not self.embedding: config = Config.default() config.embedding.model = self.model @@ -87,6 +110,11 @@ class IndexRepo(BaseModel): return [i[1] for i in scores][: self.recall_count] async def add(self, paths: List[Path]): + """Add new documents to the index. + + Args: + paths (List[Path]): A list of paths to the documents to be added. + """ encoding = tiktoken.get_encoding("cl100k_base") filenames, _ = await self._filter(paths) filter_filenames = [] @@ -105,6 +133,12 @@ class IndexRepo(BaseModel): await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames) async def _add_batch(self, filenames: List[Union[str, Path]], delete_filenames: List[Union[str, Path]]): + """Add and remove documents in a batch operation. + + Args: + filenames (List[Union[str, Path]]): List of filenames to add. + delete_filenames (List[Union[str, Path]]): List of filenames to delete. + """ if not filenames: return logger.info(f"update index repo, add {filenames}, remove {delete_filenames}") @@ -139,14 +173,35 @@ class IndexRepo(BaseModel): await awrite(filename=Path(self.filename) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) def __str__(self): + """Return a string representation of the IndexRepo. + + Returns: + str: The filename of the index repository. + """ return f"{self.filename}" def _is_buildable(self, token_count: int) -> bool: + """Check if the token count is within the buildable range. + + Args: + token_count (int): The number of tokens in the content. + + Returns: + bool: True if buildable, False otherwise. + """ if token_count < self.min_token_count or token_count > self.max_token_count: return False return True async def _filter(self, filenames: Optional[List[Union[str, Path]]] = None) -> (List[Path], List[Path]): + """Filter the provided filenames to only include valid text files. + + Args: + filenames (Optional[List[Union[str, Path]]]): List of filenames to filter. + + Returns: + Tuple[List[Path], List[Path]]: A tuple containing a list of valid pathnames and a list of excluded paths. + """ root_path = Path(self.root_path).absolute() if not filenames: filenames = [root_path] @@ -173,6 +228,15 @@ class IndexRepo(BaseModel): return pathnames, excludes async def _search(self, query: str, filters: Set[str]) -> List[NodeWithScore]: + """Perform a search for the given query using the index. + + Args: + query (str): The search query. + filters (Set[str]): A set of filenames to filter the search results. + + Returns: + List[NodeWithScore]: A list of nodes with scores matching the query. + """ if not Path(self.filename).exists(): return [] engine = SimpleEngine.from_index( @@ -182,6 +246,15 @@ class IndexRepo(BaseModel): return [i for i in rsp if i.metadata.get("file_path") in filters] def _is_fingerprint_changed(self, filename: Union[str, Path], content: str) -> bool: + """Check if the fingerprint of the given document content has changed. + + Args: + filename (Union[str, Path]): The filename of the document. + content (str): The content of the document. + + Returns: + bool: True if the fingerprint has changed, False otherwise. + """ old_fp = self.fingerprints.get(str(filename)) if not old_fp: return True