Merge pull request #974 from better629/feat_memory

Feat add rag
Alexander Wu 2024-03-17 23:39:12 +08:00 committed by GitHub
commit e783e5b208
61 changed files with 2353 additions and 248 deletions

View file

@@ -49,6 +49,7 @@ METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
EXAMPLE_PATH = METAGPT_ROOT / "examples"
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
DATA_PATH = METAGPT_ROOT / "data"
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
RESEARCH_PATH = DATA_PATH / "research"

View file

@@ -11,12 +11,9 @@ from pathlib import Path
from typing import Optional, Union
import pandas as pd
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import (
TextLoader,
UnstructuredPDFLoader,
UnstructuredWordDocumentLoader,
)
from llama_index.core import Document, SimpleDirectoryReader
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.readers.file import PDFReader
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm
@@ -29,7 +26,7 @@ def validate_cols(content_col: str, df: pd.DataFrame):
raise ValueError("Content column not found in DataFrame.")
def read_data(data_path: Path):
def read_data(data_path: Path) -> Union[pd.DataFrame, list[Document]]:
suffix = data_path.suffix
if ".xlsx" == suffix:
data = pd.read_excel(data_path)
@@ -38,14 +35,13 @@ def read_data(data_path: Path):
elif ".json" == suffix:
data = pd.read_json(data_path)
elif suffix in (".docx", ".doc"):
data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
elif ".txt" == suffix:
data = TextLoader(str(data_path)).load()
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
texts = text_splitter.split_documents(data)
data = texts
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
data = node_parser.get_nodes_from_documents(data)
elif ".pdf" == suffix:
data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
data = PDFReader.load_data(str(data_path))
else:
raise NotImplementedError("File format not supported.")
return data
@@ -150,9 +146,9 @@ class IndexableDocument(Document):
metadatas.append({})
return docs, metadatas
def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list):
data = self.data
docs = [i.page_content for i in data]
docs = [i.text for i in data]
metadatas = [i.metadata for i in data]
return docs, metadatas
@@ -160,7 +156,7 @@ class IndexableDocument(Document):
if isinstance(self.data, pd.DataFrame):
return self._get_docs_and_metadatas_by_df()
elif isinstance(self.data, list):
return self._get_docs_and_metadatas_by_langchain()
return self._get_docs_and_metadatas_by_llamaindex()
else:
raise NotImplementedError("Data type not supported for metadata extraction.")

View file

@@ -38,9 +38,9 @@ class LocalStore(BaseStore, ABC):
if not self.store:
self.store = self.write()
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
index_file = self.cache_dir / f"{self.fname}{index_ext}"
store_file = self.cache_dir / f"{self.fname}{pkl_ext}"
def _get_index_and_store_fname(self, index_ext=".json", docstore_ext=".json"):
index_file = self.cache_dir / f"default__vector_store{index_ext}"
store_file = self.cache_dir / f"docstore{docstore_ext}"
return index_file, store_file
@abstractmethod

View file

@@ -11,9 +11,9 @@ import chromadb
class ChromaStore:
"""If this class inherits from BaseStore or imports other modules from metagpt, a Python exception occurs, which is strange."""
def __init__(self, name):
def __init__(self, name: str, get_or_create: bool = False):
client = chromadb.Client()
collection = client.create_collection(name)
collection = client.create_collection(name, get_or_create=get_or_create)
self.client = client
self.collection = collection

View file

@@ -7,10 +7,14 @@
"""
import asyncio
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from langchain.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
import faiss
from llama_index.core import VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import Document, QueryBundle, TextNode
from llama_index.core.storage import StorageContext
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.document import IndexableDocument
from metagpt.document_store.base_store import LocalStore
@@ -20,36 +24,50 @@ from metagpt.utils.embedding import get_embedding
class FaissStore(LocalStore):
def __init__(
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None
):
self.meta_col = meta_col
self.content_col = content_col
self.embedding = embedding or get_embedding()
self.store: VectorStoreIndex
super().__init__(raw_data, cache_dir)
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
def _load(self) -> Optional["VectorStoreIndex"]:
index_file, store_file = self._get_index_and_store_fname()
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file; load failed, returning None")
return None
vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir)
storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store)
index = load_index_from_storage(storage_context, embed_model=self.embedding)
return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname)
return index
def _write(self, docs, metadatas):
store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas)
return store
def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex:
assert len(docs) == len(metadatas)
documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)]
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents=documents, storage_context=storage_context, embed_model=self.embedding
)
return index
def persist(self):
self.store.save_local(self.raw_data_path.parent, self.fname)
self.store.storage_context.persist(self.cache_dir)
def search(self, query: str, expand_cols=False, sep="\n", *args, k=5, **kwargs):
retriever = self.store.as_retriever(similarity_top_k=k)
rsp = retriever.retrieve(QueryBundle(query_str=query, embedding=self.embedding.get_text_embedding(query)))
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
rsp = self.store.similarity_search(query, k=k, **kwargs)
logger.debug(rsp)
if expand_cols:
return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))
return str(sep.join([f"{x.node.text}: {x.node.metadata}" for x in rsp]))
else:
return str(sep.join([f"{x.page_content}" for x in rsp]))
return str(sep.join([f"{x.node.text}" for x in rsp]))
async def asearch(self, *args, **kwargs):
return await asyncio.to_thread(self.search, *args, **kwargs)
@@ -67,8 +85,12 @@ class FaissStore(LocalStore):
def add(self, texts: list[str], *args, **kwargs) -> list[str]:
"""FIXME: Currently, the store is not updated after adding."""
return self.store.add_texts(texts)
texts_embeds = self.embedding.get_text_embedding_batch(texts)
nodes = [TextNode(text=texts[idx], embedding=embed) for idx, embed in enumerate(texts_embeds)]
self.store.insert_nodes(nodes)
return []
def delete(self, *args, **kwargs):
"""Currently, langchain does not provide a delete interface."""
"""Currently, faiss does not provide a delete interface."""
raise NotImplementedError
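For reference, a minimal usage sketch of the reworked store (the data file and query are hypothetical; assumes an OpenAI key is configured so get_embedding() resolves):

from pathlib import Path

from metagpt.document_store.faiss_store import FaissStore

# Hypothetical Q&A file with "source"/"output" columns.
store = FaissStore(raw_data=Path("data/qa.json"))
print(store.search("how do I install metagpt?", k=3))
store.persist()  # writes default__vector_store.json / docstore.json under cache_dir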

View file

@@ -8,8 +8,6 @@ import re
import time
from typing import Any, Iterable
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from pydantic import ConfigDict, Field
from metagpt.config2 import config as CONFIG
@@ -17,6 +15,7 @@ from metagpt.environment.base_env import Environment
from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
from metagpt.logs import logger
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
@@ -48,9 +47,9 @@ class MincraftEnv(Environment, MincraftExtEnv):
runtime_status: bool = False # equal to action execution status: success or failed
vectordb: Chroma = Field(default_factory=Chroma)
vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore)
qa_cache_questions_vectordb: Chroma = Field(default_factory=Chroma)
qa_cache_questions_vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore)
@property
def progress(self):
@@ -73,16 +72,14 @@ class MincraftEnv(Environment, MincraftExtEnv):
self.set_mc_resume()
def set_mc_resume(self):
self.qa_cache_questions_vectordb = Chroma(
self.qa_cache_questions_vectordb = ChromaVectorStore(
collection_name="qa_cache_questions_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{MC_CKPT_DIR}/curriculum/vectordb",
persist_dir=f"{MC_CKPT_DIR}/curriculum/vectordb",
)
self.vectordb = Chroma(
self.vectordb = ChromaVectorStore(
collection_name="skill_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{MC_CKPT_DIR}/skill/vectordb",
persist_dir=f"{MC_CKPT_DIR}/skill/vectordb",
)
if CONFIG.resume:

View file

@@ -29,16 +29,14 @@ class LongTermMemory(Memory):
msg_from_recover: bool = False
def recover_memory(self, role_id: str, rc: RoleContext):
messages = self.memory_storage.recover_memory(role_id)
self.memory_storage.recover_memory(role_id)
self.rc = rc
if not self.memory_storage.is_initialized:
logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty")
logger.warning(f"It may be the first time to run Role {role_id}, the long-term memory is empty")
else:
logger.warning(
f"Agent {role_id} has existing memory storage with {len(messages)} messages " f"and has recovered them."
)
logger.warning(f"Role {role_id} has existing memory storage and has recovered them.")
self.msg_from_recover = True
self.add_batch(messages)
# self.add_batch(messages) # TODO no need
self.msg_from_recover = False
def add(self, message: Message):
@@ -49,7 +47,7 @@ class LongTermMemory(Memory):
# and ignore adding messages from recover repeatedly
self.memory_storage.add(message)
def find_news(self, observed: list[Message], k=0) -> list[Message]:
async def find_news(self, observed: list[Message], k=0) -> list[Message]:
"""
find news (previously unseen messages) from the most recent k memories, from all memories when k=0
1. find the short-term memory(stm) news
@@ -63,11 +61,14 @@ class LongTermMemory(Memory):
ltm_news: list[Message] = []
for mem in stm_news:
# filter out messages similar to those seen previously in ltm, only keep fresh news
mem_searched = self.memory_storage.search_dissimilar(mem)
if len(mem_searched) > 0:
mem_searched = await self.memory_storage.search_similar(mem)
if len(mem_searched) == 0:
ltm_news.append(mem)
return ltm_news[-k:]
def persist(self):
self.memory_storage.persist()
def delete(self, message: Message):
super().delete(message)
# TODO delete message in memory_storage

View file

@@ -3,115 +3,75 @@
"""
@Desc : the implementation of memory storage
"""
import shutil
from pathlib import Path
from typing import Optional
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
from llama_index.core.embeddings import BaseEmbedding
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig
from metagpt.schema import Message
from metagpt.utils.embedding import get_embedding
from metagpt.utils.serialize import deserialize_message, serialize_message
class MemoryStorage(FaissStore):
class MemoryStorage(object):
"""
The memory storage with Faiss as ANN search engine
"""
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
def __init__(self, mem_ttl: int = MEM_TTL, embedding: BaseEmbedding = None):
self.role_id: str = None
self.role_mem_path: str = None
self.mem_ttl: int = mem_ttl # later use
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
self._initialized: bool = False
self.embedding = embedding or get_embedding()
self.store: FAISS = None # Faiss engine
self.faiss_engine = None
@property
def is_initialized(self) -> bool:
return self._initialized
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file; load failed, returning None")
return None
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
def recover_memory(self, role_id: str) -> list[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
self.cache_dir = self.role_mem_path
self.store = self._load()
messages = []
if not self.store:
# TODO init `self.store` under here with raw faiss api instead under `add`
pass
if self.role_mem_path.joinpath("default__vector_store.json").exists():
self.faiss_engine = SimpleEngine.from_index(
index_config=FAISSIndexConfig(persist_path=self.cache_dir),
retriever_configs=[FAISSRetrieverConfig()],
embed_model=self.embedding,
)
else:
for _id, document in self.store.docstore._dict.items():
messages.append(deserialize_message(document.metadata.get("message_ser")))
self._initialized = True
return messages
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
if not self.role_mem_path:
logger.error(f"You should call {self.__class__.__name__}.recover_memory first when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
return index_fpath, storage_fpath
def persist(self):
self.store.save_local(self.role_mem_path, self.role_id)
logger.debug(f"Agent {self.role_id} persist memory into local")
self.faiss_engine = SimpleEngine.from_objs(
objs=[], retriever_configs=[FAISSRetrieverConfig()], embed_model=self.embedding
)
self._initialized = True
def add(self, message: Message) -> bool:
"""add message into memory storage"""
docs = [message.content]
metadatas = [{"message_ser": serialize_message(message)}]
if not self.store:
# init Faiss
self.store = self._write(docs, metadatas)
self._initialized = True
else:
self.store.add_texts(texts=docs, metadatas=metadatas)
self.persist()
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
self.faiss_engine.add_objs([message])
logger.info(f"Role {self.role_id}'s memory_storage add a message")
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
"""search for dissimilar messages"""
if not self.store:
return []
resp = self.store.similarity_search_with_score(query=message.content, k=k)
async def search_similar(self, message: Message, k=4) -> list[Message]:
"""search for similar messages"""
# filter the result which score is smaller than the threshold
filtered_resp = []
for item, score in resp:
# the smaller score means more similar relation
if score < self.threshold:
continue
# convert search result into Memory
metadata = item.metadata
new_mem = deserialize_message(metadata.get("message_ser"))
filtered_resp.append(new_mem)
resp = await self.faiss_engine.aretrieve(message.content)
for item in resp:
if item.score < self.threshold:
filtered_resp.append(item.metadata.get("obj"))
return filtered_resp
def clean(self):
index_fpath, storage_fpath = self._get_index_and_store_fname()
if index_fpath and index_fpath.exists():
index_fpath.unlink(missing_ok=True)
if storage_fpath and storage_fpath.exists():
storage_fpath.unlink(missing_ok=True)
self.store = None
shutil.rmtree(self.cache_dir, ignore_errors=True)
self._initialized = False
def persist(self):
if self.faiss_engine:
self.faiss_engine.retriever._index.storage_context.persist(self.cache_dir)
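The reworked lifecycle, as a minimal sketch (the role id is hypothetical; note that search_similar is now async and the default embedding requires an OpenAI key):

import asyncio

from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message

async def demo():
    storage = MemoryStorage()
    storage.recover_memory("ProductManager")  # creates or loads the FAISS-backed engine
    storage.add(Message(content="Bob likes burgers."))
    # Returns stored messages judged similar under the 0.1 score threshold.
    similar = await storage.search_similar(Message(content="What does Bob like?"))
    storage.persist()
    return similar

asyncio.run(demo())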

metagpt/rag/__init__.py Normal file
View file

View file

@@ -0,0 +1,5 @@
"""Engines init"""
from metagpt.rag.engines.simple import SimpleEngine
__all__ = ["SimpleEngine"]

View file

@@ -0,0 +1,259 @@
"""Simple Engine."""
import json
import os
from typing import Any, Optional, Union
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.ingestion.pipeline import run_transformations
from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import (
BaseSynthesizer,
get_response_synthesizer,
)
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import (
BaseNode,
Document,
NodeWithScore,
QueryBundle,
QueryType,
TransformComponent,
)
from metagpt.rag.factories import (
get_index,
get_rag_embedding,
get_rag_llm,
get_rankers,
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseIndexConfig,
BaseRankerConfig,
BaseRetrieverConfig,
BM25RetrieverConfig,
ObjectNode,
)
from metagpt.utils.common import import_class
class SimpleEngine(RetrieverQueryEngine):
"""SimpleEngine is designed to be simple and straightforward.
It is a lightweight, easy-to-use search engine that integrates
document reading, embedding, indexing, retrieving, and ranking
into a single workflow, so that a search engine can be set up quickly
from a collection of documents.
"""
def __init__(
self,
retriever: BaseRetriever,
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
index: Optional[BaseIndex] = None,
) -> None:
super().__init__(
retriever=retriever,
response_synthesizer=response_synthesizer,
node_postprocessors=node_postprocessors,
callback_manager=callback_manager,
)
self.index = index
@classmethod
def from_docs(
cls,
input_dir: str = None,
input_files: list[str] = None,
transformations: Optional[list[TransformComponent]] = None,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""From docs.
Must provide either `input_dir` or `input_files`.
Args:
input_dir: Path to the directory.
input_files: List of file paths to read (Optional; overrides input_dir, exclude).
transformations: Parse documents to nodes. Default [SentenceSplitter].
embed_model: Parse nodes to embedding. Must be supported by LlamaIndex. Default OpenAIEmbedding.
llm: Must be supported by LlamaIndex. Default OpenAI.
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
"""
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
cls._fix_document_metadata(documents)
index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_objs(
cls,
objs: Optional[list[RAGObject]] = None,
transformations: Optional[list[TransformComponent]] = None,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""From objs.
Args:
objs: List of RAGObject.
transformations: Parse documents to nodes. Default [SentenceSplitter].
embed_model: Parse nodes to embedding. Must be supported by LlamaIndex. Default OpenAIEmbedding.
llm: Must be supported by LlamaIndex. Default OpenAI.
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
ranker_configs: Configuration for rankers.
"""
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("When using BM25RetrieverConfig, objs must not be empty.")
objs = objs or []
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_index(
cls,
index_config: BaseIndexConfig,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""Load from an index previously maintained via self.persist(); index_config contains persist_path."""
index = get_index(index_config, embed_model=cls._resolve_embed_model(embed_model, [index_config]))
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
async def asearch(self, content: str, **kwargs) -> str:
"""Implement tools.SearchInterface."""
return await self.aquery(content)
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Allow query to be str."""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
nodes = await super().aretrieve(query_bundle)
self._try_reconstruct_obj(nodes)
return nodes
def add_docs(self, input_files: list[str]):
"""Add docs to the retriever. The retriever must have an add_nodes func."""
self._ensure_retriever_modifiable()
documents = SimpleDirectoryReader(input_files=input_files).load_data()
self._fix_document_metadata(documents)
nodes = run_transformations(documents, transformations=self.index._transformations)
self._save_nodes(nodes)
def add_objs(self, objs: list[RAGObject]):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
self._save_nodes(nodes)
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
"""Persist."""
self._ensure_retriever_persistable()
self._persist(str(persist_dir), **kwargs)
@classmethod
def _from_index(
cls,
index: BaseIndex,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)
def _ensure_retriever_modifiable(self):
self._ensure_retriever_of_type(ModifiableRAGRetriever)
def _ensure_retriever_persistable(self):
self._ensure_retriever_of_type(PersistableRAGRetriever)
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.
Args:
required_type: The class that the retriever is expected to be an instance of.
"""
if isinstance(self.retriever, SimpleHybridRetriever):
if not any(isinstance(r, required_type) for r in self.retriever.retrievers):
raise TypeError(
f"Must have at least one retriever of type {required_type.__name__} in SimpleHybridRetriever"
)
if not isinstance(self.retriever, required_type):
raise TypeError(f"The retriever is not of type {required_type.__name__}: {type(self.retriever)}")
def _save_nodes(self, nodes: list[BaseNode]):
self.retriever.add_nodes(nodes)
def _persist(self, persist_dir: str, **kwargs):
self.retriever.persist(persist_dir, **kwargs)
@staticmethod
def _try_reconstruct_obj(nodes: list[NodeWithScore]):
"""If node is object, then dynamically reconstruct object, and save object to node.metadata["obj"]."""
for node in nodes:
if node.metadata.get("is_obj", False):
obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"])
obj_dict = json.loads(node.metadata["obj_json"])
node.metadata["obj"] = obj_cls(**obj_dict)
@staticmethod
def _fix_document_metadata(documents: list[Document]):
"""LlamaIndex keeps metadata['file_path'], which is unnecessary and may be deleted in the near future."""
for doc in documents:
doc.excluded_embed_metadata_keys.append("file_path")
@staticmethod
def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = None) -> BaseEmbedding:
if configs and all(isinstance(c, NoEmbedding) for c in configs):
return MockEmbedding(embed_dim=1)
return embed_model or get_rag_embedding()
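To make the intended workflow concrete, here is a minimal sketch of building and querying an engine (the input file is hypothetical; assumes an OpenAI key in MetaGPT's config):

import asyncio

from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import FAISSRetrieverConfig, LLMRankerConfig

async def main():
    engine = SimpleEngine.from_docs(
        input_files=["data/travel.txt"],  # hypothetical document
        retriever_configs=[FAISSRetrieverConfig()],
        ranker_configs=[LLMRankerConfig()],
    )
    print(await engine.asearch("What does Bob like?"))

asyncio.run(main())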

View file

@@ -0,0 +1,9 @@
"""RAG factories"""
from metagpt.rag.factories.retriever import get_retriever
from metagpt.rag.factories.ranker import get_rankers
from metagpt.rag.factories.embedding import get_rag_embedding
from metagpt.rag.factories.index import get_index
from metagpt.rag.factories.llm import get_rag_llm
__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index", "get_rag_llm"]

View file

@@ -0,0 +1,59 @@
"""Base Factory."""
from typing import Any, Callable
class GenericFactory:
"""Designed to get objects based on any keys."""
def __init__(self, creators: dict[Any, Callable] = None):
"""Creators is a dictionary.
Keys are identifiers, and the values are the associated creator functions, which create objects.
"""
self._creators = creators or {}
def get_instances(self, keys: list[Any], **kwargs) -> list[Any]:
"""Get instances by keys."""
return [self.get_instance(key, **kwargs) for key in keys]
def get_instance(self, key: Any, **kwargs) -> Any:
"""Get instance by key.
Raise Exception if key not found.
"""
creator = self._creators.get(key)
if creator:
return creator(**kwargs)
raise ValueError(f"Creator not registered for key: {key}")
class ConfigBasedFactory(GenericFactory):
"""Designed to get objects based on object type."""
def get_instance(self, key: Any, **kwargs) -> Any:
"""Key is config, such as a pydantic model.
Call func by the type of key, and the key will be passed to func.
"""
creator = self._creators.get(type(key))
if creator:
return creator(key, **kwargs)
raise ValueError(f"Unknown config: {key}")
@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
if config is not None and hasattr(config, key):
val = getattr(config, key)
if val is not None:
return val
if key in kwargs:
return kwargs[key]
raise KeyError(
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
)
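A small illustration of the factory contract (the key and creator here are made up):

# Register a creator per key; extra kwargs are forwarded to the creator.
factory = GenericFactory(creators={"upper": lambda text: text.upper()})
print(factory.get_instance("upper", text="rag"))    # -> "RAG"
print(factory.get_instances(["upper"], text="ok"))  # -> ["OK"]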

View file

@@ -0,0 +1,37 @@
"""RAG Embedding Factory."""
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.base import GenericFactory
class RAGEmbeddingFactory(GenericFactory):
"""Create LlamaIndex Embedding with MetaGPT's config."""
def __init__(self):
creators = {
LLMType.OPENAI: self._create_openai,
LLMType.AZURE: self._create_azure,
}
super().__init__(creators)
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
"""Key is LLMType; defaults to config.llm.api_type."""
return super().get_instance(key or config.llm.api_type)
def _create_openai(self):
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
def _create_azure(self):
return AzureOpenAIEmbedding(
azure_endpoint=config.llm.base_url,
api_key=config.llm.api_key,
api_version=config.llm.api_version,
)
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding

View file

@@ -0,0 +1,63 @@
"""RAG Index Factory."""
import chromadb
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
BaseIndexConfig,
BM25IndexConfig,
ChromaIndexConfig,
FAISSIndexConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RAGIndexFactory(ConfigBasedFactory):
def __init__(self):
creators = {
FAISSIndexConfig: self._create_faiss,
ChromaIndexConfig: self._create_chroma,
BM25IndexConfig: self._create_bm25,
}
super().__init__(creators)
def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex:
"""Key is PersistType."""
return super().get_instance(config, **kwargs)
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
index = VectorStoreIndex.from_vector_store(
vector_store,
embed_model=embed_model,
)
return index
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
get_index = RAGIndexFactory().get_index

View file

@@ -0,0 +1,54 @@
"""RAG LLM."""
from typing import Any
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
from llama_index.core.llms import (
CompletionResponse,
CompletionResponseGen,
CustomLLM,
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from pydantic import Field
from metagpt.config2 import config
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.async_helper import run_coroutine_in_new_loop
from metagpt.utils.token_counter import TOKEN_MAX
class RAGLLM(CustomLLM):
"""LlamaIndex's LLM is different from MetaGPT's LLM.
Inheriting CustomLLM from LlamaIndex makes MetaGPT's LLM usable by LlamaIndex.
"""
model_infer: BaseLLM = Field(..., description="The MetaGPT's LLM.")
context_window: int = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
num_output: int = config.llm.max_token
model_name: str = config.llm.model
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
@llm_completion_callback()
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
text = await self.model_infer.aask(msg=prompt, stream=False)
return CompletionResponse(text=text)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
...
def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM:
"""Get llm that can be used by LlamaIndex."""
return RAGLLM(model_infer=model_infer or LLM())
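A sketch of the adapter in use (assumes a working MetaGPT LLM configuration):

llama_llm = get_rag_llm()            # wraps metagpt.llm.LLM() for LlamaIndex
resp = llama_llm.complete("Say hi")  # sync call; acomplete runs on a fresh event loop
print(resp.text)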

View file

@@ -0,0 +1,35 @@
"""RAG Ranker Factory."""
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
class RankerFactory(ConfigBasedFactory):
"""Modify creators for dynamic instance creation."""
def __init__(self):
creators = {
LLMRankerConfig: self._create_llm_ranker,
}
super().__init__(creators)
def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
"""Creates and returns ranker instances based on the provided configurations."""
if not configs:
return []
return super().get_instances(configs, **kwargs)
def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
config.llm = self._extract_llm(config, **kwargs)
return LLMRerank(**config.model_dump())
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
return self._val_from_config_or_kwargs("llm", config, **kwargs)
get_rankers = RankerFactory().get_rankers

View file

@@ -0,0 +1,86 @@
"""RAG Retriever Factory."""
import copy
import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ChromaRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RetrieverFactory(ConfigBasedFactory):
"""Modify creators for dynamic instance creation."""
def __init__(self):
creators = {
FAISSRetrieverConfig: self._create_faiss_retriever,
BM25RetrieverConfig: self._create_bm25_retriever,
ChromaRetrieverConfig: self._create_chroma_retriever,
}
super().__init__(creators)
def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> RAGRetriever:
"""Creates and returns a retriever instance based on the provided configurations.
If there are multiple retrievers, a SimpleHybridRetriever is used.
"""
if not configs:
return self._create_default(**kwargs)
retrievers = super().get_instances(configs, **kwargs)
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
def _create_default(self, **kwargs) -> RAGRetriever:
return self._extract_index(**kwargs).as_retriever()
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
return FAISSRetriever(**config.model_dump())
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
nodes = list(config.index.docstore.docs.values())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
return ChromaRetriever(**config.model_dump())
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)
def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
storage_context=storage_context,
embed_model=old_index._embed_model,
)
return new_index
get_retriever = RetrieverFactory().get_retriever
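For example, passing two configs yields a SimpleHybridRetriever; a sketch, assuming an OpenAI embedding is configured:

from llama_index.core import VectorStoreIndex
from llama_index.core.schema import TextNode

from metagpt.rag.factories import get_rag_embedding, get_retriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig

index = VectorStoreIndex(nodes=[TextNode(text="Bob likes burgers.")], embed_model=get_rag_embedding())
retriever = get_retriever(configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()], index=index)
# Two configs, so the result is a SimpleHybridRetriever wrapping a FAISSRetriever and a DynamicBM25Retriever.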

metagpt/rag/interface.py Normal file
View file

@@ -0,0 +1,24 @@
"""RAG Interfaces."""
from typing import Protocol, runtime_checkable
@runtime_checkable
class RAGObject(Protocol):
"""Supports adding objects to RAG."""
def rag_key(self) -> str:
"""For rag search."""
def model_dump_json(self) -> str:
"""For rag persist.
Pydantic models don't need to implement this, as there is a built-in method named model_dump_json.
"""
@runtime_checkable
class NoEmbedding(Protocol):
"""Some retrievers do not require embeddings, e.g. BM25."""
_no_embedding: bool
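Because RAGObject is a runtime-checkable Protocol, any pydantic model that adds rag_key satisfies it without inheriting anything; a hypothetical example:

from pydantic import BaseModel

class Player(BaseModel):  # illustrative only
    name: str
    goal: str

    def rag_key(self) -> str:
        return self.goal  # the text used for similarity search

assert isinstance(Player(name="Jeff", goal="Win the game"), RAGObject)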

View file

@@ -0,0 +1 @@
"""Rankers init"""

View file

@@ -0,0 +1,19 @@
"""Base Ranker."""
from abc import abstractmethod
from typing import Optional
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle
class RAGRanker(BaseNodePostprocessor):
"""inherit from llama_index"""
@abstractmethod
def _postprocess_nodes(
self,
nodes: list[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> list[NodeWithScore]:
"""postprocess nodes."""

View file

@@ -0,0 +1,5 @@
"""Retrievers init."""
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
__all__ = ["SimpleHybridRetriever"]

View file

@@ -0,0 +1,47 @@
"""Base retriever."""
from abc import abstractmethod
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import BaseNode, NodeWithScore, QueryType
from metagpt.utils.reflection import check_methods
class RAGRetriever(BaseRetriever):
"""Inherit from llama_index"""
@abstractmethod
async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Retrieve nodes"""
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Retrieve nodes"""
class ModifiableRAGRetriever(RAGRetriever):
"""Support modification."""
@classmethod
def __subclasshook__(cls, C):
if cls is ModifiableRAGRetriever:
return check_methods(C, "add_nodes")
return NotImplemented
@abstractmethod
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""To support adding docs, this func must be implemented."""
class PersistableRAGRetriever(RAGRetriever):
"""Support persistence."""
@classmethod
def __subclasshook__(cls, C):
if cls is PersistableRAGRetriever:
return check_methods(C, "persist")
return NotImplemented
@abstractmethod
def persist(self, persist_dir: str, **kwargs) -> None:
"""To support persisting, this func must be implemented."""

View file

@@ -0,0 +1,47 @@
"""BM25 retriever."""
from typing import Callable, Optional
from llama_index.core import VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.schema import BaseNode, IndexNode
from llama_index.retrievers.bm25 import BM25Retriever
from rank_bm25 import BM25Okapi
class DynamicBM25Retriever(BM25Retriever):
"""BM25 retriever."""
def __init__(
self,
nodes: list[BaseNode],
tokenizer: Optional[Callable[[str], list[str]]] = None,
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
callback_manager: Optional[CallbackManager] = None,
objects: Optional[list[IndexNode]] = None,
object_map: Optional[dict] = None,
verbose: bool = False,
index: VectorStoreIndex = None,
) -> None:
super().__init__(
nodes=nodes,
tokenizer=tokenizer,
similarity_top_k=similarity_top_k,
callback_manager=callback_manager,
object_map=object_map,
objects=objects,
verbose=verbose,
)
self._index = index
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes."""
self._nodes.extend(nodes)
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
self.bm25 = BM25Okapi(self._corpus)
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
self._index.storage_context.persist(persist_dir)

View file

@@ -0,0 +1,17 @@
"""Chroma retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class ChromaRetriever(VectorIndexRetriever):
"""Chroma retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes."""
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist.
Chromadb automatically saves, so there is no need to implement."""

View file

@@ -0,0 +1,16 @@
"""FAISS retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class FAISSRetriever(VectorIndexRetriever):
"""FAISS retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes"""
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
self._index.storage_context.persist(persist_dir)

View file

@@ -0,0 +1,48 @@
"""Hybrid retriever."""
import copy
from llama_index.core.schema import BaseNode, QueryType
from metagpt.rag.retrievers.base import RAGRetriever
class SimpleHybridRetriever(RAGRetriever):
"""A composite retriever that aggregates search results from multiple retrievers."""
def __init__(self, *retrievers):
self.retrievers: list[RAGRetriever] = retrievers
super().__init__()
async def _aretrieve(self, query: QueryType, **kwargs):
"""Asynchronously retrieves and aggregates search results from all configured retrievers.
This method queries each retriever in the `retrievers` list with the given query and
additional keyword arguments. It then combines the results, ensuring that each node is
unique, based on the node's ID.
"""
all_nodes = []
for retriever in self.retrievers:
# Prevent retriever changing query
query_copy = copy.deepcopy(query)
nodes = await retriever.aretrieve(query_copy, **kwargs)
all_nodes.extend(nodes)
# combine all nodes
result = []
node_ids = set()
for n in all_nodes:
if n.node.node_id not in node_ids:
result.append(n)
node_ids.add(n.node.node_id)
return result
def add_nodes(self, nodes: list[BaseNode]) -> None:
"""Support add nodes."""
for r in self.retrievers:
r.add_nodes(nodes)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
for r in self.retrievers:
r.persist(persist_dir, **kwargs)

metagpt/rag/schema.py Normal file
View file

@@ -0,0 +1,124 @@
"""RAG schemas."""
from pathlib import Path
from typing import Any, Union
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from metagpt.rag.interface import RAGObject
class BaseRetrieverConfig(BaseModel):
"""Common config for retrievers.
If you add a new subconfig, you must add the corresponding instance implementation in rag.factories.retriever.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")
class IndexRetrieverConfig(BaseRetrieverConfig):
"""Config for index-based retrievers."""
index: BaseIndex = Field(default=None, description="Index for retriever.")
class FAISSRetrieverConfig(IndexRetrieverConfig):
"""Config for FAISS-based retrievers."""
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
class BM25RetrieverConfig(IndexRetrieverConfig):
"""Config for BM25-based retrievers."""
_no_embedding: bool = PrivateAttr(default=True)
class ChromaRetrieverConfig(IndexRetrieverConfig):
"""Config for Chroma-based retrievers."""
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
class BaseRankerConfig(BaseModel):
"""Common config for rankers.
If you add a new subconfig, you must add the corresponding instance implementation in rag.factories.ranker.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
top_n: int = Field(default=5, description="The number of top results to return.")
class LLMRankerConfig(BaseRankerConfig):
"""Config for LLM-based rankers."""
llm: Any = Field(
default=None,
description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.",
)
class BaseIndexConfig(BaseModel):
"""Common config for index.
If you add a new subconfig, you must add the corresponding instance implementation in rag.factories.index.
"""
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
class VectorIndexConfig(BaseIndexConfig):
"""Config for vector-based index."""
embed_model: BaseEmbedding = Field(default=None, description="Embed model.")
class FAISSIndexConfig(VectorIndexConfig):
"""Config for faiss-based index."""
class ChromaIndexConfig(VectorIndexConfig):
"""Config for chroma-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
class BM25IndexConfig(BaseIndexConfig):
"""Config for bm25-based index."""
_no_embedding: bool = PrivateAttr(default=True)
class ObjectNodeMetadata(BaseModel):
"""Metadata of ObjectNode."""
is_obj: bool = Field(default=True)
obj: Any = Field(default=None, description="When RAG retrieves, obj is reconstructed from obj_json")
obj_json: str = Field(..., description="The json of object, e.g. obj.model_dump_json()")
obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")
class ObjectNode(TextNode):
"""RAG add object."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys
@staticmethod
def get_obj_metadata(obj: RAGObject) -> dict:
metadata = ObjectNodeMetadata(
obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
)
return metadata.model_dump()
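Tying this together (reusing the hypothetical Player model sketched earlier): the object's JSON, class name, and module name ride along in node metadata, so SimpleEngine._try_reconstruct_obj can rebuild the object after retrieval.

player = Player(name="Jeff", goal="Win the game")
node = ObjectNode(text=player.rag_key(), metadata=ObjectNode.get_obj_metadata(player))
# node.metadata now holds is_obj/obj_json/obj_cls_name/obj_mod_name, while those
# keys stay excluded from what the LLM and the embedder see.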

View file

@@ -0,0 +1,3 @@
from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore
__all__ = ["ChromaVectorStore"]

View file

@@ -0,0 +1,290 @@
"""Chroma vector store.
Refers to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
That repo requires onnxruntime = "^1.17.0", which is too new for many operating systems, such as CentOS 7.
"""
import math
from typing import Any, Dict, Generator, List, Optional, cast
import chromadb
from chromadb.api.models.Collection import Collection
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
from llama_index.core.utils import truncate_text
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
MetadataFilters,
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.core.vector_stores.utils import (
legacy_metadata_dict_to_node,
metadata_dict_to_node,
node_to_metadata_dict,
)
from metagpt.logs import logger
def _transform_chroma_filter_condition(condition: str) -> str:
"""Translate standard metadata filter op to Chroma specific spec."""
if condition == "and":
return "$and"
elif condition == "or":
return "$or"
else:
raise ValueError(f"Filter condition {condition} not supported")
def _transform_chroma_filter_operator(operator: str) -> str:
"""Translate standard metadata filter operator to Chroma specific spec."""
if operator == "!=":
return "$ne"
elif operator == "==":
return "$eq"
elif operator == ">":
return "$gt"
elif operator == "<":
return "$lt"
elif operator == ">=":
return "$gte"
elif operator == "<=":
return "$lte"
else:
raise ValueError(f"Filter operator {operator} not supported")
def _to_chroma_filter(
standard_filters: MetadataFilters,
) -> dict:
"""Translate standard metadata filters to Chroma specific spec."""
filters = {}
filters_list = []
condition = standard_filters.condition or "and"
condition = _transform_chroma_filter_condition(condition)
if standard_filters.filters:
for filter in standard_filters.filters:
if filter.operator:
filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}})
else:
filters_list.append({filter.key: filter.value})
if len(filters_list) == 1:
# If there is only one filter, return it directly
return filters_list[0]
elif len(filters_list) > 1:
filters[condition] = filters_list
return filters
import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"
MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB
def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
"""Yield successive max_chunk_size-sized chunks from lst.
Args:
lst (List[BaseNode]): list of nodes with embeddings
max_chunk_size (int): max chunk size
Yields:
Generator[List[BaseNode], None, None]: list of nodes with embeddings
"""
for i in range(0, len(lst), max_chunk_size):
yield lst[i : i + max_chunk_size]
class ChromaVectorStore(BasePydanticVectorStore):
"""Chroma vector store.
In this vector store, embeddings are stored within a ChromaDB collection.
During query time, the index uses ChromaDB to query for the top
k most similar nodes.
Args:
chroma_collection (chromadb.api.models.Collection.Collection):
ChromaDB collection instance
"""
stores_text: bool = True
flat_metadata: bool = True
collection_name: Optional[str]
host: Optional[str]
port: Optional[str]
ssl: bool
headers: Optional[Dict[str, str]]
persist_dir: Optional[str]
collection_kwargs: Dict[str, Any] = Field(default_factory=dict)
_collection: Any = PrivateAttr()
def __init__(
self,
chroma_collection: Optional[Any] = None,
collection_name: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
persist_dir: Optional[str] = None,
collection_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> None:
"""Init params."""
collection_kwargs = collection_kwargs or {}
if chroma_collection is None:
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
else:
self._collection = cast(Collection, chroma_collection)
super().__init__(
host=host,
port=port,
ssl=ssl,
headers=headers,
collection_name=collection_name,
persist_dir=persist_dir,
collection_kwargs=collection_kwargs or {},
)
@classmethod
def from_collection(cls, collection: Any) -> "ChromaVectorStore":
try:
from chromadb import Collection
except ImportError:
raise ImportError(import_err_msg)
if not isinstance(collection, Collection):
raise Exception("argument is not chromadb collection instance")
return cls(chroma_collection=collection)
@classmethod
def from_params(
cls,
collection_name: str,
host: Optional[str] = None,
port: Optional[str] = None,
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
persist_dir: Optional[str] = None,
collection_kwargs: dict = {},
**kwargs: Any,
) -> "ChromaVectorStore":
if persist_dir:
client = chromadb.PersistentClient(path=persist_dir)
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
elif host and port:
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
else:
raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified")
return cls(
chroma_collection=collection,
host=host,
port=port,
ssl=ssl,
headers=headers,
persist_dir=persist_dir,
collection_kwargs=collection_kwargs,
**kwargs,
)
@classmethod
def class_name(cls) -> str:
return "ChromaVectorStore"
def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
"""Add nodes to index.
Args:
nodes: List[BaseNode]: list of nodes with embeddings
"""
if not self._collection:
raise ValueError("Collection not initialized")
max_chunk_size = MAX_CHUNK_SIZE
node_chunks = chunk_list(nodes, max_chunk_size)
all_ids = []
for node_chunk in node_chunks:
embeddings = []
metadatas = []
ids = []
documents = []
for node in node_chunk:
embeddings.append(node.get_embedding())
metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
for key in metadata_dict:
if metadata_dict[key] is None:
metadata_dict[key] = ""
metadatas.append(metadata_dict)
ids.append(node.node_id)
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
self._collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
documents=documents,
)
all_ids.extend(ids)
return all_ids
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
"""
Delete nodes with ref_doc_id.
Args:
ref_doc_id (str): The doc_id of the document to delete.
"""
self._collection.delete(where={"document_id": ref_doc_id})
@property
def client(self) -> Any:
"""Return client."""
return self._collection
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
Args:
query_embedding (List[float]): query embedding
similarity_top_k (int): top k most similar nodes
"""
if query.filters is not None:
if "where" in kwargs:
raise ValueError(
"Cannot specify metadata filters via both query and kwargs. "
"Use kwargs only for chroma specific items that are "
"not supported via the generic query interface."
)
where = _to_chroma_filter(query.filters)
else:
where = kwargs.pop("where", {})
results = self._collection.query(
query_embeddings=query.query_embedding,
n_results=query.similarity_top_k,
where=where,
**kwargs,
)
logger.debug(f"> Top {len(results['documents'])} nodes:")
nodes = []
similarities = []
ids = []
for node_id, text, metadata, distance in zip(
results["ids"][0],
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
):
try:
node = metadata_dict_to_node(metadata)
node.set_content(text)
except Exception:
# NOTE: deprecated legacy logic for backward compatibility
metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
node = TextNode(
text=text,
id_=node_id,
metadata=metadata,
start_char_idx=node_info.get("start", None),
end_char_idx=node_info.get("end", None),
relationships=relationships,
)
nodes.append(node)
similarity_score = math.exp(-distance)
similarities.append(similarity_score)
logger.debug(
f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}"
)
ids.append(node_id)
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
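A worked example of the filter translation defined at the top of this file (values are illustrative):

from llama_index.core.vector_stores.types import MetadataFilter, MetadataFilters

filters = MetadataFilters(
    filters=[
        MetadataFilter(key="year", value=2024, operator=">="),
        MetadataFilter(key="lang", value="en", operator="=="),
    ],
    condition="and",
)
# _to_chroma_filter(filters) -> {"$and": [{"year": {"$gte": 2024}}, {"lang": {"$eq": "en"}}]}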

View file

@@ -108,12 +108,6 @@ class RoleContext(BaseModel):
) # see `Role._set_react_mode` for definitions of the following two attributes
max_react_loop: int = 1
def check(self, role_id: str):
# if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory:
# self.long_term_memory.recover_memory(role_id, self)
# self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation
pass
@property
def important_memory(self) -> list[Message]:
"""Retrieve information corresponding to the attention action."""
@@ -311,8 +305,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
buffer during _observe.
"""
self.rc.watch = {any_to_str(t) for t in actions}
# check RoleContext after adding watch actions
self.rc.check(self.role_id)
def is_watch(self, caused_by: str):
return caused_by in self.rc.watch

View file

@@ -11,7 +11,6 @@ from typing import Optional
from pydantic import Field, model_validator
from metagpt.actions import SearchAndSummarize, UserRequirement
from metagpt.document_store.base_store import BaseStore
from metagpt.roles import Role
from metagpt.tools.search_engine import SearchEngine
@@ -27,7 +26,7 @@ class Sales(Role):
"delivered with the professionalism and courtesy expected of a seasoned sales guide."
)
store: Optional[BaseStore] = Field(default=None, exclude=True)
store: Optional[object] = Field(default=None, exclude=True) # must implement tools.SearchInterface
@model_validator(mode="after")
def validate_store(self):

View file

@@ -233,6 +233,10 @@ class Message(BaseModel):
def check_send_to(cls, send_to: Any) -> set:
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
@field_serializer("send_to", mode="plain")
def ser_send_to(self, send_to: set) -> list:
return list(send_to)
@field_serializer("instruct_content", mode="plain")
def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
ic_dict = None
@@ -276,6 +280,10 @@ class Message(BaseModel):
def __repr__(self):
return self.__str__()
def rag_key(self) -> str:
"""For search"""
return self.content
def to_dict(self) -> dict:
"""Return a dict containing `role` and `content` for the LLM call."""
return {"role": self.role, "content": self.content}

View file

@@ -30,3 +30,8 @@ class WebBrowserEngineType(Enum):
def __missing__(cls, key):
"""Default type conversion"""
return cls.CUSTOM
class SearchInterface:
async def asearch(self, *args, **kwargs):
...

View file

@@ -0,0 +1,22 @@
import asyncio
import threading
from typing import Any
def run_coroutine_in_new_loop(coroutine) -> Any:
"""Runs a coroutine in a new, separate event loop on a different thread.
This function is useful when trying to execute an async function within a sync function but encountering the error `RuntimeError: This event loop is already running`.
"""
new_loop = asyncio.new_event_loop()
t = threading.Thread(target=lambda: new_loop.run_forever())
t.start()
future = asyncio.run_coroutine_threadsafe(coroutine, new_loop)
try:
return future.result()
finally:
new_loop.call_soon_threadsafe(new_loop.stop)
t.join()
new_loop.close()
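A minimal illustration of where the helper applies:

async def fetch() -> str:
    return "done"

def sync_caller() -> str:
    # Safe even if an outer event loop is already running, because the
    # coroutine executes on its own loop in a separate thread.
    return run_coroutine_in_new_loop(fetch())

print(sync_caller())  # -> "done"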

View file

@@ -5,12 +5,15 @@
@Author : alexanderwu
@File : embedding.py
"""
from langchain_community.embeddings import OpenAIEmbeddings
from llama_index.embeddings.openai import OpenAIEmbedding
from metagpt.config2 import config
def get_embedding():
def get_embedding() -> OpenAIEmbedding:
llm = config.get_openai_llm()
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
if llm is None:
raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.")
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
return embedding

View file

@@ -0,0 +1,18 @@
"""class tools, including method inspection, class attributes, inheritance relationships, etc."""
def check_methods(C, *methods):
"""Check if the class has the given methods; borrowed from _collections_abc.
Useful when implementing implicit interfaces: after defining an abstract class, isinstance can be used for the check without requiring inheritance.
"""
mro = C.__mro__
for method in methods:
for B in mro:
if method in B.__dict__:
if B.__dict__[method] is None:
return NotImplemented
break
else:
return NotImplemented
return True
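For example (a sketch mirroring how the retriever base classes use it):

class Duck:
    def quack(self):
        ...

print(check_methods(Duck, "quack"))          # True
print(check_methods(Duck, "quack", "walk"))  # NotImplemented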