From a8773d176e2492b944e92dae68e26837f80468cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Thu, 29 Aug 2024 11:25:23 +0800 Subject: [PATCH] feat: add index repo --- metagpt/configs/embedding_config.py | 4 + metagpt/rag/engines/simple.py | 30 +++- metagpt/rag/factories/embedding.py | 15 +- metagpt/tools/libs/index_repo.py | 169 ++++++++++++++++++++ metagpt/utils/common.py | 21 +++ tests/metagpt/tools/libs/test_index_repo.py | 32 ++++ 6 files changed, 262 insertions(+), 9 deletions(-) create mode 100644 metagpt/tools/libs/index_repo.py create mode 100644 tests/metagpt/tools/libs/test_index_repo.py diff --git a/metagpt/configs/embedding_config.py b/metagpt/configs/embedding_config.py index 20de47999..f9b41b9dc 100644 --- a/metagpt/configs/embedding_config.py +++ b/metagpt/configs/embedding_config.py @@ -20,11 +20,13 @@ class EmbeddingConfig(YamlModel): --------- api_type: "openai" api_key: "YOU_API_KEY" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "azure" api_key: "YOU_API_KEY" base_url: "YOU_BASE_URL" api_version: "YOU_API_VERSION" + dimensions: "YOUR_MODEL_DIMENSIONS" api_type: "gemini" api_key: "YOU_API_KEY" @@ -32,6 +34,7 @@ class EmbeddingConfig(YamlModel): api_type: "ollama" base_url: "YOU_BASE_URL" model: "YOU_MODEL" + dimensions: "YOUR_MODEL_DIMENSIONS" """ api_type: Optional[EmbeddingType] = None @@ -41,6 +44,7 @@ class EmbeddingConfig(YamlModel): model: Optional[str] = None embed_batch_size: Optional[int] = None + dimensions: Optional[int] = None # output dimension of embedding model @field_validator("api_type", mode="before") @classmethod diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index be4c3daf5..4b0876911 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -2,7 +2,8 @@ import json import os -from typing import Any, Optional, Union +from pathlib import Path +from typing import Any, List, Optional, Set, Union import fsspec from llama_index.core import SimpleDirectoryReader @@ -77,6 +78,7 @@ class SimpleEngine(RetrieverQueryEngine): callback_manager=callback_manager, ) self._transformations = transformations or self._default_transformations() + self._filenames = set() @classmethod def from_docs( @@ -191,11 +193,11 @@ class SimpleEngine(RetrieverQueryEngine): self._try_reconstruct_obj(nodes) return nodes - def add_docs(self, input_files: list[str]): + def add_docs(self, input_files: List[Union[str, Path]]): """Add docs to retriever. retriever must has add_nodes func.""" self._ensure_retriever_modifiable() - documents = SimpleDirectoryReader(input_files=input_files).load_data() + documents = SimpleDirectoryReader(input_files=[str(i) for i in input_files]).load_data() self._fix_document_metadata(documents) nodes = run_transformations(documents, transformations=self._transformations) @@ -220,6 +222,24 @@ class SimpleEngine(RetrieverQueryEngine): return self._retriever.query_total_count() + def delete_docs(self, input_files: List[Union[str, Path]]): + """Delete documents from the index and document store. + + Args: + input_files (List[Union[str, Path]]): A list of file paths or file names to be deleted. + + Raises: + NotImplementedError: If the method is not implemented. + """ + exists_filenames = set() + filenames = {str(i) for i in input_files} + for doc_id, info in self.retriever._index.ref_doc_info.items(): + if info.metadata.get("file_path") in filenames: + exists_filenames.add(doc_id) + + for doc_id in exists_filenames: + self.retriever._index.delete_ref_doc(doc_id, delete_from_docstore=True) + @staticmethod def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]: """Converts a list of RAGObjects to a list of ObjectNodes.""" @@ -323,3 +343,7 @@ class SimpleEngine(RetrieverQueryEngine): @staticmethod def _default_transformations(): return [SentenceSplitter()] + + @property + def filenames(self) -> Set[str]: + return self._filenames diff --git a/metagpt/rag/factories/embedding.py b/metagpt/rag/factories/embedding.py index d647883bd..19b8b36f6 100644 --- a/metagpt/rag/factories/embedding.py +++ b/metagpt/rag/factories/embedding.py @@ -5,9 +5,6 @@ from typing import Any, Optional from llama_index.core.embeddings import BaseEmbedding from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding -from llama_index.embeddings.gemini import GeminiEmbedding -from llama_index.embeddings.ollama import OllamaEmbedding -from llama_index.embeddings.openai import OpenAIEmbedding from metagpt.config2 import Config from metagpt.configs.embedding_config import EmbeddingType @@ -49,7 +46,9 @@ class RAGEmbeddingFactory(GenericFactory): raise TypeError("To use RAG, please set your embedding in config2.yaml.") - def _create_openai(self) -> OpenAIEmbedding: + def _create_openai(self) -> "OpenAIEmbedding": + from llama_index.embeddings.openai import OpenAIEmbedding + params = dict( api_key=self.config.embedding.api_key or self.config.llm.api_key, api_base=self.config.embedding.base_url or self.config.llm.base_url, @@ -70,7 +69,9 @@ class RAGEmbeddingFactory(GenericFactory): return AzureOpenAIEmbedding(**params) - def _create_gemini(self) -> GeminiEmbedding: + def _create_gemini(self) -> "GeminiEmbedding": + from llama_index.embeddings.gemini import GeminiEmbedding + params = dict( api_key=self.config.embedding.api_key, api_base=self.config.embedding.base_url, @@ -80,7 +81,9 @@ class RAGEmbeddingFactory(GenericFactory): return GeminiEmbedding(**params) - def _create_ollama(self) -> OllamaEmbedding: + def _create_ollama(self) -> "OllamaEmbedding": + from llama_index.embeddings.ollama import OllamaEmbedding + params = dict( base_url=self.config.embedding.base_url, ) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py new file mode 100644 index 000000000..720fec4fd --- /dev/null +++ b/metagpt/tools/libs/index_repo.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +import json +from pathlib import Path +from typing import Dict, List, Optional, Set, Union + +import tiktoken +from llama_index.core.base.embeddings.base import BaseEmbedding +from llama_index.core.schema import NodeWithScore +from pydantic import BaseModel, Field, model_validator + +from metagpt.config2 import Config +from metagpt.logs import logger +from metagpt.rag.engines import SimpleEngine +from metagpt.rag.factories.embedding import RAGEmbeddingFactory +from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig, LLMRankerConfig +from metagpt.utils.common import aread, awrite, generate_fingerprint, list_files +from metagpt.utils.repo_to_markdown import is_text_file + + +class TextScore(BaseModel): + filename: str + text: str + score: Optional[float] = None + + +class IndexRepo(BaseModel): + filename: str + root_path: str + fingerprint_filename: str = "fingerprint.json" + model: str = "text-embedding-ada-002" + min_token_count: int = 5000 + max_token_count: int = 100000 + recall_count: int = 5 + embedding: Optional[BaseEmbedding] = Field(default=None, exclude=True) + fingerprints: Dict[str, str] = Field(default_factory=dict) + + @model_validator(mode="after") + def _update_fingerprints(self) -> "IndexRepo": + if not self.fingerprints: + filename = Path(self.filename) / self.fingerprint_filename + if not filename.exists(): + return self + with open(str(filename), "r") as reader: + self.fingerprints = json.load(reader) + return self + + async def search( + self, query: str, filenames: Optional[List[Path]] = None + ) -> Optional[List[Union[NodeWithScore, TextScore]]]: + encoding = tiktoken.get_encoding("cl100k_base") + result: List[Union[NodeWithScore, TextScore]] = [] + filenames, _ = await self._filter(filenames) + filter_filenames = set() + for i in filenames: + content = await aread(filename=i) + token_count = len(encoding.encode(content)) + if not self._is_buildable(token_count): + result.append(TextScore(filename=str(i), text=content)) + continue + file_fingerprint = generate_fingerprint(content) + if self.fingerprints.get(str(i)) != file_fingerprint: + logger.error(f'file: "{i}" changed but not indexed') + continue + filter_filenames.add(str(i)) + nodes = await self._search(query=query, filters=filter_filenames) + return result + nodes + + async def merge( + self, query: str, indices_list: List[List[Union[NodeWithScore, TextScore]]] + ) -> List[Union[NodeWithScore, TextScore]]: + if not self.embedding: + config = Config.default() + config.embedding.model = self.model + factory = RAGEmbeddingFactory(config) + self.embedding = factory.get_rag_embedding() + + scores = [] + query_embedding = await self.embedding.aget_text_embedding(query) + flat_nodes = [node for indices in indices_list for node in indices] + for i in flat_nodes: + text_embedding = await self.embedding.aget_text_embedding(i.text) + similarity = self.embedding.similarity(query_embedding, text_embedding) + scores.append((similarity, i)) + scores.sort(key=lambda x: x[0], reverse=True) + return [i[1] for i in scores][: self.recall_count] + + async def add(self, paths: List[Path]): + encoding = tiktoken.get_encoding("cl100k_base") + filenames, _ = await self._filter(paths) + filter_filenames = [] + delete_filenames = [] + for i in filenames: + content = await aread(filename=i) + token_count = len(encoding.encode(content)) + if self._is_buildable(token_count): + filter_filenames.append(i) + else: + delete_filenames.append(i) + await self._add_batch(filenames=filter_filenames, delete_filenames=delete_filenames) + + async def _add_batch(self, filenames: List[Union[str, Path]], delete_filenames: List[Union[str, Path]]): + if not filenames: + return + engine = None + if Path(self.filename).exists(): + engine = SimpleEngine.from_index( + index_config=FAISSIndexConfig(persist_path=self.filename), retriever_configs=[FAISSRetrieverConfig()] + ) + try: + engine.delete_docs(filenames + delete_filenames) + engine.add_docs(input_files=filenames) + except NotImplementedError as e: + logger.debug(f"{e}") + filenames = list(set([str(i) for i in filenames] + list(self.fingerprints.keys()))) + engine = None + if not engine: + engine = SimpleEngine.from_docs( + input_files=[str(i) for i in filenames], + retriever_configs=[FAISSRetrieverConfig()], + ranker_configs=[LLMRankerConfig()], + ) + engine.persist(persist_dir=self.filename) + for i in filenames: + content = await aread(i) + fp = generate_fingerprint(content) + self.fingerprints[str(i)] = fp + await awrite(filename=Path(self.filename) / self.fingerprint_filename, data=json.dumps(self.fingerprints)) + + def __str__(self): + return f"{self.filename}" + + def _is_buildable(self, token_count: int) -> bool: + if token_count < self.min_token_count or token_count > self.max_token_count: + return False + return True + + async def _filter(self, filenames: Optional[List[Union[str, Path]]] = None) -> (List[Path], List[Path]): + root_path = Path(self.root_path).absolute() + if not filenames: + filenames = [root_path] + pathnames = [] + excludes = [] + for i in filenames: + path = Path(i).absolute() + if not path.is_relative_to(root_path): + excludes.append(path) + continue + if not path.is_dir(): + is_text, _ = await is_text_file(path) + if is_text: + pathnames.append(path) + continue + subfiles = list_files(path) + for j in subfiles: + is_text, _ = await is_text_file(j) + if is_text: + pathnames.append(j) + + return pathnames, excludes + + async def _search(self, query: str, filters: Set[str]) -> List[NodeWithScore]: + if not Path(self.filename).exists(): + return [] + engine = SimpleEngine.from_index( + index_config=FAISSIndexConfig(persist_path=self.filename), retriever_configs=[FAISSRetrieverConfig()] + ) + rsp = await engine.aretrieve(query) + return [i for i in rsp if i.metadata.get("file_path") in filters] diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 2b2a209be..879e06772 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -16,6 +16,7 @@ import base64 import contextlib import csv import functools +import hashlib import importlib import inspect import json @@ -1175,3 +1176,23 @@ def rectify_pathname(path: Union[str, Path], default_filename: str) -> Path: else: output_pathname.parent.mkdir(parents=True, exist_ok=True) return output_pathname + + +def generate_fingerprint(text: str) -> str: + """ + Generate a fingerprint for the given text + + Args: + text (str): The text for which the fingerprint needs to be generated + + Returns: + str: The fingerprint value of the text + """ + text_bytes = text.encode("utf-8") + + # calculate SHA-256 hash + sha256 = hashlib.sha256() + sha256.update(text_bytes) + fingerprint = sha256.hexdigest() + + return fingerprint diff --git a/tests/metagpt/tools/libs/test_index_repo.py b/tests/metagpt/tools/libs/test_index_repo.py new file mode 100644 index 000000000..65c5f1af9 --- /dev/null +++ b/tests/metagpt/tools/libs/test_index_repo.py @@ -0,0 +1,32 @@ +import shutil + +import pytest + +from metagpt.const import DEFAULT_WORKSPACE_ROOT, TEST_DATA_PATH +from metagpt.tools.libs.index_repo import IndexRepo + + +@pytest.mark.asyncio +@pytest.mark.parametrize(("path", "query"), [(TEST_DATA_PATH / "requirements", "业务线")]) +async def test_index_repo(path, query): + index_path = DEFAULT_WORKSPACE_ROOT / ".index" + repo = IndexRepo(filename=str(index_path), root_path=str(path), min_token_count=0) + await repo.add([path]) + await repo.add([path]) + assert index_path.exists() + + rsp = await repo.search(query) + assert rsp + + repo2 = IndexRepo(filename=str(index_path), root_path=str(path), min_token_count=0) + rsp2 = await repo2.search(query) + assert rsp2 + + merged_rsp = await repo.merge(query=query, indices_list=[rsp, rsp2]) + assert merged_rsp + + shutil.rmtree(index_path) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])