mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 19:36:24 +02:00
replace langchain with llama-index
This commit is contained in:
parent
7005a1e86f
commit
4fcf724797
15 changed files with 175 additions and 71 deletions
6
.gitignore
vendored
6
.gitignore
vendored
|
|
@ -154,6 +154,11 @@ key.yaml
|
|||
data
|
||||
data.ms
|
||||
examples/nb/
|
||||
examples/default__vector_store.json
|
||||
examples/docstore.json
|
||||
examples/graph_store.json
|
||||
examples/image__vector_store.json
|
||||
examples/index_store.json
|
||||
.chroma
|
||||
*~$*
|
||||
workspace/*
|
||||
|
|
@ -168,6 +173,7 @@ output
|
|||
tmp.png
|
||||
.dependencies.json
|
||||
tests/metagpt/utils/file_repo_git
|
||||
tests/data/rsp_cache.json
|
||||
*.tmp
|
||||
*.png
|
||||
htmlcov
|
||||
|
|
|
|||
Binary file not shown.
Binary file not shown.
|
|
@ -6,7 +6,7 @@
|
|||
"""
|
||||
import asyncio
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from llama_index.embeddings import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DATA_PATH, EXAMPLE_PATH
|
||||
|
|
@ -17,7 +17,7 @@ from metagpt.roles import Sales
|
|||
|
||||
def get_store():
|
||||
llm = config.get_openai_llm()
|
||||
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
|
||||
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
|
||||
return FaissStore(DATA_PATH / "example.json", embedding=embedding)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -11,12 +11,8 @@ from pathlib import Path
|
|||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
from langchain.document_loaders import (
|
||||
TextLoader,
|
||||
UnstructuredPDFLoader,
|
||||
UnstructuredWordDocumentLoader,
|
||||
)
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from llama_index.node_parser import SimpleNodeParser
|
||||
from llama_index.readers import Document, PDFReader, SimpleDirectoryReader
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
@ -28,7 +24,7 @@ def validate_cols(content_col: str, df: pd.DataFrame):
|
|||
raise ValueError("Content column not found in DataFrame.")
|
||||
|
||||
|
||||
def read_data(data_path: Path):
|
||||
def read_data(data_path: Path) -> Union[pd.DataFrame, list[Document]]:
|
||||
suffix = data_path.suffix
|
||||
if ".xlsx" == suffix:
|
||||
data = pd.read_excel(data_path)
|
||||
|
|
@ -37,14 +33,13 @@ def read_data(data_path: Path):
|
|||
elif ".json" == suffix:
|
||||
data = pd.read_json(data_path)
|
||||
elif suffix in (".docx", ".doc"):
|
||||
data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
|
||||
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
|
||||
elif ".txt" == suffix:
|
||||
data = TextLoader(str(data_path)).load()
|
||||
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
|
||||
texts = text_splitter.split_documents(data)
|
||||
data = texts
|
||||
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
|
||||
node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
|
||||
data = node_parser.get_nodes_from_documents(data)
|
||||
elif ".pdf" == suffix:
|
||||
data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
|
||||
data = PDFReader.load_data(str(data_path))
|
||||
else:
|
||||
raise NotImplementedError("File format not supported.")
|
||||
return data
|
||||
|
|
@ -146,9 +141,9 @@ class IndexableDocument(Document):
|
|||
metadatas.append({})
|
||||
return docs, metadatas
|
||||
|
||||
def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
|
||||
def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list):
|
||||
data = self.data
|
||||
docs = [i.page_content for i in data]
|
||||
docs = [i.text for i in data]
|
||||
metadatas = [i.metadata for i in data]
|
||||
return docs, metadatas
|
||||
|
||||
|
|
@ -156,7 +151,7 @@ class IndexableDocument(Document):
|
|||
if isinstance(self.data, pd.DataFrame):
|
||||
return self._get_docs_and_metadatas_by_df()
|
||||
elif isinstance(self.data, list):
|
||||
return self._get_docs_and_metadatas_by_langchain()
|
||||
return self._get_docs_and_metadatas_by_llamaindex()
|
||||
else:
|
||||
raise NotImplementedError("Data type not supported for metadata extraction.")
|
||||
|
||||
|
|
|
|||
|
|
@ -39,8 +39,8 @@ class LocalStore(BaseStore, ABC):
|
|||
self.store = self.write()
|
||||
|
||||
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
|
||||
index_file = self.cache_dir / f"{self.fname}{index_ext}"
|
||||
store_file = self.cache_dir / f"{self.fname}{pkl_ext}"
|
||||
index_file = self.cache_dir / "default__vector_store.json"
|
||||
store_file = self.cache_dir / "docstore.json"
|
||||
return index_file, store_file
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -7,10 +7,14 @@
|
|||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
import faiss
|
||||
from llama_index import VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.embeddings import BaseEmbedding
|
||||
from llama_index.schema import Document, QueryBundle, TextNode
|
||||
from llama_index.storage import StorageContext
|
||||
from llama_index.vector_stores import FaissVectorStore
|
||||
|
||||
from metagpt.document import IndexableDocument
|
||||
from metagpt.document_store.base_store import LocalStore
|
||||
|
|
@ -20,36 +24,52 @@ from metagpt.utils.embedding import get_embedding
|
|||
|
||||
class FaissStore(LocalStore):
|
||||
def __init__(
|
||||
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None
|
||||
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None
|
||||
):
|
||||
self.meta_col = meta_col
|
||||
self.content_col = content_col
|
||||
self.embedding = embedding or get_embedding()
|
||||
self.store: VectorStoreIndex
|
||||
super().__init__(raw_data, cache_dir)
|
||||
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
|
||||
def _load(self) -> Optional["VectorStoreIndex"]:
|
||||
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # FAISS using .faiss
|
||||
|
||||
if not (index_file.exists() and store_file.exists()):
|
||||
logger.info("Missing at least one of index_file/store_file, load failed and return None")
|
||||
return None
|
||||
vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir)
|
||||
storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store)
|
||||
index = load_index_from_storage(storage_context)
|
||||
|
||||
return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname)
|
||||
return index
|
||||
|
||||
def _write(self, docs, metadatas):
|
||||
store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas)
|
||||
return store
|
||||
def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex:
|
||||
assert len(docs) == len(metadatas)
|
||||
texts_embeds = self.embedding.get_text_embedding_batch(docs)
|
||||
documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)]
|
||||
|
||||
[TextNode(embedding=embed, metadata=metadatas[idx]) for idx, embed in enumerate(texts_embeds)]
|
||||
# doc_store = SimpleDocumentStore()
|
||||
# doc_store.add_documents(nodes)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(documents=documents, storage_context=storage_context)
|
||||
|
||||
return index
|
||||
|
||||
def persist(self):
|
||||
self.store.save_local(self.raw_data_path.parent, self.fname)
|
||||
self.store.storage_context.persist(self.cache_dir)
|
||||
|
||||
def search(self, query: str, expand_cols=False, sep="\n", *args, k=5, **kwargs):
|
||||
retriever = self.store.as_retriever(similarity_top_k=k)
|
||||
rsp = retriever.retrieve(QueryBundle(query_str=query, embedding=self.embedding.get_text_embedding(query)))
|
||||
|
||||
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
|
||||
rsp = self.store.similarity_search(query, k=k, **kwargs)
|
||||
logger.debug(rsp)
|
||||
if expand_cols:
|
||||
return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))
|
||||
return str(sep.join([f"{x.node.text}: {x.node.metadata}" for x in rsp]))
|
||||
else:
|
||||
return str(sep.join([f"{x.page_content}" for x in rsp]))
|
||||
return str(sep.join([f"{x.node.text}" for x in rsp]))
|
||||
|
||||
async def asearch(self, *args, **kwargs):
|
||||
return await asyncio.to_thread(self.search, *args, **kwargs)
|
||||
|
|
@ -67,8 +87,12 @@ class FaissStore(LocalStore):
|
|||
|
||||
def add(self, texts: list[str], *args, **kwargs) -> list[str]:
|
||||
"""FIXME: Currently, the store is not updated after adding."""
|
||||
return self.store.add_texts(texts)
|
||||
texts_embeds = self.embedding.get_text_embedding_batch(texts)
|
||||
nodes = [TextNode(embedding=embed) for embed in texts_embeds]
|
||||
self.store.insert_nodes(nodes)
|
||||
|
||||
return []
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
"""Currently, langchain does not provide a delete interface."""
|
||||
"""Currently, faiss does not provide a delete interface."""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
22
metagpt/memory/memory2.py
Normal file
22
metagpt/memory/memory2.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : memory mechanism including store/retrieval/rank
|
||||
|
||||
from typing import Union, Optional
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from metagpt.memory.memory_network import MemoryNetwork
|
||||
from metagpt.memory.schema import MemoryNode
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
mem_network: Optional[MemoryNetwork] = Field(default_factory=MemoryNetwork, description="the network to store memory")
|
||||
|
||||
def add_msg(self, message: Message):
|
||||
mem_node = MemoryNode.create_mem_node_from_message(message)
|
||||
self.mem_network.add_mem(mem_node)
|
||||
|
||||
def add_msgs(self, messages: list[Message]):
|
||||
for msg in messages:
|
||||
self.add_msg(msg)
|
||||
18
metagpt/memory/memory_network.py
Normal file
18
metagpt/memory/memory_network.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the memory network to store memory segment
|
||||
|
||||
from pydantic import Field, BaseModel
|
||||
|
||||
from metagpt.memory.schema import MemorySegment, MemoryNode
|
||||
|
||||
|
||||
class MemoryNetwork(BaseModel):
|
||||
mem_seg: MemorySegment = Field(default_factory=MemorySegment, description="the memory segment to store memory nodes")
|
||||
|
||||
def add_mem(self, mem_node: MemoryNode):
|
||||
self.mem_seg.add_mem_node(mem_node)
|
||||
|
||||
def add_mems(self, mem_nodes: list[MemoryNode]):
|
||||
for mem_node in mem_nodes:
|
||||
self.add_mem(mem_node)
|
||||
|
|
@ -5,11 +5,8 @@
|
|||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from llama_index.embeddings import BaseEmbedding
|
||||
|
||||
from metagpt.const import DATA_PATH, MEM_TTL
|
||||
from metagpt.document_store.faiss_store import FaissStore
|
||||
|
|
@ -23,29 +20,17 @@ class MemoryStorage(FaissStore):
|
|||
The memory storage with Faiss as ANN search engine
|
||||
"""
|
||||
|
||||
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
|
||||
def __init__(self, mem_ttl: int = MEM_TTL, embedding: BaseEmbedding = None):
|
||||
self.role_id: str = None
|
||||
self.role_mem_path: str = None
|
||||
self.mem_ttl: int = mem_ttl # later use
|
||||
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
|
||||
self._initialized: bool = False
|
||||
|
||||
self.embedding = embedding or OpenAIEmbeddings()
|
||||
self.store: FAISS = None # Faiss engine
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
|
||||
|
||||
if not (index_file.exists() and store_file.exists()):
|
||||
logger.info("Missing at least one of index_file/store_file, load failed and return None")
|
||||
return None
|
||||
|
||||
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
|
||||
|
||||
def recover_memory(self, role_id: str) -> list[Message]:
|
||||
self.role_id = role_id
|
||||
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
|
||||
|
|
@ -69,6 +54,7 @@ class MemoryStorage(FaissStore):
|
|||
return None, None
|
||||
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
|
||||
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
|
||||
self.cache_dir = Path(self.role_mem_path).joinpath(self.role_id)
|
||||
return index_fpath, storage_fpath
|
||||
|
||||
def persist(self):
|
||||
|
|
|
|||
61
metagpt/memory/schema.py
Normal file
61
metagpt/memory/schema.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the memory schema definition
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class MemNodeType(Enum):
|
||||
OBSERVE = "observe" # memory from observation
|
||||
THINK = "think" # memory from self-think/reflect
|
||||
|
||||
|
||||
class MemoryNode(BaseModel):
|
||||
"""base unit of memory abstraction"""
|
||||
|
||||
mem_node_id: UUID = Field(default_factory=uuid4(), description="unique node id")
|
||||
parent_node_id: Optional[str] = Field(default=None, description="memory's parent memory node id")
|
||||
node_type: MemNodeType = Field(default=MemNodeType.OBSERVE, description="memory node type")
|
||||
|
||||
content: str = Field(default="", description="the memory content")
|
||||
summary: Optional[str] = Field(default=None, description="the summary of the content by providers")
|
||||
keywords: list[str] = Field(default=[], description="the extracted keywords of the content")
|
||||
embedding: list[float] = Field(default=[], description="the embeeding of the content")
|
||||
|
||||
raw_path: Optional[str] = Field(default=None, description="the relative path of the media like image")
|
||||
raw_corpus: list[Union[str, dict, tuple]] = Field(default=[], description="the raw corpus of the memory")
|
||||
|
||||
create_at: datetime = Field(default_factory=datetime, description="the memory create time")
|
||||
access_at: datetime = Field(default_factory=datetime, description="the memory last access time")
|
||||
expire_at: datetime = Field(default_factory=datetime, description="the memory expire time due to a TTL")
|
||||
|
||||
importance: int = Field(default=0, ge=0, le=10, description="the memory importance")
|
||||
access_cnt: int = Field(default=0, description="the memory acess count time")
|
||||
|
||||
@classmethod
|
||||
def create_mem_node(
|
||||
cls,
|
||||
content: str,
|
||||
summary: Optional[str] = None,
|
||||
keywords: list[str] = [],
|
||||
node_type: MemNodeType = MemNodeType.OBSERVE,
|
||||
):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def create_mem_node_from_message(cls, message: "Message"):
|
||||
pass
|
||||
|
||||
|
||||
class MemorySegment(BaseModel):
|
||||
"""segment abstraction to store memory_node"""
|
||||
|
||||
mem_nodes: list[MemoryNode] = Field(default=[], description="memory list to store MemoryNode")
|
||||
|
||||
def add_mem_node(self, mem_node: MemoryNode):
|
||||
self.mem_nodes.append(mem_node)
|
||||
|
|
@ -102,12 +102,6 @@ class RoleContext(BaseModel):
|
|||
) # see `Role._set_react_mode` for definitions of the following two attributes
|
||||
max_react_loop: int = 1
|
||||
|
||||
def check(self, role_id: str):
|
||||
# if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory:
|
||||
# self.long_term_memory.recover_memory(role_id, self)
|
||||
# self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation
|
||||
pass
|
||||
|
||||
@property
|
||||
def important_memory(self) -> list[Message]:
|
||||
"""Retrieve information corresponding to the attention action."""
|
||||
|
|
@ -300,8 +294,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
buffer during _observe.
|
||||
"""
|
||||
self.rc.watch = {any_to_str(t) for t in actions}
|
||||
# check RoleContext after adding watch actions
|
||||
self.rc.check(self.role_id)
|
||||
|
||||
def is_watch(self, caused_by: str):
|
||||
return caused_by in self.rc.watch
|
||||
|
|
|
|||
|
|
@ -5,12 +5,12 @@
|
|||
@Author : alexanderwu
|
||||
@File : embedding.py
|
||||
"""
|
||||
from langchain_community.embeddings import OpenAIEmbeddings
|
||||
from llama_index.embeddings import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
def get_embedding():
|
||||
def get_embedding() -> OpenAIEmbedding:
|
||||
llm = config.get_openai_llm()
|
||||
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
|
||||
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
|
||||
return embedding
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
aiohttp==3.8.4
|
||||
aiohttp==3.8.6
|
||||
#azure_storage==0.37.0
|
||||
channels==4.0.0
|
||||
# chromadb
|
||||
|
|
@ -11,11 +11,11 @@ typer==0.9.0
|
|||
# godot==0.1.1
|
||||
# google_api_python_client==2.93.0 # Used by search_engine.py
|
||||
lancedb==0.4.0
|
||||
langchain==0.0.352
|
||||
llama-index==0.9.31
|
||||
loguru==0.6.0
|
||||
meilisearch==0.21.0
|
||||
numpy==1.24.3
|
||||
openai==1.6.0
|
||||
openai==1.6.1
|
||||
openpyxl
|
||||
beautifulsoup4==4.12.2
|
||||
pandas==2.0.3
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ async def test_search_json():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_xlsx():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -36,5 +36,5 @@ async def test_search_xlsx():
|
|||
async def test_write():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.docstore
|
||||
assert _faiss_store.index
|
||||
assert _faiss_store.storage_context.docstore
|
||||
assert _faiss_store.storage_context.vector_store.client
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue