feat: no need to set IndexRepo's min/max_token_count when searching

This commit is contained in:
莘权 马 2024-09-06 11:07:31 +08:00
parent ad80dab678
commit 95bf4c3e22
3 changed files with 39 additions and 14 deletions

View file

@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict
from metagpt.const import DEFAULT_WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.tools.libs.index_repo import DEFAULT_MIN_TOKEN_COUNT, OTHER_TYPE, IndexRepo
from metagpt.tools.libs.index_repo import OTHER_TYPE, IndexRepo
from metagpt.tools.libs.linter import Linter
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.file import File
@ -866,9 +866,7 @@ class Editor(BaseModel):
return path
@staticmethod
async def search_index_repo(
query: str, files_or_paths: List[Union[str, Path]], min_token_count: int = DEFAULT_MIN_TOKEN_COUNT
) -> List[str]:
async def search_index_repo(query: str, files_or_paths: List[Union[str, Path]]) -> List[str]:
"""Searches the index repository for a given query across specified files or paths.
This method classifies the provided files or paths, performing a search on each cluster
@ -878,7 +876,6 @@ class Editor(BaseModel):
Args:
query (str): The search query string to look for in the indexed files.
files_or_paths (List[Union[str, Path]]): A list of file paths or names to search within.
min_token_count (int, optional): The minimum token count to consider for indexing. Defaults to 0.
Returns:
List[str]: A list of search results as strings, containing the text from the merged results
@ -892,7 +889,7 @@ class Editor(BaseModel):
others.update(filenames)
continue
root = roots[persist_path]
repo = IndexRepo(persist_path=persist_path, root_path=root, min_token_count=min_token_count)
repo = IndexRepo(persist_path=persist_path, root_path=root)
futures.append(repo.search(query=query, filenames=list(filenames)))
for i in others:
@ -910,6 +907,6 @@ class Editor(BaseModel):
else:
v_result.append(i)
repo = IndexRepo(min_token_count=min_token_count)
repo = IndexRepo()
merged = await repo.merge(query=query, indices_list=v_result)
return [i.text for i in merged] + result

View file

@ -16,7 +16,7 @@ from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.factories.embedding import RAGEmbeddingFactory
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig
from metagpt.utils.common import awrite, generate_fingerprint, list_files
from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files
from metagpt.utils.file import File
UPLOADS_INDEX_ROOT = "/data/.index/uploads"
@ -31,6 +31,11 @@ DEFAULT_MIN_TOKEN_COUNT = 10000
DEFAULT_MAX_TOKEN_COUNT = 100000000
class IndexRepoMeta(BaseModel):
    """Persisted metadata describing how an index repo was built."""

    # Token-count bounds that were in effect when the index was created;
    # saved by _save_meta() and read back by _read_meta() so that a later
    # search does not need these values to be configured again.
    min_token_count: int
    max_token_count: int
class TextScore(BaseModel):
filename: str
text: str
@ -43,6 +48,7 @@ class IndexRepo(BaseModel):
DEFAULT_ROOT # `/data/uploads` or r`/data/chats/\d+`, the root path of files indexed by the index repo.
)
fingerprint_filename: str = "fingerprint.json"
meta_filename: str = "meta.json"
model: Optional[str] = None
min_token_count: int = DEFAULT_MIN_TOKEN_COUNT
max_token_count: int = DEFAULT_MAX_TOKEN_COUNT
@ -81,10 +87,13 @@ class IndexRepo(BaseModel):
result: List[Union[NodeWithScore, TextScore]] = []
filenames, _ = await self._filter(filenames)
filter_filenames = set()
meta = await self._read_meta()
for i in filenames:
content = await File.read_text_file(i)
token_count = len(encoding.encode(content))
if not self._is_buildable(token_count):
if not self._is_buildable(
token_count, min_token_count=meta.min_token_count, max_token_count=meta.max_token_count
):
result.append(TextScore(filename=str(i), text=content))
continue
file_fingerprint = generate_fingerprint(content)
@ -190,6 +199,7 @@ class IndexRepo(BaseModel):
fp = generate_fingerprint(content)
self.fingerprints[str(i)] = fp
await awrite(filename=Path(self.persist_path) / self.fingerprint_filename, data=json.dumps(self.fingerprints))
await self._save_meta()
def __str__(self):
"""Return a string representation of the IndexRepo.
@ -199,7 +209,7 @@ class IndexRepo(BaseModel):
"""
return f"{self.persist_path}"
def _is_buildable(self, token_count: int, min_token_count: int = -1, max_token_count: int = -1) -> bool:
    """Check if the token count is within the buildable range.

    Args:
        token_count (int): The token count of the file content.
        min_token_count (int, optional): Lower bound to apply; a negative
            value means "fall back to self.min_token_count". Defaults to -1.
        max_token_count (int, optional): Upper bound to apply; a negative
            value means "fall back to self.max_token_count". Defaults to -1.

    Returns:
        bool: True if the token count lies within [min, max], False otherwise.
    """
    # Negative sentinels select the repo's own configured bounds; callers
    # (e.g. search) may pass bounds loaded from the persisted meta file.
    min_token_count = min_token_count if min_token_count >= 0 else self.min_token_count
    max_token_count = max_token_count if max_token_count >= 0 else self.max_token_count
    return min_token_count <= token_count <= max_token_count
@ -319,3 +331,21 @@ class IndexRepo(BaseModel):
clusters[path_type] = {path}
return clusters, roots
async def _save_meta(self):
    """Persist the repo's current token-count bounds next to the index."""
    meta_path = Path(self.persist_path) / self.meta_filename
    payload = IndexRepoMeta(
        min_token_count=self.min_token_count,
        max_token_count=self.max_token_count,
    ).model_dump_json()
    await awrite(filename=meta_path, data=payload)
async def _read_meta(self) -> IndexRepoMeta:
    """Load persisted token-count metadata for this index repo.

    Returns:
        IndexRepoMeta: The stored metadata; falls back to the repo's own
        current min/max token counts when no meta file exists or when the
        stored file cannot be parsed.
    """
    fallback = IndexRepoMeta(
        min_token_count=self.min_token_count,
        max_token_count=self.max_token_count,
    )
    meta_path = Path(self.persist_path) / self.meta_filename
    if not meta_path.exists():
        return fallback
    raw = await aread(filename=meta_path)
    try:
        return IndexRepoMeta.model_validate_json(raw)
    except Exception as e:
        # Best-effort: a corrupt or incompatible meta file must not break
        # search — log and fall back to the current configuration.
        logger.warning(f"Load meta error: {e}")
        return fallback