replace langchain with llama-index

This commit is contained in:
better629 2024-01-19 17:37:12 +08:00
parent 7005a1e86f
commit 4fcf724797
15 changed files with 175 additions and 71 deletions

6
.gitignore vendored
View file

@ -154,6 +154,11 @@ key.yaml
data
data.ms
examples/nb/
examples/default__vector_store.json
examples/docstore.json
examples/graph_store.json
examples/image__vector_store.json
examples/index_store.json
.chroma
*~$*
workspace/*
@ -168,6 +173,7 @@ output
tmp.png
.dependencies.json
tests/metagpt/utils/file_repo_git
tests/data/rsp_cache.json
*.tmp
*.png
htmlcov

Binary file not shown.

Binary file not shown.

View file

@ -6,7 +6,7 @@
"""
import asyncio
from langchain.embeddings import OpenAIEmbeddings
from llama_index.embeddings import OpenAIEmbedding
from metagpt.config2 import config
from metagpt.const import DATA_PATH, EXAMPLE_PATH
@ -17,7 +17,7 @@ from metagpt.roles import Sales
def get_store():
llm = config.get_openai_llm()
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
return FaissStore(DATA_PATH / "example.json", embedding=embedding)

View file

@ -11,12 +11,8 @@ from pathlib import Path
from typing import Optional, Union
import pandas as pd
from langchain.document_loaders import (
TextLoader,
UnstructuredPDFLoader,
UnstructuredWordDocumentLoader,
)
from langchain.text_splitter import CharacterTextSplitter
from llama_index.node_parser import SimpleNodeParser
from llama_index.readers import Document, PDFReader, SimpleDirectoryReader
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm
@ -28,7 +24,7 @@ def validate_cols(content_col: str, df: pd.DataFrame):
raise ValueError("Content column not found in DataFrame.")
def read_data(data_path: Path):
def read_data(data_path: Path) -> Union[pd.DataFrame, list[Document]]:
suffix = data_path.suffix
if ".xlsx" == suffix:
data = pd.read_excel(data_path)
@ -37,14 +33,13 @@ def read_data(data_path: Path):
elif ".json" == suffix:
data = pd.read_json(data_path)
elif suffix in (".docx", ".doc"):
data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
elif ".txt" == suffix:
data = TextLoader(str(data_path)).load()
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
texts = text_splitter.split_documents(data)
data = texts
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
data = node_parser.get_nodes_from_documents(data)
elif ".pdf" == suffix:
data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
data = PDFReader.load_data(str(data_path))
else:
raise NotImplementedError("File format not supported.")
return data
@ -146,9 +141,9 @@ class IndexableDocument(Document):
metadatas.append({})
return docs, metadatas
def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list):
data = self.data
docs = [i.page_content for i in data]
docs = [i.text for i in data]
metadatas = [i.metadata for i in data]
return docs, metadatas
@ -156,7 +151,7 @@ class IndexableDocument(Document):
if isinstance(self.data, pd.DataFrame):
return self._get_docs_and_metadatas_by_df()
elif isinstance(self.data, list):
return self._get_docs_and_metadatas_by_langchain()
return self._get_docs_and_metadatas_by_llamaindex()
else:
raise NotImplementedError("Data type not supported for metadata extraction.")

View file

@ -39,8 +39,8 @@ class LocalStore(BaseStore, ABC):
self.store = self.write()
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
index_file = self.cache_dir / f"{self.fname}{index_ext}"
store_file = self.cache_dir / f"{self.fname}{pkl_ext}"
index_file = self.cache_dir / "default__vector_store.json"
store_file = self.cache_dir / "docstore.json"
return index_file, store_file
@abstractmethod

View file

@ -7,10 +7,14 @@
"""
import asyncio
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from langchain.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
import faiss
from llama_index import VectorStoreIndex, load_index_from_storage
from llama_index.embeddings import BaseEmbedding
from llama_index.schema import Document, QueryBundle, TextNode
from llama_index.storage import StorageContext
from llama_index.vector_stores import FaissVectorStore
from metagpt.document import IndexableDocument
from metagpt.document_store.base_store import LocalStore
@ -20,36 +24,52 @@ from metagpt.utils.embedding import get_embedding
class FaissStore(LocalStore):
def __init__(
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None
):
self.meta_col = meta_col
self.content_col = content_col
self.embedding = embedding or get_embedding()
self.store: VectorStoreIndex
super().__init__(raw_data, cache_dir)
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
def _load(self) -> Optional["VectorStoreIndex"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # FAISS using .faiss
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file, load failed and return None")
return None
vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir)
storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store)
index = load_index_from_storage(storage_context)
return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname)
return index
def _write(self, docs, metadatas):
store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas)
return store
def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex:
assert len(docs) == len(metadatas)
texts_embeds = self.embedding.get_text_embedding_batch(docs)
documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)]
[TextNode(embedding=embed, metadata=metadatas[idx]) for idx, embed in enumerate(texts_embeds)]
# doc_store = SimpleDocumentStore()
# doc_store.add_documents(nodes)
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(documents=documents, storage_context=storage_context)
return index
def persist(self):
self.store.save_local(self.raw_data_path.parent, self.fname)
self.store.storage_context.persist(self.cache_dir)
def search(self, query: str, expand_cols=False, sep="\n", *args, k=5, **kwargs):
retriever = self.store.as_retriever(similarity_top_k=k)
rsp = retriever.retrieve(QueryBundle(query_str=query, embedding=self.embedding.get_text_embedding(query)))
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
rsp = self.store.similarity_search(query, k=k, **kwargs)
logger.debug(rsp)
if expand_cols:
return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))
return str(sep.join([f"{x.node.text}: {x.node.metadata}" for x in rsp]))
else:
return str(sep.join([f"{x.page_content}" for x in rsp]))
return str(sep.join([f"{x.node.text}" for x in rsp]))
async def asearch(self, *args, **kwargs):
return await asyncio.to_thread(self.search, *args, **kwargs)
@ -67,8 +87,12 @@ class FaissStore(LocalStore):
def add(self, texts: list[str], *args, **kwargs) -> list[str]:
"""FIXME: Currently, the store is not updated after adding."""
return self.store.add_texts(texts)
texts_embeds = self.embedding.get_text_embedding_batch(texts)
nodes = [TextNode(embedding=embed) for embed in texts_embeds]
self.store.insert_nodes(nodes)
return []
def delete(self, *args, **kwargs):
"""Currently, langchain does not provide a delete interface."""
"""Currently, faiss does not provide a delete interface."""
raise NotImplementedError

22
metagpt/memory/memory2.py Normal file
View file

@ -0,0 +1,22 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : memory mechanism including store/retrieval/rank
from typing import Union, Optional
from pydantic import Field, BaseModel
from metagpt.memory.memory_network import MemoryNetwork
from metagpt.memory.schema import MemoryNode
from metagpt.schema import Message
class Memory(BaseModel):
mem_network: Optional[MemoryNetwork] = Field(default_factory=MemoryNetwork, description="the network to store memory")
def add_msg(self, message: Message):
mem_node = MemoryNode.create_mem_node_from_message(message)
self.mem_network.add_mem(mem_node)
def add_msgs(self, messages: list[Message]):
for msg in messages:
self.add_msg(msg)

View file

@ -0,0 +1,18 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the memory network to store memory segment
from pydantic import Field, BaseModel
from metagpt.memory.schema import MemorySegment, MemoryNode
class MemoryNetwork(BaseModel):
mem_seg: MemorySegment = Field(default_factory=MemorySegment, description="the memory segment to store memory nodes")
def add_mem(self, mem_node: MemoryNode):
self.mem_seg.add_mem_node(mem_node)
def add_mems(self, mem_nodes: list[MemoryNode]):
for mem_node in mem_nodes:
self.add_mem(mem_node)

View file

@ -5,11 +5,8 @@
"""
from pathlib import Path
from typing import Optional
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
from llama_index.embeddings import BaseEmbedding
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
@ -23,29 +20,17 @@ class MemoryStorage(FaissStore):
The memory storage with Faiss as ANN search engine
"""
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
def __init__(self, mem_ttl: int = MEM_TTL, embedding: BaseEmbedding = None):
self.role_id: str = None
self.role_mem_path: str = None
self.mem_ttl: int = mem_ttl # later use
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or OpenAIEmbeddings()
self.store: FAISS = None # Faiss engine
@property
def is_initialized(self) -> bool:
return self._initialized
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file, load failed and return None")
return None
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
def recover_memory(self, role_id: str) -> list[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
@ -69,6 +54,7 @@ class MemoryStorage(FaissStore):
return None, None
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
self.cache_dir = Path(self.role_mem_path).joinpath(self.role_id)
return index_fpath, storage_fpath
def persist(self):

61
metagpt/memory/schema.py Normal file
View file

@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the memory schema definition
from datetime import datetime
from enum import Enum
from typing import Optional, Union
from uuid import UUID, uuid4
from pydantic import BaseModel, Field
class MemNodeType(Enum):
OBSERVE = "observe" # memory from observation
THINK = "think" # memory from self-think/reflect
class MemoryNode(BaseModel):
"""base unit of memory abstraction"""
mem_node_id: UUID = Field(default_factory=uuid4(), description="unique node id")
parent_node_id: Optional[str] = Field(default=None, description="memory's parent memory node id")
node_type: MemNodeType = Field(default=MemNodeType.OBSERVE, description="memory node type")
content: str = Field(default="", description="the memory content")
summary: Optional[str] = Field(default=None, description="the summary of the content by providers")
keywords: list[str] = Field(default=[], description="the extracted keywords of the content")
embedding: list[float] = Field(default=[], description="the embeeding of the content")
raw_path: Optional[str] = Field(default=None, description="the relative path of the media like image")
raw_corpus: list[Union[str, dict, tuple]] = Field(default=[], description="the raw corpus of the memory")
create_at: datetime = Field(default_factory=datetime, description="the memory create time")
access_at: datetime = Field(default_factory=datetime, description="the memory last access time")
expire_at: datetime = Field(default_factory=datetime, description="the memory expire time due to a TTL")
importance: int = Field(default=0, ge=0, le=10, description="the memory importance")
access_cnt: int = Field(default=0, description="the memory acess count time")
@classmethod
def create_mem_node(
cls,
content: str,
summary: Optional[str] = None,
keywords: list[str] = [],
node_type: MemNodeType = MemNodeType.OBSERVE,
):
pass
@classmethod
def create_mem_node_from_message(cls, message: "Message"):
pass
class MemorySegment(BaseModel):
"""segment abstraction to store memory_node"""
mem_nodes: list[MemoryNode] = Field(default=[], description="memory list to store MemoryNode")
def add_mem_node(self, mem_node: MemoryNode):
self.mem_nodes.append(mem_node)

View file

@ -102,12 +102,6 @@ class RoleContext(BaseModel):
) # see `Role._set_react_mode` for definitions of the following two attributes
max_react_loop: int = 1
def check(self, role_id: str):
# if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory:
# self.long_term_memory.recover_memory(role_id, self)
# self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation
pass
@property
def important_memory(self) -> list[Message]:
"""Retrieve information corresponding to the attention action."""
@ -300,8 +294,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
buffer during _observe.
"""
self.rc.watch = {any_to_str(t) for t in actions}
# check RoleContext after adding watch actions
self.rc.check(self.role_id)
def is_watch(self, caused_by: str):
return caused_by in self.rc.watch

View file

@ -5,12 +5,12 @@
@Author : alexanderwu
@File : embedding.py
"""
from langchain_community.embeddings import OpenAIEmbeddings
from llama_index.embeddings import OpenAIEmbedding
from metagpt.config2 import config
def get_embedding():
def get_embedding() -> OpenAIEmbedding:
llm = config.get_openai_llm()
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
return embedding

View file

@ -1,4 +1,4 @@
aiohttp==3.8.4
aiohttp==3.8.6
#azure_storage==0.37.0
channels==4.0.0
# chromadb
@ -11,11 +11,11 @@ typer==0.9.0
# godot==0.1.1
# google_api_python_client==2.93.0 # Used by search_engine.py
lancedb==0.4.0
langchain==0.0.352
llama-index==0.9.31
loguru==0.6.0
meilisearch==0.21.0
numpy==1.24.3
openai==1.6.0
openai==1.6.1
openpyxl
beautifulsoup4==4.12.2
pandas==2.0.3

View file

@ -25,7 +25,7 @@ async def test_search_json():
@pytest.mark.asyncio
async def test_search_xlsx():
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
@ -36,5 +36,5 @@ async def test_search_xlsx():
async def test_write():
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
_faiss_store = store.write()
assert _faiss_store.docstore
assert _faiss_store.index
assert _faiss_store.storage_context.docstore
assert _faiss_store.storage_context.vector_store.client