feat: + comments

This commit is contained in:
莘权 马 2024-09-02 20:06:57 +08:00
parent a2b212703f
commit 86cc30b081

View file

@ -38,6 +38,11 @@ class IndexRepo(BaseModel):
@model_validator(mode="after")
def _update_fingerprints(self) -> "IndexRepo":
"""Load fingerprints from the fingerprint file if not already loaded.
Returns:
IndexRepo: The updated IndexRepo instance.
"""
if not self.fingerprints:
filename = Path(self.filename) / self.fingerprint_filename
if not filename.exists():
@ -49,6 +54,15 @@ class IndexRepo(BaseModel):
async def search(
self, query: str, filenames: Optional[List[Path]] = None
) -> Optional[List[Union[NodeWithScore, TextScore]]]:
"""Search for documents related to the given query.
Args:
query (str): The search query.
filenames (Optional[List[Path]]): A list of filenames to filter the search.
Returns:
Optional[List[Union[NodeWithScore, TextScore]]]: A list of search results containing NodeWithScore or TextScore.
"""
encoding = tiktoken.get_encoding("cl100k_base")
result: List[Union[NodeWithScore, TextScore]] = []
filenames, _ = await self._filter(filenames)
@ -70,6 +84,15 @@ class IndexRepo(BaseModel):
async def merge(
self, query: str, indices_list: List[List[Union[NodeWithScore, TextScore]]]
) -> List[Union[NodeWithScore, TextScore]]:
"""Merge results from multiple indices based on the query.
Args:
query (str): The search query.
indices_list (List[List[Union[NodeWithScore, TextScore]]]): A list of result lists from different indices.
Returns:
List[Union[NodeWithScore, TextScore]]: A list of merged results sorted by similarity.
"""
if not self.embedding:
config = Config.default()
config.embedding.model = self.model
@ -87,6 +110,11 @@ class IndexRepo(BaseModel):
return [i[1] for i in scores][: self.recall_count]
async def add(self, paths: List[Path]):
"""Add new documents to the index.
Args:
paths (List[Path]): A list of paths to the documents to be added.
"""
encoding = tiktoken.get_encoding("cl100k_base")
filenames, _ = await self._filter(paths)
filter_filenames = []
@ -105,6 +133,12 @@ class IndexRepo(BaseModel):
await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames)
async def _add_batch(self, filenames: List[Union[str, Path]], delete_filenames: List[Union[str, Path]]):
"""Add and remove documents in a batch operation.
Args:
filenames (List[Union[str, Path]]): List of filenames to add.
delete_filenames (List[Union[str, Path]]): List of filenames to delete.
"""
if not filenames:
return
logger.info(f"update index repo, add {filenames}, remove {delete_filenames}")
@ -139,14 +173,35 @@ class IndexRepo(BaseModel):
await awrite(filename=Path(self.filename) / self.fingerprint_filename, data=json.dumps(self.fingerprints))
def __str__(self):
    """Render this IndexRepo as text.

    The result is simply ``self.filename`` formatted through an f-string,
    so it works whether the attribute holds a ``str`` or a path-like object.

    Returns:
        str: The formatted value of ``self.filename``.
    """
    rendered = f"{self.filename}"
    return rendered
def _is_buildable(self, token_count: int) -> bool:
    """Tell whether a token count lies inside the configured bounds.

    Both bounds are inclusive: a count equal to ``self.min_token_count`` or
    ``self.max_token_count`` is accepted.

    Args:
        token_count (int): The number of tokens in the content.

    Returns:
        bool: False when the count is below ``self.min_token_count`` or above
            ``self.max_token_count``; True otherwise.
    """
    below_minimum = token_count < self.min_token_count
    # The upper bound is only consulted when the lower bound check passes.
    return not (below_minimum or token_count > self.max_token_count)
async def _filter(self, filenames: Optional[List[Union[str, Path]]] = None) -> (List[Path], List[Path]):
"""Filter the provided filenames to only include valid text files.
Args:
filenames (Optional[List[Union[str, Path]]]): List of filenames to filter.
Returns:
Tuple[List[Path], List[Path]]: A tuple containing a list of valid pathnames and a list of excluded paths.
"""
root_path = Path(self.root_path).absolute()
if not filenames:
filenames = [root_path]
@ -173,6 +228,15 @@ class IndexRepo(BaseModel):
return pathnames, excludes
async def _search(self, query: str, filters: Set[str]) -> List[NodeWithScore]:
"""Perform a search for the given query using the index.
Args:
query (str): The search query.
filters (Set[str]): A set of filenames to filter the search results.
Returns:
List[NodeWithScore]: A list of nodes with scores matching the query.
"""
if not Path(self.filename).exists():
return []
engine = SimpleEngine.from_index(
@ -182,6 +246,15 @@ class IndexRepo(BaseModel):
return [i for i in rsp if i.metadata.get("file_path") in filters]
def _is_fingerprint_changed(self, filename: Union[str, Path], content: str) -> bool:
"""Check if the fingerprint of the given document content has changed.
Args:
filename (Union[str, Path]): The filename of the document.
content (str): The content of the document.
Returns:
bool: True if the fingerprint has changed, False otherwise.
"""
old_fp = self.fingerprints.get(str(filename))
if not old_fp:
return True