mirror of https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-30 19:36:24 +02:00
commit e783e5b208
61 changed files with 2353 additions and 248 deletions
@@ -49,6 +49,7 @@ METAGPT_ROOT = get_metagpt_root()  # Dependent on METAGPT_PROJECT_ROOT
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"

EXAMPLE_PATH = METAGPT_ROOT / "examples"
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
DATA_PATH = METAGPT_ROOT / "data"
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
RESEARCH_PATH = DATA_PATH / "research"
@@ -11,12 +11,9 @@ from pathlib import Path
from typing import Optional, Union

import pandas as pd
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import (
    TextLoader,
    UnstructuredPDFLoader,
    UnstructuredWordDocumentLoader,
)
from llama_index.core import Document, SimpleDirectoryReader
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.readers.file import PDFReader
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm

@@ -29,7 +26,7 @@ def validate_cols(content_col: str, df: pd.DataFrame):
        raise ValueError("Content column not found in DataFrame.")


def read_data(data_path: Path):
def read_data(data_path: Path) -> Union[pd.DataFrame, list[Document]]:
    suffix = data_path.suffix
    if ".xlsx" == suffix:
        data = pd.read_excel(data_path)
@@ -38,14 +35,13 @@ def read_data(data_path: Path):
    elif ".json" == suffix:
        data = pd.read_json(data_path)
    elif suffix in (".docx", ".doc"):
        data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
        data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
    elif ".txt" == suffix:
        data = TextLoader(str(data_path)).load()
        text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
        texts = text_splitter.split_documents(data)
        data = texts
        data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
        node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
        data = node_parser.get_nodes_from_documents(data)
    elif ".pdf" == suffix:
        data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
        data = PDFReader.load_data(str(data_path))
    else:
        raise NotImplementedError("File format not supported.")
    return data
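Note: in the new branch, llama_index's SimpleDirectoryReader plus SimpleNodeParser replace langchain's TextLoader plus CharacterTextSplitter. A minimal standalone sketch of that pattern, outside the diff (the file path is illustrative):

from pathlib import Path

from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SimpleNodeParser

data_path = Path("notes.txt")  # illustrative input file
# Load the file into llama_index Documents, then split into nodes,
# mirroring the replacement lines in the hunk above.
documents = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
nodes = node_parser.get_nodes_from_documents(documents)
print(f"{len(nodes)} nodes")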
@@ -150,9 +146,9 @@ class IndexableDocument(Document):
                metadatas.append({})
        return docs, metadatas

    def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
    def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list):
        data = self.data
        docs = [i.page_content for i in data]
        docs = [i.text for i in data]
        metadatas = [i.metadata for i in data]
        return docs, metadatas

@@ -160,7 +156,7 @@ class IndexableDocument(Document):
        if isinstance(self.data, pd.DataFrame):
            return self._get_docs_and_metadatas_by_df()
        elif isinstance(self.data, list):
            return self._get_docs_and_metadatas_by_langchain()
            return self._get_docs_and_metadatas_by_llamaindex()
        else:
            raise NotImplementedError("Data type not supported for metadata extraction.")
@@ -38,9 +38,9 @@ class LocalStore(BaseStore, ABC):
        if not self.store:
            self.store = self.write()

    def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
        index_file = self.cache_dir / f"{self.fname}{index_ext}"
        store_file = self.cache_dir / f"{self.fname}{pkl_ext}"
    def _get_index_and_store_fname(self, index_ext=".json", docstore_ext=".json"):
        index_file = self.cache_dir / "default__vector_store" / index_ext
        store_file = self.cache_dir / "docstore" / docstore_ext
        return index_file, store_file

    @abstractmethod
@@ -11,9 +11,9 @@ import chromadb
class ChromaStore:
    """If inherited from BaseStore, or importing other modules from metagpt, a Python exception occurs, which is strange."""

    def __init__(self, name):
    def __init__(self, name: str, get_or_create: bool = False):
        client = chromadb.Client()
        collection = client.create_collection(name)
        collection = client.create_collection(name, get_or_create=get_or_create)
        self.client = client
        self.collection = collection
@@ -7,10 +7,14 @@
"""
import asyncio
from pathlib import Path
from typing import Optional
from typing import Any, Optional

from langchain.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
import faiss
from llama_index.core import VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import Document, QueryBundle, TextNode
from llama_index.core.storage import StorageContext
from llama_index.vector_stores.faiss import FaissVectorStore

from metagpt.document import IndexableDocument
from metagpt.document_store.base_store import LocalStore
@@ -20,36 +24,50 @@ from metagpt.utils.embedding import get_embedding

class FaissStore(LocalStore):
    def __init__(
        self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None
        self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None
    ):
        self.meta_col = meta_col
        self.content_col = content_col
        self.embedding = embedding or get_embedding()
        self.store: VectorStoreIndex
        super().__init__(raw_data, cache_dir)

    def _load(self) -> Optional["FaissStore"]:
        index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss")  # langchain FAISS using .faiss
    def _load(self) -> Optional["VectorStoreIndex"]:
        index_file, store_file = self._get_index_and_store_fname()

        if not (index_file.exists() and store_file.exists()):
            logger.info("Missing at least one of index_file/store_file, load failed and return None")
            return None
        vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir)
        storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store)
        index = load_index_from_storage(storage_context, embed_model=self.embedding)

        return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname)
        return index

    def _write(self, docs, metadatas):
        store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas)
        return store
    def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex:
        assert len(docs) == len(metadatas)
        documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)]

        vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        index = VectorStoreIndex.from_documents(
            documents=documents, storage_context=storage_context, embed_model=self.embedding
        )

        return index

    def persist(self):
        self.store.save_local(self.raw_data_path.parent, self.fname)
        self.store.storage_context.persist(self.cache_dir)

    def search(self, query: str, expand_cols=False, sep="\n", *args, k=5, **kwargs):
        retriever = self.store.as_retriever(similarity_top_k=k)
        rsp = retriever.retrieve(QueryBundle(query_str=query, embedding=self.embedding.get_text_embedding(query)))

    def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
        rsp = self.store.similarity_search(query, k=k, **kwargs)
        logger.debug(rsp)
        if expand_cols:
            return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))
            return str(sep.join([f"{x.node.text}: {x.node.metadata}" for x in rsp]))
        else:
            return str(sep.join([f"{x.page_content}" for x in rsp]))
            return str(sep.join([f"{x.node.text}" for x in rsp]))

    async def asearch(self, *args, **kwargs):
        return await asyncio.to_thread(self.search, *args, **kwargs)
@@ -67,8 +85,12 @@ class FaissStore(LocalStore):

    def add(self, texts: list[str], *args, **kwargs) -> list[str]:
        """FIXME: Currently, the store is not updated after adding."""
        return self.store.add_texts(texts)
        texts_embeds = self.embedding.get_text_embedding_batch(texts)
        nodes = [TextNode(text=texts[idx], embedding=embed) for idx, embed in enumerate(texts_embeds)]
        self.store.insert_nodes(nodes)

        return []

    def delete(self, *args, **kwargs):
        """Currently, langchain does not provide a delete interface."""
        """Currently, faiss does not provide a delete interface."""
        raise NotImplementedError
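Note: a minimal usage sketch of the migrated FaissStore, outside the diff (the xlsx path is illustrative; the store still reads content_col="output" by default):

from pathlib import Path

from metagpt.document_store.faiss_store import FaissStore

store = FaissStore(Path("faq.xlsx"))          # builds or loads the llama_index VectorStoreIndex
store.add(["MetaGPT is a multi-agent framework."])
store.persist()                               # writes the index under cache_dir via storage_context
print(store.search("What is MetaGPT?", k=2))  # joined node texts of the top-2 hits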
@@ -8,8 +8,6 @@ import re
import time
from typing import Any, Iterable

from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from pydantic import ConfigDict, Field

from metagpt.config2 import config as CONFIG

@@ -17,6 +15,7 @@ from metagpt.environment.base_env import Environment
from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
from metagpt.logs import logger
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file

@@ -48,9 +47,9 @@ class MincraftEnv(Environment, MincraftExtEnv):

    runtime_status: bool = False  # equal to action execution status: success or failed

    vectordb: Chroma = Field(default_factory=Chroma)
    vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore)

    qa_cache_questions_vectordb: Chroma = Field(default_factory=Chroma)
    qa_cache_questions_vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore)

    @property
    def progress(self):

@@ -73,16 +72,14 @@ class MincraftEnv(Environment, MincraftExtEnv):
        self.set_mc_resume()

    def set_mc_resume(self):
        self.qa_cache_questions_vectordb = Chroma(
        self.qa_cache_questions_vectordb = ChromaVectorStore(
            collection_name="qa_cache_questions_vectordb",
            embedding_function=OpenAIEmbeddings(),
            persist_directory=f"{MC_CKPT_DIR}/curriculum/vectordb",
            persist_dir=f"{MC_CKPT_DIR}/curriculum/vectordb",
        )

        self.vectordb = Chroma(
        self.vectordb = ChromaVectorStore(
            collection_name="skill_vectordb",
            embedding_function=OpenAIEmbeddings(),
            persist_directory=f"{MC_CKPT_DIR}/skill/vectordb",
            persist_dir=f"{MC_CKPT_DIR}/skill/vectordb",
        )

        if CONFIG.resume:
@@ -29,16 +29,14 @@ class LongTermMemory(Memory):
    msg_from_recover: bool = False

    def recover_memory(self, role_id: str, rc: RoleContext):
        messages = self.memory_storage.recover_memory(role_id)
        self.memory_storage.recover_memory(role_id)
        self.rc = rc
        if not self.memory_storage.is_initialized:
            logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty")
            logger.warning(f"It may the first time to run Role {role_id}, the long-term memory is empty")
        else:
            logger.warning(
                f"Agent {role_id} has existing memory storage with {len(messages)} messages " f"and has recovered them."
            )
            logger.warning(f"Role {role_id} has existing memory storage and has recovered them.")
        self.msg_from_recover = True
        self.add_batch(messages)
        # self.add_batch(messages)  # TODO no need
        self.msg_from_recover = False

    def add(self, message: Message):

@@ -49,7 +47,7 @@ class LongTermMemory(Memory):
        # and ignore adding messages from recover repeatedly
        self.memory_storage.add(message)

    def find_news(self, observed: list[Message], k=0) -> list[Message]:
    async def find_news(self, observed: list[Message], k=0) -> list[Message]:
        """
        find news (previously unseen messages) from the the most recent k memories, from all memories when k=0
        1. find the short-term memory(stm) news
@@ -63,11 +61,14 @@ class LongTermMemory(Memory):
        ltm_news: list[Message] = []
        for mem in stm_news:
            # filter out messages similar to those seen previously in ltm, only keep fresh news
            mem_searched = self.memory_storage.search_dissimilar(mem)
            if len(mem_searched) > 0:
            mem_searched = await self.memory_storage.search_similar(mem)
            if len(mem_searched) == 0:
                ltm_news.append(mem)
        return ltm_news[-k:]

    def persist(self):
        self.memory_storage.persist()

    def delete(self, message: Message):
        super().delete(message)
        # TODO delete message in memory_storage
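Note: find_news is now a coroutine, and the filter's polarity flips: a message counts as news only when search_similar returns nothing. A sketch of calling it, outside the diff (assumes recover_memory has already run):

from metagpt.memory.longterm_memory import LongTermMemory
from metagpt.schema import Message

async def fresh_news(ltm: LongTermMemory, observed: list[Message]) -> list[Message]:
    # Messages that already have a similar match in long-term memory are
    # filtered out; only genuinely unseen ones come back.
    return await ltm.find_news(observed)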
@@ -3,115 +3,75 @@
"""
@Desc : the implement of memory storage
"""

import shutil
from pathlib import Path
from typing import Optional

from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
from llama_index.core.embeddings import BaseEmbedding

from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig
from metagpt.schema import Message
from metagpt.utils.embedding import get_embedding
from metagpt.utils.serialize import deserialize_message, serialize_message


class MemoryStorage(FaissStore):
class MemoryStorage(object):
    """
    The memory storage with Faiss as ANN search engine
    """

    def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
    def __init__(self, mem_ttl: int = MEM_TTL, embedding: BaseEmbedding = None):
        self.role_id: str = None
        self.role_mem_path: str = None
        self.mem_ttl: int = mem_ttl  # later use
        self.threshold: float = 0.1  # experience value. TODO The threshold to filter similar memories
        self._initialized: bool = False

        self.embedding = embedding or get_embedding()
        self.store: FAISS = None  # Faiss engine

        self.faiss_engine = None

    @property
    def is_initialized(self) -> bool:
        return self._initialized

    def _load(self) -> Optional["FaissStore"]:
        index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss")  # langchain FAISS using .faiss

        if not (index_file.exists() and store_file.exists()):
            logger.info("Missing at least one of index_file/store_file, load failed and return None")
            return None

        return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)

    def recover_memory(self, role_id: str) -> list[Message]:
        self.role_id = role_id
        self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
        self.role_mem_path.mkdir(parents=True, exist_ok=True)
        self.cache_dir = self.role_mem_path

        self.store = self._load()
        messages = []
        if not self.store:
            # TODO init `self.store` under here with raw faiss api instead under `add`
            pass
        if self.role_mem_path.joinpath("default__vector_store.json").exists():
            self.faiss_engine = SimpleEngine.from_index(
                index_config=FAISSIndexConfig(persist_path=self.cache_dir),
                retriever_configs=[FAISSRetrieverConfig()],
                embed_model=self.embedding,
            )
        else:
            for _id, document in self.store.docstore._dict.items():
                messages.append(deserialize_message(document.metadata.get("message_ser")))
            self._initialized = True

        return messages

    def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
        if not self.role_mem_path:
            logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
            return None, None
        index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
        storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
        return index_fpath, storage_fpath

    def persist(self):
        self.store.save_local(self.role_mem_path, self.role_id)
        logger.debug(f"Agent {self.role_id} persist memory into local")
            self.faiss_engine = SimpleEngine.from_objs(
                objs=[], retriever_configs=[FAISSRetrieverConfig()], embed_model=self.embedding
            )
        self._initialized = True

    def add(self, message: Message) -> bool:
        """add message into memory storage"""
        docs = [message.content]
        metadatas = [{"message_ser": serialize_message(message)}]
        if not self.store:
            # init Faiss
            self.store = self._write(docs, metadatas)
            self._initialized = True
        else:
            self.store.add_texts(texts=docs, metadatas=metadatas)
        self.persist()
        logger.info(f"Agent {self.role_id}'s memory_storage add a message")
        self.faiss_engine.add_objs([message])
        logger.info(f"Role {self.role_id}'s memory_storage add a message")

    def search_dissimilar(self, message: Message, k=4) -> list[Message]:
        """search for dissimilar messages"""
        if not self.store:
            return []

        resp = self.store.similarity_search_with_score(query=message.content, k=k)
    async def search_similar(self, message: Message, k=4) -> list[Message]:
        """search for similar messages"""
        # filter the result which score is smaller than the threshold
        filtered_resp = []
        for item, score in resp:
            # the smaller score means more similar relation
            if score < self.threshold:
                continue
            # convert search result into Memory
            metadata = item.metadata
            new_mem = deserialize_message(metadata.get("message_ser"))
            filtered_resp.append(new_mem)
        resp = await self.faiss_engine.aretrieve(message.content)
        for item in resp:
            if item.score < self.threshold:
                filtered_resp.append(item.metadata.get("obj"))
        return filtered_resp

    def clean(self):
        index_fpath, storage_fpath = self._get_index_and_store_fname()
        if index_fpath and index_fpath.exists():
            index_fpath.unlink(missing_ok=True)
        if storage_fpath and storage_fpath.exists():
            storage_fpath.unlink(missing_ok=True)

        self.store = None
        shutil.rmtree(self.cache_dir, ignore_errors=True)
        self._initialized = False

    def persist(self):
        if self.faiss_engine:
            self.faiss_engine.retriever._index.storage_context.persist(self.cache_dir)
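Note: the storage's new lifecycle end to end, outside the diff; a sketch assuming MetaGPT's config provides a working embedding:

import asyncio

from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message

async def demo() -> None:
    storage = MemoryStorage()
    storage.recover_memory("ProductManager")  # loads the persisted index, or starts empty
    msg = Message(content="write a PRD for a 2048 game")
    storage.add(msg)                          # indexed as an object node via SimpleEngine
    hits = await storage.search_similar(msg)  # Messages passing the threshold filter
    storage.persist()                         # saves the FAISS index under DATA_PATH/role_mem
    print(len(hits))

asyncio.run(demo())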
0 metagpt/rag/__init__.py Normal file
5 metagpt/rag/engines/__init__.py Normal file
@@ -0,0 +1,5 @@
"""Engines init"""

from metagpt.rag.engines.simple import SimpleEngine

__all__ = ["SimpleEngine"]
259 metagpt/rag/engines/simple.py Normal file
@@ -0,0 +1,259 @@
"""Simple Engine."""

import json
import os
from typing import Any, Optional, Union

from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.ingestion.pipeline import run_transformations
from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import (
    BaseSynthesizer,
    get_response_synthesizer,
)
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import (
    BaseNode,
    Document,
    NodeWithScore,
    QueryBundle,
    QueryType,
    TransformComponent,
)

from metagpt.rag.factories import (
    get_index,
    get_rag_embedding,
    get_rag_llm,
    get_rankers,
    get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
    BaseIndexConfig,
    BaseRankerConfig,
    BaseRetrieverConfig,
    BM25RetrieverConfig,
    ObjectNode,
)
from metagpt.utils.common import import_class


class SimpleEngine(RetrieverQueryEngine):
    """SimpleEngine is designed to be simple and straightforward.

    It is a lightweight and easy-to-use search engine that integrates
    document reading, embedding, indexing, retrieving, and ranking functionalities
    into a single, straightforward workflow. It is designed to quickly set up a
    search engine from a collection of documents.
    """

    def __init__(
        self,
        retriever: BaseRetriever,
        response_synthesizer: Optional[BaseSynthesizer] = None,
        node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
        callback_manager: Optional[CallbackManager] = None,
        index: Optional[BaseIndex] = None,
    ) -> None:
        super().__init__(
            retriever=retriever,
            response_synthesizer=response_synthesizer,
            node_postprocessors=node_postprocessors,
            callback_manager=callback_manager,
        )
        self.index = index

    @classmethod
    def from_docs(
        cls,
        input_dir: str = None,
        input_files: list[str] = None,
        transformations: Optional[list[TransformComponent]] = None,
        embed_model: BaseEmbedding = None,
        llm: LLM = None,
        retriever_configs: list[BaseRetrieverConfig] = None,
        ranker_configs: list[BaseRankerConfig] = None,
    ) -> "SimpleEngine":
        """From docs.

        Must provide either `input_dir` or `input_files`.

        Args:
            input_dir: Path to the directory.
            input_files: List of file paths to read (Optional; overrides input_dir, exclude).
            transformations: Parse documents to nodes. Default [SentenceSplitter].
            embed_model: Parse nodes to embedding. Must be supported by llama index. Default OpenAIEmbedding.
            llm: Must be supported by llama index. Default OpenAI.
            retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
            ranker_configs: Configuration for rankers.
        """
        if not input_dir and not input_files:
            raise ValueError("Must provide either `input_dir` or `input_files`.")

        documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
        cls._fix_document_metadata(documents)

        index = VectorStoreIndex.from_documents(
            documents=documents,
            transformations=transformations or [SentenceSplitter()],
            embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
        )
        return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)

    @classmethod
    def from_objs(
        cls,
        objs: Optional[list[RAGObject]] = None,
        transformations: Optional[list[TransformComponent]] = None,
        embed_model: BaseEmbedding = None,
        llm: LLM = None,
        retriever_configs: list[BaseRetrieverConfig] = None,
        ranker_configs: list[BaseRankerConfig] = None,
    ) -> "SimpleEngine":
        """From objs.

        Args:
            objs: List of RAGObject.
            transformations: Parse documents to nodes. Default [SentenceSplitter].
            embed_model: Parse nodes to embedding. Must be supported by llama index. Default OpenAIEmbedding.
            llm: Must be supported by llama index. Default OpenAI.
            retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
            ranker_configs: Configuration for rankers.
        """
        if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
            raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")

        objs = objs or []
        nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
        index = VectorStoreIndex(
            nodes=nodes,
            transformations=transformations or [SentenceSplitter()],
            embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
        )
        return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)

    @classmethod
    def from_index(
        cls,
        index_config: BaseIndexConfig,
        embed_model: BaseEmbedding = None,
        llm: LLM = None,
        retriever_configs: list[BaseRetrieverConfig] = None,
        ranker_configs: list[BaseRankerConfig] = None,
    ) -> "SimpleEngine":
        """Load from previously maintained index by self.persist(); index_config contains persist_path."""
        index = get_index(index_config, embed_model=cls._resolve_embed_model(embed_model, [index_config]))
        return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)

    async def asearch(self, content: str, **kwargs) -> str:
        """Implement tools.SearchInterface"""
        return await self.aquery(content)

    async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
        """Allow query to be str."""
        query_bundle = QueryBundle(query) if isinstance(query, str) else query

        nodes = await super().aretrieve(query_bundle)
        self._try_reconstruct_obj(nodes)
        return nodes

    def add_docs(self, input_files: list[str]):
        """Add docs to retriever. The retriever must have an add_nodes func."""
        self._ensure_retriever_modifiable()

        documents = SimpleDirectoryReader(input_files=input_files).load_data()
        self._fix_document_metadata(documents)

        nodes = run_transformations(documents, transformations=self.index._transformations)
        self._save_nodes(nodes)

    def add_objs(self, objs: list[RAGObject]):
        """Adds objects to the retriever, storing each object's original form in metadata for future reference."""
        self._ensure_retriever_modifiable()

        nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
        self._save_nodes(nodes)

    def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
        """Persist."""
        self._ensure_retriever_persistable()

        self._persist(str(persist_dir), **kwargs)

    @classmethod
    def _from_index(
        cls,
        index: BaseIndex,
        llm: LLM = None,
        retriever_configs: list[BaseRetrieverConfig] = None,
        ranker_configs: list[BaseRankerConfig] = None,
    ) -> "SimpleEngine":
        llm = llm or get_rag_llm()
        retriever = get_retriever(configs=retriever_configs, index=index)  # Default index.as_retriever
        rankers = get_rankers(configs=ranker_configs, llm=llm)  # Default []

        return cls(
            retriever=retriever,
            node_postprocessors=rankers,
            response_synthesizer=get_response_synthesizer(llm=llm),
            index=index,
        )

    def _ensure_retriever_modifiable(self):
        self._ensure_retriever_of_type(ModifiableRAGRetriever)

    def _ensure_retriever_persistable(self):
        self._ensure_retriever_of_type(PersistableRAGRetriever)

    def _ensure_retriever_of_type(self, required_type: BaseRetriever):
        """Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.

        Args:
            required_type: The class that the retriever is expected to be an instance of.
        """
        if isinstance(self.retriever, SimpleHybridRetriever):
            if not any(isinstance(r, required_type) for r in self.retriever.retrievers):
                raise TypeError(
                    f"Must have at least one retriever of type {required_type.__name__} in SimpleHybridRetriever"
                )

        if not isinstance(self.retriever, required_type):
            raise TypeError(f"The retriever is not of type {required_type.__name__}: {type(self.retriever)}")

    def _save_nodes(self, nodes: list[BaseNode]):
        self.retriever.add_nodes(nodes)

    def _persist(self, persist_dir: str, **kwargs):
        self.retriever.persist(persist_dir, **kwargs)

    @staticmethod
    def _try_reconstruct_obj(nodes: list[NodeWithScore]):
        """If node is object, then dynamically reconstruct object, and save object to node.metadata["obj"]."""
        for node in nodes:
            if node.metadata.get("is_obj", False):
                obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"])
                obj_dict = json.loads(node.metadata["obj_json"])
                node.metadata["obj"] = obj_cls(**obj_dict)

    @staticmethod
    def _fix_document_metadata(documents: list[Document]):
        """LlamaIndex keeps metadata['file_path'], which is unnecessary, maybe deleted in the near future."""
        for doc in documents:
            doc.excluded_embed_metadata_keys.append("file_path")

    @staticmethod
    def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = None) -> BaseEmbedding:
        if configs and all(isinstance(c, NoEmbedding) for c in configs):
            return MockEmbedding(embed_dim=1)

        return embed_model or get_rag_embedding()
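Note: a minimal end-to-end sketch of the new engine, outside the diff (the input file is illustrative; the defaults resolve the embedding and LLM from MetaGPT's config):

import asyncio

from metagpt.rag.engines import SimpleEngine

engine = SimpleEngine.from_docs(input_files=["docs/faq.md"])  # illustrative path
answer = asyncio.run(engine.asearch("How do I configure the LLM?"))
print(answer)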
9 metagpt/rag/factories/__init__.py Normal file
@@ -0,0 +1,9 @@
"""RAG factories"""

from metagpt.rag.factories.retriever import get_retriever
from metagpt.rag.factories.ranker import get_rankers
from metagpt.rag.factories.embedding import get_rag_embedding
from metagpt.rag.factories.index import get_index
from metagpt.rag.factories.llm import get_rag_llm

__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index", "get_rag_llm"]
59 metagpt/rag/factories/base.py Normal file
@@ -0,0 +1,59 @@
"""Base Factory."""

from typing import Any, Callable


class GenericFactory:
    """Designed to get objects based on any keys."""

    def __init__(self, creators: dict[Any, Callable] = None):
        """Creators is a dictionary.

        Keys are identifiers, and the values are the associated creator functions, which create objects.
        """
        self._creators = creators or {}

    def get_instances(self, keys: list[Any], **kwargs) -> list[Any]:
        """Get instances by keys."""
        return [self.get_instance(key, **kwargs) for key in keys]

    def get_instance(self, key: Any, **kwargs) -> Any:
        """Get instance by key.

        Raise Exception if key not found.
        """
        creator = self._creators.get(key)
        if creator:
            return creator(**kwargs)

        raise ValueError(f"Creator not registered for key: {key}")


class ConfigBasedFactory(GenericFactory):
    """Designed to get objects based on object type."""

    def get_instance(self, key: Any, **kwargs) -> Any:
        """Key is config, such as a pydantic model.

        Call func by the type of key, and the key will be passed to func.
        """
        creator = self._creators.get(type(key))
        if creator:
            return creator(key, **kwargs)

        raise ValueError(f"Unknown config: {key}")

    @staticmethod
    def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
        """It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
        if config is not None and hasattr(config, key):
            val = getattr(config, key)
            if val is not None:
                return val

        if key in kwargs:
            return kwargs[key]

        raise KeyError(
            f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
        )
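Note: the dispatch pattern in isolation, outside the diff, with a hypothetical string key and lambda creator:

from metagpt.rag.factories.base import GenericFactory

factory = GenericFactory(creators={"upper": lambda text: text.upper()})
print(factory.get_instance("upper", text="hello"))  # -> "HELLO"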
37 metagpt/rag/factories/embedding.py Normal file
@@ -0,0 +1,37 @@
"""RAG Embedding Factory."""

from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding

from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.base import GenericFactory


class RAGEmbeddingFactory(GenericFactory):
    """Create LlamaIndex Embedding with MetaGPT's config."""

    def __init__(self):
        creators = {
            LLMType.OPENAI: self._create_openai,
            LLMType.AZURE: self._create_azure,
        }
        super().__init__(creators)

    def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
        """Key is LLMType, default use config.llm.api_type."""
        return super().get_instance(key or config.llm.api_type)

    def _create_openai(self):
        return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)

    def _create_azure(self):
        return AzureOpenAIEmbedding(
            azure_endpoint=config.llm.base_url,
            api_key=config.llm.api_key,
            api_version=config.llm.api_version,
        )


get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding
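Note: the module-level get_rag_embedding is how the rest of the RAG package obtains its embedding; a sketch, outside the diff:

from metagpt.rag.factories import get_rag_embedding

# Resolves OpenAIEmbedding or AzureOpenAIEmbedding from config.llm.api_type;
# other api_types raise ValueError since no creator is registered for them.
embedding = get_rag_embedding()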
63 metagpt/rag/factories/index.py Normal file
@@ -0,0 +1,63 @@
"""RAG Index Factory."""

import chromadb
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.vector_stores.faiss import FaissVectorStore

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
    BaseIndexConfig,
    BM25IndexConfig,
    ChromaIndexConfig,
    FAISSIndexConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore


class RAGIndexFactory(ConfigBasedFactory):
    def __init__(self):
        creators = {
            FAISSIndexConfig: self._create_faiss,
            ChromaIndexConfig: self._create_chroma,
            BM25IndexConfig: self._create_bm25,
        }
        super().__init__(creators)

    def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex:
        """Key is PersistType."""
        return super().get_instance(config, **kwargs)

    def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
        embed_model = self._extract_embed_model(config, **kwargs)

        vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
        storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
        index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
        return index

    def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
        embed_model = self._extract_embed_model(config, **kwargs)

        db = chromadb.PersistentClient(str(config.persist_path))
        chroma_collection = db.get_or_create_collection(config.collection_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        index = VectorStoreIndex.from_vector_store(
            vector_store,
            embed_model=embed_model,
        )
        return index

    def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
        embed_model = self._extract_embed_model(config, **kwargs)

        storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
        index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
        return index

    def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
        return self._val_from_config_or_kwargs("embed_model", config, **kwargs)


get_index = RAGIndexFactory().get_index
54 metagpt/rag/factories/llm.py Normal file
@@ -0,0 +1,54 @@
"""RAG LLM."""

from typing import Any

from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
from llama_index.core.llms import (
    CompletionResponse,
    CompletionResponseGen,
    CustomLLM,
    LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from pydantic import Field

from metagpt.config2 import config
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.async_helper import run_coroutine_in_new_loop
from metagpt.utils.token_counter import TOKEN_MAX


class RAGLLM(CustomLLM):
    """LlamaIndex's LLM is different from MetaGPT's LLM.

    Inherit CustomLLM from llamaindex, so that MetaGPT's LLM can be used by LlamaIndex.
    """

    model_infer: BaseLLM = Field(..., description="The MetaGPT's LLM.")
    context_window: int = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
    num_output: int = config.llm.max_token
    model_name: str = config.llm.model

    @property
    def metadata(self) -> LLMMetadata:
        """Get LLM metadata."""
        return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name)

    @llm_completion_callback()
    def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
        return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))

    @llm_completion_callback()
    async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
        text = await self.model_infer.aask(msg=prompt, stream=False)
        return CompletionResponse(text=text)

    @llm_completion_callback()
    def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
        ...


def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM:
    """Get llm that can be used by LlamaIndex."""
    return RAGLLM(model_infer=model_infer or LLM())
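Note: a sketch of the adapter in use, outside the diff (a configured MetaGPT LLM is assumed):

from metagpt.llm import LLM
from metagpt.rag.factories import get_rag_llm

# Wraps MetaGPT's LLM in the CustomLLM interface so llama_index components
# (e.g. SimpleEngine's response synthesizer) can drive it.
rag_llm = get_rag_llm(LLM())
print(rag_llm.metadata.model_name)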
35 metagpt/rag/factories/ranker.py Normal file
@@ -0,0 +1,35 @@
"""RAG Ranker Factory."""

from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig


class RankerFactory(ConfigBasedFactory):
    """Modify creators for dynamic instance implementation."""

    def __init__(self):
        creators = {
            LLMRankerConfig: self._create_llm_ranker,
        }
        super().__init__(creators)

    def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
        """Creates and returns ranker instances based on the provided configurations."""
        if not configs:
            return []

        return super().get_instances(configs, **kwargs)

    def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
        config.llm = self._extract_llm(config, **kwargs)
        return LLMRerank(**config.model_dump())

    def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
        return self._val_from_config_or_kwargs("llm", config, **kwargs)


get_rankers = RankerFactory().get_rankers
86 metagpt/rag/factories/retriever.py Normal file
@@ -0,0 +1,86 @@
"""RAG Retriever Factory."""

import copy

import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.faiss import FaissVectorStore

from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
    BaseRetrieverConfig,
    BM25RetrieverConfig,
    ChromaRetrieverConfig,
    FAISSRetrieverConfig,
    IndexRetrieverConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore


class RetrieverFactory(ConfigBasedFactory):
    """Modify creators for dynamic instance implementation."""

    def __init__(self):
        creators = {
            FAISSRetrieverConfig: self._create_faiss_retriever,
            BM25RetrieverConfig: self._create_bm25_retriever,
            ChromaRetrieverConfig: self._create_chroma_retriever,
        }
        super().__init__(creators)

    def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> RAGRetriever:
        """Creates and returns a retriever instance based on the provided configurations.

        If multiple retrievers, using SimpleHybridRetriever.
        """
        if not configs:
            return self._create_default(**kwargs)

        retrievers = super().get_instances(configs, **kwargs)

        return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]

    def _create_default(self, **kwargs) -> RAGRetriever:
        return self._extract_index(**kwargs).as_retriever()

    def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
        vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
        config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
        return FAISSRetriever(**config.model_dump())

    def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
        config.index = copy.deepcopy(self._extract_index(config, **kwargs))
        nodes = list(config.index.docstore.docs.values())
        return DynamicBM25Retriever(nodes=nodes, **config.model_dump())

    def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
        db = chromadb.PersistentClient(path=str(config.persist_path))
        chroma_collection = db.get_or_create_collection(config.collection_name)
        vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
        config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
        return ChromaRetriever(**config.model_dump())

    def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
        return self._val_from_config_or_kwargs("index", config, **kwargs)

    def _build_index_from_vector_store(
        self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
    ) -> VectorStoreIndex:
        storage_context = StorageContext.from_defaults(vector_store=vector_store)
        old_index = self._extract_index(config, **kwargs)
        new_index = VectorStoreIndex(
            nodes=list(old_index.docstore.docs.values()),
            storage_context=storage_context,
            embed_model=old_index._embed_model,
        )
        return new_index


get_retriever = RetrieverFactory().get_retriever
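Note: a sketch of the factory in action through SimpleEngine, outside the diff: two retriever configs yield a SimpleHybridRetriever, and the ranker config adds an LLMRerank postprocessor (the input file is illustrative):

from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig, LLMRankerConfig

engine = SimpleEngine.from_docs(
    input_files=["docs/faq.md"],  # illustrative path
    retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
    ranker_configs=[LLMRankerConfig()],
)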
24 metagpt/rag/interface.py Normal file
@@ -0,0 +1,24 @@
"""RAG Interfaces."""

from typing import Protocol, runtime_checkable


@runtime_checkable
class RAGObject(Protocol):
    """Support rag add object."""

    def rag_key(self) -> str:
        """For rag search."""

    def model_dump_json(self) -> str:
        """For rag persist.

        Pydantic Models don't need to implement this, as there is a built-in function named model_dump_json.
        """


@runtime_checkable
class NoEmbedding(Protocol):
    """Some retrievers do not require embeddings, e.g. BM25."""

    _no_embedding: bool
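Note: any pydantic model satisfies the protocol once it defines rag_key(); a hypothetical example, outside the diff:

from pydantic import BaseModel

from metagpt.rag.interface import RAGObject

class Player(BaseModel):
    """Hypothetical object; model_dump_json is inherited from pydantic."""

    name: str
    goal: str

    def rag_key(self) -> str:
        # The text that gets embedded and searched.
        return self.goal

assert isinstance(Player(name="alice", goal="win the game"), RAGObject)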
1 metagpt/rag/rankers/__init__.py Normal file
@@ -0,0 +1 @@
"""Rankers init"""
19 metagpt/rag/rankers/base.py Normal file
@@ -0,0 +1,19 @@
"""Base Ranker."""

from abc import abstractmethod
from typing import Optional

from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle


class RAGRanker(BaseNodePostprocessor):
    """inherit from llama_index"""

    @abstractmethod
    def _postprocess_nodes(
        self,
        nodes: list[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> list[NodeWithScore]:
        """postprocess nodes."""
5 metagpt/rag/retrievers/__init__.py Normal file
@@ -0,0 +1,5 @@
"""Retrievers init."""

from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever

__all__ = ["SimpleHybridRetriever"]
47 metagpt/rag/retrievers/base.py Normal file
@@ -0,0 +1,47 @@
"""Base retriever."""

from abc import abstractmethod

from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import BaseNode, NodeWithScore, QueryType

from metagpt.utils.reflection import check_methods


class RAGRetriever(BaseRetriever):
    """Inherit from llama_index"""

    @abstractmethod
    async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
        """Retrieve nodes"""

    def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
        """Retrieve nodes"""


class ModifiableRAGRetriever(RAGRetriever):
    """Support modification."""

    @classmethod
    def __subclasshook__(cls, C):
        if cls is ModifiableRAGRetriever:
            return check_methods(C, "add_nodes")
        return NotImplemented

    @abstractmethod
    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """To support add docs, must implement this func"""


class PersistableRAGRetriever(RAGRetriever):
    """Support persistence."""

    @classmethod
    def __subclasshook__(cls, C):
        if cls is PersistableRAGRetriever:
            return check_methods(C, "persist")
        return NotImplemented

    @abstractmethod
    def persist(self, persist_dir: str, **kwargs) -> None:
        """To support persist, must implement this func"""
47 metagpt/rag/retrievers/bm25_retriever.py Normal file
@@ -0,0 +1,47 @@
"""BM25 retriever."""
from typing import Callable, Optional

from llama_index.core import VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.schema import BaseNode, IndexNode
from llama_index.retrievers.bm25 import BM25Retriever
from rank_bm25 import BM25Okapi


class DynamicBM25Retriever(BM25Retriever):
    """BM25 retriever."""

    def __init__(
        self,
        nodes: list[BaseNode],
        tokenizer: Optional[Callable[[str], list[str]]] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[list[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
        index: VectorStoreIndex = None,
    ) -> None:
        super().__init__(
            nodes=nodes,
            tokenizer=tokenizer,
            similarity_top_k=similarity_top_k,
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )
        self._index = index

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Support add nodes."""
        self._nodes.extend(nodes)
        self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
        self.bm25 = BM25Okapi(self._corpus)

        self._index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Support persist."""
        self._index.storage_context.persist(persist_dir)
17 metagpt/rag/retrievers/chroma_retriever.py Normal file
@@ -0,0 +1,17 @@
"""Chroma retriever."""

from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode


class ChromaRetriever(VectorIndexRetriever):
    """Chroma retriever."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Support add nodes."""
        self._index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Support persist.

        Chromadb automatically saves, so there is no need to implement."""
16 metagpt/rag/retrievers/faiss_retriever.py Normal file
@@ -0,0 +1,16 @@
"""FAISS retriever."""

from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode


class FAISSRetriever(VectorIndexRetriever):
    """FAISS retriever."""

    def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
        """Support add nodes"""
        self._index.insert_nodes(nodes, **kwargs)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Support persist."""
        self._index.storage_context.persist(persist_dir)
48 metagpt/rag/retrievers/hybrid_retriever.py Normal file
@@ -0,0 +1,48 @@
"""Hybrid retriever."""

import copy

from llama_index.core.schema import BaseNode, QueryType

from metagpt.rag.retrievers.base import RAGRetriever


class SimpleHybridRetriever(RAGRetriever):
    """A composite retriever that aggregates search results from multiple retrievers."""

    def __init__(self, *retrievers):
        self.retrievers: list[RAGRetriever] = retrievers
        super().__init__()

    async def _aretrieve(self, query: QueryType, **kwargs):
        """Asynchronously retrieves and aggregates search results from all configured retrievers.

        This method queries each retriever in the `retrievers` list with the given query and
        additional keyword arguments. It then combines the results, ensuring that each node is
        unique, based on the node's ID.
        """
        all_nodes = []
        for retriever in self.retrievers:
            # Prevent retriever changing query
            query_copy = copy.deepcopy(query)
            nodes = await retriever.aretrieve(query_copy, **kwargs)
            all_nodes.extend(nodes)

        # combine all nodes
        result = []
        node_ids = set()
        for n in all_nodes:
            if n.node.node_id not in node_ids:
                result.append(n)
                node_ids.add(n.node.node_id)
        return result

    def add_nodes(self, nodes: list[BaseNode]) -> None:
        """Support add nodes."""
        for r in self.retrievers:
            r.add_nodes(nodes)

    def persist(self, persist_dir: str, **kwargs) -> None:
        """Support persist."""
        for r in self.retrievers:
            r.persist(persist_dir, **kwargs)
124
metagpt/rag/schema.py
Normal file
124
metagpt/rag/schema.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
||||
class BaseRetrieverConfig(BaseModel):
|
||||
"""Common config for retrievers.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.retriever.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")
|
||||
|
||||
|
||||
class IndexRetrieverConfig(BaseRetrieverConfig):
|
||||
"""Config for Index-basd retrievers."""
|
||||
|
||||
index: BaseIndex = Field(default=None, description="Index for retriver.")
|
||||
|
||||
|
||||
class FAISSRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for FAISS-based retrievers."""
|
||||
|
||||
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
|
||||
|
||||
class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for BM25-based retrievers."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ChromaRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for Chroma-based retrievers."""
|
||||
|
||||
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class BaseRankerConfig(BaseModel):
|
||||
"""Common config for rankers.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.ranker.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
top_n: int = Field(default=5, description="The number of top results to return.")
|
||||
|
||||
|
||||
class LLMRankerConfig(BaseRankerConfig):
|
||||
"""Config for LLM-based rankers."""
|
||||
|
||||
llm: Any = Field(
|
||||
default=None,
|
||||
description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.",
|
||||
)
|
||||
|
||||
|
||||
class BaseIndexConfig(BaseModel):
|
||||
"""Common config for index.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index.
|
||||
"""
|
||||
|
||||
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
|
||||
|
||||
|
||||
class VectorIndexConfig(BaseIndexConfig):
|
||||
"""Config for vector-based index."""
|
||||
|
||||
embed_model: BaseEmbedding = Field(default=None, description="Embed model.")
|
||||
|
||||
|
||||
class FAISSIndexConfig(VectorIndexConfig):
|
||||
"""Config for faiss-based index."""
|
||||
|
||||
|
||||
class ChromaIndexConfig(VectorIndexConfig):
|
||||
"""Config for chroma-based index."""
|
||||
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class BM25IndexConfig(BaseIndexConfig):
|
||||
"""Config for bm25-based index."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
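These configs are plain pydantic models, so a retrieval setup reduces to instantiation. A sketch with illustrative values (the factory wiring that consumes them lives in rag.factories):

    retriever_config = FAISSRetrieverConfig(similarity_top_k=3, dimensions=1536)
    ranker_config = LLMRankerConfig(top_n=3)
    index_config = ChromaIndexConfig(persist_path="./chroma_db", collection_name="metagpt")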
class ObjectNodeMetadata(BaseModel):
    """Metadata of ObjectNode."""

    is_obj: bool = Field(default=True)
    obj: Any = Field(default=None, description="When RAG retrieves, obj will be reconstructed from obj_json.")
    obj_json: str = Field(..., description="The json of the object, e.g. obj.model_dump_json().")
    obj_cls_name: str = Field(..., description="The class name of the object, e.g. obj.__class__.__name__.")
    obj_mod_name: str = Field(..., description="The module name of the class, e.g. obj.__class__.__module__.")


class ObjectNode(TextNode):
    """RAG add object."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
        self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys

    @staticmethod
    def get_obj_metadata(obj: RAGObject) -> dict:
        metadata = ObjectNodeMetadata(
            obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
        )

        return metadata.model_dump()
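At retrieval time, obj_mod_name and obj_cls_name are enough to rebuild the original object from obj_json. A minimal sketch of that round trip (the reconstruct helper below is an assumption for illustration, not MetaGPT's actual code):

    import importlib

    def reconstruct(metadata: dict):
        """Rebuild a pydantic object from ObjectNodeMetadata fields."""
        module = importlib.import_module(metadata["obj_mod_name"])
        cls = getattr(module, metadata["obj_cls_name"])
        return cls.model_validate_json(metadata["obj_json"])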

0 metagpt/rag/vector_stores/__init__.py Normal file
3 metagpt/rag/vector_stores/chroma/__init__.py Normal file

@ -0,0 +1,3 @@
from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore

__all__ = ["ChromaVectorStore"]

290 metagpt/rag/vector_stores/chroma/base.py Normal file

@ -0,0 +1,290 @@
"""Chroma vector store.
|
||||
|
||||
Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
|
||||
The repo requires onnxruntime = "^1.17.0", which is too new for many OS systems, such as CentOS7.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Generator, List, Optional, cast
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from llama_index.core.bridge.pydantic import Field, PrivateAttr
|
||||
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
|
||||
from llama_index.core.utils import truncate_text
|
||||
from llama_index.core.vector_stores.types import (
|
||||
BasePydanticVectorStore,
|
||||
MetadataFilters,
|
||||
VectorStoreQuery,
|
||||
VectorStoreQueryResult,
|
||||
)
|
||||
from llama_index.core.vector_stores.utils import (
|
||||
legacy_metadata_dict_to_node,
|
||||
metadata_dict_to_node,
|
||||
node_to_metadata_dict,
|
||||
)
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def _transform_chroma_filter_condition(condition: str) -> str:
|
||||
"""Translate standard metadata filter op to Chroma specific spec."""
|
||||
if condition == "and":
|
||||
return "$and"
|
||||
elif condition == "or":
|
||||
return "$or"
|
||||
else:
|
||||
raise ValueError(f"Filter condition {condition} not supported")
|
||||
|
||||
|
||||
def _transform_chroma_filter_operator(operator: str) -> str:
|
||||
"""Translate standard metadata filter operator to Chroma specific spec."""
|
||||
if operator == "!=":
|
||||
return "$ne"
|
||||
elif operator == "==":
|
||||
return "$eq"
|
||||
elif operator == ">":
|
||||
return "$gt"
|
||||
elif operator == "<":
|
||||
return "$lt"
|
||||
elif operator == ">=":
|
||||
return "$gte"
|
||||
elif operator == "<=":
|
||||
return "$lte"
|
||||
else:
|
||||
raise ValueError(f"Filter operator {operator} not supported")
|
||||
|
||||
|
||||
def _to_chroma_filter(
|
||||
standard_filters: MetadataFilters,
|
||||
) -> dict:
|
||||
"""Translate standard metadata filters to Chroma specific spec."""
|
||||
filters = {}
|
||||
filters_list = []
|
||||
condition = standard_filters.condition or "and"
|
||||
condition = _transform_chroma_filter_condition(condition)
|
||||
if standard_filters.filters:
|
||||
for filter in standard_filters.filters:
|
||||
if filter.operator:
|
||||
filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}})
|
||||
else:
|
||||
filters_list.append({filter.key: filter.value})
|
||||
if len(filters_list) == 1:
|
||||
# If there is only one filter, return it directly
|
||||
return filters_list[0]
|
||||
elif len(filters_list) > 1:
|
||||
filters[condition] = filters_list
|
||||
return filters
|
||||
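For instance, a two-clause filter should translate as follows (a sketch; the expected output is shown as a comment):

    from llama_index.core.vector_stores.types import (
        FilterOperator,
        MetadataFilter,
        MetadataFilters,
    )

    filters = MetadataFilters(
        filters=[
            MetadataFilter(key="year", operator=FilterOperator.GT, value=2020),
            MetadataFilter(key="lang", operator=FilterOperator.EQ, value="en"),
        ],
    )
    print(_to_chroma_filter(filters))
    # Expected: {'$and': [{'year': {'$gt': 2020}}, {'lang': {'$eq': 'en'}}]}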

import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"
MAX_CHUNK_SIZE = 41665  # One less than the max chunk size for ChromaDB


def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
    """Yield successive max_chunk_size-sized chunks from lst.

    Args:
        lst (List[BaseNode]): list of nodes with embeddings
        max_chunk_size (int): max chunk size

    Yields:
        Generator[List[BaseNode], None, None]: list of nodes with embeddings
    """
    for i in range(0, len(lst), max_chunk_size):
        yield lst[i : i + max_chunk_size]
class ChromaVectorStore(BasePydanticVectorStore):
    """Chroma vector store.

    In this vector store, embeddings are stored within a ChromaDB collection.
    During query time, the index uses ChromaDB to query for the top
    k most similar nodes.

    Args:
        chroma_collection (chromadb.api.models.Collection.Collection):
            ChromaDB collection instance
    """

    stores_text: bool = True
    flat_metadata: bool = True
    collection_name: Optional[str]
    host: Optional[str]
    port: Optional[str]
    ssl: bool
    headers: Optional[Dict[str, str]]
    persist_dir: Optional[str]
    collection_kwargs: Dict[str, Any] = Field(default_factory=dict)
    _collection: Any = PrivateAttr()
    def __init__(
        self,
        chroma_collection: Optional[Any] = None,
        collection_name: Optional[str] = None,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> None:
        """Init params."""
        collection_kwargs = collection_kwargs or {}
        if chroma_collection is None:
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
            self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        else:
            self._collection = cast(Collection, chroma_collection)
        super().__init__(
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            collection_name=collection_name,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
        )
    @classmethod
    def from_collection(cls, collection: Any) -> "ChromaVectorStore":
        try:
            from chromadb import Collection
        except ImportError:
            raise ImportError(import_err_msg)
        if not isinstance(collection, Collection):
            raise TypeError("argument is not a chromadb Collection instance")
        return cls(chroma_collection=collection)
    @classmethod
    def from_params(
        cls,
        collection_name: str,
        host: Optional[str] = None,
        port: Optional[str] = None,
        ssl: bool = False,
        headers: Optional[Dict[str, str]] = None,
        persist_dir: Optional[str] = None,
        collection_kwargs: Optional[dict] = None,
        **kwargs: Any,
    ) -> "ChromaVectorStore":
        collection_kwargs = collection_kwargs or {}
        if persist_dir:
            client = chromadb.PersistentClient(path=persist_dir)
            collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        elif host and port:
            client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
            collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
        else:
            raise ValueError("Either `persist_dir` or (`host`, `port`) must be specified")
        return cls(
            chroma_collection=collection,
            host=host,
            port=port,
            ssl=ssl,
            headers=headers,
            persist_dir=persist_dir,
            collection_kwargs=collection_kwargs,
            **kwargs,
        )
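A usage sketch with illustrative values: a local, persistent store needs only a collection name and a directory.

    store = ChromaVectorStore.from_params(
        collection_name="metagpt",
        persist_dir="./chroma_db",
    )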
    @classmethod
    def class_name(cls) -> str:
        return "ChromaVectorStore"
    def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
        """Add nodes to index.

        Args:
            nodes (List[BaseNode]): list of nodes with embeddings

        """
        if not self._collection:
            raise ValueError("Collection not initialized")
        max_chunk_size = MAX_CHUNK_SIZE
        node_chunks = chunk_list(nodes, max_chunk_size)
        all_ids = []
        for node_chunk in node_chunks:
            embeddings = []
            metadatas = []
            ids = []
            documents = []
            for node in node_chunk:
                embeddings.append(node.get_embedding())
                metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
                # Chroma rejects None metadata values, so replace them with "".
                for key in metadata_dict:
                    if metadata_dict[key] is None:
                        metadata_dict[key] = ""
                metadatas.append(metadata_dict)
                ids.append(node.node_id)
                documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
            self._collection.add(
                embeddings=embeddings,
                ids=ids,
                metadatas=metadatas,
                documents=documents,
            )
            all_ids.extend(ids)
        return all_ids
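Continuing the sketch above, add() assumes embeddings are already computed; the zero vector below is a stand-in for a real embedding:

    node = TextNode(id_="n1", text="hello chroma", embedding=[0.0] * 1536)
    ids = store.add([node])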
    def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
        """Delete nodes with the given ref_doc_id.

        Args:
            ref_doc_id (str): The doc_id of the document to delete.

        """
        self._collection.delete(where={"document_id": ref_doc_id})
    @property
    def client(self) -> Any:
        """Return client."""
        return self._collection
    def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
        """Query index for top k most similar nodes.

        Args:
            query (VectorStoreQuery): query spec carrying the query embedding and similarity_top_k.

        """
        if query.filters is not None:
            if "where" in kwargs:
                raise ValueError(
                    "Cannot specify metadata filters via both query and kwargs. "
                    "Use kwargs only for chroma specific items that are "
                    "not supported via the generic query interface."
                )
            where = _to_chroma_filter(query.filters)
        else:
            where = kwargs.pop("where", {})
        results = self._collection.query(
            query_embeddings=query.query_embedding,
            n_results=query.similarity_top_k,
            where=where,
            **kwargs,
        )
        logger.debug(f"> Top {len(results['documents'])} nodes:")
        nodes = []
        similarities = []
        ids = []
        for node_id, text, metadata, distance in zip(
            results["ids"][0],
            results["documents"][0],
            results["metadatas"][0],
            results["distances"][0],
        ):
            try:
                node = metadata_dict_to_node(metadata)
                node.set_content(text)
            except Exception:
                # NOTE: deprecated legacy logic for backward compatibility
                metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
                node = TextNode(
                    text=text,
                    id_=node_id,
                    metadata=metadata,
                    start_char_idx=node_info.get("start", None),
                    end_char_idx=node_info.get("end", None),
                    relationships=relationships,
                )
            nodes.append(node)
            # Map Chroma's distance to a (0, 1] similarity score.
            similarity_score = math.exp(-distance)
            similarities.append(similarity_score)
            logger.debug(
                f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}"
            )
            ids.append(node_id)
        return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
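Continuing the sketch, query() takes a raw embedding rather than text, and maps Chroma's distances to similarities via exp(-distance):

    from llama_index.core.vector_stores.types import VectorStoreQuery

    result = store.query(VectorStoreQuery(query_embedding=[0.0] * 1536, similarity_top_k=3))
    # result.similarities holds exp(-distance), so smaller distances score closer to 1.0.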

@ -108,12 +108,6 @@ class RoleContext(BaseModel):
    )  # see `Role._set_react_mode` for definitions of the following two attributes
    max_react_loop: int = 1

    def check(self, role_id: str):
        # if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory:
        #     self.long_term_memory.recover_memory(role_id, self)
        #     self.memory = self.long_term_memory  # use memory to act as long_term_memory for unify operation
        pass

    @property
    def important_memory(self) -> list[Message]:
        """Retrieve information corresponding to the attention action."""

@ -311,8 +305,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
        buffer during _observe.
        """
        self.rc.watch = {any_to_str(t) for t in actions}
        # check RoleContext after adding watch actions
        self.rc.check(self.role_id)

    def is_watch(self, caused_by: str):
        return caused_by in self.rc.watch

@ -11,7 +11,6 @@ from typing import Optional
from pydantic import Field, model_validator

from metagpt.actions import SearchAndSummarize, UserRequirement
from metagpt.document_store.base_store import BaseStore
from metagpt.roles import Role
from metagpt.tools.search_engine import SearchEngine

@ -27,7 +26,7 @@ class Sales(Role):
"delivered with the professionalism and courtesy expected of a seasoned sales guide."
|
||||
)
|
||||
|
||||
store: Optional[BaseStore] = Field(default=None, exclude=True)
|
||||
store: Optional[object] = Field(default=None, exclude=True) # must inplement tools.SearchInterface
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_stroe(self):
|
||||

@ -233,6 +233,10 @@ class Message(BaseModel):
    def check_send_to(cls, send_to: Any) -> set:
        return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})

    @field_serializer("send_to", mode="plain")
    def ser_send_to(self, send_to: set) -> list:
        return list(send_to)

    @field_serializer("instruct_content", mode="plain")
    def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
        ic_dict = None

@ -276,6 +280,10 @@
    def __repr__(self):
        return self.__str__()

    def rag_key(self) -> str:
        """For search."""
        return self.content

    def to_dict(self) -> dict:
        """Return a dict containing `role` and `content` for the LLM call."""
        return {"role": self.role, "content": self.content}

@ -30,3 +30,8 @@ class WebBrowserEngineType(Enum):
    def __missing__(cls, key):
        """Default type conversion"""
        return cls.CUSTOM


class SearchInterface:
    async def asearch(self, *args, **kwargs):
        ...

22 metagpt/utils/async_helper.py Normal file

@ -0,0 +1,22 @@
import asyncio
import threading
from typing import Any


def run_coroutine_in_new_loop(coroutine) -> Any:
    """Runs a coroutine in a new, separate event loop on a different thread.

    This function is useful when trying to execute an async function within a sync function but encountering the error `RuntimeError: This event loop is already running`.
    """
    new_loop = asyncio.new_event_loop()
    t = threading.Thread(target=lambda: new_loop.run_forever())
    t.start()

    future = asyncio.run_coroutine_threadsafe(coroutine, new_loop)

    try:
        return future.result()
    finally:
        new_loop.call_soon_threadsafe(new_loop.stop)
        t.join()
        new_loop.close()
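A usage sketch (the coroutine is illustrative):

    async def compute() -> int:
        return 42

    print(run_coroutine_in_new_loop(compute()))  # 42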

@ -5,12 +5,15 @@
@Author  : alexanderwu
@File    : embedding.py
"""
from langchain_community.embeddings import OpenAIEmbeddings
from llama_index.embeddings.openai import OpenAIEmbedding

from metagpt.config2 import config


def get_embedding():
def get_embedding() -> OpenAIEmbedding:
    llm = config.get_openai_llm()
    embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
    if llm is None:
        raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.")

    embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
    return embedding

18 metagpt/utils/reflection.py Normal file

@ -0,0 +1,18 @@
"""class tools, including method inspection, class attributes, inheritance relationships, etc."""
|
||||
|
||||
|
||||
def check_methods(C, *methods):
|
||||
"""Check if the class has methods. borrow from _collections_abc.
|
||||
|
||||
Useful when implementing implicit interfaces, such as defining an abstract class, isinstance can be used for determination without inheritance.
|
||||
"""
|
||||
mro = C.__mro__
|
||||
for method in methods:
|
||||
for B in mro:
|
||||
if method in B.__dict__:
|
||||
if B.__dict__[method] is None:
|
||||
return NotImplemented
|
||||
break
|
||||
else:
|
||||
return NotImplemented
|
||||
return True
|
||||
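A sketch of the implicit-interface check in action (class names are illustrative):

    class HasSearch:
        async def asearch(self, query):
            ...

    assert check_methods(HasSearch, "asearch") is True
    assert check_methods(object, "asearch") is NotImplemented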