feat: add index repo

This commit is contained in:
莘权 马 2024-08-29 11:25:23 +08:00
parent 17a1e76120
commit a8773d176e
6 changed files with 262 additions and 9 deletions

View file

@ -2,7 +2,8 @@
import json
import os
from typing import Any, Optional, Union
from pathlib import Path
from typing import Any, List, Optional, Set, Union
import fsspec
from llama_index.core import SimpleDirectoryReader
@ -77,6 +78,7 @@ class SimpleEngine(RetrieverQueryEngine):
callback_manager=callback_manager,
)
self._transformations = transformations or self._default_transformations()
self._filenames = set()
@classmethod
def from_docs(
@ -191,11 +193,11 @@ class SimpleEngine(RetrieverQueryEngine):
self._try_reconstruct_obj(nodes)
return nodes
def add_docs(self, input_files: list[str]):
def add_docs(self, input_files: List[Union[str, Path]]):
"""Add docs to retriever. retriever must has add_nodes func."""
self._ensure_retriever_modifiable()
documents = SimpleDirectoryReader(input_files=input_files).load_data()
documents = SimpleDirectoryReader(input_files=[str(i) for i in input_files]).load_data()
self._fix_document_metadata(documents)
nodes = run_transformations(documents, transformations=self._transformations)
@ -220,6 +222,24 @@ class SimpleEngine(RetrieverQueryEngine):
return self._retriever.query_total_count()
def delete_docs(self, input_files: List[Union[str, Path]]):
    """Delete documents from the index and document store.

    Each entry of ``input_files`` is converted to ``str`` and compared for
    exact equality with the ``file_path`` metadata of every reference
    document known to the retriever's index. Matching reference documents
    are removed from the index and from the docstore. Entries that match no
    indexed document are silently ignored.

    Args:
        input_files (List[Union[str, Path]]): File paths of the documents to
            delete. They must be spelled exactly as stored in the
            ``file_path`` metadata; no path normalisation is performed.
    """
    index = self.retriever._index
    target_paths = {str(path) for path in input_files}

    # Gather the matching ids before deleting anything, so that deletion never
    # runs while ``ref_doc_info`` is still being iterated.
    matched_doc_ids = [
        doc_id
        for doc_id, info in index.ref_doc_info.items()
        if info.metadata.get("file_path") in target_paths
    ]
    for doc_id in matched_doc_ids:
        index.delete_ref_doc(doc_id, delete_from_docstore=True)
@staticmethod
def get_obj_nodes(objs: Optional[list[RAGObject]] = None) -> list[ObjectNode]:
"""Converts a list of RAGObjects to a list of ObjectNodes."""
@ -323,3 +343,7 @@ class SimpleEngine(RetrieverQueryEngine):
@staticmethod
def _default_transformations():
    """Return the fallback transformation list: one default-configured SentenceSplitter."""
    sentence_splitter = SentenceSplitter()
    return [sentence_splitter]
@property
def filenames(self) -> Set[str]:
    """Read-only accessor for the engine's ``_filenames`` set.

    The stored set object itself is returned, not a copy.
    """
    tracked_names = self._filenames
    return tracked_names

View file

@ -5,9 +5,6 @@ from typing import Any, Optional
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.gemini import GeminiEmbedding
from llama_index.embeddings.ollama import OllamaEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from metagpt.config2 import Config
from metagpt.configs.embedding_config import EmbeddingType
@ -49,7 +46,9 @@ class RAGEmbeddingFactory(GenericFactory):
raise TypeError("To use RAG, please set your embedding in config2.yaml.")
def _create_openai(self) -> OpenAIEmbedding:
def _create_openai(self) -> "OpenAIEmbedding":
from llama_index.embeddings.openai import OpenAIEmbedding
params = dict(
api_key=self.config.embedding.api_key or self.config.llm.api_key,
api_base=self.config.embedding.base_url or self.config.llm.base_url,
@ -70,7 +69,9 @@ class RAGEmbeddingFactory(GenericFactory):
return AzureOpenAIEmbedding(**params)
def _create_gemini(self) -> GeminiEmbedding:
def _create_gemini(self) -> "GeminiEmbedding":
from llama_index.embeddings.gemini import GeminiEmbedding
params = dict(
api_key=self.config.embedding.api_key,
api_base=self.config.embedding.base_url,
@ -80,7 +81,9 @@ class RAGEmbeddingFactory(GenericFactory):
return GeminiEmbedding(**params)
def _create_ollama(self) -> OllamaEmbedding:
def _create_ollama(self) -> "OllamaEmbedding":
from llama_index.embeddings.ollama import OllamaEmbedding
params = dict(
base_url=self.config.embedding.base_url,
)