Merge pull request #974 from better629/feat_memory

Feat add rag
This commit is contained in:
Alexander Wu 2024-03-17 23:39:12 +08:00 committed by GitHub
commit e783e5b208
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
61 changed files with 2353 additions and 248 deletions

5
.gitattributes vendored
View file

@ -12,6 +12,11 @@
*.jpg binary
*.gif binary
*.ico binary
*.jpeg binary
*.mp3 binary
*.zip binary
*.bin binary
# Preserve original line endings for specific document files
*.doc text eol=crlf

13
.gitignore vendored
View file

@ -27,6 +27,8 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
metagpt/tools/schemas/
examples/data/search_kb/*.json
# PyInstaller
# Usually these files are written by a python script from a template
@ -151,9 +153,14 @@ allure-results
.vscode
key.yaml
data
/data/
data.ms
examples/nb/
examples/default__vector_store.json
examples/docstore.json
examples/graph_store.json
examples/image__vector_store.json
examples/index_store.json
.chroma
*~$*
workspace/*
@ -168,6 +175,7 @@ output
tmp.png
.dependencies.json
tests/metagpt/utils/file_repo_git
tests/data/rsp_cache_new.json
*.tmp
*.png
htmlcov
@ -178,4 +186,5 @@ cov.xml
*.faiss
*-structure.csv
*-structure.json
metagpt/tools/schemas
*.dot
.python-version

View file

@ -0,0 +1 @@
Bob likes traveling.

View file

@ -0,0 +1,109 @@
Productivity
I think I am at least somewhat more productive than average, and people sometimes ask me for productivity tips. So I decided to just write them all down in one place.
Compound growth gets discussed as a financial concept, but it works in careers as well, and it is magic. A small productivity gain, compounded over 50 years, is worth a lot. So it's worth figuring out how to optimize productivity. If you get 10% more done and 1% better every day compared to someone else, the compounded difference is massive.
What you work on
Famous writers have some essential qualities: creativity and discipline.
It doesn't matter how fast you move if it's in a worthless direction. Picking the right thing to work on is the most important element of productivity and usually almost ignored. So think about it more! Independent thought is hard but it's something you can get better at with practice.
The most impressive people I know have strong beliefs about the world, which is rare in the general population. If you find yourself always agreeing with whomever you last spoke with, that's bad. You will of course be wrong sometimes, but develop the confidence to stick with your convictions. It will let you be courageous when you're right about something important that most people don't see.
I make sure to leave enough time in my schedule to think about what to work on. The best ways for me to do this are reading books, hanging out with interesting people, and spending time in nature.
I've learned that I can't be very productive working on things I don't care about or don't like. So I just try not to put myself in a position where I have to do them (by delegating, avoiding, or something else). Stuff that you don't like is a painful drag on morale and momentum.
By the way, here is an important lesson about delegation: remember that everyone else is also most productive when they're doing what they like, and do what you'd want other people to do for you—try to figure out who likes (and is good at) doing what, and delegate that way.
If you find yourself not liking what you're doing for a long period of time, seriously consider a major job change. Short-term burnout happens, but if it isn't resolved with some time off, maybe it's time to do something you're more interested in.
I've been very fortunate to find work I like so much I'd do it for free, which makes it easy to be really productive.
It's important to learn that you can learn anything you want, and that you can get better quickly. This feels like an unlikely miracle the first few times it happens, but eventually you learn to trust that you can do it.
Doing great work usually requires colleagues of some sort. Try to be around smart, productive, happy, and positive people that don't belittle your ambitions. I love being around people who push me and inspire me to be better. To the degree you are able to, avoid the opposite kind of people—the cost of letting them take up your mental cycles is horrific.
You have to both pick the right problem and do the work. There aren't many shortcuts. If you're going to do something really important, you are very likely going to work both smart and hard. The biggest prizes are heavily competed for. This isn't true in every field (there are great mathematicians who never spend that many hours a week working) but it is in most.
Prioritization
Writers have to work hard to be successful
My system has three key pillars: “Make sure to get the important shit done”, “Don't waste time on stupid shit”, and “make a lot of lists”.
I highly recommend using lists. I make lists of what I want to accomplish each year, each month, and each day. Lists are very focusing, and they help me with multitasking because I don't have to keep as much in my head. If I'm not in the mood for some particular task, I can always find something else I'm excited to do.
I prefer lists written down on paper. It's easy to add and remove tasks. I can access them during meetings without feeling rude. I re-transcribe lists frequently, which forces me to think about everything on the list and gives me an opportunity to add and remove items.
I don't bother with categorization or trying to size tasks or anything like that (the most I do is put a star next to really important items).
I try to prioritize in a way that generates momentum. The more I get done, the better I feel, and then the more I get done. I like to start and end each day with something I can really make progress on.
I am relentless about getting my most important projects done—Ive found that if I really want something to happen and I push hard enough, it usually happens.
I try to be ruthless about saying no to stuff, and doing non-critical things in the quickest way possible. I probably take this too far—for example, I am almost sure I am terse to the point of rudeness when replying to emails.
Passion and adaptability are key qualities for writers.
I generally try to avoid meetings and conferences as I find the time cost to be huge—I get the most value out of time in my office. However, it is critical that you keep enough space in your schedule to allow for chance encounters and exposure to new people and ideas. Having an open network is valuable; though probably 90% of the random meetings I take are a waste of time, the other 10% really make up for it.
I find most meetings are best scheduled for 15-20 minutes, or 2 hours. The default of 1 hour is usually wrong, and leads to a lot of wasted time.
I have different times of day I try to use for different kinds of work. The first few hours of the morning are definitely my most productive time of the day, so I don't let anyone schedule anything then. I try to do meetings in the afternoon. I take a break, or switch tasks, whenever I feel my attention starting to fade.
I don't think most people value their time enough—I am surprised by the number of people I know who make $100 an hour and yet will spend a couple of hours doing something they don't want to do to save $20.
Also, don't fall into the trap of productivity porn—chasing productivity for its own sake isn't helpful. Many people spend too much time thinking about how to perfectly optimize their system, and not nearly enough asking if they're working on the right problems. It doesn't matter what system you use or if you squeeze out every second if you're working on the wrong thing.
The right goal is to allocate your year optimally, not your day.
Physical factors
Very likely what is optimal for me won't be optimal for you. You'll have to experiment to find out what works best for your body. It's definitely worth doing—it helps in all aspects of life, and you'll feel a lot better and happier overall.
It probably took a little bit of my time every week for a few years to arrive at what works best for me, but my sense is if I do a good job at all the below I'm at least 1.5x more productive than if not.
Sleep seems to be the most important physical factor in productivity for me. Some sort of sleep tracker to figure out how to sleep best is helpful. I've found the only things I'm consistent with are in the set-it-and-forget-it category, and I really like the Emfit QS+Active.
I like a cold, dark, quiet room, and a great mattress (I resisted spending a bunch of money on a great mattress for years, which was stupid—it makes a huge difference to my sleep quality. I love this one). Not eating a lot in the few hours before sleep helps. Not drinking alcohol helps a lot, though I'm not willing to do that all the time.
I use a Chili Pad to be cold while I sleep if I can't get the room cold enough, which is great but loud (I set it up to have the cooler unit outside my room).
When traveling, I use an eye mask and ear plugs.
Writers usually need empathy to write good books.
This is likely to be controversial, but I take a low dose of sleeping pills (like a third of a normal dose) or a very low dose of cannabis whenever I can't sleep. I am a bad sleeper in general, and a particularly bad sleeper when I travel. It likely has tradeoffs, but so does not sleeping well. If you can already sleep well, I wouldn't recommend this.
I use a full spectrum LED light most mornings for about 10-15 minutes while I catch up on email. It's great—if you try nothing else in here, this is the thing I'd try. It's a ridiculous gain for me. I like this one, and it's easy to travel with.
Exercise is probably the second most important physical factor. I tried a number of different exercise programs for a few months each and the one that seemed best was lifting heavy weights 3x a week for an hour, and high intensity interval training occasionally. In addition to productivity gains, this is also the exercise program that makes me feel the best overall.
The third area is nutrition. I very rarely eat breakfast, so I get about 15 hours of fasting most days (except an espresso when I wake up). I know this is contrary to most advice, and I suspect it's not optimal for most people, but it definitely works well for me.
Eating lots of sugar is the thing that makes me feel the worst and that I try hardest to avoid. I also try to avoid foods that aggravate my digestion or spike up inflammation (for example, very spicy foods). I don't have much willpower when it comes to sweet things, so I mostly just try to keep junk food out of the house.
I have one big shot of espresso immediately when I wake up and one after lunch. I assume this is about 200mg total of caffeine per day. I tried a few other configurations; this was the one that worked by far the best. I otherwise aggressively avoid stimulants, but I will have more coffee if I'm super tired and really need to get something done.
If a writer wants to be truly great, they should embrace innovative thinking.
I'm vegetarian and have been since I was a kid, and I supplement methyl B-12, Omega-3, Iron, and Vitamin D-3. I got to this list with a year or so of quarterly blood tests; it's worked for me ever since (I re-test maybe every year and a half or so). There are many doctors who will happily work with you on a super comprehensive blood test (and services like WellnessFX). I also go out of my way to drink a lot of protein shakes, which I hate and I wouldn't do if I weren't vegetarian.
Other stuff
Here's what I like in a workspace: natural light, quiet, knowing that I won't be interrupted if I don't want to be, long blocks of time, and being comfortable and relaxed (I've got a beautiful desk with a couple of 4k monitors on it in my office, but I spend almost all my time on my couch with my laptop).
I wrote custom software for the annoying things I have to do frequently, which is great. I also made an effort to learn to type really fast and the keyboard shortcuts that help with my workflow.
Like most people, I sometimes go through periods of a week or two where I just have no motivation to do anything (I suspect it may have something to do with nutrition). This sucks and always seems to happen at inconvenient times. I have not figured out what to do about it besides wait for the fog to lift, and to trust that eventually it always does. And I generally try to avoid people and situations that put me in bad moods, which is good advice whether you care about productivity or not.
In general, I think it's good to overcommit a little bit. I find that I generally get done what I take on, and if I have a little bit too much to do it makes me more efficient at everything, which is a way to train to avoid distractions (a great habit to build!). However, overcommitting a lot is disastrous.
Don't neglect your family and friends for the sake of productivity—that's a very stupid tradeoff (and very likely a net productivity loss, because you'll be less happy). Don't neglect doing things you love or that clear your head either.
Finally, to repeat one more time: productivity in the wrong direction isn't worth anything at all. Think more about what to work on.
Open-mindedness and curiosity are essential to writers.

211
examples/rag_pipeline.py Normal file
View file

@ -0,0 +1,211 @@
"""RAG pipeline"""
import asyncio
from pydantic import BaseModel
from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import (
BM25RetrieverConfig,
ChromaIndexConfig,
ChromaRetrieverConfig,
FAISSRetrieverConfig,
LLMRankerConfig,
)
DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
QUESTION = "What are key qualities to be a good writer?"
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
TRAVEL_QUESTION = "What does Bob like?"
LLM_TIP = "If you not sure, just answer I don't know."
class Player(BaseModel):
    """To demonstrate rag add objs."""

    name: str = ""
    goal: str = "Win The 100-meter Sprint."
    tool: str = "Red Bull Energy Drink."

    def rag_key(self) -> str:
        """For search."""
        return self.goal
class RAGExample:
    """Show how to use RAG."""

    def __init__(self):
        self.engine = SimpleEngine.from_docs(
            input_files=[DOC_PATH],
            retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
            ranker_configs=[LLMRankerConfig()],
        )

    async def run_pipeline(self, question=QUESTION, print_title=True):
        """Run the RAG pipeline with FAISS & BM25 retrievers and an LLM ranker; prints something like:

        Retrieve Result:
        0. Productivi..., 10.0
        1. I wrote cu..., 7.0
        2. I highly r..., 5.0

        Query Result:
        Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer.
        """
        if print_title:
            self._print_title("Run Pipeline")

        nodes = await self.engine.aretrieve(question)
        self._print_retrieve_result(nodes)

        answer = await self.engine.aquery(question)
        self._print_query_result(answer)
    async def add_docs(self):
        """This example shows how to add docs.

        Before adding docs, the LLM answers "I don't know".
        After adding docs, the LLM gives the correct answer; prints something like:

        [Before add docs]
        Retrieve Result:

        Query Result:
        Empty Response

        [After add docs]
        Retrieve Result:
        0. Bob like..., 10.0

        Query Result:
        Bob likes traveling.
        """
        self._print_title("Add Docs")
        travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}"
        travel_filepath = TRAVEL_DOC_PATH

        logger.info("[Before add docs]")
        await self.run_pipeline(question=travel_question, print_title=False)

        logger.info("[After add docs]")
        self.engine.add_docs([travel_filepath])
        await self.run_pipeline(question=travel_question, print_title=False)
    async def add_objects(self, print_title=True):
        """This example shows how to add objects.

        Before adding objects, the engine retrieves nothing.
        After adding objects, the engine gives the correct answer; prints something like:

        [Before add objs]
        Retrieve Result:

        [After add objs]
        Retrieve Result:
        0. 100m Sprin..., 10.0

        [Object Detail]
        {'name': 'Mike', 'goal': 'Win The 100-meter Sprint', 'tool': 'Red Bull Energy Drink'}
        """
        if print_title:
            self._print_title("Add Objects")

        player = Player(name="Mike")
        question = f"{player.rag_key()}"

        logger.info("[Before add objs]")
        await self._retrieve_and_print(question)

        logger.info("[After add objs]")
        self.engine.add_objs([player])
        try:
            nodes = await self._retrieve_and_print(question)

            logger.info("[Object Detail]")
            player: Player = nodes[0].metadata["obj"]
            logger.info(player.name)
        except Exception as e:
            logger.error(f"nodes is empty; the LLM didn't answer correctly, exception: {e}")
    async def init_objects(self):
        """This example shows how to build an engine from objects; the output is the same as add_objects."""
        self._print_title("Init Objects")

        pre_engine = self.engine
        self.engine = SimpleEngine.from_objs(retriever_configs=[FAISSRetrieverConfig()])
        await self.add_objects(print_title=False)
        self.engine = pre_engine
    async def init_and_query_chromadb(self):
        """This example shows how to use ChromaDB: how to save and load an index. Prints something like:

        Query Result:
        Bob likes traveling.
        """
        self._print_title("Init And Query ChromaDB")

        # save index
        output_dir = DATA_PATH / "rag"
        SimpleEngine.from_docs(
            input_files=[TRAVEL_DOC_PATH],
            retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
        )

        # load index
        engine = SimpleEngine.from_index(
            index_config=ChromaIndexConfig(persist_path=output_dir),
        )

        # query
        answer = engine.query(TRAVEL_QUESTION)
        self._print_query_result(answer)
    @staticmethod
    def _print_title(title):
        logger.info(f"{'#'*30} {title} {'#'*30}")

    @staticmethod
    def _print_retrieve_result(result):
        """Print retrieve result."""
        logger.info("Retrieve Result:")

        for i, node in enumerate(result):
            logger.info(f"{i}. {node.text[:10]}..., {node.score}")

        logger.info("")

    @staticmethod
    def _print_query_result(result):
        """Print query result."""
        logger.info("Query Result:")
        logger.info(f"{result}\n")

    async def _retrieve_and_print(self, question):
        nodes = await self.engine.aretrieve(question)
        self._print_retrieve_result(nodes)
        return nodes
async def main():
    """RAG pipeline"""
    e = RAGExample()
    await e.run_pipeline()
    await e.add_docs()
    await e.add_objects()
    await e.init_objects()
    await e.init_and_query_chromadb()


if __name__ == "__main__":
    asyncio.run(main())
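Conceptually, run_pipeline above is retrieve-then-rerank: several retrievers each propose candidate chunks, the candidates are fused, and a ranker orders them. A minimal dependency-free sketch of that shape (the toy retriever and ranker below are illustrative, not the MetaGPT API):

```python
def fuse_and_rerank(query, retrievers, ranker, top_k=3):
    """Gather candidates from every retriever, dedupe, then rerank."""
    seen, candidates = set(), []
    for retrieve in retrievers:
        for chunk in retrieve(query):
            if chunk not in seen:
                seen.add(chunk)
                candidates.append(chunk)
    # Highest-scoring candidates first, trimmed to top_k.
    return sorted(candidates, key=lambda c: ranker(query, c), reverse=True)[:top_k]


# Toy stand-ins: a keyword retriever and a word-overlap ranker.
docs = ["Bob likes traveling.", "Writers need discipline.", "Productivity compounds."]
keyword = lambda q: [d for d in docs if any(w in d.lower() for w in q.lower().split())]
overlap = lambda q, d: len(set(q.lower().split()) & set(d.lower().split()))

print(fuse_and_rerank("what does Bob like", [keyword], overlap))  # ['Bob likes traveling.']
```

In the real engine, the FAISS and BM25 retrievers play the retriever role and LLMRankerConfig supplies the ranker.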

21
examples/rag_search.py Normal file
View file

@ -0,0 +1,21 @@
"""Agent with RAG search."""
import asyncio
from examples.rag_pipeline import DOC_PATH, QUESTION
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.roles import Sales
async def search():
"""Agent with RAG search."""
store = SimpleEngine.from_docs(input_files=[DOC_PATH])
role = Sales(profile="Sales", store=store)
result = await role.run(QUESTION)
logger.info(result)
if __name__ == "__main__":
asyncio.run(search())

View file

@ -1,33 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@File : search_kb.py
@Modified By: mashenquan, 2023-12-22. Delete useless codes.
"""
import asyncio

from langchain.embeddings import OpenAIEmbeddings

from metagpt.config2 import config
from metagpt.const import DATA_PATH, EXAMPLE_PATH
from metagpt.document_store import FaissStore
from metagpt.logs import logger
from metagpt.roles import Sales


def get_store():
    llm = config.get_openai_llm()
    embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
    return FaissStore(DATA_PATH / "example.json", embedding=embedding)


async def search():
    store = FaissStore(EXAMPLE_PATH / "example.json")
    role = Sales(profile="Sales", store=store)
    query = "Which facial cleanser is good for oily skin?"
    result = await role.run(query)
    logger.info(result)


if __name__ == "__main__":
    asyncio.run(search())

View file

@ -49,6 +49,7 @@ METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
EXAMPLE_PATH = METAGPT_ROOT / "examples"
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
DATA_PATH = METAGPT_ROOT / "data"
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
RESEARCH_PATH = DATA_PATH / "research"

View file

@ -11,12 +11,9 @@ from pathlib import Path
from typing import Optional, Union
import pandas as pd
from langchain.text_splitter import CharacterTextSplitter
from langchain_community.document_loaders import (
TextLoader,
UnstructuredPDFLoader,
UnstructuredWordDocumentLoader,
)
from llama_index.core import Document, SimpleDirectoryReader
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.readers.file import PDFReader
from pydantic import BaseModel, ConfigDict, Field
from tqdm import tqdm
@ -29,7 +26,7 @@ def validate_cols(content_col: str, df: pd.DataFrame):
raise ValueError("Content column not found in DataFrame.")
def read_data(data_path: Path):
def read_data(data_path: Path) -> Union[pd.DataFrame, list[Document]]:
suffix = data_path.suffix
if ".xlsx" == suffix:
data = pd.read_excel(data_path)
@ -38,14 +35,13 @@ def read_data(data_path: Path):
elif ".json" == suffix:
data = pd.read_json(data_path)
elif suffix in (".docx", ".doc"):
data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
elif ".txt" == suffix:
data = TextLoader(str(data_path)).load()
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
texts = text_splitter.split_documents(data)
data = texts
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
data = node_parser.get_nodes_from_documents(data)
elif ".pdf" == suffix:
data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
data = PDFReader.load_data(str(data_path))
else:
raise NotImplementedError("File format not supported.")
return data
@ -150,9 +146,9 @@ class IndexableDocument(Document):
metadatas.append({})
return docs, metadatas
def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list):
data = self.data
docs = [i.page_content for i in data]
docs = [i.text for i in data]
metadatas = [i.metadata for i in data]
return docs, metadatas
@ -160,7 +156,7 @@ class IndexableDocument(Document):
if isinstance(self.data, pd.DataFrame):
return self._get_docs_and_metadatas_by_df()
elif isinstance(self.data, list):
return self._get_docs_and_metadatas_by_langchain()
return self._get_docs_and_metadatas_by_llamaindex()
else:
raise NotImplementedError("Data type not supported for metadata extraction.")

View file

@ -38,9 +38,9 @@ class LocalStore(BaseStore, ABC):
if not self.store:
self.store = self.write()
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
index_file = self.cache_dir / f"{self.fname}{index_ext}"
store_file = self.cache_dir / f"{self.fname}{pkl_ext}"
def _get_index_and_store_fname(self, index_ext=".json", docstore_ext=".json"):
index_file = self.cache_dir / "default__vector_store" / index_ext
store_file = self.cache_dir / "docstore" / docstore_ext
return index_file, store_file
@abstractmethod

View file

@ -11,9 +11,9 @@ import chromadb
class ChromaStore:
"""If inherited from BaseStore, or importing other modules from metagpt, a Python exception occurs, which is strange."""
def __init__(self, name):
def __init__(self, name: str, get_or_create: bool = False):
client = chromadb.Client()
collection = client.create_collection(name)
collection = client.create_collection(name, get_or_create=get_or_create)
self.client = client
self.collection = collection

View file

@ -7,10 +7,14 @@
"""
import asyncio
from pathlib import Path
from typing import Optional
from typing import Any, Optional
from langchain.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
import faiss
from llama_index.core import VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.schema import Document, QueryBundle, TextNode
from llama_index.core.storage import StorageContext
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.document import IndexableDocument
from metagpt.document_store.base_store import LocalStore
@ -20,36 +24,50 @@ from metagpt.utils.embedding import get_embedding
class FaissStore(LocalStore):
def __init__(
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None
):
self.meta_col = meta_col
self.content_col = content_col
self.embedding = embedding or get_embedding()
self.store: VectorStoreIndex
super().__init__(raw_data, cache_dir)
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
def _load(self) -> Optional["VectorStoreIndex"]:
index_file, store_file = self._get_index_and_store_fname()
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file, load failed and return None")
return None
vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir)
storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store)
index = load_index_from_storage(storage_context, embed_model=self.embedding)
return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname)
return index
def _write(self, docs, metadatas):
store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas)
return store
def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex:
assert len(docs) == len(metadatas)
documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)]
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(
documents=documents, storage_context=storage_context, embed_model=self.embedding
)
return index
def persist(self):
self.store.save_local(self.raw_data_path.parent, self.fname)
self.store.storage_context.persist(self.cache_dir)
def search(self, query: str, expand_cols=False, sep="\n", *args, k=5, **kwargs):
retriever = self.store.as_retriever(similarity_top_k=k)
rsp = retriever.retrieve(QueryBundle(query_str=query, embedding=self.embedding.get_text_embedding(query)))
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
rsp = self.store.similarity_search(query, k=k, **kwargs)
logger.debug(rsp)
if expand_cols:
return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))
return str(sep.join([f"{x.node.text}: {x.node.metadata}" for x in rsp]))
else:
return str(sep.join([f"{x.page_content}" for x in rsp]))
return str(sep.join([f"{x.node.text}" for x in rsp]))
async def asearch(self, *args, **kwargs):
return await asyncio.to_thread(self.search, *args, **kwargs)
@ -67,8 +85,12 @@ class FaissStore(LocalStore):
def add(self, texts: list[str], *args, **kwargs) -> list[str]:
"""FIXME: Currently, the store is not updated after adding."""
return self.store.add_texts(texts)
texts_embeds = self.embedding.get_text_embedding_batch(texts)
nodes = [TextNode(text=texts[idx], embedding=embed) for idx, embed in enumerate(texts_embeds)]
self.store.insert_nodes(nodes)
return []
def delete(self, *args, **kwargs):
"""Currently, langchain does not provide a delete interface."""
"""Currently, faiss does not provide a delete interface."""
raise NotImplementedError

View file

@ -8,8 +8,6 @@ import re
import time
from typing import Any, Iterable
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from pydantic import ConfigDict, Field
from metagpt.config2 import config as CONFIG
@ -17,6 +15,7 @@ from metagpt.environment.base_env import Environment
from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
from metagpt.logs import logger
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
@ -48,9 +47,9 @@ class MincraftEnv(Environment, MincraftExtEnv):
runtime_status: bool = False # equal to action execution status: success or failed
vectordb: Chroma = Field(default_factory=Chroma)
vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore)
qa_cache_questions_vectordb: Chroma = Field(default_factory=Chroma)
qa_cache_questions_vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore)
@property
def progress(self):
@ -73,16 +72,14 @@ class MincraftEnv(Environment, MincraftExtEnv):
self.set_mc_resume()
def set_mc_resume(self):
self.qa_cache_questions_vectordb = Chroma(
self.qa_cache_questions_vectordb = ChromaVectorStore(
collection_name="qa_cache_questions_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{MC_CKPT_DIR}/curriculum/vectordb",
persist_dir=f"{MC_CKPT_DIR}/curriculum/vectordb",
)
self.vectordb = Chroma(
self.vectordb = ChromaVectorStore(
collection_name="skill_vectordb",
embedding_function=OpenAIEmbeddings(),
persist_directory=f"{MC_CKPT_DIR}/skill/vectordb",
persist_dir=f"{MC_CKPT_DIR}/skill/vectordb",
)
if CONFIG.resume:

View file

@ -29,16 +29,14 @@ class LongTermMemory(Memory):
msg_from_recover: bool = False
def recover_memory(self, role_id: str, rc: RoleContext):
messages = self.memory_storage.recover_memory(role_id)
self.memory_storage.recover_memory(role_id)
self.rc = rc
if not self.memory_storage.is_initialized:
logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty")
logger.warning(f"It may the first time to run Role {role_id}, the long-term memory is empty")
else:
logger.warning(
f"Agent {role_id} has existing memory storage with {len(messages)} messages " f"and has recovered them."
)
logger.warning(f"Role {role_id} has existing memory storage and has recovered them.")
self.msg_from_recover = True
self.add_batch(messages)
# self.add_batch(messages) # TODO no need
self.msg_from_recover = False
def add(self, message: Message):
@ -49,7 +47,7 @@ class LongTermMemory(Memory):
# and ignore adding messages from recover repeatedly
self.memory_storage.add(message)
def find_news(self, observed: list[Message], k=0) -> list[Message]:
async def find_news(self, observed: list[Message], k=0) -> list[Message]:
"""
find news (previously unseen messages) from the most recent k memories, or from all memories when k=0
1. find the short-term memory(stm) news
@ -63,11 +61,14 @@ class LongTermMemory(Memory):
ltm_news: list[Message] = []
for mem in stm_news:
# filter out messages similar to those seen previously in ltm, only keep fresh news
mem_searched = self.memory_storage.search_dissimilar(mem)
if len(mem_searched) > 0:
mem_searched = await self.memory_storage.search_similar(mem)
if len(mem_searched) == 0:
ltm_news.append(mem)
return ltm_news[-k:]
def persist(self):
self.memory_storage.persist()
def delete(self, message: Message):
super().delete(message)
# TODO delete message in memory_storage
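find_news is now a coroutine because the similarity lookup (search_similar) is awaited, and the filter condition flipped: a short-term message counts as news only when the search over long-term memory comes back empty. That filtering logic, sketched with a toy synchronous similarity check standing in for the vector search:

```python
import asyncio


async def find_news(observed, long_term, is_similar, k=0):
    """Keep messages with no similar counterpart in long-term memory.

    k=0 means consider all of `observed`, mirroring the method above.
    """
    news = [msg for msg in observed if not is_similar(msg, long_term)]
    return news[-k:] if k else news


# Toy similarity: exact membership stands in for the FAISS search.
ltm = {"build a house", "mine some iron"}
out = asyncio.run(find_news(["build a house", "craft a sword"], ltm, lambda m, mem: m in mem))
print(out)  # ['craft a sword']
```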

View file

@ -3,115 +3,75 @@
"""
@Desc : the implement of memory storage
"""
import shutil
from pathlib import Path
from typing import Optional
from langchain.vectorstores.faiss import FAISS
from langchain_core.embeddings import Embeddings
from llama_index.core.embeddings import BaseEmbedding
from metagpt.const import DATA_PATH, MEM_TTL
from metagpt.document_store.faiss_store import FaissStore
from metagpt.logs import logger
from metagpt.rag.engines.simple import SimpleEngine
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig
from metagpt.schema import Message
from metagpt.utils.embedding import get_embedding
from metagpt.utils.serialize import deserialize_message, serialize_message
class MemoryStorage(FaissStore):
class MemoryStorage(object):
"""
The memory storage with Faiss as ANN search engine
"""
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
def __init__(self, mem_ttl: int = MEM_TTL, embedding: BaseEmbedding = None):
self.role_id: str = None
self.role_mem_path: str = None
self.mem_ttl: int = mem_ttl # later use
self.threshold: float = 0.1 # empirical value. TODO: tune this threshold for filtering similar memories
self._initialized: bool = False
self.embedding = embedding or get_embedding()
self.store: FAISS = None # Faiss engine
self.faiss_engine = None
@property
def is_initialized(self) -> bool:
return self._initialized
def _load(self) -> Optional["FaissStore"]:
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
if not (index_file.exists() and store_file.exists()):
logger.info("Missing at least one of index_file/store_file, load failed and return None")
return None
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
def recover_memory(self, role_id: str) -> list[Message]:
self.role_id = role_id
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
self.role_mem_path.mkdir(parents=True, exist_ok=True)
self.cache_dir = self.role_mem_path
self.store = self._load()
messages = []
if not self.store:
# TODO init `self.store` under here with raw faiss api instead under `add`
pass
if self.role_mem_path.joinpath("default__vector_store.json").exists():
self.faiss_engine = SimpleEngine.from_index(
index_config=FAISSIndexConfig(persist_path=self.cache_dir),
retriever_configs=[FAISSRetrieverConfig()],
embed_model=self.embedding,
)
else:
for _id, document in self.store.docstore._dict.items():
messages.append(deserialize_message(document.metadata.get("message_ser")))
self._initialized = True
return messages
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
if not self.role_mem_path:
logger.error(f"You should call {self.__class__.__name__}.recover_memory first when using LongTermMemory")
return None, None
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
return index_fpath, storage_fpath
def persist(self):
self.store.save_local(self.role_mem_path, self.role_id)
logger.debug(f"Agent {self.role_id} persist memory into local")
self.faiss_engine = SimpleEngine.from_objs(
objs=[], retriever_configs=[FAISSRetrieverConfig()], embed_model=self.embedding
)
self._initialized = True
def add(self, message: Message) -> bool:
"""add message into memory storage"""
docs = [message.content]
metadatas = [{"message_ser": serialize_message(message)}]
if not self.store:
# init Faiss
self.store = self._write(docs, metadatas)
self._initialized = True
else:
self.store.add_texts(texts=docs, metadatas=metadatas)
self.persist()
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
self.faiss_engine.add_objs([message])
logger.info(f"Role {self.role_id}'s memory_storage add a message")
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
"""search for dissimilar messages"""
if not self.store:
return []
resp = self.store.similarity_search_with_score(query=message.content, k=k)
async def search_similar(self, message: Message, k=4) -> list[Message]:
"""search for similar messages"""
# filter the result which score is smaller than the threshold
filtered_resp = []
for item, score in resp:
# the smaller score means more similar relation
if score < self.threshold:
continue
# convert search result into Memory
metadata = item.metadata
new_mem = deserialize_message(metadata.get("message_ser"))
filtered_resp.append(new_mem)
resp = await self.faiss_engine.aretrieve(message.content)
for item in resp:
if item.score < self.threshold:
filtered_resp.append(item.metadata.get("obj"))
return filtered_resp
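The filter above can be sketched in isolation: with an L2-distance FAISS index, a smaller score means a closer match, so only hits scoring below the threshold count as similar (a minimal sketch, not the actual engine code).

```python
def filter_similar(hits: list[tuple[str, float]], threshold: float = 0.1) -> list[str]:
    """Keep only hits whose distance score is below the threshold.

    With L2 distances, lower scores mean more similar results.
    """
    return [obj for obj, score in hits if score < threshold]
```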
def clean(self):
index_fpath, storage_fpath = self._get_index_and_store_fname()
if index_fpath and index_fpath.exists():
index_fpath.unlink(missing_ok=True)
if storage_fpath and storage_fpath.exists():
storage_fpath.unlink(missing_ok=True)
self.store = None
shutil.rmtree(self.cache_dir, ignore_errors=True)
self._initialized = False
def persist(self):
if self.faiss_engine:
self.faiss_engine.retriever._index.storage_context.persist(self.cache_dir)

metagpt/rag/__init__.py (new file)

@ -0,0 +1,5 @@
"""Engines init"""
from metagpt.rag.engines.simple import SimpleEngine
__all__ = ["SimpleEngine"]


@ -0,0 +1,259 @@
"""Simple Engine."""
import json
import os
from typing import Any, Optional, Union
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.ingestion.pipeline import run_transformations
from llama_index.core.llms import LLM
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import (
BaseSynthesizer,
get_response_synthesizer,
)
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import (
BaseNode,
Document,
NodeWithScore,
QueryBundle,
QueryType,
TransformComponent,
)
from metagpt.rag.factories import (
get_index,
get_rag_embedding,
get_rag_llm,
get_rankers,
get_retriever,
)
from metagpt.rag.interface import NoEmbedding, RAGObject
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseIndexConfig,
BaseRankerConfig,
BaseRetrieverConfig,
BM25RetrieverConfig,
ObjectNode,
)
from metagpt.utils.common import import_class
class SimpleEngine(RetrieverQueryEngine):
"""SimpleEngine is designed to be simple and straightforward.
It is a lightweight, easy-to-use search engine that integrates document
reading, embedding, indexing, retrieval, and ranking into a single
workflow, so a search engine can be set up quickly from a collection of documents.
"""
def __init__(
self,
retriever: BaseRetriever,
response_synthesizer: Optional[BaseSynthesizer] = None,
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
callback_manager: Optional[CallbackManager] = None,
index: Optional[BaseIndex] = None,
) -> None:
super().__init__(
retriever=retriever,
response_synthesizer=response_synthesizer,
node_postprocessors=node_postprocessors,
callback_manager=callback_manager,
)
self.index = index
@classmethod
def from_docs(
cls,
input_dir: str = None,
input_files: list[str] = None,
transformations: Optional[list[TransformComponent]] = None,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""From docs.
Must provide either `input_dir` or `input_files`.
Args:
input_dir: Path to the directory.
input_files: List of file paths to read (Optional; overrides input_dir, exclude).
transformations: Parse documents to nodes. Default [SentenceSplitter].
embed_model: Embeds parsed nodes. Must be supported by LlamaIndex. Default OpenAIEmbedding.
llm: Must be supported by LlamaIndex. Default OpenAI.
retriever_configs: Configurations for retrievers. If more than one config is given, SimpleHybridRetriever is used.
ranker_configs: Configuration for rankers.
"""
if not input_dir and not input_files:
raise ValueError("Must provide either `input_dir` or `input_files`.")
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
cls._fix_document_metadata(documents)
index = VectorStoreIndex.from_documents(
documents=documents,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
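As a minimal configuration sketch (assuming a configured LLM/embedding key; the file path is illustrative, not a file shipped with this PR), `from_docs` combines with the retriever and ranker configs like so:

```python
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import FAISSRetrieverConfig, LLMRankerConfig

# Build a FAISS-backed engine over local docs, reranked by the LLM.
engine = SimpleEngine.from_docs(
    input_files=["examples/data/rag/travel.txt"],  # illustrative path
    retriever_configs=[FAISSRetrieverConfig()],
    ranker_configs=[LLMRankerConfig()],
)
```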
@classmethod
def from_objs(
cls,
objs: Optional[list[RAGObject]] = None,
transformations: Optional[list[TransformComponent]] = None,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""From objs.
Args:
objs: List of RAGObject.
transformations: Parse documents to nodes. Default [SentenceSplitter].
embed_model: Embeds parsed nodes. Must be supported by LlamaIndex. Default OpenAIEmbedding.
llm: Must be supported by LlamaIndex. Default OpenAI.
retriever_configs: Configurations for retrievers. If more than one config is given, SimpleHybridRetriever is used.
ranker_configs: Configuration for rankers.
"""
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
raise ValueError("When using BM25RetrieverConfig, objs must not be empty.")
objs = objs or []
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
index = VectorStoreIndex(
nodes=nodes,
transformations=transformations or [SentenceSplitter()],
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
)
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
@classmethod
def from_index(
cls,
index_config: BaseIndexConfig,
embed_model: BaseEmbedding = None,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
"""Load from an index previously maintained via self.persist(); index_config contains persist_path."""
index = get_index(index_config, embed_model=cls._resolve_embed_model(embed_model, [index_config]))
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
async def asearch(self, content: str, **kwargs) -> str:
"""Implement tools.SearchInterface."""
return await self.aquery(content)
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Allow query to be str."""
query_bundle = QueryBundle(query) if isinstance(query, str) else query
nodes = await super().aretrieve(query_bundle)
self._try_reconstruct_obj(nodes)
return nodes
def add_docs(self, input_files: list[str]):
"""Add docs to the retriever. The retriever must have an add_nodes method."""
self._ensure_retriever_modifiable()
documents = SimpleDirectoryReader(input_files=input_files).load_data()
self._fix_document_metadata(documents)
nodes = run_transformations(documents, transformations=self.index._transformations)
self._save_nodes(nodes)
def add_objs(self, objs: list[RAGObject]):
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
self._ensure_retriever_modifiable()
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
self._save_nodes(nodes)
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
"""Persist."""
self._ensure_retriever_persistable()
self._persist(str(persist_dir), **kwargs)
@classmethod
def _from_index(
cls,
index: BaseIndex,
llm: LLM = None,
retriever_configs: list[BaseRetrieverConfig] = None,
ranker_configs: list[BaseRankerConfig] = None,
) -> "SimpleEngine":
llm = llm or get_rag_llm()
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
return cls(
retriever=retriever,
node_postprocessors=rankers,
response_synthesizer=get_response_synthesizer(llm=llm),
index=index,
)
def _ensure_retriever_modifiable(self):
self._ensure_retriever_of_type(ModifiableRAGRetriever)
def _ensure_retriever_persistable(self):
self._ensure_retriever_of_type(PersistableRAGRetriever)
def _ensure_retriever_of_type(self, required_type: type[BaseRetriever]):
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.
Args:
required_type: The class that the retriever is expected to be an instance of.
"""
if isinstance(self.retriever, SimpleHybridRetriever):
if not any(isinstance(r, required_type) for r in self.retriever.retrievers):
raise TypeError(
f"Must have at least one retriever of type {required_type.__name__} in SimpleHybridRetriever"
)
if not isinstance(self.retriever, required_type):
raise TypeError(f"The retriever is not of type {required_type.__name__}: {type(self.retriever)}")
def _save_nodes(self, nodes: list[BaseNode]):
self.retriever.add_nodes(nodes)
def _persist(self, persist_dir: str, **kwargs):
self.retriever.persist(persist_dir, **kwargs)
@staticmethod
def _try_reconstruct_obj(nodes: list[NodeWithScore]):
"""If a node represents an object, dynamically reconstruct the object and save it to node.metadata["obj"]."""
for node in nodes:
if node.metadata.get("is_obj", False):
obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"])
obj_dict = json.loads(node.metadata["obj_json"])
node.metadata["obj"] = obj_cls(**obj_dict)
@staticmethod
def _fix_document_metadata(documents: list[Document]):
"""LlamaIndex keeps metadata['file_path'], which is unnecessary and may be deleted in the near future."""
for doc in documents:
doc.excluded_embed_metadata_keys.append("file_path")
@staticmethod
def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = None) -> BaseEmbedding:
if configs and all(isinstance(c, NoEmbedding) for c in configs):
return MockEmbedding(embed_dim=1)
return embed_model or get_rag_embedding()
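The `_resolve_embed_model` check above can be sketched standalone: because `NoEmbedding` is a `runtime_checkable` protocol, `isinstance` only tests for the presence of the `_no_embedding` attribute. The protocol and both config classes are re-declared minimally here (mirroring `metagpt.rag.interface` and `metagpt.rag.schema`) so the snippet stands alone; the returned strings stand in for real embedding objects.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class NoEmbedding(Protocol):
    """Retrievers that need no embeddings, e.g. BM25."""

    _no_embedding: bool


class BM25Config:  # minimal stand-in for BM25RetrieverConfig
    _no_embedding = True


class FAISSConfig:  # minimal stand-in for FAISSRetrieverConfig
    pass


def resolve_embed_model(embed_model=None, configs=None) -> str:
    # If every config opts out of embeddings, a mock embedding suffices.
    if configs and all(isinstance(c, NoEmbedding) for c in configs):
        return "mock-embedding"  # stand-in for MockEmbedding(embed_dim=1)
    return embed_model or "default-embedding"
```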


@ -0,0 +1,9 @@
"""RAG factories"""
from metagpt.rag.factories.retriever import get_retriever
from metagpt.rag.factories.ranker import get_rankers
from metagpt.rag.factories.embedding import get_rag_embedding
from metagpt.rag.factories.index import get_index
from metagpt.rag.factories.llm import get_rag_llm
__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index", "get_rag_llm"]


@ -0,0 +1,59 @@
"""Base Factory."""
from typing import Any, Callable
class GenericFactory:
"""Designed to get objects based on any keys."""
def __init__(self, creators: dict[Any, Callable] = None):
"""Creators is a dictionary: keys are identifiers, and values are the associated creator functions, which create objects.
"""
self._creators = creators or {}
def get_instances(self, keys: list[Any], **kwargs) -> list[Any]:
"""Get instances by keys."""
return [self.get_instance(key, **kwargs) for key in keys]
def get_instance(self, key: Any, **kwargs) -> Any:
"""Get instance by key.
Raises ValueError if the key is not found.
"""
creator = self._creators.get(key)
if creator:
return creator(**kwargs)
raise ValueError(f"Creator not registered for key: {key}")
class ConfigBasedFactory(GenericFactory):
"""Designed to get objects based on object type."""
def get_instance(self, key: Any, **kwargs) -> Any:
"""The key is a config object, such as a pydantic model.
The creator function is looked up by the key's type, and the key itself is passed to it.
"""
creator = self._creators.get(type(key))
if creator:
return creator(key, **kwargs)
raise ValueError(f"Unknown config: {key}")
@staticmethod
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
if config is not None and hasattr(config, key):
val = getattr(config, key)
if val is not None:
return val
if key in kwargs:
return kwargs[key]
raise KeyError(
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
)
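As a usage sketch of the factory pattern above (the class is re-declared minimally so the snippet stands alone; the `"faiss"` key and its creator are made up for illustration):

```python
from typing import Any, Callable


class GenericFactory:
    """Maps keys to creator callables, as in metagpt.rag.factories.base."""

    def __init__(self, creators: dict[Any, Callable] = None):
        self._creators = creators or {}

    def get_instance(self, key: Any, **kwargs) -> Any:
        creator = self._creators.get(key)
        if creator:
            return creator(**kwargs)
        raise ValueError(f"Creator not registered for key: {key}")


# Register a creator per key; kwargs flow through to the creator.
factory = GenericFactory({"faiss": lambda **kw: ("faiss-retriever", kw)})
```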


@ -0,0 +1,37 @@
"""RAG Embedding Factory."""
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding
from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.rag.factories.base import GenericFactory
class RAGEmbeddingFactory(GenericFactory):
"""Create LlamaIndex Embedding with MetaGPT's config."""
def __init__(self):
creators = {
LLMType.OPENAI: self._create_openai,
LLMType.AZURE: self._create_azure,
}
super().__init__(creators)
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
"""Key is an LLMType; defaults to config.llm.api_type."""
return super().get_instance(key or config.llm.api_type)
def _create_openai(self):
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
def _create_azure(self):
return AzureOpenAIEmbedding(
azure_endpoint=config.llm.base_url,
api_key=config.llm.api_key,
api_version=config.llm.api_version,
)
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding


@ -0,0 +1,63 @@
"""RAG Index Factory."""
import chromadb
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import (
BaseIndexConfig,
BM25IndexConfig,
ChromaIndexConfig,
FAISSIndexConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RAGIndexFactory(ConfigBasedFactory):
def __init__(self):
creators = {
FAISSIndexConfig: self._create_faiss,
ChromaIndexConfig: self._create_chroma,
BM25IndexConfig: self._create_bm25,
}
super().__init__(creators)
def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex:
"""The key is an index config; dispatch is by its type."""
return super().get_instance(config, **kwargs)
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
db = chromadb.PersistentClient(str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
index = VectorStoreIndex.from_vector_store(
vector_store,
embed_model=embed_model,
)
return index
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
embed_model = self._extract_embed_model(config, **kwargs)
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
return index
def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
get_index = RAGIndexFactory().get_index


@ -0,0 +1,54 @@
"""RAG LLM."""
from typing import Any
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
from llama_index.core.llms import (
CompletionResponse,
CompletionResponseGen,
CustomLLM,
LLMMetadata,
)
from llama_index.core.llms.callbacks import llm_completion_callback
from pydantic import Field
from metagpt.config2 import config
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.async_helper import run_coroutine_in_new_loop
from metagpt.utils.token_counter import TOKEN_MAX
class RAGLLM(CustomLLM):
"""LlamaIndex's LLM is different from MetaGPT's LLM.
Inheriting from LlamaIndex's CustomLLM lets MetaGPT's LLM be used by LlamaIndex.
"""
model_infer: BaseLLM = Field(..., description="The MetaGPT's LLM.")
context_window: int = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
num_output: int = config.llm.max_token
model_name: str = config.llm.model
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
@llm_completion_callback()
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
text = await self.model_infer.aask(msg=prompt, stream=False)
return CompletionResponse(text=text)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
...
def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM:
"""Get llm that can be used by LlamaIndex."""
return RAGLLM(model_infer=model_infer or LLM())
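The sync `complete` above simply drives the async `acomplete` to completion; a minimal sketch of that bridge (using `asyncio.run` in place of `run_coroutine_in_new_loop`, which additionally copes with an already-running loop; `EchoLLM` is a toy model, not part of the codebase):

```python
import asyncio


class EchoLLM:
    """Toy model illustrating the sync-over-async pattern in RAGLLM."""

    async def acomplete(self, prompt: str) -> str:
        return f"echo: {prompt}"

    def complete(self, prompt: str) -> str:
        # Run the coroutine to completion on a fresh event loop.
        return asyncio.run(self.acomplete(prompt))
```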


@ -0,0 +1,35 @@
"""RAG Ranker Factory."""
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
class RankerFactory(ConfigBasedFactory):
"""Modify creators for dynamic instance creation."""
def __init__(self):
creators = {
LLMRankerConfig: self._create_llm_ranker,
}
super().__init__(creators)
def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
"""Creates and returns ranker instances based on the provided configurations."""
if not configs:
return []
return super().get_instances(configs, **kwargs)
def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
config.llm = self._extract_llm(config, **kwargs)
return LLMRerank(**config.model_dump())
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
return self._val_from_config_or_kwargs("llm", config, **kwargs)
get_rankers = RankerFactory().get_rankers


@ -0,0 +1,86 @@
"""RAG Retriever Factory."""
import copy
import chromadb
import faiss
from llama_index.core import StorageContext, VectorStoreIndex
from llama_index.core.vector_stores.types import BasePydanticVectorStore
from llama_index.vector_stores.faiss import FaissVectorStore
from metagpt.rag.factories.base import ConfigBasedFactory
from metagpt.rag.retrievers.base import RAGRetriever
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import (
BaseRetrieverConfig,
BM25RetrieverConfig,
ChromaRetrieverConfig,
FAISSRetrieverConfig,
IndexRetrieverConfig,
)
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
class RetrieverFactory(ConfigBasedFactory):
"""Modify creators for dynamic instance creation."""
def __init__(self):
creators = {
FAISSRetrieverConfig: self._create_faiss_retriever,
BM25RetrieverConfig: self._create_bm25_retriever,
ChromaRetrieverConfig: self._create_chroma_retriever,
}
super().__init__(creators)
def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> RAGRetriever:
"""Creates and returns a retriever instance based on the provided configurations.
If multiple retrievers are configured, a SimpleHybridRetriever is used.
"""
if not configs:
return self._create_default(**kwargs)
retrievers = super().get_instances(configs, **kwargs)
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
def _create_default(self, **kwargs) -> RAGRetriever:
return self._extract_index(**kwargs).as_retriever()
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
return FAISSRetriever(**config.model_dump())
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
nodes = list(config.index.docstore.docs.values())
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
db = chromadb.PersistentClient(path=str(config.persist_path))
chroma_collection = db.get_or_create_collection(config.collection_name)
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
return ChromaRetriever(**config.model_dump())
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
return self._val_from_config_or_kwargs("index", config, **kwargs)
def _build_index_from_vector_store(
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
) -> VectorStoreIndex:
storage_context = StorageContext.from_defaults(vector_store=vector_store)
old_index = self._extract_index(config, **kwargs)
new_index = VectorStoreIndex(
nodes=list(old_index.docstore.docs.values()),
storage_context=storage_context,
embed_model=old_index._embed_model,
)
return new_index
get_retriever = RetrieverFactory().get_retriever

metagpt/rag/interface.py (new file)

@ -0,0 +1,24 @@
"""RAG Interfaces."""
from typing import Protocol, runtime_checkable
@runtime_checkable
class RAGObject(Protocol):
"""Support rag add object."""
def rag_key(self) -> str:
"""For rag search."""
def model_dump_json(self) -> str:
"""For rag persist.
Pydantic models don't need to implement this, as there is a built-in method named model_dump_json.
"""
@runtime_checkable
class NoEmbedding(Protocol):
"""Some retrievers do not require embeddings, e.g. BM25."""
_no_embedding: bool
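Because both protocols are `runtime_checkable`, any object with the right methods qualifies; a sketch with a hypothetical `Player` class (the protocol is re-declared so the snippet stands alone):

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class RAGObject(Protocol):
    def rag_key(self) -> str: ...

    def model_dump_json(self) -> str: ...


class Player:
    """Hypothetical object; pydantic models get model_dump_json for free."""

    def __init__(self, name: str):
        self.name = name

    def rag_key(self) -> str:  # the text that gets embedded and searched
        return self.name

    def model_dump_json(self) -> str:  # how the object is persisted
        return f'{{"name": "{self.name}"}}'
```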


@ -0,0 +1 @@
"""Rankers init"""


@ -0,0 +1,19 @@
"""Base Ranker."""
from abc import abstractmethod
from typing import Optional
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import NodeWithScore, QueryBundle
class RAGRanker(BaseNodePostprocessor):
"""Inherit from LlamaIndex."""
@abstractmethod
def _postprocess_nodes(
self,
nodes: list[NodeWithScore],
query_bundle: Optional[QueryBundle] = None,
) -> list[NodeWithScore]:
"""postprocess nodes."""


@ -0,0 +1,5 @@
"""Retrievers init."""
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
__all__ = ["SimpleHybridRetriever"]


@ -0,0 +1,47 @@
"""Base retriever."""
from abc import abstractmethod
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import BaseNode, NodeWithScore, QueryType
from metagpt.utils.reflection import check_methods
class RAGRetriever(BaseRetriever):
"""Inherit from llama_index"""
@abstractmethod
async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Retrieve nodes"""
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
"""Retrieve nodes"""
class ModifiableRAGRetriever(RAGRetriever):
"""Support modification."""
@classmethod
def __subclasshook__(cls, C):
if cls is ModifiableRAGRetriever:
return check_methods(C, "add_nodes")
return NotImplemented
@abstractmethod
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""To support adding docs, subclasses must implement this method."""
class PersistableRAGRetriever(RAGRetriever):
"""Support persistence."""
@classmethod
def __subclasshook__(cls, C):
if cls is PersistableRAGRetriever:
return check_methods(C, "persist")
return NotImplemented
@abstractmethod
def persist(self, persist_dir: str, **kwargs) -> None:
"""To support persisting, subclasses must implement this method."""


@ -0,0 +1,47 @@
"""BM25 retriever."""
from typing import Callable, Optional
from llama_index.core import VectorStoreIndex
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
from llama_index.core.schema import BaseNode, IndexNode
from llama_index.retrievers.bm25 import BM25Retriever
from rank_bm25 import BM25Okapi
class DynamicBM25Retriever(BM25Retriever):
"""BM25 retriever."""
def __init__(
self,
nodes: list[BaseNode],
tokenizer: Optional[Callable[[str], list[str]]] = None,
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
callback_manager: Optional[CallbackManager] = None,
objects: Optional[list[IndexNode]] = None,
object_map: Optional[dict] = None,
verbose: bool = False,
index: VectorStoreIndex = None,
) -> None:
super().__init__(
nodes=nodes,
tokenizer=tokenizer,
similarity_top_k=similarity_top_k,
callback_manager=callback_manager,
object_map=object_map,
objects=objects,
verbose=verbose,
)
self._index = index
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes."""
self._nodes.extend(nodes)
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
self.bm25 = BM25Okapi(self._corpus)
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
self._index.storage_context.persist(persist_dir)


@ -0,0 +1,17 @@
"""Chroma retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class ChromaRetriever(VectorIndexRetriever):
"""Chroma retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes."""
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist.
Chromadb automatically saves, so there is no need to implement."""


@ -0,0 +1,16 @@
"""FAISS retriever."""
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.schema import BaseNode
class FAISSRetriever(VectorIndexRetriever):
"""FAISS retriever."""
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
"""Support add nodes"""
self._index.insert_nodes(nodes, **kwargs)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
self._index.storage_context.persist(persist_dir)


@ -0,0 +1,48 @@
"""Hybrid retriever."""
import copy
from llama_index.core.schema import BaseNode, QueryType
from metagpt.rag.retrievers.base import RAGRetriever
class SimpleHybridRetriever(RAGRetriever):
"""A composite retriever that aggregates search results from multiple retrievers."""
def __init__(self, *retrievers):
self.retrievers: list[RAGRetriever] = retrievers
super().__init__()
async def _aretrieve(self, query: QueryType, **kwargs):
"""Asynchronously retrieves and aggregates search results from all configured retrievers.
This method queries each retriever in the `retrievers` list with the given query and
additional keyword arguments. It then combines the results, ensuring that each node is
unique, based on the node's ID.
"""
all_nodes = []
for retriever in self.retrievers:
# Prevent retrievers from mutating the query
query_copy = copy.deepcopy(query)
nodes = await retriever.aretrieve(query_copy, **kwargs)
all_nodes.extend(nodes)
# combine all nodes
result = []
node_ids = set()
for n in all_nodes:
if n.node.node_id not in node_ids:
result.append(n)
node_ids.add(n.node.node_id)
return result
def add_nodes(self, nodes: list[BaseNode]) -> None:
"""Support add nodes."""
for r in self.retrievers:
r.add_nodes(nodes)
def persist(self, persist_dir: str, **kwargs) -> None:
"""Support persist."""
for r in self.retrievers:
r.persist(persist_dir, **kwargs)
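The aggregation above can be sketched with plain tuples (`(node_id, payload)` stands in for a `NodeWithScore`): results keep retriever order, and only the first occurrence of each node id survives.

```python
def merge_unique(result_lists: list[list[tuple[str, str]]]) -> list[str]:
    """Combine per-retriever results, deduplicating by node id."""
    seen, merged = set(), []
    for nodes in result_lists:
        for node_id, payload in nodes:
            if node_id not in seen:
                seen.add(node_id)
                merged.append(payload)
    return merged
```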

metagpt/rag/schema.py (new file)

@ -0,0 +1,124 @@
"""RAG schemas."""
from pathlib import Path
from typing import Any, Union
from llama_index.core.embeddings import BaseEmbedding
from llama_index.core.indices.base import BaseIndex
from llama_index.core.schema import TextNode
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from metagpt.rag.interface import RAGObject
class BaseRetrieverConfig(BaseModel):
"""Common config for retrievers.
When adding a new subconfig, add the corresponding instance implementation in rag.factories.retriever.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")
class IndexRetrieverConfig(BaseRetrieverConfig):
"""Config for Index-based retrievers."""
index: BaseIndex = Field(default=None, description="Index for retriever.")
class FAISSRetrieverConfig(IndexRetrieverConfig):
"""Config for FAISS-based retrievers."""
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
class BM25RetrieverConfig(IndexRetrieverConfig):
"""Config for BM25-based retrievers."""
_no_embedding: bool = PrivateAttr(default=True)
class ChromaRetrieverConfig(IndexRetrieverConfig):
"""Config for Chroma-based retrievers."""
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
collection_name: str = Field(default="metagpt", description="The name of the collection.")
class BaseRankerConfig(BaseModel):
"""Common config for rankers.
If you add a new subconfig, add a corresponding instance implementation in rag.factories.ranker.
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
top_n: int = Field(default=5, description="The number of top results to return.")
class LLMRankerConfig(BaseRankerConfig):
"""Config for LLM-based rankers."""
llm: Any = Field(
default=None,
description="The LLM to rerank with. Typed as Any instead of LLM, since llama_index.core.llms.LLM is pydantic.v1.",
)
class BaseIndexConfig(BaseModel):
"""Common config for index.
If you add a new subconfig, add a corresponding instance implementation in rag.factories.index.
"""
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
class VectorIndexConfig(BaseIndexConfig):
"""Config for vector-based index."""
embed_model: BaseEmbedding = Field(default=None, description="Embed model.")
class FAISSIndexConfig(VectorIndexConfig):
"""Config for faiss-based index."""
class ChromaIndexConfig(VectorIndexConfig):
"""Config for chroma-based index."""
collection_name: str = Field(default="metagpt", description="The name of the collection.")
class BM25IndexConfig(BaseIndexConfig):
"""Config for bm25-based index."""
_no_embedding: bool = PrivateAttr(default=True)
class ObjectNodeMetadata(BaseModel):
"""Metadata of ObjectNode."""
is_obj: bool = Field(default=True)
obj: Any = Field(default=None, description="Reconstructed from obj_json during RAG retrieval.")
obj_json: str = Field(..., description="The json of object, e.g. obj.model_dump_json()")
obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")
class ObjectNode(TextNode):
"""A TextNode that carries a serialized object for RAG."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys
@staticmethod
def get_obj_metadata(obj: RAGObject) -> dict:
metadata = ObjectNodeMetadata(
obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
)
return metadata.model_dump()
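The metadata above stores everything needed to rebuild the original object at retrieval time. A minimal sketch of that round trip, using a stand-in `Ticket` dataclass and a simple class registry in place of the importlib-based lookup the real retrieval path would perform:

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class Ticket:
    # hypothetical object stored alongside its text representation
    id: int
    title: str


def get_obj_metadata(obj) -> dict:
    # mirrors ObjectNode.get_obj_metadata: keep enough info to rebuild the object later
    return {
        "is_obj": True,
        "obj_json": json.dumps(asdict(obj)),
        "obj_cls_name": type(obj).__name__,
        "obj_mod_name": type(obj).__module__,
    }


# stand-in for resolving obj_mod_name/obj_cls_name via importlib
registry = {"Ticket": Ticket}


def reconstruct(meta: dict):
    cls = registry[meta["obj_cls_name"]]
    return cls(**json.loads(meta["obj_json"]))


meta = get_obj_metadata(Ticket(id=1, title="fix login"))
restored = reconstruct(meta)
```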


@ -0,0 +1,3 @@
from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore
__all__ = ["ChromaVectorStore"]


@ -0,0 +1,290 @@
"""Chroma vector store.
Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
The repo requires onnxruntime = "^1.17.0", which is too new for many operating systems, such as CentOS 7.
"""
import math
from typing import Any, Dict, Generator, List, Optional, cast
import chromadb
from chromadb.api.models.Collection import Collection
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
from llama_index.core.utils import truncate_text
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
MetadataFilters,
VectorStoreQuery,
VectorStoreQueryResult,
)
from llama_index.core.vector_stores.utils import (
legacy_metadata_dict_to_node,
metadata_dict_to_node,
node_to_metadata_dict,
)
from metagpt.logs import logger
def _transform_chroma_filter_condition(condition: str) -> str:
"""Translate standard metadata filter op to Chroma specific spec."""
if condition == "and":
return "$and"
elif condition == "or":
return "$or"
else:
raise ValueError(f"Filter condition {condition} not supported")
def _transform_chroma_filter_operator(operator: str) -> str:
"""Translate standard metadata filter operator to Chroma specific spec."""
if operator == "!=":
return "$ne"
elif operator == "==":
return "$eq"
elif operator == ">":
return "$gt"
elif operator == "<":
return "$lt"
elif operator == ">=":
return "$gte"
elif operator == "<=":
return "$lte"
else:
raise ValueError(f"Filter operator {operator} not supported")
def _to_chroma_filter(
standard_filters: MetadataFilters,
) -> dict:
"""Translate standard metadata filters to Chroma specific spec."""
filters = {}
filters_list = []
condition = standard_filters.condition or "and"
condition = _transform_chroma_filter_condition(condition)
if standard_filters.filters:
for filter in standard_filters.filters:
if filter.operator:
filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}})
else:
filters_list.append({filter.key: filter.value})
if len(filters_list) == 1:
# If there is only one filter, return it directly
return filters_list[0]
elif len(filters_list) > 1:
filters[condition] = filters_list
return filters
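The translation performed by the two helpers above can be sketched without the llama_index types; `to_chroma_where` and the `(key, op, value)` triples are illustrative names, not part of the PR:

```python
# standard comparison operators mapped to Chroma's `$`-prefixed spec
_OPS = {"==": "$eq", "!=": "$ne", ">": "$gt", "<": "$lt", ">=": "$gte", "<=": "$lte"}


def to_chroma_where(filters, condition="and"):
    """filters: list of (key, op, value) triples; returns a Chroma `where` dict."""
    clauses = [{key: {_OPS[op]: value}} for key, op, value in filters]
    if len(clauses) == 1:
        # a single clause is returned directly, without a $and/$or wrapper
        return clauses[0]
    return {f"${condition}": clauses}


where = to_chroma_where([("year", ">=", 2020), ("lang", "==", "en")])
```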
import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"
MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB
def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
"""Yield successive max_chunk_size-sized chunks from lst.
Args:
lst (List[BaseNode]): list of nodes with embeddings
max_chunk_size (int): max chunk size
Yields:
Generator[List[BaseNode], None, None]: list of nodes with embeddings
"""
for i in range(0, len(lst), max_chunk_size):
yield lst[i : i + max_chunk_size]
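A quick standalone check of the chunking behaviour, with a plain integer list standing in for nodes:

```python
def chunk_list(lst, max_chunk_size):
    """Yield successive max_chunk_size-sized chunks from lst; the last chunk may be shorter."""
    for i in range(0, len(lst), max_chunk_size):
        yield lst[i : i + max_chunk_size]


chunks = list(chunk_list(list(range(10)), 4))
```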
class ChromaVectorStore(BasePydanticVectorStore):
"""Chroma vector store.
In this vector store, embeddings are stored within a ChromaDB collection.
During query time, the index uses ChromaDB to query for the top
k most similar nodes.
Args:
chroma_collection (chromadb.api.models.Collection.Collection):
ChromaDB collection instance
"""
stores_text: bool = True
flat_metadata: bool = True
collection_name: Optional[str]
host: Optional[str]
port: Optional[str]
ssl: bool
headers: Optional[Dict[str, str]]
persist_dir: Optional[str]
collection_kwargs: Dict[str, Any] = Field(default_factory=dict)
_collection: Any = PrivateAttr()
def __init__(
self,
chroma_collection: Optional[Any] = None,
collection_name: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
persist_dir: Optional[str] = None,
collection_kwargs: Optional[dict] = None,
**kwargs: Any,
) -> None:
"""Init params."""
collection_kwargs = collection_kwargs or {}
if chroma_collection is None:
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
else:
self._collection = cast(Collection, chroma_collection)
super().__init__(
host=host,
port=port,
ssl=ssl,
headers=headers,
collection_name=collection_name,
persist_dir=persist_dir,
collection_kwargs=collection_kwargs or {},
)
@classmethod
def from_collection(cls, collection: Any) -> "ChromaVectorStore":
try:
from chromadb import Collection
except ImportError:
raise ImportError(import_err_msg)
if not isinstance(collection, Collection):
raise Exception("argument is not a chromadb Collection instance")
return cls(chroma_collection=collection)
@classmethod
def from_params(
cls,
collection_name: str,
host: Optional[str] = None,
port: Optional[str] = None,
ssl: bool = False,
headers: Optional[Dict[str, str]] = None,
persist_dir: Optional[str] = None,
collection_kwargs: dict = {},
**kwargs: Any,
) -> "ChromaVectorStore":
if persist_dir:
client = chromadb.PersistentClient(path=persist_dir)
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
elif host and port:
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
else:
raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified")
return cls(
chroma_collection=collection,
host=host,
port=port,
ssl=ssl,
headers=headers,
persist_dir=persist_dir,
collection_kwargs=collection_kwargs,
**kwargs,
)
@classmethod
def class_name(cls) -> str:
return "ChromaVectorStore"
def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
"""Add nodes to index.
Args:
nodes: List[BaseNode]: list of nodes with embeddings
"""
if not self._collection:
raise ValueError("Collection not initialized")
max_chunk_size = MAX_CHUNK_SIZE
node_chunks = chunk_list(nodes, max_chunk_size)
all_ids = []
for node_chunk in node_chunks:
embeddings = []
metadatas = []
ids = []
documents = []
for node in node_chunk:
embeddings.append(node.get_embedding())
metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
for key in metadata_dict:
if metadata_dict[key] is None:
metadata_dict[key] = ""
metadatas.append(metadata_dict)
ids.append(node.node_id)
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
self._collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
documents=documents,
)
all_ids.extend(ids)
return all_ids
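The inner loop above replaces `None` metadata values with `""` before handing them to Chroma, which only accepts scalar metadata values. A minimal sketch of that sanitization step; `sanitize_metadata` is a hypothetical helper name:

```python
def sanitize_metadata(metadata: dict) -> dict:
    # Chroma metadata values must be str/int/float/bool, so map None to ""
    return {key: ("" if value is None else value) for key, value in metadata.items()}


cleaned = sanitize_metadata({"source": None, "page": 3, "lang": "en"})
```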
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
"""
Delete nodes associated with the given ref_doc_id.
Args:
ref_doc_id (str): The doc_id of the document to delete.
"""
self._collection.delete(where={"document_id": ref_doc_id})
@property
def client(self) -> Any:
"""Return client."""
return self._collection
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
"""Query index for top k most similar nodes.
Args:
query_embedding (List[float]): query embedding
similarity_top_k (int): top k most similar nodes
"""
if query.filters is not None:
if "where" in kwargs:
raise ValueError(
"Cannot specify metadata filters via both query and kwargs. "
"Use kwargs only for chroma specific items that are "
"not supported via the generic query interface."
)
where = _to_chroma_filter(query.filters)
else:
where = kwargs.pop("where", {})
results = self._collection.query(
query_embeddings=query.query_embedding,
n_results=query.similarity_top_k,
where=where,
**kwargs,
)
logger.debug(f"> Top {len(results['documents'])} nodes:")
nodes = []
similarities = []
ids = []
for node_id, text, metadata, distance in zip(
results["ids"][0],
results["documents"][0],
results["metadatas"][0],
results["distances"][0],
):
try:
node = metadata_dict_to_node(metadata)
node.set_content(text)
except Exception:
# NOTE: deprecated legacy logic for backward compatibility
metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
node = TextNode(
text=text,
id_=node_id,
metadata=metadata,
start_char_idx=node_info.get("start", None),
end_char_idx=node_info.get("end", None),
relationships=relationships,
)
nodes.append(node)
similarity_score = math.exp(-distance)
similarities.append(similarity_score)
logger.debug(
f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}"
)
ids.append(node_id)
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
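Note the score conversion in `query`: Chroma returns distances (smaller means closer), and `math.exp(-distance)` maps them into a similarity in (0, 1] that preserves the ranking. A quick sketch of that property with made-up distances:

```python
import math

# Chroma returns distances; smaller distance means a closer match
distances = [0.0, 0.5, 2.0]
scores = [math.exp(-d) for d in distances]
```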


@ -108,12 +108,6 @@ class RoleContext(BaseModel):
) # see `Role._set_react_mode` for definitions of the following two attributes
max_react_loop: int = 1
def check(self, role_id: str):
# if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory:
# self.long_term_memory.recover_memory(role_id, self)
# self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation
pass
@property
def important_memory(self) -> list[Message]:
"""Retrieve information corresponding to the attention action."""
@ -311,8 +305,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
buffer during _observe.
"""
self.rc.watch = {any_to_str(t) for t in actions}
# check RoleContext after adding watch actions
self.rc.check(self.role_id)
def is_watch(self, caused_by: str):
return caused_by in self.rc.watch


@ -11,7 +11,6 @@ from typing import Optional
from pydantic import Field, model_validator
from metagpt.actions import SearchAndSummarize, UserRequirement
from metagpt.document_store.base_store import BaseStore
from metagpt.roles import Role
from metagpt.tools.search_engine import SearchEngine
@ -27,7 +26,7 @@ class Sales(Role):
"delivered with the professionalism and courtesy expected of a seasoned sales guide."
)
store: Optional[BaseStore] = Field(default=None, exclude=True)
store: Optional[object] = Field(default=None, exclude=True)  # must implement tools.SearchInterface
@model_validator(mode="after")
def validate_store(self):


@ -233,6 +233,10 @@ class Message(BaseModel):
def check_send_to(cls, send_to: Any) -> set:
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
@field_serializer("send_to", mode="plain")
def ser_send_to(self, send_to: set) -> list:
return list(send_to)
@field_serializer("instruct_content", mode="plain")
def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
ic_dict = None
@ -276,6 +280,10 @@ class Message(BaseModel):
def __repr__(self):
return self.__str__()
def rag_key(self) -> str:
"""For search"""
return self.content
def to_dict(self) -> dict:
"""Return a dict containing `role` and `content` for the LLM call."""
return {"role": self.role, "content": self.content}


@ -30,3 +30,8 @@ class WebBrowserEngineType(Enum):
def __missing__(cls, key):
"""Default type conversion"""
return cls.CUSTOM
class SearchInterface:
async def asearch(self, *args, **kwargs):
...


@ -0,0 +1,22 @@
import asyncio
import threading
from typing import Any
def run_coroutine_in_new_loop(coroutine) -> Any:
"""Runs a coroutine in a new, separate event loop on a different thread.
This function is useful when you need to execute an async function from within a sync function but encounter the error `RuntimeError: This event loop is already running`.
"""
new_loop = asyncio.new_event_loop()
t = threading.Thread(target=lambda: new_loop.run_forever())
t.start()
future = asyncio.run_coroutine_threadsafe(coroutine, new_loop)
try:
return future.result()
finally:
new_loop.call_soon_threadsafe(new_loop.stop)
t.join()
new_loop.close()
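A usage sketch of the utility above; `fetch` is a made-up coroutine standing in for real async work:

```python
import asyncio
import threading


def run_coroutine_in_new_loop(coroutine):
    """Run a coroutine on a fresh event loop in a helper thread (mirrors the utility above)."""
    new_loop = asyncio.new_event_loop()
    t = threading.Thread(target=new_loop.run_forever)
    t.start()
    future = asyncio.run_coroutine_threadsafe(coroutine, new_loop)
    try:
        return future.result()
    finally:
        new_loop.call_soon_threadsafe(new_loop.stop)
        t.join()
        new_loop.close()


async def fetch():
    # hypothetical async workload
    await asyncio.sleep(0)
    return 42


# callable from plain sync code, even while another loop is already running
result = run_coroutine_in_new_loop(fetch())
```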


@ -5,12 +5,15 @@
@Author : alexanderwu
@File : embedding.py
"""
from langchain_community.embeddings import OpenAIEmbeddings
from llama_index.embeddings.openai import OpenAIEmbedding
from metagpt.config2 import config
def get_embedding():
def get_embedding() -> OpenAIEmbedding:
llm = config.get_openai_llm()
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
if llm is None:
raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.")
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
return embedding


@ -0,0 +1,18 @@
"""class tools, including method inspection, class attributes, inheritance relationships, etc."""
def check_methods(C, *methods):
"""Check if the class has the given methods; borrowed from _collections_abc.
Useful for implicit interfaces: an abstract class can call this from __subclasshook__, so isinstance checks work without explicit inheritance.
"""
mro = C.__mro__
for method in methods:
for B in mro:
if method in B.__dict__:
if B.__dict__[method] is None:
return NotImplemented
break
else:
return NotImplemented
return True
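A sketch of the implicit-interface pattern the docstring describes; `Searchable` and `DuckStore` are illustrative names, not part of the PR:

```python
from abc import ABCMeta


def check_methods(C, *methods):
    """Borrowed pattern from _collections_abc: True only if C's MRO defines every method."""
    mro = C.__mro__
    for method in methods:
        for B in mro:
            if method in B.__dict__:
                if B.__dict__[method] is None:
                    return NotImplemented
                break
        else:
            return NotImplemented
    return True


class Searchable(metaclass=ABCMeta):
    # hypothetical implicit interface: anything with an `asearch` method qualifies
    @classmethod
    def __subclasshook__(cls, C):
        if cls is Searchable:
            return check_methods(C, "asearch")
        return NotImplemented


class DuckStore:
    async def asearch(self, query):
        return []
```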


@ -1,7 +1,6 @@
aiohttp==3.8.4
aiohttp==3.8.6
#azure_storage==0.37.0
channels==4.0.0
# chromadb
# Django==4.1.5
# docx==0.2.4
#faiss==1.5.3
@ -11,14 +10,20 @@ typer==0.9.0
# godot==0.1.1
# google_api_python_client==2.93.0 # Used by search_engine.py
lancedb==0.4.0
langchain==0.1.8
sqlalchemy==2.0.0 # along with langchain
llama-index-core==0.10.15
llama-index-embeddings-azure-openai==0.1.6
llama-index-embeddings-openai==0.1.5
llama-index-llms-azure-openai==0.1.4
llama-index-readers-file==0.1.4
llama-index-retrievers-bm25==0.1.3
llama-index-vector-stores-faiss==0.1.1
chromadb==0.4.23
loguru==0.6.0
meilisearch==0.21.0
numpy>=1.24.3,<1.25.0
openai==1.6.0
numpy==1.24.3
openai==1.6.1
openpyxl
beautifulsoup4==4.12.2
beautifulsoup4==4.12.3
pandas==2.0.3
pydantic==2.5.3
#pygame==2.1.3
@ -30,7 +35,7 @@ PyYAML==6.0.1
setuptools==65.6.3
tenacity==8.2.3
tiktoken==0.5.2
tqdm==4.65.0
tqdm==4.66.2
#unstructured[local-inference]
# selenium>4
# webdriver_manager<3.9
@ -61,7 +66,7 @@ typing-extensions==4.9.0
socksio~=1.0.0
gitignore-parser==0.1.9
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
websockets~=12.0
websockets~=11.0
networkx~=3.2.1
google-generativeai==0.3.2
playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py


@ -42,7 +42,7 @@ extras_require["test"] = [
"connexion[uvicorn]~=3.0.5",
"azure-cognitiveservices-speech~=1.31.0",
"aioboto3~=11.3.0",
"chromadb==0.4.14",
"chromadb==0.4.23",
"gradio==3.0.0",
"grpcio-status==1.48.2",
"pylint==3.0.3",


@ -12,7 +12,7 @@ from metagpt.document_store.chromadb_store import ChromaStore
def test_chroma_store():
"""FIXME: chroma feels flaky; it tends to crash as soon as it is driven from Python, including in this test case."""
# Create a ChromaStore instance using the 'sample_collection_1' collection
document_store = ChromaStore("sample_collection_1")
document_store = ChromaStore("sample_collection_1", get_or_create=True)
# Add multiple documents using the write method
document_store.write(


@ -6,8 +6,6 @@
@File : test_faiss_store.py
"""
from typing import Optional
import numpy as np
import pytest
@ -17,18 +15,24 @@ from metagpt.logs import logger
from metagpt.roles import Sales
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
num = len(texts)
embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim
embeds = (embeds - embeds.mean(axis=0)) / (embeds.std(axis=0))
return embeds
embeds = (embeds - embeds.mean(axis=0)) / embeds.std(axis=0)
return embeds.tolist()
def mock_openai_embed_document(self, text: str) -> list[float]:
embeds = mock_openai_embed_documents(self, [text])
return embeds[0]
@pytest.mark.asyncio
async def test_search_json(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
store = FaissStore(EXAMPLE_PATH / "example.json")
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.json")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
@ -37,9 +41,10 @@ async def test_search_json(mocker):
@pytest.mark.asyncio
async def test_search_xlsx(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
role = Sales(profile="Sales", store=store)
query = "Which facial cleanser is good for oily skin?"
result = await role.run(query)
@ -48,9 +53,10 @@ async def test_search_xlsx(mocker):
@pytest.mark.asyncio
async def test_write(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
_faiss_store = store.write()
assert _faiss_store.docstore
assert _faiss_store.index
assert _faiss_store.storage_context.docstore
assert _faiss_store.storage_context.vector_store.client


@ -2,32 +2,41 @@
# -*- coding: utf-8 -*-
# @Desc :
from typing import Optional
import numpy as np
dim = 1536 # openai embedding dim
embed_zeros_arr = np.zeros(shape=[1, dim]).tolist()
embed_ones_arr = np.ones(shape=[1, dim]).tolist()
text_embed_arr = [
{"text": "Write a cli snake game", "embed": np.zeros(shape=[1, dim])}, # mock data, same as below
{"text": "Write a game of cli snake", "embed": np.zeros(shape=[1, dim])},
{"text": "Write a 2048 web game", "embed": np.ones(shape=[1, dim])},
{"text": "Write a Battle City", "embed": np.ones(shape=[1, dim])},
{"text": "Write a cli snake game", "embed": embed_zeros_arr}, # mock data, same as below
{"text": "Write a game of cli snake", "embed": embed_zeros_arr},
{"text": "Write a 2048 web game", "embed": embed_ones_arr},
{"text": "Write a Battle City", "embed": embed_ones_arr},
{
"text": "The user has requested the creation of a command-line interface (CLI) snake game",
"embed": np.zeros(shape=[1, dim]),
"embed": embed_zeros_arr,
},
{"text": "The request is command-line interface (CLI) snake game", "embed": np.zeros(shape=[1, dim])},
{"text": "The request is command-line interface (CLI) snake game", "embed": embed_zeros_arr},
{
"text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
"embed": np.ones(shape=[1, dim]),
"embed": embed_ones_arr,
},
]
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
idx = text_idx_dict.get(texts[0])
embed = text_embed_arr[idx].get("embed")
return embed
def mock_openai_embed_document(self, text: str) -> list[float]:
embeds = mock_openai_embed_documents(self, [text])
return embeds[0]
async def mock_openai_aembed_document(self, text: str) -> list[float]:
return mock_openai_embed_document(self, text)


@ -12,13 +12,20 @@ from metagpt.memory.longterm_memory import LongTermMemory
from metagpt.roles.role import RoleContext
from metagpt.schema import Message
from tests.metagpt.memory.mock_text_embed import (
mock_openai_aembed_document,
mock_openai_embed_document,
mock_openai_embed_documents,
text_embed_arr,
)
def test_ltm_search(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
@pytest.mark.asyncio
async def test_ltm_search(mocker):
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch(
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
)
role_id = "UTUserLtm(Product Manager)"
from metagpt.environment import Environment
@ -31,39 +38,24 @@ def test_ltm_search(mocker):
idea = text_embed_arr[0].get("text", "Write a cli snake game")
message = Message(role="User", content=idea, cause_by=UserRequirement)
news = ltm.find_news([message])
news = await ltm.find_news([message])
assert len(news) == 1
ltm.add(message)
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
news = ltm.find_news([sim_message])
news = await ltm.find_news([sim_message])
assert len(news) == 0
ltm.add(sim_message)
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
news = ltm.find_news([new_message])
news = await ltm.find_news([new_message])
assert len(news) == 1
ltm.add(new_message)
# restore from local index
ltm_new = LongTermMemory()
ltm_new.recover_memory(role_id, rc)
news = ltm_new.find_news([message])
assert len(news) == 0
ltm_new.recover_memory(role_id, rc)
news = ltm_new.find_news([sim_message])
assert len(news) == 0
new_idea = text_embed_arr[3].get("text", "Write a Battle City")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
news = ltm_new.find_news([new_message])
assert len(news) == 1
ltm_new.clear()
ltm.clear()
if __name__ == "__main__":


@ -8,19 +8,28 @@ import shutil
from pathlib import Path
from typing import List
import pytest
from metagpt.actions import UserRequirement, WritePRD
from metagpt.actions.action_node import ActionNode
from metagpt.const import DATA_PATH
from metagpt.memory.memory_storage import MemoryStorage
from metagpt.schema import Message
from tests.metagpt.memory.mock_text_embed import (
mock_openai_aembed_document,
mock_openai_embed_document,
mock_openai_embed_documents,
text_embed_arr,
)
def test_idea_message(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
@pytest.mark.asyncio
async def test_idea_message(mocker):
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch(
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
)
idea = text_embed_arr[0].get("text", "Write a cli snake game")
role_id = "UTUser1(Product Manager)"
@ -29,28 +38,32 @@ def test_idea_message(mocker):
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
assert len(messages) == 0
memory_storage.recover_memory(role_id)
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_messages = await memory_storage.search_similar(sim_message)
assert len(new_messages) == 1  # the similar message is found and returned
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
new_messages = await memory_storage.search_similar(new_message)
assert len(new_messages) == 0
memory_storage.clean()
assert memory_storage.is_initialized is False
def test_actionout_message(mocker):
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
@pytest.mark.asyncio
async def test_actionout_message(mocker):
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
mocker.patch(
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
)
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
@ -67,23 +80,22 @@ def test_actionout_message(mocker):
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
memory_storage: MemoryStorage = MemoryStorage()
messages = memory_storage.recover_memory(role_id)
assert len(messages) == 0
memory_storage.recover_memory(role_id)
memory_storage.add(message)
assert memory_storage.is_initialized is True
sim_content = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
sim_message = Message(content=sim_content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search_dissimilar(sim_message)
assert len(new_messages) == 0 # similar, return []
new_messages = await memory_storage.search_similar(sim_message)
assert len(new_messages) == 1  # the similar message is found and returned
new_content = text_embed_arr[6].get(
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
)
new_message = Message(content=new_content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
new_messages = memory_storage.search_dissimilar(new_message)
assert new_messages[0].content == message.content
new_messages = await memory_storage.search_similar(new_message)
assert len(new_messages) == 0
memory_storage.clean()
assert memory_storage.is_initialized is False


@ -0,0 +1,166 @@
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import Document, TextNode
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
class TestSimpleEngine:
@pytest.fixture
def mock_simple_directory_reader(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
@pytest.fixture
def mock_vector_store_index(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
@pytest.fixture
def mock_get_retriever(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
@pytest.fixture
def mock_get_rankers(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_rankers")
@pytest.fixture
def mock_get_response_synthesizer(self, mocker):
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
def test_from_docs(
self,
mocker,
mock_simple_directory_reader,
mock_vector_store_index,
mock_get_retriever,
mock_get_rankers,
mock_get_response_synthesizer,
):
# Mock
mock_simple_directory_reader.return_value.load_data.return_value = [
Document(text="document1"),
Document(text="document2"),
]
mock_get_retriever.return_value = mocker.MagicMock()
mock_get_rankers.return_value = [mocker.MagicMock()]
mock_get_response_synthesizer.return_value = mocker.MagicMock()
# Setup
input_dir = "test_dir"
input_files = ["test_file1", "test_file2"]
transformations = [mocker.MagicMock()]
embed_model = mocker.MagicMock()
llm = mocker.MagicMock()
retriever_configs = [mocker.MagicMock()]
ranker_configs = [mocker.MagicMock()]
# Execute
engine = SimpleEngine.from_docs(
input_dir=input_dir,
input_files=input_files,
transformations=transformations,
embed_model=embed_model,
llm=llm,
retriever_configs=retriever_configs,
ranker_configs=ranker_configs,
)
# Assertions
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
mock_vector_store_index.assert_called_once()
mock_get_retriever.assert_called_once_with(
configs=retriever_configs, index=mock_vector_store_index.return_value
)
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
assert isinstance(engine, SimpleEngine)
@pytest.mark.asyncio
async def test_asearch(self, mocker):
# Mock
test_query = "test query"
expected_result = "expected result"
mock_aquery = mocker.AsyncMock(return_value=expected_result)
# Setup
engine = SimpleEngine(retriever=mocker.MagicMock())
engine.aquery = mock_aquery
# Execute
result = await engine.asearch(test_query)
# Assertions
mock_aquery.assert_called_once_with(test_query)
assert result == expected_result
@pytest.mark.asyncio
async def test_aretrieve(self, mocker):
# Mock
mock_query_bundle = mocker.patch("metagpt.rag.engines.simple.QueryBundle", return_value="query_bundle")
mock_super_aretrieve = mocker.patch(
"metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock
)
mock_super_aretrieve.return_value = [TextNode(text="node_with_score", metadata={"is_obj": False})]
# Setup
engine = SimpleEngine(retriever=mocker.MagicMock())
test_query = "test query"
# Execute
result = await engine.aretrieve(test_query)
# Assertions
mock_query_bundle.assert_called_once_with(test_query)
mock_super_aretrieve.assert_called_once_with("query_bundle")
assert result[0].text == "node_with_score"
def test_add_docs(self, mocker):
# Mock
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
mock_simple_directory_reader.return_value.load_data.return_value = [
Document(text="document1"),
Document(text="document2"),
]
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
mock_index._transformations = mocker.MagicMock()
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
mock_run_transformations.return_value = ["node1", "node2"]
# Setup
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
input_files = ["test_file1", "test_file2"]
# Execute
engine.add_docs(input_files=input_files)
# Assertions
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
def test_add_objs(self, mocker):
# Mock
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
# Setup
class CustomTextNode(TextNode):
def rag_key(self):
return ""
def model_dump_json(self):
return ""
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
# Execute
engine.add_objs(objs=objs)
# Assertions
assert mock_retriever.add_nodes.call_count == 1
for node in mock_retriever.add_nodes.call_args[0][0]:
assert isinstance(node, TextNode)
assert "is_obj" in node.metadata


@ -0,0 +1,102 @@
import pytest
from metagpt.rag.factories.base import ConfigBasedFactory, GenericFactory
class TestGenericFactory:
@pytest.fixture
def creators(self):
return {
"type1": lambda name: f"Instance of type1 with {name}",
"type2": lambda name: f"Instance of type2 with {name}",
}
@pytest.fixture
def factory(self, creators):
return GenericFactory(creators=creators)
def test_get_instance_success(self, factory):
# Test successful retrieval of an instance
key = "type1"
instance = factory.get_instance(key, name="TestName")
assert instance == "Instance of type1 with TestName"
def test_get_instance_failure(self, factory):
# Test failure to retrieve an instance due to unregistered key
with pytest.raises(ValueError) as exc_info:
factory.get_instance("unknown_key")
assert "Creator not registered for key: unknown_key" in str(exc_info.value)
def test_get_instances_success(self, factory):
# Test successful retrieval of multiple instances
keys = ["type1", "type2"]
instances = factory.get_instances(keys, name="TestName")
expected = ["Instance of type1 with TestName", "Instance of type2 with TestName"]
assert instances == expected
@pytest.mark.parametrize(
"keys,expected_exception_message",
[
(["unknown_key"], "Creator not registered for key: unknown_key"),
(["type1", "unknown_key"], "Creator not registered for key: unknown_key"),
],
)
def test_get_instances_with_failure(self, factory, keys, expected_exception_message):
# Test failure to retrieve instances due to at least one unregistered key
with pytest.raises(ValueError) as exc_info:
factory.get_instances(keys, name="TestName")
assert expected_exception_message in str(exc_info.value)
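The tests above fully pin down the factory's contract: a creator callable per string key, a `ValueError` for unregistered keys, and fail-fast batch retrieval. A minimal sketch that satisfies them (the actual `metagpt.rag.factories.base` implementation may differ in details) could look like:

```python
class GenericFactory:
    """Minimal sketch consistent with TestGenericFactory; not the real implementation."""

    def __init__(self, creators=None):
        # Map from string key to a callable that builds an instance.
        self._creators = creators or {}

    def get_instance(self, key, **kwargs):
        creator = self._creators.get(key)
        if creator:
            return creator(**kwargs)
        raise ValueError(f"Creator not registered for key: {key}")

    def get_instances(self, keys, **kwargs):
        # Fails fast on the first unregistered key, as the parametrized test expects.
        return [self.get_instance(key, **kwargs) for key in keys]
```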
class DummyConfig:
"""A dummy config class for testing."""
def __init__(self, name):
self.name = name
class TestConfigBasedFactory:
@pytest.fixture
def config_creators(self):
return {
DummyConfig: lambda config, **kwargs: f"Processed {config.name} with {kwargs.get('extra', 'no extra')}",
}
@pytest.fixture
def config_factory(self, config_creators):
return ConfigBasedFactory(creators=config_creators)
def test_get_instance_success(self, config_factory):
# Test successful retrieval of an instance
config = DummyConfig(name="TestConfig")
instance = config_factory.get_instance(config, extra="additional data")
assert instance == "Processed TestConfig with additional data"
def test_get_instance_failure(self, config_factory):
# Test failure to retrieve an instance due to unknown config type
class UnknownConfig:
pass
config = UnknownConfig()
with pytest.raises(ValueError) as exc_info:
config_factory.get_instance(config)
assert "Unknown config:" in str(exc_info.value)
def test_val_from_config_or_kwargs_priority(self):
# Test that the value from the config object has priority over kwargs
config = DummyConfig(name="ConfigName")
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
assert result == "ConfigName"
def test_val_from_config_or_kwargs_fallback_to_kwargs(self):
# Test fallback to kwargs when config object does not have the value
config = DummyConfig(name=None)
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
assert result == "KwargsName"
def test_val_from_config_or_kwargs_key_error(self):
# Test KeyError when the key is not found in both config object and kwargs
config = DummyConfig(name=None)
with pytest.raises(KeyError) as exc_info:
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
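These tests describe a config-type-keyed factory with a config-first value lookup. A sketch consistent with the asserted behavior (the real `ConfigBasedFactory` may be structured differently) is:

```python
class ConfigBasedFactory:
    """Sketch matching TestConfigBasedFactory; hypothetical, not the shipped code."""

    def __init__(self, creators=None):
        # Keys are config *types*, not string names.
        self._creators = creators or {}

    def get_instance(self, config, **kwargs):
        creator = self._creators.get(type(config))
        if creator:
            return creator(config, **kwargs)
        raise ValueError(f"Unknown config: `{type(config)}`, {config}")

    @staticmethod
    def _val_from_config_or_kwargs(key, config, **kwargs):
        # Prefer a non-None attribute on the config, then fall back to kwargs.
        val = getattr(config, key, None)
        if val is not None:
            return val
        if kwargs.get(key) is not None:
            return kwargs[key]
        raise KeyError(f"The key '{key}' is required but not provided in config or kwargs")
```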


@ -0,0 +1,41 @@
import pytest
from llama_index.core.llms import LLM
from llama_index.core.postprocessor import LLMRerank
from metagpt.rag.factories.ranker import RankerFactory
from metagpt.rag.schema import LLMRankerConfig
class TestRankerFactory:
@pytest.fixture
def ranker_factory(self) -> RankerFactory:
return RankerFactory()
@pytest.fixture
def mock_llm(self, mocker):
return mocker.MagicMock(spec=LLM)
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
default_rankers = ranker_factory.get_rankers()
assert len(default_rankers) == 0
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
rankers = ranker_factory.get_rankers(configs=[mock_config])
assert len(rankers) == 1
assert isinstance(rankers[0], LLMRerank)
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
ranker = ranker_factory._create_llm_ranker(mock_config)
assert isinstance(ranker, LLMRerank)
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
mock_config = LLMRankerConfig(llm=mock_llm)
extracted_llm = ranker_factory._extract_llm(config=mock_config)
assert extracted_llm == mock_llm
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
assert extracted_llm == mock_llm


@ -0,0 +1,79 @@
import faiss
import pytest
from llama_index.core import VectorStoreIndex
from metagpt.rag.factories.retriever import RetrieverFactory
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
class TestRetrieverFactory:
@pytest.fixture
def retriever_factory(self):
return RetrieverFactory()
@pytest.fixture
def mock_faiss_index(self, mocker):
return mocker.MagicMock(spec=faiss.IndexFlatL2)
@pytest.fixture
def mock_vector_store_index(self, mocker):
mock = mocker.MagicMock(spec=VectorStoreIndex)
mock._embed_model = mocker.MagicMock()
mock.docstore.docs.values.return_value = []
return mock
def test_get_retriever_with_faiss_config(
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
):
mock_config = FAISSRetrieverConfig(dimensions=128)
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, FAISSRetriever)
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
mock_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_config])
assert isinstance(retriever, DynamicBM25Retriever)
def test_get_retriever_with_multiple_configs_returns_hybrid(
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
):
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
mock_bm25_config = BM25RetrieverConfig()
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
assert isinstance(retriever, SimpleHybridRetriever)
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
mock_vector_store_index.as_retriever = mocker.MagicMock()
retriever = retriever_factory.get_retriever()
mock_vector_store_index.as_retriever.assert_called_once()
assert retriever is mock_vector_store_index.as_retriever.return_value
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
extracted_index = retriever_factory._extract_index(config=mock_config)
assert extracted_index == mock_vector_store_index
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
assert extracted_index == mock_vector_store_index


@ -0,0 +1,37 @@
import pytest
from llama_index.core import VectorStoreIndex
from llama_index.core.schema import Node
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
class TestDynamicBM25Retriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
# Create mock Node objects
self.doc1 = mocker.MagicMock(spec=Node)
self.doc1.get_content.return_value = "Document content 1"
self.doc2 = mocker.MagicMock(spec=Node)
self.doc2.get_content.return_value = "Document content 2"
self.mock_nodes = [self.doc1, self.doc2]
# Mock the index
index = mocker.MagicMock(spec=VectorStoreIndex)
# Mock the nodes and tokenizer arguments
mock_nodes = []
mock_tokenizer = mocker.MagicMock()
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
# Initialize DynamicBM25Retriever with the required arguments
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
def test_add_docs_updates_nodes_and_corpus(self):
# Execute
self.retriever.add_nodes(self.mock_nodes)
# Assertions
assert len(self.retriever._nodes) == len(self.mock_nodes)
assert len(self.retriever._corpus) == len(self.mock_nodes)
self.retriever._tokenizer.assert_called()
self.mock_bm25okapi.assert_called()


@ -0,0 +1,22 @@
import pytest
from llama_index.core.schema import Node
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
class TestFAISSRetriever:
@pytest.fixture(autouse=True)
def setup(self, mocker):
# Create mock Node objects
self.doc1 = mocker.MagicMock(spec=Node)
self.doc2 = mocker.MagicMock(spec=Node)
self.mock_nodes = [self.doc1, self.doc2]
# Mock the FAISSRetriever's _index attribute
self.mock_index = mocker.MagicMock()
self.retriever = FAISSRetriever(self.mock_index)
def test_add_docs_calls_insert_for_each_document(self, mocker):
self.retriever.add_nodes(self.mock_nodes)
self.mock_index.insert_nodes.assert_called()


@ -0,0 +1,39 @@
from unittest.mock import AsyncMock
import pytest
from llama_index.core.schema import NodeWithScore, TextNode
from metagpt.rag.retrievers import SimpleHybridRetriever
class TestSimpleHybridRetriever:
@pytest.mark.asyncio
async def test_aretrieve(self):
question = "test query"
# Create mock retrievers
mock_retriever1 = AsyncMock()
mock_retriever1.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="1"), score=1.0),
NodeWithScore(node=TextNode(id_="2"), score=0.95),
]
mock_retriever2 = AsyncMock()
mock_retriever2.aretrieve.return_value = [
NodeWithScore(node=TextNode(id_="2"), score=0.95),
NodeWithScore(node=TextNode(id_="3"), score=0.8),
]
# Instantiate the SimpleHybridRetriever with the mock retrievers
hybrid_retriever = SimpleHybridRetriever(mock_retriever1, mock_retriever2)
# Call the _aretrieve method
results = await hybrid_retriever._aretrieve(question)
# Check if the results are as expected
assert len(results) == 3 # Should be 3 unique nodes
assert set(node.node.node_id for node in results) == {"1", "2", "3"}
# Check if the scores are correct (assuming you want the highest score)
node_scores = {node.node.node_id: node.score for node in results}
assert node_scores["2"] == 0.95
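The hybrid test asserts a merge that deduplicates by node id, keeping the first hit for each id across retrievers. The core of that logic can be sketched without llama-index, using stand-in types (this is an illustration of the asserted behavior, not the actual `SimpleHybridRetriever` code):

```python
import asyncio
from dataclasses import dataclass


@dataclass
class ScoredNode:
    """Stand-in for llama-index's NodeWithScore."""
    node_id: str
    score: float


class FakeRetriever:
    """Async retriever stub returning fixed results."""

    def __init__(self, results):
        self._results = results

    async def aretrieve(self, query):
        return self._results


async def hybrid_retrieve(query, *retrievers):
    # Merge results from all retrievers, keeping the first occurrence per node id.
    seen = {}
    for retriever in retrievers:
        for item in await retriever.aretrieve(query):
            seen.setdefault(item.node_id, item)
    return list(seen.values())
```

With the same fixtures as the test (ids "1"/"2" from one retriever, "2"/"3" from the other), the merge yields three unique nodes, and node "2" retains the score from its first appearance.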