replace langchain with llama-index

This commit is contained in:
better629 2024-01-19 17:37:12 +08:00 committed by betterwang
parent 01d40e077b
commit cc91df59e5
13 changed files with 175 additions and 71 deletions

View file

@ -11,12 +11,8 @@ from pathlib import Path
from typing import Optional, Union
import pandas as pd
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import (
TextLoader,
UnstructuredPDFLoader,
UnstructuredWordDocumentLoader,
)
from llama_index.node_parser import SimpleNodeParser
from llama_index.readers import Document, PDFReader, SimpleDirectoryReader
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm
@ -29,7 +25,7 @@ def validate_cols(content_col: str, df: pd.DataFrame):
raise ValueError("Content column not found in DataFrame.")
def read_data(data_path: Path):
def read_data(data_path: Path) -> Union[pd.DataFrame, list[Document]]:
suffix = data_path.suffix
if ".xlsx" == suffix:
data = pd.read_excel(data_path)
@ -38,14 +34,13 @@ def read_data(data_path: Path):
elif ".json" == suffix:
data = pd.read_json(data_path)
elif suffix in (".docx", ".doc"):
data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
elif ".txt" == suffix:
data = TextLoader(str(data_path)).load()
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
texts = text_splitter.split_documents(data)
data = texts
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
data = node_parser.get_nodes_from_documents(data)
elif ".pdf" == suffix:
data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
data = PDFReader.load_data(str(data_path))
else:
raise NotImplementedError("File format not supported.")
return data
@ -150,9 +145,9 @@ class IndexableDocument(Document):
metadatas.append({})
return docs, metadatas
def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list):
data = self.data
docs = [i.page_content for i in data]
docs = [i.text for i in data]
metadatas = [i.metadata for i in data]
return docs, metadatas
@ -160,7 +155,7 @@ class IndexableDocument(Document):
if isinstance(self.data, pd.DataFrame):
return self._get_docs_and_metadatas_by_df()
elif isinstance(self.data, list):
return self._get_docs_and_metadatas_by_langchain()
return self._get_docs_and_metadatas_by_llamaindex()
else:
raise NotImplementedError("Data type not supported for metadata extraction.")