mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
commit
e783e5b208
61 changed files with 2353 additions and 248 deletions
5
.gitattributes
vendored
5
.gitattributes
vendored
|
|
@ -12,6 +12,11 @@
|
|||
*.jpg binary
|
||||
*.gif binary
|
||||
*.ico binary
|
||||
*.jpeg binary
|
||||
*.mp3 binary
|
||||
*.zip binary
|
||||
*.bin binary
|
||||
|
||||
|
||||
# Preserve original line endings for specific document files
|
||||
*.doc text eol=crlf
|
||||
|
|
|
|||
13
.gitignore
vendored
13
.gitignore
vendored
|
|
@ -27,6 +27,8 @@ share/python-wheels/
|
|||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
metagpt/tools/schemas/
|
||||
examples/data/search_kb/*.json
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python scripts from a template
|
||||
|
|
@ -151,9 +153,14 @@ allure-results
|
|||
.vscode
|
||||
|
||||
key.yaml
|
||||
data
|
||||
/data/
|
||||
data.ms
|
||||
examples/nb/
|
||||
examples/default__vector_store.json
|
||||
examples/docstore.json
|
||||
examples/graph_store.json
|
||||
examples/image__vector_store.json
|
||||
examples/index_store.json
|
||||
.chroma
|
||||
*~$*
|
||||
workspace/*
|
||||
|
|
@ -168,6 +175,7 @@ output
|
|||
tmp.png
|
||||
.dependencies.json
|
||||
tests/metagpt/utils/file_repo_git
|
||||
tests/data/rsp_cache_new.json
|
||||
*.tmp
|
||||
*.png
|
||||
htmlcov
|
||||
|
|
@ -178,4 +186,5 @@ cov.xml
|
|||
*.faiss
|
||||
*-structure.csv
|
||||
*-structure.json
|
||||
metagpt/tools/schemas
|
||||
*.dot
|
||||
.python-version
|
||||
|
|
|
|||
1
examples/data/rag/travel.txt
Normal file
1
examples/data/rag/travel.txt
Normal file
|
|
@ -0,0 +1 @@
|
|||
Bob likes traveling.
|
||||
109
examples/data/rag/writer.txt
Normal file
109
examples/data/rag/writer.txt
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
Productivity
|
||||
I think I am at least somewhat more productive than average, and people sometimes ask me for productivity tips. So I decided to just write them all down in one place.
|
||||
|
||||
Compound growth gets discussed as a financial concept, but it works in careers as well, and it is magic. A small productivity gain, compounded over 50 years, is worth a lot. So it’s worth figuring out how to optimize productivity. If you get 10% more done and 1% better every day compared to someone else, the compounded difference is massive.
|
||||
|
||||
What you work on
|
||||
|
||||
Famous writers have some essential qualities, creativity and discipline
|
||||
|
||||
It doesn’t matter how fast you move if it’s in a worthless direction. Picking the right thing to work on is the most important element of productivity and usually almost ignored. So think about it more! Independent thought is hard but it’s something you can get better at with practice.
|
||||
|
||||
The most impressive people I know have strong beliefs about the world, which is rare in the general population. If you find yourself always agreeing with whomever you last spoke with, that’s bad. You will of course be wrong sometimes, but develop the confidence to stick with your convictions. It will let you be courageous when you’re right about something important that most people don’t see.
|
||||
|
||||
I make sure to leave enough time in my schedule to think about what to work on. The best ways for me to do this are reading books, hanging out with interesting people, and spending time in nature.
|
||||
|
||||
I’ve learned that I can’t be very productive working on things I don’t care about or don’t like. So I just try not to put myself in a position where I have to do them (by delegating, avoiding, or something else). Stuff that you don’t like is a painful drag on morale and momentum.
|
||||
|
||||
By the way, here is an important lesson about delegation: remember that everyone else is also most productive when they’re doing what they like, and do what you’d want other people to do for you—try to figure out who likes (and is good at) doing what, and delegate that way.
|
||||
|
||||
If you find yourself not liking what you’re doing for a long period of time, seriously consider a major job change. Short-term burnout happens, but if it isn’t resolved with some time off, maybe it’s time to do something you’re more interested in.
|
||||
|
||||
I’ve been very fortunate to find work I like so much I’d do it for free, which makes it easy to be really productive.
|
||||
|
||||
It’s important to learn that you can learn anything you want, and that you can get better quickly. This feels like an unlikely miracle the first few times it happens, but eventually you learn to trust that you can do it.
|
||||
|
||||
Doing great work usually requires colleagues of some sort. Try to be around smart, productive, happy, and positive people that don’t belittle your ambitions. I love being around people who push me and inspire me to be better. To the degree you able to, avoid the opposite kind of people—the cost of letting them take up your mental cycles is horrific.
|
||||
|
||||
You have to both pick the right problem and do the work. There aren’t many shortcuts. If you’re going to do something really important, you are very likely going to work both smart and hard. The biggest prizes are heavily competed for. This isn’t true in every field (there are great mathematicians who never spend that many hours a week working) but it is in most.
|
||||
|
||||
Prioritization
|
||||
|
||||
Writers have to work hard to be successful
|
||||
|
||||
My system has three key pillars: “Make sure to get the important shit done”, “Don’t waste time on stupid shit”, and “make a lot of lists”.
|
||||
|
||||
I highly recommend using lists. I make lists of what I want to accomplish each year, each month, and each day. Lists are very focusing, and they help me with multitasking because I don’t have to keep as much in my head. If I’m not in the mood for some particular task, I can always find something else I’m excited to do.
|
||||
|
||||
I prefer lists written down on paper. It’s easy to add and remove tasks. I can access them during meetings without feeling rude. I re-transcribe lists frequently, which forces me to think about everything on the list and gives me an opportunity to add and remove items.
|
||||
|
||||
I don’t bother with categorization or trying to size tasks or anything like that (the most I do is put a star next to really important items).
|
||||
|
||||
I try to prioritize in a way that generates momentum. The more I get done, the better I feel, and then the more I get done. I like to start and end each day with something I can really make progress on.
|
||||
|
||||
I am relentless about getting my most important projects done—I’ve found that if I really want something to happen and I push hard enough, it usually happens.
|
||||
|
||||
I try to be ruthless about saying no to stuff, and doing non-critical things in the quickest way possible. I probably take this too far—for example, I am almost sure I am terse to the point of rudeness when replying to emails.
|
||||
|
||||
Passion and adaptability are key qualities to writers
|
||||
|
||||
I generally try to avoid meetings and conferences as I find the time cost to be huge—I get the most value out of time in my office. However, it is critical that you keep enough space in your schedule to allow for chance encounters and exposure to new people and ideas. Having an open network is valuable; though probably 90% of the random meetings I take are a waste of time, the other 10% really make up for it.
|
||||
|
||||
I find most meetings are best scheduled for 15-20 minutes, or 2 hours. The default of 1 hour is usually wrong, and leads to a lot of wasted time.
|
||||
|
||||
I have different times of day I try to use for different kinds of work. The first few hours of the morning are definitely my most productive time of the day, so I don’t let anyone schedule anything then. I try to do meetings in the afternoon. I take a break, or switch tasks, whenever I feel my attention starting to fade.
|
||||
|
||||
I don’t think most people value their time enough—I am surprised by the number of people I know who make $100 an hour and yet will spend a couple of hours doing something they don’t want to do to save $20.
|
||||
|
||||
Also, don’t fall into the trap of productivity porn—chasing productivity for its own sake isn’t helpful. Many people spend too much time thinking about how to perfectly optimize their system, and not nearly enough asking if they’re working on the right problems. It doesn’t matter what system you use or if you squeeze out every second if you’re working on the wrong thing.
|
||||
|
||||
The right goal is to allocate your year optimally, not your day.
|
||||
|
||||
Physical factors
|
||||
|
||||
Very likely what is optimal for me won’t be optimal for you. You’ll have to experiment to find out what works best for your body. It’s definitely worth doing—it helps in all aspects of life, and you’ll feel a lot better and happier overall.
|
||||
|
||||
It probably took a little bit of my time every week for a few years to arrive at what works best for me, but my sense is if I do a good job at all the below I’m at least 1.5x more productive than if not.
|
||||
|
||||
Sleep seems to be the most important physical factor in productivity for me. Some sort of sleep tracker to figure out how to sleep best is helpful. I’ve found the only thing I’m consistent with are in the set-it-and-forget-it category, and I really like the Emfit QS+Active.
|
||||
|
||||
I like a cold, dark, quiet room, and a great mattress (I resisted spending a bunch of money on a great mattress for years, which was stupid—it makes a huge difference to my sleep quality. I love this one). Not eating a lot in the few hours before sleep helps. Not drinking alcohol helps a lot, though I’m not willing to do that all the time.
|
||||
|
||||
I use a Chili Pad to be cold while I sleep if I can’t get the room cold enough, which is great but loud (I set it up to have the cooler unit outside my room).
|
||||
|
||||
When traveling, I use an eye mask and ear plugs.
|
||||
|
||||
Writers usually have empathy to write good books.
|
||||
|
||||
This is likely to be controversial, but I take a low dose of sleeping pills (like a third of a normal dose) or a very low dose of cannabis whenever I can’t sleep. I am a bad sleeper in general, and a particularly bad sleeper when I travel. It likely has tradeoffs, but so does not sleeping well. If you can already sleep well, I wouldn’t recommend this.
|
||||
|
||||
I use a full spectrum LED light most mornings for about 10-15 minutes while I catch up on email. It’s great—if you try nothing else in here, this is the thing I’d try. It’s a ridiculous gain for me. I like this one, and it’s easy to travel with.
|
||||
|
||||
Exercise is probably the second most important physical factor. I tried a number of different exercise programs for a few months each and the one that seemed best was lifting heavy weights 3x a week for an hour, and high intensity interval training occasionally. In addition to productivity gains, this is also the exercise program that makes me feel the best overall.
|
||||
|
||||
The third area is nutrition. I very rarely eat breakfast, so I get about 15 hours of fasting most days (except an espresso when I wake up). I know this is contrary to most advice, and I suspect it’s not optimal for most people, but it definitely works well for me.
|
||||
|
||||
Eating lots of sugar is the thing that makes me feel the worst and that I try hardest to avoid. I also try to avoid foods that aggravate my digestion or spike up inflammation (for example, very spicy foods). I don’t have much willpower when it comes to sweet things, so I mostly just try to keep junk food out of the house.
|
||||
|
||||
I have one big shot of espresso immediately when I wake up and one after lunch. I assume this is about 200mg total of caffeine per day. I tried a few other configurations; this was the one that worked by far the best. I otherwise aggressively avoid stimulants, but I will have more coffee if I’m super tired and really need to get something done.
|
||||
|
||||
If a writer want to be super, then should include innovative thinking.
|
||||
|
||||
I’m vegetarian and have been since I was a kid, and I supplement methyl B-12, Omega-3, Iron, and Vitamin D-3. I got to this list with a year or so of quarterly blood tests; it’s worked for me ever since (I re-test maybe every year and a half or so). There are many doctors who will happily work with you on a super comprehensive blood test (and services like WellnessFX). I also go out of my way to drink a lot of protein shakes, which I hate and I wouldn’t do if I weren’t vegetarian.
|
||||
|
||||
Other stuff
|
||||
|
||||
Here’s what I like in a workspace: natural light, quiet, knowing that I won’t be interrupted if I don’t want to be, long blocks of time, and being comfortable and relaxed (I’ve got a beautiful desk with a couple of 4k monitors on it in my office, but I spend almost all my time on my couch with my laptop).
|
||||
|
||||
I wrote custom software for the annoying things I have to do frequently, which is great. I also made an effort to learn to type really fast and the keyboard shortcuts that help with my workflow.
|
||||
|
||||
Like most people, I sometimes go through periods of a week or two where I just have no motivation to do anything (I suspect it may have something to do with nutrition). This sucks and always seems to happen at inconvenient times. I have not figured out what to do about it besides wait for the fog to lift, and to trust that eventually it always does. And I generally try to avoid people and situations that put me in bad moods, which is good advice whether you care about productivity or not.
|
||||
|
||||
In general, I think it’s good to overcommit a little bit. I find that I generally get done what I take on, and if I have a little bit too much to do it makes me more efficient at everything, which is a way to train to avoid distractions (a great habit to build!). However, overcommitting a lot is disastrous.
|
||||
|
||||
Don’t neglect your family and friends for the sake of productivity—that’s a very stupid tradeoff (and very likely a net productivity loss, because you’ll be less happy). Don’t neglect doing things you love or that clear your head either.
|
||||
|
||||
Finally, to repeat one more time: productivity in the wrong direction isn’t worth anything at all. Think more about what to work on.
|
||||
|
||||
Open-Mindedness and curiosity are essential to writers
|
||||
|
||||
211
examples/rag_pipeline.py
Normal file
211
examples/rag_pipeline.py
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
"""RAG pipeline"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import DATA_PATH, EXAMPLE_DATA_PATH
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
ChromaIndexConfig,
|
||||
ChromaRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
LLMRankerConfig,
|
||||
)
|
||||
|
||||
DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
|
||||
QUESTION = "What are key qualities to be a good writer?"
|
||||
|
||||
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
|
||||
TRAVEL_QUESTION = "What does Bob like?"
|
||||
|
||||
LLM_TIP = "If you not sure, just answer I don't know."
|
||||
|
||||
|
||||
class Player(BaseModel):
|
||||
"""To demonstrate rag add objs."""
|
||||
|
||||
name: str = ""
|
||||
goal: str = "Win The 100-meter Sprint."
|
||||
tool: str = "Red Bull Energy Drink."
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""For search"""
|
||||
return self.goal
|
||||
|
||||
|
||||
class RAGExample:
|
||||
"""Show how to use RAG."""
|
||||
|
||||
def __init__(self):
|
||||
self.engine = SimpleEngine.from_docs(
|
||||
input_files=[DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
|
||||
async def run_pipeline(self, question=QUESTION, print_title=True):
|
||||
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
|
||||
|
||||
Retrieve Result:
|
||||
0. Productivi..., 10.0
|
||||
1. I wrote cu..., 7.0
|
||||
2. I highly r..., 5.0
|
||||
|
||||
Query Result:
|
||||
Passion, adaptability, open-mindedness, creativity, discipline, and empathy are key qualities to be a good writer.
|
||||
"""
|
||||
if print_title:
|
||||
self._print_title("Run Pipeline")
|
||||
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_retrieve_result(nodes)
|
||||
|
||||
answer = await self.engine.aquery(question)
|
||||
self._print_query_result(answer)
|
||||
|
||||
async def add_docs(self):
|
||||
"""This example show how to add docs.
|
||||
|
||||
Before add docs llm anwser I don't know.
|
||||
After add docs llm give the correct answer, will print something like:
|
||||
|
||||
[Before add docs]
|
||||
Retrieve Result:
|
||||
|
||||
Query Result:
|
||||
Empty Response
|
||||
|
||||
[After add docs]
|
||||
Retrieve Result:
|
||||
0. Bob like..., 10.0
|
||||
|
||||
Query Result:
|
||||
Bob likes traveling.
|
||||
"""
|
||||
self._print_title("Add Docs")
|
||||
|
||||
travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}"
|
||||
travel_filepath = TRAVEL_DOC_PATH
|
||||
|
||||
logger.info("[Before add docs]")
|
||||
await self.run_pipeline(question=travel_question, print_title=False)
|
||||
|
||||
logger.info("[After add docs]")
|
||||
self.engine.add_docs([travel_filepath])
|
||||
await self.run_pipeline(question=travel_question, print_title=False)
|
||||
|
||||
async def add_objects(self, print_title=True):
|
||||
"""This example show how to add objects.
|
||||
|
||||
Before add docs, engine retrieve nothing.
|
||||
After add objects, engine give the correct answer, will print something like:
|
||||
|
||||
[Before add objs]
|
||||
Retrieve Result:
|
||||
|
||||
[After add objs]
|
||||
Retrieve Result:
|
||||
0. 100m Sprin..., 10.0
|
||||
|
||||
[Object Detail]
|
||||
{'name': 'Mike', 'goal': 'Win The 100-meter Sprint', 'tool': 'Red Bull Energy Drink'}
|
||||
"""
|
||||
if print_title:
|
||||
self._print_title("Add Objects")
|
||||
|
||||
player = Player(name="Mike")
|
||||
question = f"{player.rag_key()}"
|
||||
|
||||
logger.info("[Before add objs]")
|
||||
await self._retrieve_and_print(question)
|
||||
|
||||
logger.info("[After add objs]")
|
||||
self.engine.add_objs([player])
|
||||
|
||||
try:
|
||||
nodes = await self._retrieve_and_print(question)
|
||||
|
||||
logger.info("[Object Detail]")
|
||||
player: Player = nodes[0].metadata["obj"]
|
||||
logger.info(player.name)
|
||||
except Exception as e:
|
||||
logger.error(f"nodes is empty, llm don't answer correctly, exception: {e}")
|
||||
|
||||
async def init_objects(self):
|
||||
"""This example show how to from objs, will print something like:
|
||||
|
||||
Same as add_objects.
|
||||
"""
|
||||
self._print_title("Init Objects")
|
||||
|
||||
pre_engine = self.engine
|
||||
self.engine = SimpleEngine.from_objs(retriever_configs=[FAISSRetrieverConfig()])
|
||||
await self.add_objects(print_title=False)
|
||||
self.engine = pre_engine
|
||||
|
||||
async def init_and_query_chromadb(self):
|
||||
"""This example show how to use chromadb. how to save and load index. will print something like:
|
||||
|
||||
Query Result:
|
||||
Bob likes traveling.
|
||||
"""
|
||||
self._print_title("Init And Query ChromaDB")
|
||||
|
||||
# save index
|
||||
output_dir = DATA_PATH / "rag"
|
||||
SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
|
||||
)
|
||||
|
||||
# load index
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=ChromaIndexConfig(persist_path=output_dir),
|
||||
)
|
||||
|
||||
# query
|
||||
answer = engine.query(TRAVEL_QUESTION)
|
||||
self._print_query_result(answer)
|
||||
|
||||
@staticmethod
|
||||
def _print_title(title):
|
||||
logger.info(f"{'#'*30} {title} {'#'*30}")
|
||||
|
||||
@staticmethod
|
||||
def _print_retrieve_result(result):
|
||||
"""Print retrieve result."""
|
||||
logger.info("Retrieve Result:")
|
||||
|
||||
for i, node in enumerate(result):
|
||||
logger.info(f"{i}. {node.text[:10]}..., {node.score}")
|
||||
|
||||
logger.info("")
|
||||
|
||||
@staticmethod
|
||||
def _print_query_result(result):
|
||||
"""Print query result."""
|
||||
logger.info("Query Result:")
|
||||
|
||||
logger.info(f"{result}\n")
|
||||
|
||||
async def _retrieve_and_print(self, question):
|
||||
nodes = await self.engine.aretrieve(question)
|
||||
self._print_retrieve_result(nodes)
|
||||
return nodes
|
||||
|
||||
|
||||
async def main():
|
||||
"""RAG pipeline"""
|
||||
e = RAGExample()
|
||||
await e.run_pipeline()
|
||||
await e.add_docs()
|
||||
await e.add_objects()
|
||||
await e.init_objects()
|
||||
await e.init_and_query_chromadb()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
21
examples/rag_search.py
Normal file
21
examples/rag_search.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
"""Agent with RAG search."""
|
||||
|
||||
import asyncio
|
||||
|
||||
from examples.rag_pipeline import DOC_PATH, QUESTION
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
async def search():
|
||||
"""Agent with RAG search."""
|
||||
|
||||
store = SimpleEngine.from_docs(input_files=[DOC_PATH])
|
||||
role = Sales(profile="Sales", store=store)
|
||||
result = await role.run(QUESTION)
|
||||
logger.info(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(search())
|
||||
|
|
@ -1,33 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@File : search_kb.py
|
||||
@Modified By: mashenquan, 2023-12-22. Delete useless codes.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DATA_PATH, EXAMPLE_PATH
|
||||
from metagpt.document_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
def get_store():
|
||||
llm = config.get_openai_llm()
|
||||
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
|
||||
return FaissStore(DATA_PATH / "example.json", embedding=embedding)
|
||||
|
||||
|
||||
async def search():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
logger.info(result)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(search())
|
||||
|
|
@ -49,6 +49,7 @@ METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT
|
|||
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
|
||||
|
||||
EXAMPLE_PATH = METAGPT_ROOT / "examples"
|
||||
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
|
||||
DATA_PATH = METAGPT_ROOT / "data"
|
||||
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
|
||||
RESEARCH_PATH = DATA_PATH / "research"
|
||||
|
|
|
|||
|
|
@ -11,12 +11,9 @@ from pathlib import Path
|
|||
from typing import Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from langchain_community.document_loaders import (
|
||||
TextLoader,
|
||||
UnstructuredPDFLoader,
|
||||
UnstructuredWordDocumentLoader,
|
||||
)
|
||||
from llama_index.core import Document, SimpleDirectoryReader
|
||||
from llama_index.core.node_parser import SimpleNodeParser
|
||||
from llama_index.readers.file import PDFReader
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from tqdm import tqdm
|
||||
|
||||
|
|
@ -29,7 +26,7 @@ def validate_cols(content_col: str, df: pd.DataFrame):
|
|||
raise ValueError("Content column not found in DataFrame.")
|
||||
|
||||
|
||||
def read_data(data_path: Path):
|
||||
def read_data(data_path: Path) -> Union[pd.DataFrame, list[Document]]:
|
||||
suffix = data_path.suffix
|
||||
if ".xlsx" == suffix:
|
||||
data = pd.read_excel(data_path)
|
||||
|
|
@ -38,14 +35,13 @@ def read_data(data_path: Path):
|
|||
elif ".json" == suffix:
|
||||
data = pd.read_json(data_path)
|
||||
elif suffix in (".docx", ".doc"):
|
||||
data = UnstructuredWordDocumentLoader(str(data_path), mode="elements").load()
|
||||
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
|
||||
elif ".txt" == suffix:
|
||||
data = TextLoader(str(data_path)).load()
|
||||
text_splitter = CharacterTextSplitter(separator="\n", chunk_size=256, chunk_overlap=0)
|
||||
texts = text_splitter.split_documents(data)
|
||||
data = texts
|
||||
data = SimpleDirectoryReader(input_files=[str(data_path)]).load_data()
|
||||
node_parser = SimpleNodeParser.from_defaults(separator="\n", chunk_size=256, chunk_overlap=0)
|
||||
data = node_parser.get_nodes_from_documents(data)
|
||||
elif ".pdf" == suffix:
|
||||
data = UnstructuredPDFLoader(str(data_path), mode="elements").load()
|
||||
data = PDFReader.load_data(str(data_path))
|
||||
else:
|
||||
raise NotImplementedError("File format not supported.")
|
||||
return data
|
||||
|
|
@ -150,9 +146,9 @@ class IndexableDocument(Document):
|
|||
metadatas.append({})
|
||||
return docs, metadatas
|
||||
|
||||
def _get_docs_and_metadatas_by_langchain(self) -> (list, list):
|
||||
def _get_docs_and_metadatas_by_llamaindex(self) -> (list, list):
|
||||
data = self.data
|
||||
docs = [i.page_content for i in data]
|
||||
docs = [i.text for i in data]
|
||||
metadatas = [i.metadata for i in data]
|
||||
return docs, metadatas
|
||||
|
||||
|
|
@ -160,7 +156,7 @@ class IndexableDocument(Document):
|
|||
if isinstance(self.data, pd.DataFrame):
|
||||
return self._get_docs_and_metadatas_by_df()
|
||||
elif isinstance(self.data, list):
|
||||
return self._get_docs_and_metadatas_by_langchain()
|
||||
return self._get_docs_and_metadatas_by_llamaindex()
|
||||
else:
|
||||
raise NotImplementedError("Data type not supported for metadata extraction.")
|
||||
|
||||
|
|
|
|||
|
|
@ -38,9 +38,9 @@ class LocalStore(BaseStore, ABC):
|
|||
if not self.store:
|
||||
self.store = self.write()
|
||||
|
||||
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
|
||||
index_file = self.cache_dir / f"{self.fname}{index_ext}"
|
||||
store_file = self.cache_dir / f"{self.fname}{pkl_ext}"
|
||||
def _get_index_and_store_fname(self, index_ext=".json", docstore_ext=".json"):
|
||||
index_file = self.cache_dir / "default__vector_store" / index_ext
|
||||
store_file = self.cache_dir / "docstore" / docstore_ext
|
||||
return index_file, store_file
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -11,9 +11,9 @@ import chromadb
|
|||
class ChromaStore:
|
||||
"""If inherited from BaseStore, or importing other modules from metagpt, a Python exception occurs, which is strange."""
|
||||
|
||||
def __init__(self, name):
|
||||
def __init__(self, name: str, get_or_create: bool = False):
|
||||
client = chromadb.Client()
|
||||
collection = client.create_collection(name)
|
||||
collection = client.create_collection(name, get_or_create=get_or_create)
|
||||
self.client = client
|
||||
self.collection = collection
|
||||
|
||||
|
|
|
|||
|
|
@ -7,10 +7,14 @@
|
|||
"""
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
import faiss
|
||||
from llama_index.core import VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.schema import Document, QueryBundle, TextNode
|
||||
from llama_index.core.storage import StorageContext
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.document import IndexableDocument
|
||||
from metagpt.document_store.base_store import LocalStore
|
||||
|
|
@ -20,36 +24,50 @@ from metagpt.utils.embedding import get_embedding
|
|||
|
||||
class FaissStore(LocalStore):
|
||||
def __init__(
|
||||
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None
|
||||
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: BaseEmbedding = None
|
||||
):
|
||||
self.meta_col = meta_col
|
||||
self.content_col = content_col
|
||||
self.embedding = embedding or get_embedding()
|
||||
self.store: VectorStoreIndex
|
||||
super().__init__(raw_data, cache_dir)
|
||||
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
|
||||
def _load(self) -> Optional["VectorStoreIndex"]:
|
||||
index_file, store_file = self._get_index_and_store_fname()
|
||||
|
||||
if not (index_file.exists() and store_file.exists()):
|
||||
logger.info("Missing at least one of index_file/store_file, load failed and return None")
|
||||
return None
|
||||
vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir)
|
||||
storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store)
|
||||
index = load_index_from_storage(storage_context, embed_model=self.embedding)
|
||||
|
||||
return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname)
|
||||
return index
|
||||
|
||||
def _write(self, docs, metadatas):
|
||||
store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas)
|
||||
return store
|
||||
def _write(self, docs: list[str], metadatas: list[dict[str, Any]]) -> VectorStoreIndex:
|
||||
assert len(docs) == len(metadatas)
|
||||
documents = [Document(text=doc, metadata=metadatas[idx]) for idx, doc in enumerate(docs)]
|
||||
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents, storage_context=storage_context, embed_model=self.embedding
|
||||
)
|
||||
|
||||
return index
|
||||
|
||||
def persist(self):
|
||||
self.store.save_local(self.raw_data_path.parent, self.fname)
|
||||
self.store.storage_context.persist(self.cache_dir)
|
||||
|
||||
def search(self, query: str, expand_cols=False, sep="\n", *args, k=5, **kwargs):
|
||||
retriever = self.store.as_retriever(similarity_top_k=k)
|
||||
rsp = retriever.retrieve(QueryBundle(query_str=query, embedding=self.embedding.get_text_embedding(query)))
|
||||
|
||||
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
|
||||
rsp = self.store.similarity_search(query, k=k, **kwargs)
|
||||
logger.debug(rsp)
|
||||
if expand_cols:
|
||||
return str(sep.join([f"{x.page_content}: {x.metadata}" for x in rsp]))
|
||||
return str(sep.join([f"{x.node.text}: {x.node.metadata}" for x in rsp]))
|
||||
else:
|
||||
return str(sep.join([f"{x.page_content}" for x in rsp]))
|
||||
return str(sep.join([f"{x.node.text}" for x in rsp]))
|
||||
|
||||
async def asearch(self, *args, **kwargs):
|
||||
return await asyncio.to_thread(self.search, *args, **kwargs)
|
||||
|
|
@ -67,8 +85,12 @@ class FaissStore(LocalStore):
|
|||
|
||||
def add(self, texts: list[str], *args, **kwargs) -> list[str]:
|
||||
"""FIXME: Currently, the store is not updated after adding."""
|
||||
return self.store.add_texts(texts)
|
||||
texts_embeds = self.embedding.get_text_embedding_batch(texts)
|
||||
nodes = [TextNode(text=texts[idx], embedding=embed) for idx, embed in enumerate(texts_embeds)]
|
||||
self.store.insert_nodes(nodes)
|
||||
|
||||
return []
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
"""Currently, langchain does not provide a delete interface."""
|
||||
"""Currently, faiss does not provide a delete interface."""
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -8,8 +8,6 @@ import re
|
|||
import time
|
||||
from typing import Any, Iterable
|
||||
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.vectorstores import Chroma
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.config2 import config as CONFIG
|
||||
|
|
@ -17,6 +15,7 @@ from metagpt.environment.base_env import Environment
|
|||
from metagpt.environment.mincraft_env.const import MC_CKPT_DIR
|
||||
from metagpt.environment.mincraft_env.mincraft_ext_env import MincraftExtEnv
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
from metagpt.utils.common import load_mc_skills_code, read_json_file, write_json_file
|
||||
|
||||
|
||||
|
|
@ -48,9 +47,9 @@ class MincraftEnv(Environment, MincraftExtEnv):
|
|||
|
||||
runtime_status: bool = False # equal to action execution status: success or failed
|
||||
|
||||
vectordb: Chroma = Field(default_factory=Chroma)
|
||||
vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore)
|
||||
|
||||
qa_cache_questions_vectordb: Chroma = Field(default_factory=Chroma)
|
||||
qa_cache_questions_vectordb: ChromaVectorStore = Field(default_factory=ChromaVectorStore)
|
||||
|
||||
@property
|
||||
def progress(self):
|
||||
|
|
@ -73,16 +72,14 @@ class MincraftEnv(Environment, MincraftExtEnv):
|
|||
self.set_mc_resume()
|
||||
|
||||
def set_mc_resume(self):
|
||||
self.qa_cache_questions_vectordb = Chroma(
|
||||
self.qa_cache_questions_vectordb = ChromaVectorStore(
|
||||
collection_name="qa_cache_questions_vectordb",
|
||||
embedding_function=OpenAIEmbeddings(),
|
||||
persist_directory=f"{MC_CKPT_DIR}/curriculum/vectordb",
|
||||
persist_dir=f"{MC_CKPT_DIR}/curriculum/vectordb",
|
||||
)
|
||||
|
||||
self.vectordb = Chroma(
|
||||
self.vectordb = ChromaVectorStore(
|
||||
collection_name="skill_vectordb",
|
||||
embedding_function=OpenAIEmbeddings(),
|
||||
persist_directory=f"{MC_CKPT_DIR}/skill/vectordb",
|
||||
persist_dir=f"{MC_CKPT_DIR}/skill/vectordb",
|
||||
)
|
||||
|
||||
if CONFIG.resume:
|
||||
|
|
|
|||
|
|
@ -29,16 +29,14 @@ class LongTermMemory(Memory):
|
|||
msg_from_recover: bool = False
|
||||
|
||||
def recover_memory(self, role_id: str, rc: RoleContext):
|
||||
messages = self.memory_storage.recover_memory(role_id)
|
||||
self.memory_storage.recover_memory(role_id)
|
||||
self.rc = rc
|
||||
if not self.memory_storage.is_initialized:
|
||||
logger.warning(f"It may the first time to run Agent {role_id}, the long-term memory is empty")
|
||||
logger.warning(f"It may the first time to run Role {role_id}, the long-term memory is empty")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Agent {role_id} has existing memory storage with {len(messages)} messages " f"and has recovered them."
|
||||
)
|
||||
logger.warning(f"Role {role_id} has existing memory storage and has recovered them.")
|
||||
self.msg_from_recover = True
|
||||
self.add_batch(messages)
|
||||
# self.add_batch(messages) # TODO no need
|
||||
self.msg_from_recover = False
|
||||
|
||||
def add(self, message: Message):
|
||||
|
|
@ -49,7 +47,7 @@ class LongTermMemory(Memory):
|
|||
# and ignore adding messages from recover repeatedly
|
||||
self.memory_storage.add(message)
|
||||
|
||||
def find_news(self, observed: list[Message], k=0) -> list[Message]:
|
||||
async def find_news(self, observed: list[Message], k=0) -> list[Message]:
|
||||
"""
|
||||
find news (previously unseen messages) from the the most recent k memories, from all memories when k=0
|
||||
1. find the short-term memory(stm) news
|
||||
|
|
@ -63,11 +61,14 @@ class LongTermMemory(Memory):
|
|||
ltm_news: list[Message] = []
|
||||
for mem in stm_news:
|
||||
# filter out messages similar to those seen previously in ltm, only keep fresh news
|
||||
mem_searched = self.memory_storage.search_dissimilar(mem)
|
||||
if len(mem_searched) > 0:
|
||||
mem_searched = await self.memory_storage.search_similar(mem)
|
||||
if len(mem_searched) == 0:
|
||||
ltm_news.append(mem)
|
||||
return ltm_news[-k:]
|
||||
|
||||
def persist(self):
|
||||
self.memory_storage.persist()
|
||||
|
||||
def delete(self, message: Message):
|
||||
super().delete(message)
|
||||
# TODO delete message in memory_storage
|
||||
|
|
|
|||
|
|
@ -3,115 +3,75 @@
|
|||
"""
|
||||
@Desc : the implement of memory storage
|
||||
"""
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from langchain.vectorstores.faiss import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
|
||||
from metagpt.const import DATA_PATH, MEM_TTL
|
||||
from metagpt.document_store.faiss_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines.simple import SimpleEngine
|
||||
from metagpt.rag.schema import FAISSIndexConfig, FAISSRetrieverConfig
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
from metagpt.utils.serialize import deserialize_message, serialize_message
|
||||
|
||||
|
||||
class MemoryStorage(FaissStore):
|
||||
class MemoryStorage(object):
|
||||
"""
|
||||
The memory storage with Faiss as ANN search engine
|
||||
"""
|
||||
|
||||
def __init__(self, mem_ttl: int = MEM_TTL, embedding: Embeddings = None):
|
||||
def __init__(self, mem_ttl: int = MEM_TTL, embedding: BaseEmbedding = None):
|
||||
self.role_id: str = None
|
||||
self.role_mem_path: str = None
|
||||
self.mem_ttl: int = mem_ttl # later use
|
||||
self.threshold: float = 0.1 # experience value. TODO The threshold to filter similar memories
|
||||
self._initialized: bool = False
|
||||
|
||||
self.embedding = embedding or get_embedding()
|
||||
self.store: FAISS = None # Faiss engine
|
||||
|
||||
self.faiss_engine = None
|
||||
|
||||
@property
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss
|
||||
|
||||
if not (index_file.exists() and store_file.exists()):
|
||||
logger.info("Missing at least one of index_file/store_file, load failed and return None")
|
||||
return None
|
||||
|
||||
return FAISS.load_local(self.role_mem_path, self.embedding, self.role_id)
|
||||
|
||||
def recover_memory(self, role_id: str) -> list[Message]:
|
||||
self.role_id = role_id
|
||||
self.role_mem_path = Path(DATA_PATH / f"role_mem/{self.role_id}/")
|
||||
self.role_mem_path.mkdir(parents=True, exist_ok=True)
|
||||
self.cache_dir = self.role_mem_path
|
||||
|
||||
self.store = self._load()
|
||||
messages = []
|
||||
if not self.store:
|
||||
# TODO init `self.store` under here with raw faiss api instead under `add`
|
||||
pass
|
||||
if self.role_mem_path.joinpath("default__vector_store.json").exists():
|
||||
self.faiss_engine = SimpleEngine.from_index(
|
||||
index_config=FAISSIndexConfig(persist_path=self.cache_dir),
|
||||
retriever_configs=[FAISSRetrieverConfig()],
|
||||
embed_model=self.embedding,
|
||||
)
|
||||
else:
|
||||
for _id, document in self.store.docstore._dict.items():
|
||||
messages.append(deserialize_message(document.metadata.get("message_ser")))
|
||||
self._initialized = True
|
||||
|
||||
return messages
|
||||
|
||||
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
|
||||
if not self.role_mem_path:
|
||||
logger.error(f"You should call {self.__class__.__name__}.recover_memory fist when using LongTermMemory")
|
||||
return None, None
|
||||
index_fpath = Path(self.role_mem_path / f"{self.role_id}{index_ext}")
|
||||
storage_fpath = Path(self.role_mem_path / f"{self.role_id}{pkl_ext}")
|
||||
return index_fpath, storage_fpath
|
||||
|
||||
def persist(self):
|
||||
self.store.save_local(self.role_mem_path, self.role_id)
|
||||
logger.debug(f"Agent {self.role_id} persist memory into local")
|
||||
self.faiss_engine = SimpleEngine.from_objs(
|
||||
objs=[], retriever_configs=[FAISSRetrieverConfig()], embed_model=self.embedding
|
||||
)
|
||||
self._initialized = True
|
||||
|
||||
def add(self, message: Message) -> bool:
|
||||
"""add message into memory storage"""
|
||||
docs = [message.content]
|
||||
metadatas = [{"message_ser": serialize_message(message)}]
|
||||
if not self.store:
|
||||
# init Faiss
|
||||
self.store = self._write(docs, metadatas)
|
||||
self._initialized = True
|
||||
else:
|
||||
self.store.add_texts(texts=docs, metadatas=metadatas)
|
||||
self.persist()
|
||||
logger.info(f"Agent {self.role_id}'s memory_storage add a message")
|
||||
self.faiss_engine.add_objs([message])
|
||||
logger.info(f"Role {self.role_id}'s memory_storage add a message")
|
||||
|
||||
def search_dissimilar(self, message: Message, k=4) -> list[Message]:
|
||||
"""search for dissimilar messages"""
|
||||
if not self.store:
|
||||
return []
|
||||
|
||||
resp = self.store.similarity_search_with_score(query=message.content, k=k)
|
||||
async def search_similar(self, message: Message, k=4) -> list[Message]:
|
||||
"""search for similar messages"""
|
||||
# filter the result which score is smaller than the threshold
|
||||
filtered_resp = []
|
||||
for item, score in resp:
|
||||
# the smaller score means more similar relation
|
||||
if score < self.threshold:
|
||||
continue
|
||||
# convert search result into Memory
|
||||
metadata = item.metadata
|
||||
new_mem = deserialize_message(metadata.get("message_ser"))
|
||||
filtered_resp.append(new_mem)
|
||||
resp = await self.faiss_engine.aretrieve(message.content)
|
||||
for item in resp:
|
||||
if item.score < self.threshold:
|
||||
filtered_resp.append(item.metadata.get("obj"))
|
||||
return filtered_resp
|
||||
|
||||
def clean(self):
|
||||
index_fpath, storage_fpath = self._get_index_and_store_fname()
|
||||
if index_fpath and index_fpath.exists():
|
||||
index_fpath.unlink(missing_ok=True)
|
||||
if storage_fpath and storage_fpath.exists():
|
||||
storage_fpath.unlink(missing_ok=True)
|
||||
|
||||
self.store = None
|
||||
shutil.rmtree(self.cache_dir, ignore_errors=True)
|
||||
self._initialized = False
|
||||
|
||||
def persist(self):
|
||||
if self.faiss_engine:
|
||||
self.faiss_engine.retriever._index.storage_context.persist(self.cache_dir)
|
||||
|
|
|
|||
0
metagpt/rag/__init__.py
Normal file
0
metagpt/rag/__init__.py
Normal file
5
metagpt/rag/engines/__init__.py
Normal file
5
metagpt/rag/engines/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""Engines init"""
|
||||
|
||||
from metagpt.rag.engines.simple import SimpleEngine
|
||||
|
||||
__all__ = ["SimpleEngine"]
|
||||
259
metagpt/rag/engines/simple.py
Normal file
259
metagpt/rag/engines/simple.py
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
"""Simple Engine."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from llama_index.core import SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.embeddings.mock_embed_model import MockEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.ingestion.pipeline import run_transformations
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import (
|
||||
BaseSynthesizer,
|
||||
get_response_synthesizer,
|
||||
)
|
||||
from llama_index.core.retrievers import BaseRetriever
|
||||
from llama_index.core.schema import (
|
||||
BaseNode,
|
||||
Document,
|
||||
NodeWithScore,
|
||||
QueryBundle,
|
||||
QueryType,
|
||||
TransformComponent,
|
||||
)
|
||||
|
||||
from metagpt.rag.factories import (
|
||||
get_index,
|
||||
get_rag_embedding,
|
||||
get_rag_llm,
|
||||
get_rankers,
|
||||
get_retriever,
|
||||
)
|
||||
from metagpt.rag.interface import NoEmbedding, RAGObject
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever, PersistableRAGRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseIndexConfig,
|
||||
BaseRankerConfig,
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
ObjectNode,
|
||||
)
|
||||
from metagpt.utils.common import import_class
|
||||
|
||||
|
||||
class SimpleEngine(RetrieverQueryEngine):
|
||||
"""SimpleEngine is designed to be simple and straightforward.
|
||||
|
||||
It is a lightweight and easy-to-use search engine that integrates
|
||||
document reading, embedding, indexing, retrieving, and ranking functionalities
|
||||
into a single, straightforward workflow. It is designed to quickly set up a
|
||||
search engine from a collection of documents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
retriever: BaseRetriever,
|
||||
response_synthesizer: Optional[BaseSynthesizer] = None,
|
||||
node_postprocessors: Optional[list[BaseNodePostprocessor]] = None,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
index: Optional[BaseIndex] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
retriever=retriever,
|
||||
response_synthesizer=response_synthesizer,
|
||||
node_postprocessors=node_postprocessors,
|
||||
callback_manager=callback_manager,
|
||||
)
|
||||
self.index = index
|
||||
|
||||
@classmethod
|
||||
def from_docs(
|
||||
cls,
|
||||
input_dir: str = None,
|
||||
input_files: list[str] = None,
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
embed_model: BaseEmbedding = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
"""From docs.
|
||||
|
||||
Must provide either `input_dir` or `input_files`.
|
||||
|
||||
Args:
|
||||
input_dir: Path to the directory.
|
||||
input_files: List of file paths to read (Optional; overrides input_dir, exclude).
|
||||
transformations: Parse documents to nodes. Default [SentenceSplitter].
|
||||
embed_model: Parse nodes to embedding. Must supported by llama index. Default OpenAIEmbedding.
|
||||
llm: Must supported by llama index. Default OpenAI.
|
||||
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
|
||||
ranker_configs: Configuration for rankers.
|
||||
"""
|
||||
if not input_dir and not input_files:
|
||||
raise ValueError("Must provide either `input_dir` or `input_files`.")
|
||||
|
||||
documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
|
||||
cls._fix_document_metadata(documents)
|
||||
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_objs(
|
||||
cls,
|
||||
objs: Optional[list[RAGObject]] = None,
|
||||
transformations: Optional[list[TransformComponent]] = None,
|
||||
embed_model: BaseEmbedding = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
"""From objs.
|
||||
|
||||
Args:
|
||||
objs: List of RAGObject.
|
||||
transformations: Parse documents to nodes. Default [SentenceSplitter].
|
||||
embed_model: Parse nodes to embedding. Must supported by llama index. Default OpenAIEmbedding.
|
||||
llm: Must supported by llama index. Default OpenAI.
|
||||
retriever_configs: Configuration for retrievers. If more than one config, will use SimpleHybridRetriever.
|
||||
ranker_configs: Configuration for rankers.
|
||||
"""
|
||||
if not objs and any(isinstance(config, BM25RetrieverConfig) for config in retriever_configs):
|
||||
raise ValueError("In BM25RetrieverConfig, Objs must not be empty.")
|
||||
|
||||
objs = objs or []
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
index = VectorStoreIndex(
|
||||
nodes=nodes,
|
||||
transformations=transformations or [SentenceSplitter()],
|
||||
embed_model=cls._resolve_embed_model(embed_model, retriever_configs),
|
||||
)
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
@classmethod
|
||||
def from_index(
|
||||
cls,
|
||||
index_config: BaseIndexConfig,
|
||||
embed_model: BaseEmbedding = None,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
"""Load from previously maintained index by self.persist(), index_config contains persis_path."""
|
||||
index = get_index(index_config, embed_model=cls._resolve_embed_model(embed_model, [index_config]))
|
||||
return cls._from_index(index, llm=llm, retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
async def asearch(self, content: str, **kwargs) -> str:
|
||||
"""Inplement tools.SearchInterface"""
|
||||
return await self.aquery(content)
|
||||
|
||||
async def aretrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Allow query to be str."""
|
||||
query_bundle = QueryBundle(query) if isinstance(query, str) else query
|
||||
|
||||
nodes = await super().aretrieve(query_bundle)
|
||||
self._try_reconstruct_obj(nodes)
|
||||
return nodes
|
||||
|
||||
def add_docs(self, input_files: list[str]):
|
||||
"""Add docs to retriever. retriever must has add_nodes func."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
self._fix_document_metadata(documents)
|
||||
|
||||
nodes = run_transformations(documents, transformations=self.index._transformations)
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def add_objs(self, objs: list[RAGObject]):
|
||||
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
nodes = [ObjectNode(text=obj.rag_key(), metadata=ObjectNode.get_obj_metadata(obj)) for obj in objs]
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def persist(self, persist_dir: Union[str, os.PathLike], **kwargs):
|
||||
"""Persist."""
|
||||
self._ensure_retriever_persistable()
|
||||
|
||||
self._persist(str(persist_dir), **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _from_index(
|
||||
cls,
|
||||
index: BaseIndex,
|
||||
llm: LLM = None,
|
||||
retriever_configs: list[BaseRetrieverConfig] = None,
|
||||
ranker_configs: list[BaseRankerConfig] = None,
|
||||
) -> "SimpleEngine":
|
||||
llm = llm or get_rag_llm()
|
||||
retriever = get_retriever(configs=retriever_configs, index=index) # Default index.as_retriever
|
||||
rankers = get_rankers(configs=ranker_configs, llm=llm) # Default []
|
||||
|
||||
return cls(
|
||||
retriever=retriever,
|
||||
node_postprocessors=rankers,
|
||||
response_synthesizer=get_response_synthesizer(llm=llm),
|
||||
index=index,
|
||||
)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
self._ensure_retriever_of_type(ModifiableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_persistable(self):
|
||||
self._ensure_retriever_of_type(PersistableRAGRetriever)
|
||||
|
||||
def _ensure_retriever_of_type(self, required_type: BaseRetriever):
|
||||
"""Ensure that self.retriever is required_type, or at least one of its components, if it's a SimpleHybridRetriever.
|
||||
|
||||
Args:
|
||||
required_type: The class that the retriever is expected to be an instance of.
|
||||
"""
|
||||
if isinstance(self.retriever, SimpleHybridRetriever):
|
||||
if not any(isinstance(r, required_type) for r in self.retriever.retrievers):
|
||||
raise TypeError(
|
||||
f"Must have at least one retriever of type {required_type.__name__} in SimpleHybridRetriever"
|
||||
)
|
||||
|
||||
if not isinstance(self.retriever, required_type):
|
||||
raise TypeError(f"The retriever is not of type {required_type.__name__}: {type(self.retriever)}")
|
||||
|
||||
def _save_nodes(self, nodes: list[BaseNode]):
|
||||
self.retriever.add_nodes(nodes)
|
||||
|
||||
def _persist(self, persist_dir: str, **kwargs):
|
||||
self.retriever.persist(persist_dir, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _try_reconstruct_obj(nodes: list[NodeWithScore]):
|
||||
"""If node is object, then dynamically reconstruct object, and save object to node.metadata["obj"]."""
|
||||
for node in nodes:
|
||||
if node.metadata.get("is_obj", False):
|
||||
obj_cls = import_class(node.metadata["obj_cls_name"], node.metadata["obj_mod_name"])
|
||||
obj_dict = json.loads(node.metadata["obj_json"])
|
||||
node.metadata["obj"] = obj_cls(**obj_dict)
|
||||
|
||||
@staticmethod
|
||||
def _fix_document_metadata(documents: list[Document]):
|
||||
"""LlamaIndex keep metadata['file_path'], which is unnecessary, maybe deleted in the near future."""
|
||||
for doc in documents:
|
||||
doc.excluded_embed_metadata_keys.append("file_path")
|
||||
|
||||
@staticmethod
|
||||
def _resolve_embed_model(embed_model: BaseEmbedding = None, configs: list[Any] = None) -> BaseEmbedding:
|
||||
if configs and all(isinstance(c, NoEmbedding) for c in configs):
|
||||
return MockEmbedding(embed_dim=1)
|
||||
|
||||
return embed_model or get_rag_embedding()
|
||||
9
metagpt/rag/factories/__init__.py
Normal file
9
metagpt/rag/factories/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""RAG factories"""
|
||||
|
||||
from metagpt.rag.factories.retriever import get_retriever
|
||||
from metagpt.rag.factories.ranker import get_rankers
|
||||
from metagpt.rag.factories.embedding import get_rag_embedding
|
||||
from metagpt.rag.factories.index import get_index
|
||||
from metagpt.rag.factories.llm import get_rag_llm
|
||||
|
||||
__all__ = ["get_retriever", "get_rankers", "get_rag_embedding", "get_index", "get_rag_llm"]
|
||||
59
metagpt/rag/factories/base.py
Normal file
59
metagpt/rag/factories/base.py
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
"""Base Factory."""
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
|
||||
class GenericFactory:
|
||||
"""Designed to get objects based on any keys."""
|
||||
|
||||
def __init__(self, creators: dict[Any, Callable] = None):
|
||||
"""Creators is a dictionary.
|
||||
|
||||
Keys are identifiers, and the values are the associated creator function, which create objects.
|
||||
"""
|
||||
self._creators = creators or {}
|
||||
|
||||
def get_instances(self, keys: list[Any], **kwargs) -> list[Any]:
|
||||
"""Get instances by keys."""
|
||||
return [self.get_instance(key, **kwargs) for key in keys]
|
||||
|
||||
def get_instance(self, key: Any, **kwargs) -> Any:
|
||||
"""Get instance by key.
|
||||
|
||||
Raise Exception if key not found.
|
||||
"""
|
||||
creator = self._creators.get(key)
|
||||
if creator:
|
||||
return creator(**kwargs)
|
||||
|
||||
raise ValueError(f"Creator not registered for key: {key}")
|
||||
|
||||
|
||||
class ConfigBasedFactory(GenericFactory):
|
||||
"""Designed to get objects based on object type."""
|
||||
|
||||
def get_instance(self, key: Any, **kwargs) -> Any:
|
||||
"""Key is config, such as a pydantic model.
|
||||
|
||||
Call func by the type of key, and the key will be passed to func.
|
||||
"""
|
||||
creator = self._creators.get(type(key))
|
||||
if creator:
|
||||
return creator(key, **kwargs)
|
||||
|
||||
raise ValueError(f"Unknown config: {key}")
|
||||
|
||||
@staticmethod
|
||||
def _val_from_config_or_kwargs(key: str, config: object = None, **kwargs) -> Any:
|
||||
"""It prioritizes the configuration object's value unless it is None, in which case it looks into kwargs."""
|
||||
if config is not None and hasattr(config, key):
|
||||
val = getattr(config, key)
|
||||
if val is not None:
|
||||
return val
|
||||
|
||||
if key in kwargs:
|
||||
return kwargs[key]
|
||||
|
||||
raise KeyError(
|
||||
f"The key '{key}' is required but not provided in either configuration object or keyword arguments."
|
||||
)
|
||||
37
metagpt/rag/factories/embedding.py
Normal file
37
metagpt/rag/factories/embedding.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""RAG Embedding Factory."""
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.rag.factories.base import GenericFactory
|
||||
|
||||
|
||||
class RAGEmbeddingFactory(GenericFactory):
|
||||
"""Create LlamaIndex Embedding with MetaGPT's config."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
LLMType.OPENAI: self._create_openai,
|
||||
LLMType.AZURE: self._create_azure,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rag_embedding(self, key: LLMType = None) -> BaseEmbedding:
|
||||
"""Key is LLMType, default use config.llm.api_type."""
|
||||
return super().get_instance(key or config.llm.api_type)
|
||||
|
||||
def _create_openai(self):
|
||||
return OpenAIEmbedding(api_key=config.llm.api_key, api_base=config.llm.base_url)
|
||||
|
||||
def _create_azure(self):
|
||||
return AzureOpenAIEmbedding(
|
||||
azure_endpoint=config.llm.base_url,
|
||||
api_key=config.llm.api_key,
|
||||
api_version=config.llm.api_version,
|
||||
)
|
||||
|
||||
|
||||
get_rag_embedding = RAGEmbeddingFactory().get_rag_embedding
|
||||
63
metagpt/rag/factories/index.py
Normal file
63
metagpt/rag/factories/index.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
"""RAG Index Factory."""
|
||||
|
||||
import chromadb
|
||||
from llama_index.core import StorageContext, VectorStoreIndex, load_index_from_storage
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.schema import (
|
||||
BaseIndexConfig,
|
||||
BM25IndexConfig,
|
||||
ChromaIndexConfig,
|
||||
FAISSIndexConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class RAGIndexFactory(ConfigBasedFactory):
|
||||
def __init__(self):
|
||||
creators = {
|
||||
FAISSIndexConfig: self._create_faiss,
|
||||
ChromaIndexConfig: self._create_chroma,
|
||||
BM25IndexConfig: self._create_bm25,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_index(self, config: BaseIndexConfig, **kwargs) -> BaseIndex:
|
||||
"""Key is PersistType."""
|
||||
return super().get_instance(config, **kwargs)
|
||||
|
||||
def _create_faiss(self, config: FAISSIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
vector_store = FaissVectorStore.from_persist_dir(str(config.persist_path))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store, persist_dir=config.persist_path)
|
||||
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
return index
|
||||
|
||||
def _create_chroma(self, config: ChromaIndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
db = chromadb.PersistentClient(str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
index = VectorStoreIndex.from_vector_store(
|
||||
vector_store,
|
||||
embed_model=embed_model,
|
||||
)
|
||||
return index
|
||||
|
||||
def _create_bm25(self, config: BM25IndexConfig, **kwargs) -> VectorStoreIndex:
|
||||
embed_model = self._extract_embed_model(config, **kwargs)
|
||||
|
||||
storage_context = StorageContext.from_defaults(persist_dir=config.persist_path)
|
||||
index = load_index_from_storage(storage_context=storage_context, embed_model=embed_model)
|
||||
return index
|
||||
|
||||
def _extract_embed_model(self, config, **kwargs) -> BaseEmbedding:
|
||||
return self._val_from_config_or_kwargs("embed_model", config, **kwargs)
|
||||
|
||||
|
||||
get_index = RAGIndexFactory().get_index
|
||||
54
metagpt/rag/factories/llm.py
Normal file
54
metagpt/rag/factories/llm.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
"""RAG LLM."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW
|
||||
from llama_index.core.llms import (
|
||||
CompletionResponse,
|
||||
CompletionResponseGen,
|
||||
CustomLLM,
|
||||
LLMMetadata,
|
||||
)
|
||||
from llama_index.core.llms.callbacks import llm_completion_callback
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.async_helper import run_coroutine_in_new_loop
|
||||
from metagpt.utils.token_counter import TOKEN_MAX
|
||||
|
||||
|
||||
class RAGLLM(CustomLLM):
|
||||
"""LlamaIndex's LLM is different from MetaGPT's LLM.
|
||||
|
||||
Inherit CustomLLM from llamaindex, making MetaGPT's LLM can be used by LlamaIndex.
|
||||
"""
|
||||
|
||||
model_infer: BaseLLM = Field(..., description="The MetaGPT's LLM.")
|
||||
context_window: int = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW)
|
||||
num_output: int = config.llm.max_token
|
||||
model_name: str = config.llm.model
|
||||
|
||||
@property
|
||||
def metadata(self) -> LLMMetadata:
|
||||
"""Get LLM metadata."""
|
||||
return LLMMetadata(context_window=self.context_window, num_output=self.num_output, model_name=self.model_name)
|
||||
|
||||
@llm_completion_callback()
|
||||
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
|
||||
return run_coroutine_in_new_loop(self.acomplete(prompt, **kwargs))
|
||||
|
||||
@llm_completion_callback()
|
||||
async def acomplete(self, prompt: str, formatted: bool = False, **kwargs: Any) -> CompletionResponse:
|
||||
text = await self.model_infer.aask(msg=prompt, stream=False)
|
||||
return CompletionResponse(text=text)
|
||||
|
||||
@llm_completion_callback()
|
||||
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
|
||||
...
|
||||
|
||||
|
||||
def get_rag_llm(model_infer: BaseLLM = None) -> RAGLLM:
|
||||
"""Get llm that can be used by LlamaIndex."""
|
||||
return RAGLLM(model_infer=model_infer or LLM())
|
||||
35
metagpt/rag/factories/ranker.py
Normal file
35
metagpt/rag/factories/ranker.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
"""RAG Ranker Factory."""
|
||||
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.schema import BaseRankerConfig, LLMRankerConfig
|
||||
|
||||
|
||||
class RankerFactory(ConfigBasedFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
LLMRankerConfig: self._create_llm_ranker,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_rankers(self, configs: list[BaseRankerConfig] = None, **kwargs) -> list[BaseNodePostprocessor]:
|
||||
"""Creates and returns a retriever instance based on the provided configurations."""
|
||||
if not configs:
|
||||
return []
|
||||
|
||||
return super().get_instances(configs, **kwargs)
|
||||
|
||||
def _create_llm_ranker(self, config: LLMRankerConfig, **kwargs) -> LLMRerank:
|
||||
config.llm = self._extract_llm(config, **kwargs)
|
||||
return LLMRerank(**config.model_dump())
|
||||
|
||||
def _extract_llm(self, config: BaseRankerConfig = None, **kwargs) -> LLM:
|
||||
return self._val_from_config_or_kwargs("llm", config, **kwargs)
|
||||
|
||||
|
||||
get_rankers = RankerFactory().get_rankers
|
||||
86
metagpt/rag/factories/retriever.py
Normal file
86
metagpt/rag/factories/retriever.py
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
"""RAG Retriever Factory."""
|
||||
|
||||
import copy
|
||||
|
||||
import chromadb
|
||||
import faiss
|
||||
from llama_index.core import StorageContext, VectorStoreIndex
|
||||
from llama_index.core.vector_stores.types import BasePydanticVectorStore
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.chroma_retriever import ChromaRetriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import (
|
||||
BaseRetrieverConfig,
|
||||
BM25RetrieverConfig,
|
||||
ChromaRetrieverConfig,
|
||||
FAISSRetrieverConfig,
|
||||
IndexRetrieverConfig,
|
||||
)
|
||||
from metagpt.rag.vector_stores.chroma import ChromaVectorStore
|
||||
|
||||
|
||||
class RetrieverFactory(ConfigBasedFactory):
|
||||
"""Modify creators for dynamically instance implementation."""
|
||||
|
||||
def __init__(self):
|
||||
creators = {
|
||||
FAISSRetrieverConfig: self._create_faiss_retriever,
|
||||
BM25RetrieverConfig: self._create_bm25_retriever,
|
||||
ChromaRetrieverConfig: self._create_chroma_retriever,
|
||||
}
|
||||
super().__init__(creators)
|
||||
|
||||
def get_retriever(self, configs: list[BaseRetrieverConfig] = None, **kwargs) -> RAGRetriever:
|
||||
"""Creates and returns a retriever instance based on the provided configurations.
|
||||
|
||||
If multiple retrievers, using SimpleHybridRetriever.
|
||||
"""
|
||||
if not configs:
|
||||
return self._create_default(**kwargs)
|
||||
|
||||
retrievers = super().get_instances(configs, **kwargs)
|
||||
|
||||
return SimpleHybridRetriever(*retrievers) if len(retrievers) > 1 else retrievers[0]
|
||||
|
||||
def _create_default(self, **kwargs) -> RAGRetriever:
|
||||
return self._extract_index(**kwargs).as_retriever()
|
||||
|
||||
def _create_faiss_retriever(self, config: FAISSRetrieverConfig, **kwargs) -> FAISSRetriever:
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(config.dimensions))
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
return FAISSRetriever(**config.model_dump())
|
||||
|
||||
def _create_bm25_retriever(self, config: BM25RetrieverConfig, **kwargs) -> DynamicBM25Retriever:
|
||||
config.index = copy.deepcopy(self._extract_index(config, **kwargs))
|
||||
nodes = list(config.index.docstore.docs.values())
|
||||
return DynamicBM25Retriever(nodes=nodes, **config.model_dump())
|
||||
|
||||
def _create_chroma_retriever(self, config: ChromaRetrieverConfig, **kwargs) -> ChromaRetriever:
|
||||
db = chromadb.PersistentClient(path=str(config.persist_path))
|
||||
chroma_collection = db.get_or_create_collection(config.collection_name)
|
||||
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
|
||||
config.index = self._build_index_from_vector_store(config, vector_store, **kwargs)
|
||||
return ChromaRetriever(**config.model_dump())
|
||||
|
||||
def _extract_index(self, config: BaseRetrieverConfig = None, **kwargs) -> VectorStoreIndex:
|
||||
return self._val_from_config_or_kwargs("index", config, **kwargs)
|
||||
|
||||
def _build_index_from_vector_store(
|
||||
self, config: IndexRetrieverConfig, vector_store: BasePydanticVectorStore, **kwargs
|
||||
) -> VectorStoreIndex:
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
old_index = self._extract_index(config, **kwargs)
|
||||
new_index = VectorStoreIndex(
|
||||
nodes=list(old_index.docstore.docs.values()),
|
||||
storage_context=storage_context,
|
||||
embed_model=old_index._embed_model,
|
||||
)
|
||||
return new_index
|
||||
|
||||
|
||||
get_retriever = RetrieverFactory().get_retriever
|
||||
24
metagpt/rag/interface.py
Normal file
24
metagpt/rag/interface.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"""RAG Interfaces."""
|
||||
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class RAGObject(Protocol):
|
||||
"""Support rag add object."""
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""For rag search."""
|
||||
|
||||
def model_dump_json(self) -> str:
|
||||
"""For rag persist.
|
||||
|
||||
Pydantic Model don't need to implement this, as there is a built-in function named model_dump_json.
|
||||
"""
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class NoEmbedding(Protocol):
|
||||
"""Some retriever does not require embeddings, e.g. BM25"""
|
||||
|
||||
_no_embedding: bool
|
||||
1
metagpt/rag/rankers/__init__.py
Normal file
1
metagpt/rag/rankers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Rankers init"""
|
||||
19
metagpt/rag/rankers/base.py
Normal file
19
metagpt/rag/rankers/base.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""Base Ranker."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import NodeWithScore, QueryBundle
|
||||
|
||||
|
||||
class RAGRanker(BaseNodePostprocessor):
|
||||
"""inherit from llama_index"""
|
||||
|
||||
@abstractmethod
|
||||
def _postprocess_nodes(
|
||||
self,
|
||||
nodes: list[NodeWithScore],
|
||||
query_bundle: Optional[QueryBundle] = None,
|
||||
) -> list[NodeWithScore]:
|
||||
"""postprocess nodes."""
|
||||
5
metagpt/rag/retrievers/__init__.py
Normal file
5
metagpt/rag/retrievers/__init__.py
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
"""Retrievers init."""
|
||||
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
|
||||
__all__ = ["SimpleHybridRetriever"]
|
||||
47
metagpt/rag/retrievers/base.py
Normal file
47
metagpt/rag/retrievers/base.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""Base retriever."""
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from llama_index.core.retrievers import BaseRetriever
|
||||
from llama_index.core.schema import BaseNode, NodeWithScore, QueryType
|
||||
|
||||
from metagpt.utils.reflection import check_methods
|
||||
|
||||
|
||||
class RAGRetriever(BaseRetriever):
|
||||
"""Inherit from llama_index"""
|
||||
|
||||
@abstractmethod
|
||||
async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Retrieve nodes"""
|
||||
|
||||
def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
|
||||
"""Retrieve nodes"""
|
||||
|
||||
|
||||
class ModifiableRAGRetriever(RAGRetriever):
|
||||
"""Support modification."""
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, C):
|
||||
if cls is ModifiableRAGRetriever:
|
||||
return check_methods(C, "add_nodes")
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""To support add docs, must inplement this func"""
|
||||
|
||||
|
||||
class PersistableRAGRetriever(RAGRetriever):
|
||||
"""Support persistent."""
|
||||
|
||||
@classmethod
|
||||
def __subclasshook__(cls, C):
|
||||
if cls is PersistableRAGRetriever:
|
||||
return check_methods(C, "persist")
|
||||
return NotImplemented
|
||||
|
||||
@abstractmethod
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""To support persist, must inplement this func"""
|
||||
47
metagpt/rag/retrievers/bm25_retriever.py
Normal file
47
metagpt/rag/retrievers/bm25_retriever.py
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
"""BM25 retriever."""
|
||||
from typing import Callable, Optional
|
||||
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
|
||||
from llama_index.core.schema import BaseNode, IndexNode
|
||||
from llama_index.retrievers.bm25 import BM25Retriever
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
|
||||
class DynamicBM25Retriever(BM25Retriever):
|
||||
"""BM25 retriever."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: list[BaseNode],
|
||||
tokenizer: Optional[Callable[[str], list[str]]] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
objects: Optional[list[IndexNode]] = None,
|
||||
object_map: Optional[dict] = None,
|
||||
verbose: bool = False,
|
||||
index: VectorStoreIndex = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
nodes=nodes,
|
||||
tokenizer=tokenizer,
|
||||
similarity_top_k=similarity_top_k,
|
||||
callback_manager=callback_manager,
|
||||
object_map=object_map,
|
||||
objects=objects,
|
||||
verbose=verbose,
|
||||
)
|
||||
self._index = index
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes."""
|
||||
self._nodes.extend(nodes)
|
||||
self._corpus = [self._tokenizer(node.get_content()) for node in self._nodes]
|
||||
self.bm25 = BM25Okapi(self._corpus)
|
||||
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist."""
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
17
metagpt/rag/retrievers/chroma_retriever.py
Normal file
17
metagpt/rag/retrievers/chroma_retriever.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Chroma retriever."""
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
|
||||
class ChromaRetriever(VectorIndexRetriever):
|
||||
"""Chroma retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes."""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist.
|
||||
|
||||
Chromadb automatically saves, so there is no need to implement."""
|
||||
16
metagpt/rag/retrievers/faiss_retriever.py
Normal file
16
metagpt/rag/retrievers/faiss_retriever.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
"""FAISS retriever."""
|
||||
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
from llama_index.core.schema import BaseNode
|
||||
|
||||
|
||||
class FAISSRetriever(VectorIndexRetriever):
|
||||
"""FAISS retriever."""
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode], **kwargs) -> None:
|
||||
"""Support add nodes"""
|
||||
self._index.insert_nodes(nodes, **kwargs)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist."""
|
||||
self._index.storage_context.persist(persist_dir)
|
||||
48
metagpt/rag/retrievers/hybrid_retriever.py
Normal file
48
metagpt/rag/retrievers/hybrid_retriever.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Hybrid retriever."""
|
||||
|
||||
import copy
|
||||
|
||||
from llama_index.core.schema import BaseNode, QueryType
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
|
||||
|
||||
class SimpleHybridRetriever(RAGRetriever):
|
||||
"""A composite retriever that aggregates search results from multiple retrievers."""
|
||||
|
||||
def __init__(self, *retrievers):
|
||||
self.retrievers: list[RAGRetriever] = retrievers
|
||||
super().__init__()
|
||||
|
||||
async def _aretrieve(self, query: QueryType, **kwargs):
|
||||
"""Asynchronously retrieves and aggregates search results from all configured retrievers.
|
||||
|
||||
This method queries each retriever in the `retrievers` list with the given query and
|
||||
additional keyword arguments. It then combines the results, ensuring that each node is
|
||||
unique, based on the node's ID.
|
||||
"""
|
||||
all_nodes = []
|
||||
for retriever in self.retrievers:
|
||||
# Prevent retriever changing query
|
||||
query_copy = copy.deepcopy(query)
|
||||
nodes = await retriever.aretrieve(query_copy, **kwargs)
|
||||
all_nodes.extend(nodes)
|
||||
|
||||
# combine all nodes
|
||||
result = []
|
||||
node_ids = set()
|
||||
for n in all_nodes:
|
||||
if n.node.node_id not in node_ids:
|
||||
result.append(n)
|
||||
node_ids.add(n.node.node_id)
|
||||
return result
|
||||
|
||||
def add_nodes(self, nodes: list[BaseNode]) -> None:
|
||||
"""Support add nodes."""
|
||||
for r in self.retrievers:
|
||||
r.add_nodes(nodes)
|
||||
|
||||
def persist(self, persist_dir: str, **kwargs) -> None:
|
||||
"""Support persist."""
|
||||
for r in self.retrievers:
|
||||
r.persist(persist_dir, **kwargs)
|
||||
124
metagpt/rag/schema.py
Normal file
124
metagpt/rag/schema.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
"""RAG schemas."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Union
|
||||
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from llama_index.core.indices.base import BaseIndex
|
||||
from llama_index.core.schema import TextNode
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
|
||||
from metagpt.rag.interface import RAGObject
|
||||
|
||||
|
||||
class BaseRetrieverConfig(BaseModel):
|
||||
"""Common config for retrievers.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.retriever.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
similarity_top_k: int = Field(default=5, description="Number of top-k similar results to return during retrieval.")
|
||||
|
||||
|
||||
class IndexRetrieverConfig(BaseRetrieverConfig):
|
||||
"""Config for Index-basd retrievers."""
|
||||
|
||||
index: BaseIndex = Field(default=None, description="Index for retriver.")
|
||||
|
||||
|
||||
class FAISSRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for FAISS-based retrievers."""
|
||||
|
||||
dimensions: int = Field(default=1536, description="Dimensionality of the vectors for FAISS index construction.")
|
||||
|
||||
|
||||
class BM25RetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for BM25-based retrievers."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ChromaRetrieverConfig(IndexRetrieverConfig):
|
||||
"""Config for Chroma-based retrievers."""
|
||||
|
||||
persist_path: Union[str, Path] = Field(default="./chroma_db", description="The directory to save data.")
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class BaseRankerConfig(BaseModel):
|
||||
"""Common config for rankers.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.ranker.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
top_n: int = Field(default=5, description="The number of top results to return.")
|
||||
|
||||
|
||||
class LLMRankerConfig(BaseRankerConfig):
|
||||
"""Config for LLM-based rankers."""
|
||||
|
||||
llm: Any = Field(
|
||||
default=None,
|
||||
description="The LLM to rerank with. using Any instead of LLM, as llama_index.core.llms.LLM is pydantic.v1.",
|
||||
)
|
||||
|
||||
|
||||
class BaseIndexConfig(BaseModel):
|
||||
"""Common config for index.
|
||||
|
||||
If add new subconfig, it is necessary to add the corresponding instance implementation in rag.factories.index.
|
||||
"""
|
||||
|
||||
persist_path: Union[str, Path] = Field(description="The directory of saved data.")
|
||||
|
||||
|
||||
class VectorIndexConfig(BaseIndexConfig):
|
||||
"""Config for vector-based index."""
|
||||
|
||||
embed_model: BaseEmbedding = Field(default=None, description="Embed model.")
|
||||
|
||||
|
||||
class FAISSIndexConfig(VectorIndexConfig):
|
||||
"""Config for faiss-based index."""
|
||||
|
||||
|
||||
class ChromaIndexConfig(VectorIndexConfig):
|
||||
"""Config for chroma-based index."""
|
||||
|
||||
collection_name: str = Field(default="metagpt", description="The name of the collection.")
|
||||
|
||||
|
||||
class BM25IndexConfig(BaseIndexConfig):
|
||||
"""Config for bm25-based index."""
|
||||
|
||||
_no_embedding: bool = PrivateAttr(default=True)
|
||||
|
||||
|
||||
class ObjectNodeMetadata(BaseModel):
|
||||
"""Metadata of ObjectNode."""
|
||||
|
||||
is_obj: bool = Field(default=True)
|
||||
obj: Any = Field(default=None, description="When rag retrieve, will reconstruct obj from obj_json")
|
||||
obj_json: str = Field(..., description="The json of object, e.g. obj.model_dump_json()")
|
||||
obj_cls_name: str = Field(..., description="The class name of object, e.g. obj.__class__.__name__")
|
||||
obj_mod_name: str = Field(..., description="The module name of class, e.g. obj.__class__.__module__")
|
||||
|
||||
|
||||
class ObjectNode(TextNode):
|
||||
"""RAG add object."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.excluded_llm_metadata_keys = list(ObjectNodeMetadata.model_fields.keys())
|
||||
self.excluded_embed_metadata_keys = self.excluded_llm_metadata_keys
|
||||
|
||||
@staticmethod
|
||||
def get_obj_metadata(obj: RAGObject) -> dict:
|
||||
metadata = ObjectNodeMetadata(
|
||||
obj_json=obj.model_dump_json(), obj_cls_name=obj.__class__.__name__, obj_mod_name=obj.__class__.__module__
|
||||
)
|
||||
|
||||
return metadata.model_dump()
|
||||
0
metagpt/rag/vector_stores/__init__.py
Normal file
0
metagpt/rag/vector_stores/__init__.py
Normal file
3
metagpt/rag/vector_stores/chroma/__init__.py
Normal file
3
metagpt/rag/vector_stores/chroma/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from metagpt.rag.vector_stores.chroma.base import ChromaVectorStore
|
||||
|
||||
__all__ = ["ChromaVectorStore"]
|
||||
290
metagpt/rag/vector_stores/chroma/base.py
Normal file
290
metagpt/rag/vector_stores/chroma/base.py
Normal file
|
|
@ -0,0 +1,290 @@
|
|||
"""Chroma vector store.
|
||||
|
||||
Refs to https://github.com/run-llama/llama_index/blob/v0.10.12/llama-index-integrations/vector_stores/llama-index-vector-stores-chroma/llama_index/vector_stores/chroma/base.py.
|
||||
The repo requires onnxruntime = "^1.17.0", which is too new for many OS systems, such as CentOS7.
|
||||
"""
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, Generator, List, Optional, cast
|
||||
|
||||
import chromadb
|
||||
from chromadb.api.models.Collection import Collection
|
||||
from llama_index.core.bridge.pydantic import Field, PrivateAttr
|
||||
from llama_index.core.schema import BaseNode, MetadataMode, TextNode
|
||||
from llama_index.core.utils import truncate_text
|
||||
from llama_index.core.vector_stores.types import (
|
||||
BasePydanticVectorStore,
|
||||
MetadataFilters,
|
||||
VectorStoreQuery,
|
||||
VectorStoreQueryResult,
|
||||
)
|
||||
from llama_index.core.vector_stores.utils import (
|
||||
legacy_metadata_dict_to_node,
|
||||
metadata_dict_to_node,
|
||||
node_to_metadata_dict,
|
||||
)
|
||||
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def _transform_chroma_filter_condition(condition: str) -> str:
|
||||
"""Translate standard metadata filter op to Chroma specific spec."""
|
||||
if condition == "and":
|
||||
return "$and"
|
||||
elif condition == "or":
|
||||
return "$or"
|
||||
else:
|
||||
raise ValueError(f"Filter condition {condition} not supported")
|
||||
|
||||
|
||||
def _transform_chroma_filter_operator(operator: str) -> str:
|
||||
"""Translate standard metadata filter operator to Chroma specific spec."""
|
||||
if operator == "!=":
|
||||
return "$ne"
|
||||
elif operator == "==":
|
||||
return "$eq"
|
||||
elif operator == ">":
|
||||
return "$gt"
|
||||
elif operator == "<":
|
||||
return "$lt"
|
||||
elif operator == ">=":
|
||||
return "$gte"
|
||||
elif operator == "<=":
|
||||
return "$lte"
|
||||
else:
|
||||
raise ValueError(f"Filter operator {operator} not supported")
|
||||
|
||||
|
||||
def _to_chroma_filter(
|
||||
standard_filters: MetadataFilters,
|
||||
) -> dict:
|
||||
"""Translate standard metadata filters to Chroma specific spec."""
|
||||
filters = {}
|
||||
filters_list = []
|
||||
condition = standard_filters.condition or "and"
|
||||
condition = _transform_chroma_filter_condition(condition)
|
||||
if standard_filters.filters:
|
||||
for filter in standard_filters.filters:
|
||||
if filter.operator:
|
||||
filters_list.append({filter.key: {_transform_chroma_filter_operator(filter.operator): filter.value}})
|
||||
else:
|
||||
filters_list.append({filter.key: filter.value})
|
||||
if len(filters_list) == 1:
|
||||
# If there is only one filter, return it directly
|
||||
return filters_list[0]
|
||||
elif len(filters_list) > 1:
|
||||
filters[condition] = filters_list
|
||||
return filters
|
||||
|
||||
|
||||
import_err_msg = "`chromadb` package not found, please run `pip install chromadb`"
|
||||
MAX_CHUNK_SIZE = 41665 # One less than the max chunk size for ChromaDB
|
||||
|
||||
|
||||
def chunk_list(lst: List[BaseNode], max_chunk_size: int) -> Generator[List[BaseNode], None, None]:
|
||||
"""Yield successive max_chunk_size-sized chunks from lst.
|
||||
Args:
|
||||
lst (List[BaseNode]): list of nodes with embeddings
|
||||
max_chunk_size (int): max chunk size
|
||||
Yields:
|
||||
Generator[List[BaseNode], None, None]: list of nodes with embeddings
|
||||
"""
|
||||
for i in range(0, len(lst), max_chunk_size):
|
||||
yield lst[i : i + max_chunk_size]
|
||||
|
||||
|
||||
class ChromaVectorStore(BasePydanticVectorStore):
|
||||
"""Chroma vector store.
|
||||
In this vector store, embeddings are stored within a ChromaDB collection.
|
||||
During query time, the index uses ChromaDB to query for the top
|
||||
k most similar nodes.
|
||||
Args:
|
||||
chroma_collection (chromadb.api.models.Collection.Collection):
|
||||
ChromaDB collection instance
|
||||
"""
|
||||
|
||||
stores_text: bool = True
|
||||
flat_metadata: bool = True
|
||||
collection_name: Optional[str]
|
||||
host: Optional[str]
|
||||
port: Optional[str]
|
||||
ssl: bool
|
||||
headers: Optional[Dict[str, str]]
|
||||
persist_dir: Optional[str]
|
||||
collection_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
||||
_collection: Any = PrivateAttr()
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
chroma_collection: Optional[Any] = None,
|
||||
collection_name: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
ssl: bool = False,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
persist_dir: Optional[str] = None,
|
||||
collection_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
collection_kwargs = collection_kwargs or {}
|
||||
if chroma_collection is None:
|
||||
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
|
||||
self._collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
|
||||
else:
|
||||
self._collection = cast(Collection, chroma_collection)
|
||||
super().__init__(
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=ssl,
|
||||
headers=headers,
|
||||
collection_name=collection_name,
|
||||
persist_dir=persist_dir,
|
||||
collection_kwargs=collection_kwargs or {},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_collection(cls, collection: Any) -> "ChromaVectorStore":
|
||||
try:
|
||||
from chromadb import Collection
|
||||
except ImportError:
|
||||
raise ImportError(import_err_msg)
|
||||
if not isinstance(collection, Collection):
|
||||
raise Exception("argument is not chromadb collection instance")
|
||||
return cls(chroma_collection=collection)
|
||||
|
||||
@classmethod
|
||||
def from_params(
|
||||
cls,
|
||||
collection_name: str,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
ssl: bool = False,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
persist_dir: Optional[str] = None,
|
||||
collection_kwargs: dict = {},
|
||||
**kwargs: Any,
|
||||
) -> "ChromaVectorStore":
|
||||
if persist_dir:
|
||||
client = chromadb.PersistentClient(path=persist_dir)
|
||||
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
|
||||
elif host and port:
|
||||
client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers)
|
||||
collection = client.get_or_create_collection(name=collection_name, **collection_kwargs)
|
||||
else:
|
||||
raise ValueError("Either `persist_dir` or (`host`,`port`) must be specified")
|
||||
return cls(
|
||||
chroma_collection=collection,
|
||||
host=host,
|
||||
port=port,
|
||||
ssl=ssl,
|
||||
headers=headers,
|
||||
persist_dir=persist_dir,
|
||||
collection_kwargs=collection_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "ChromaVectorStore"
|
||||
|
||||
def add(self, nodes: List[BaseNode], **add_kwargs: Any) -> List[str]:
|
||||
"""Add nodes to index.
|
||||
Args:
|
||||
nodes: List[BaseNode]: list of nodes with embeddings
|
||||
"""
|
||||
if not self._collection:
|
||||
raise ValueError("Collection not initialized")
|
||||
max_chunk_size = MAX_CHUNK_SIZE
|
||||
node_chunks = chunk_list(nodes, max_chunk_size)
|
||||
all_ids = []
|
||||
for node_chunk in node_chunks:
|
||||
embeddings = []
|
||||
metadatas = []
|
||||
ids = []
|
||||
documents = []
|
||||
for node in node_chunk:
|
||||
embeddings.append(node.get_embedding())
|
||||
metadata_dict = node_to_metadata_dict(node, remove_text=True, flat_metadata=self.flat_metadata)
|
||||
for key in metadata_dict:
|
||||
if metadata_dict[key] is None:
|
||||
metadata_dict[key] = ""
|
||||
metadatas.append(metadata_dict)
|
||||
ids.append(node.node_id)
|
||||
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
|
||||
self._collection.add(
|
||||
embeddings=embeddings,
|
||||
ids=ids,
|
||||
metadatas=metadatas,
|
||||
documents=documents,
|
||||
)
|
||||
all_ids.extend(ids)
|
||||
return all_ids
|
||||
|
||||
def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None:
|
||||
"""
|
||||
Delete nodes using with ref_doc_id.
|
||||
Args:
|
||||
ref_doc_id (str): The doc_id of the document to delete.
|
||||
"""
|
||||
self._collection.delete(where={"document_id": ref_doc_id})
|
||||
|
||||
@property
|
||||
def client(self) -> Any:
|
||||
"""Return client."""
|
||||
return self._collection
|
||||
|
||||
def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult:
|
||||
"""Query index for top k most similar nodes.
|
||||
Args:
|
||||
query_embedding (List[float]): query embedding
|
||||
similarity_top_k (int): top k most similar nodes
|
||||
"""
|
||||
if query.filters is not None:
|
||||
if "where" in kwargs:
|
||||
raise ValueError(
|
||||
"Cannot specify metadata filters via both query and kwargs. "
|
||||
"Use kwargs only for chroma specific items that are "
|
||||
"not supported via the generic query interface."
|
||||
)
|
||||
where = _to_chroma_filter(query.filters)
|
||||
else:
|
||||
where = kwargs.pop("where", {})
|
||||
results = self._collection.query(
|
||||
query_embeddings=query.query_embedding,
|
||||
n_results=query.similarity_top_k,
|
||||
where=where,
|
||||
**kwargs,
|
||||
)
|
||||
logger.debug(f"> Top {len(results['documents'])} nodes:")
|
||||
nodes = []
|
||||
similarities = []
|
||||
ids = []
|
||||
for node_id, text, metadata, distance in zip(
|
||||
results["ids"][0],
|
||||
results["documents"][0],
|
||||
results["metadatas"][0],
|
||||
results["distances"][0],
|
||||
):
|
||||
try:
|
||||
node = metadata_dict_to_node(metadata)
|
||||
node.set_content(text)
|
||||
except Exception:
|
||||
# NOTE: deprecated legacy logic for backward compatibility
|
||||
metadata, node_info, relationships = legacy_metadata_dict_to_node(metadata)
|
||||
node = TextNode(
|
||||
text=text,
|
||||
id_=node_id,
|
||||
metadata=metadata,
|
||||
start_char_idx=node_info.get("start", None),
|
||||
end_char_idx=node_info.get("end", None),
|
||||
relationships=relationships,
|
||||
)
|
||||
nodes.append(node)
|
||||
similarity_score = math.exp(-distance)
|
||||
similarities.append(similarity_score)
|
||||
logger.debug(
|
||||
f"> [Node {node_id}] [Similarity score: {similarity_score}] " f"{truncate_text(str(text), 100)}"
|
||||
)
|
||||
ids.append(node_id)
|
||||
return VectorStoreQueryResult(nodes=nodes, similarities=similarities, ids=ids)
|
||||
|
|
@ -108,12 +108,6 @@ class RoleContext(BaseModel):
|
|||
) # see `Role._set_react_mode` for definitions of the following two attributes
|
||||
max_react_loop: int = 1
|
||||
|
||||
def check(self, role_id: str):
|
||||
# if hasattr(CONFIG, "enable_longterm_memory") and CONFIG.enable_longterm_memory:
|
||||
# self.long_term_memory.recover_memory(role_id, self)
|
||||
# self.memory = self.long_term_memory # use memory to act as long_term_memory for unify operation
|
||||
pass
|
||||
|
||||
@property
|
||||
def important_memory(self) -> list[Message]:
|
||||
"""Retrieve information corresponding to the attention action."""
|
||||
|
|
@ -311,8 +305,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
|
|||
buffer during _observe.
|
||||
"""
|
||||
self.rc.watch = {any_to_str(t) for t in actions}
|
||||
# check RoleContext after adding watch actions
|
||||
self.rc.check(self.role_id)
|
||||
|
||||
def is_watch(self, caused_by: str):
|
||||
return caused_by in self.rc.watch
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ from typing import Optional
|
|||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions import SearchAndSummarize, UserRequirement
|
||||
from metagpt.document_store.base_store import BaseStore
|
||||
from metagpt.roles import Role
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
|
@ -27,7 +26,7 @@ class Sales(Role):
|
|||
"delivered with the professionalism and courtesy expected of a seasoned sales guide."
|
||||
)
|
||||
|
||||
store: Optional[BaseStore] = Field(default=None, exclude=True)
|
||||
store: Optional[object] = Field(default=None, exclude=True) # must inplement tools.SearchInterface
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_stroe(self):
|
||||
|
|
|
|||
|
|
@ -233,6 +233,10 @@ class Message(BaseModel):
|
|||
def check_send_to(cls, send_to: Any) -> set:
|
||||
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
|
||||
|
||||
@field_serializer("send_to", mode="plain")
|
||||
def ser_send_to(self, send_to: set) -> list:
|
||||
return list(send_to)
|
||||
|
||||
@field_serializer("instruct_content", mode="plain")
|
||||
def ser_instruct_content(self, ic: BaseModel) -> Union[dict, None]:
|
||||
ic_dict = None
|
||||
|
|
@ -276,6 +280,10 @@ class Message(BaseModel):
|
|||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""For search"""
|
||||
return self.content
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Return a dict containing `role` and `content` for the LLM call.l"""
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
|
|
|||
|
|
@ -30,3 +30,8 @@ class WebBrowserEngineType(Enum):
|
|||
def __missing__(cls, key):
|
||||
"""Default type conversion"""
|
||||
return cls.CUSTOM
|
||||
|
||||
|
||||
class SearchInterface:
|
||||
async def asearch(self, *args, **kwargs):
|
||||
...
|
||||
|
|
|
|||
22
metagpt/utils/async_helper.py
Normal file
22
metagpt/utils/async_helper.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import asyncio
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
|
||||
def run_coroutine_in_new_loop(coroutine) -> Any:
|
||||
"""Runs a coroutine in a new, separate event loop on a different thread.
|
||||
|
||||
This function is useful when try to execute an async function within a sync function, but encounter the error `RuntimeError: This event loop is already running`.
|
||||
"""
|
||||
new_loop = asyncio.new_event_loop()
|
||||
t = threading.Thread(target=lambda: new_loop.run_forever())
|
||||
t.start()
|
||||
|
||||
future = asyncio.run_coroutine_threadsafe(coroutine, new_loop)
|
||||
|
||||
try:
|
||||
return future.result()
|
||||
finally:
|
||||
new_loop.call_soon_threadsafe(new_loop.stop)
|
||||
t.join()
|
||||
new_loop.close()
|
||||
|
|
@ -5,12 +5,15 @@
|
|||
@Author : alexanderwu
|
||||
@File : embedding.py
|
||||
"""
|
||||
from langchain_community.embeddings import OpenAIEmbeddings
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
def get_embedding():
|
||||
def get_embedding() -> OpenAIEmbedding:
|
||||
llm = config.get_openai_llm()
|
||||
embedding = OpenAIEmbeddings(openai_api_key=llm.api_key, openai_api_base=llm.base_url)
|
||||
if llm is None:
|
||||
raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.")
|
||||
|
||||
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
|
||||
return embedding
|
||||
|
|
|
|||
18
metagpt/utils/reflection.py
Normal file
18
metagpt/utils/reflection.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""class tools, including method inspection, class attributes, inheritance relationships, etc."""
|
||||
|
||||
|
||||
def check_methods(C, *methods):
|
||||
"""Check if the class has methods. borrow from _collections_abc.
|
||||
|
||||
Useful when implementing implicit interfaces, such as defining an abstract class, isinstance can be used for determination without inheritance.
|
||||
"""
|
||||
mro = C.__mro__
|
||||
for method in methods:
|
||||
for B in mro:
|
||||
if method in B.__dict__:
|
||||
if B.__dict__[method] is None:
|
||||
return NotImplemented
|
||||
break
|
||||
else:
|
||||
return NotImplemented
|
||||
return True
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
aiohttp==3.8.4
|
||||
aiohttp==3.8.6
|
||||
#azure_storage==0.37.0
|
||||
channels==4.0.0
|
||||
# chromadb
|
||||
# Django==4.1.5
|
||||
# docx==0.2.4
|
||||
#faiss==1.5.3
|
||||
|
|
@ -11,14 +10,20 @@ typer==0.9.0
|
|||
# godot==0.1.1
|
||||
# google_api_python_client==2.93.0 # Used by search_engine.py
|
||||
lancedb==0.4.0
|
||||
langchain==0.1.8
|
||||
sqlalchemy==2.0.0 # along with langchain
|
||||
llama-index-core==0.10.15
|
||||
llama-index-embeddings-azure-openai==0.1.6
|
||||
llama-index-embeddings-openai==0.1.5
|
||||
llama-index-llms-azure-openai==0.1.4
|
||||
llama-index-readers-file==0.1.4
|
||||
llama-index-retrievers-bm25==0.1.3
|
||||
llama-index-vector-stores-faiss==0.1.1
|
||||
chromadb==0.4.23
|
||||
loguru==0.6.0
|
||||
meilisearch==0.21.0
|
||||
numpy>=1.24.3,<1.25.0
|
||||
openai==1.6.0
|
||||
numpy==1.24.3
|
||||
openai==1.6.1
|
||||
openpyxl
|
||||
beautifulsoup4==4.12.2
|
||||
beautifulsoup4==4.12.3
|
||||
pandas==2.0.3
|
||||
pydantic==2.5.3
|
||||
#pygame==2.1.3
|
||||
|
|
@ -30,7 +35,7 @@ PyYAML==6.0.1
|
|||
setuptools==65.6.3
|
||||
tenacity==8.2.3
|
||||
tiktoken==0.5.2
|
||||
tqdm==4.65.0
|
||||
tqdm==4.66.2
|
||||
#unstructured[local-inference]
|
||||
# selenium>4
|
||||
# webdriver_manager<3.9
|
||||
|
|
@ -61,7 +66,7 @@ typing-extensions==4.9.0
|
|||
socksio~=1.0.0
|
||||
gitignore-parser==0.1.9
|
||||
# connexion[uvicorn]~=3.0.5 # Used by metagpt/tools/openapi_v3_hello.py
|
||||
websockets~=12.0
|
||||
websockets~=11.0
|
||||
networkx~=3.2.1
|
||||
google-generativeai==0.3.2
|
||||
playwright>=1.26 # used at metagpt/tools/libs/web_scraping.py
|
||||
|
|
|
|||
2
setup.py
2
setup.py
|
|
@ -42,7 +42,7 @@ extras_require["test"] = [
|
|||
"connexion[uvicorn]~=3.0.5",
|
||||
"azure-cognitiveservices-speech~=1.31.0",
|
||||
"aioboto3~=11.3.0",
|
||||
"chromadb==0.4.14",
|
||||
"chromadb==0.4.23",
|
||||
"gradio==3.0.0",
|
||||
"grpcio-status==1.48.2",
|
||||
"pylint==3.0.3",
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from metagpt.document_store.chromadb_store import ChromaStore
|
|||
def test_chroma_store():
|
||||
"""FIXME:chroma使用感觉很诡异,一用Python就挂,测试用例里也是"""
|
||||
# 创建 ChromaStore 实例,使用 'sample_collection' 集合
|
||||
document_store = ChromaStore("sample_collection_1")
|
||||
document_store = ChromaStore("sample_collection_1", get_or_create=True)
|
||||
|
||||
# 使用 write 方法添加多个文档
|
||||
document_store.write(
|
||||
|
|
|
|||
|
|
@ -6,8 +6,6 @@
|
|||
@File : test_faiss_store.py
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
|
|
@ -17,18 +15,24 @@ from metagpt.logs import logger
|
|||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
|
||||
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
|
||||
num = len(texts)
|
||||
embeds = np.random.randint(1, 100, size=(num, 1536)) # 1536: openai embedding dim
|
||||
embeds = (embeds - embeds.mean(axis=0)) / (embeds.std(axis=0))
|
||||
return embeds
|
||||
embeds = (embeds - embeds.mean(axis=0)) / embeds.std(axis=0)
|
||||
return embeds.tolist()
|
||||
|
||||
|
||||
def mock_openai_embed_document(self, text: str) -> list[float]:
|
||||
embeds = mock_openai_embed_documents(self, [text])
|
||||
return embeds[0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_json(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -37,9 +41,10 @@ async def test_search_json(mocker):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_xlsx(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -48,9 +53,10 @@ async def test_search_xlsx(mocker):
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/search_kb/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.docstore
|
||||
assert _faiss_store.index
|
||||
assert _faiss_store.storage_context.docstore
|
||||
assert _faiss_store.storage_context.vector_store.client
|
||||
|
|
|
|||
|
|
@ -2,32 +2,41 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
dim = 1536 # openai embedding dim
|
||||
embed_zeros_arrr = np.zeros(shape=[1, dim]).tolist()
|
||||
embed_ones_arrr = np.ones(shape=[1, dim]).tolist()
|
||||
|
||||
text_embed_arr = [
|
||||
{"text": "Write a cli snake game", "embed": np.zeros(shape=[1, dim])}, # mock data, same as below
|
||||
{"text": "Write a game of cli snake", "embed": np.zeros(shape=[1, dim])},
|
||||
{"text": "Write a 2048 web game", "embed": np.ones(shape=[1, dim])},
|
||||
{"text": "Write a Battle City", "embed": np.ones(shape=[1, dim])},
|
||||
{"text": "Write a cli snake game", "embed": embed_zeros_arrr}, # mock data, same as below
|
||||
{"text": "Write a game of cli snake", "embed": embed_zeros_arrr},
|
||||
{"text": "Write a 2048 web game", "embed": embed_ones_arrr},
|
||||
{"text": "Write a Battle City", "embed": embed_ones_arrr},
|
||||
{
|
||||
"text": "The user has requested the creation of a command-line interface (CLI) snake game",
|
||||
"embed": np.zeros(shape=[1, dim]),
|
||||
"embed": embed_zeros_arrr,
|
||||
},
|
||||
{"text": "The request is command-line interface (CLI) snake game", "embed": np.zeros(shape=[1, dim])},
|
||||
{"text": "The request is command-line interface (CLI) snake game", "embed": embed_zeros_arrr},
|
||||
{
|
||||
"text": "Incorporate basic features of a snake game such as scoring and increasing difficulty",
|
||||
"embed": np.ones(shape=[1, dim]),
|
||||
"embed": embed_ones_arrr,
|
||||
},
|
||||
]
|
||||
|
||||
text_idx_dict = {item["text"]: idx for idx, item in enumerate(text_embed_arr)}
|
||||
|
||||
|
||||
def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int] = 0) -> list[list[float]]:
|
||||
def mock_openai_embed_documents(self, texts: list[str], show_progress: bool = False) -> list[list[float]]:
|
||||
idx = text_idx_dict.get(texts[0])
|
||||
embed = text_embed_arr[idx].get("embed")
|
||||
return embed
|
||||
|
||||
|
||||
def mock_openai_embed_document(self, text: str) -> list[float]:
|
||||
embeds = mock_openai_embed_documents(self, [text])
|
||||
return embeds[0]
|
||||
|
||||
|
||||
async def mock_openai_aembed_document(self, text: str) -> list[float]:
|
||||
return mock_openai_embed_document(self, text)
|
||||
|
|
|
|||
|
|
@ -12,13 +12,20 @@ from metagpt.memory.longterm_memory import LongTermMemory
|
|||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_aembed_document,
|
||||
mock_openai_embed_document,
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_ltm_search(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
@pytest.mark.asyncio
|
||||
async def test_ltm_search(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
mocker.patch(
|
||||
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
|
||||
)
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
from metagpt.environment import Environment
|
||||
|
|
@ -31,39 +38,24 @@ def test_ltm_search(mocker):
|
|||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([message])
|
||||
news = await ltm.find_news([message])
|
||||
assert len(news) == 1
|
||||
ltm.add(message)
|
||||
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([sim_message])
|
||||
news = await ltm.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
ltm.add(sim_message)
|
||||
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm.find_news([new_message])
|
||||
news = await ltm.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
ltm.add(new_message)
|
||||
|
||||
# restore from local index
|
||||
ltm_new = LongTermMemory()
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.find_news([message])
|
||||
assert len(news) == 0
|
||||
|
||||
ltm_new.recover_memory(role_id, rc)
|
||||
news = ltm_new.find_news([sim_message])
|
||||
assert len(news) == 0
|
||||
|
||||
new_idea = text_embed_arr[3].get("text", "Write a Battle City")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
news = ltm_new.find_news([new_message])
|
||||
assert len(news) == 1
|
||||
|
||||
ltm_new.clear()
|
||||
ltm.clear()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -8,19 +8,28 @@ import shutil
|
|||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
from tests.metagpt.memory.mock_text_embed import (
|
||||
mock_openai_aembed_document,
|
||||
mock_openai_embed_document,
|
||||
mock_openai_embed_documents,
|
||||
text_embed_arr,
|
||||
)
|
||||
|
||||
|
||||
def test_idea_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
@pytest.mark.asyncio
|
||||
async def test_idea_message(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
mocker.patch(
|
||||
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
|
||||
)
|
||||
|
||||
idea = text_embed_arr[0].get("text", "Write a cli snake game")
|
||||
role_id = "UTUser1(Product Manager)"
|
||||
|
|
@ -29,28 +38,32 @@ def test_idea_message(mocker):
|
|||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
memory_storage.recover_memory(role_id)
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_idea = text_embed_arr[1].get("text", "Write a game of cli snake")
|
||||
sim_message = Message(role="User", content=sim_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
new_messages = await memory_storage.search_similar(sim_message)
|
||||
assert len(new_messages) == 1 # similar, return []
|
||||
|
||||
new_idea = text_embed_arr[2].get("text", "Write a 2048 web game")
|
||||
new_message = Message(role="User", content=new_idea, cause_by=UserRequirement)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
new_messages = await memory_storage.search_similar(new_message)
|
||||
assert len(new_messages) == 0
|
||||
|
||||
memory_storage.clean()
|
||||
assert memory_storage.is_initialized is False
|
||||
|
||||
|
||||
def test_actionout_message(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
@pytest.mark.asyncio
|
||||
async def test_actionout_message(mocker):
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embeddings", mock_openai_embed_documents)
|
||||
mocker.patch("llama_index.embeddings.openai.base.OpenAIEmbedding._get_text_embedding", mock_openai_embed_document)
|
||||
mocker.patch(
|
||||
"llama_index.embeddings.openai.base.OpenAIEmbedding._aget_query_embedding", mock_openai_aembed_document
|
||||
)
|
||||
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
|
|
@ -67,23 +80,22 @@ def test_actionout_message(mocker):
|
|||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
assert len(messages) == 0
|
||||
memory_storage.recover_memory(role_id)
|
||||
|
||||
memory_storage.add(message)
|
||||
assert memory_storage.is_initialized is True
|
||||
|
||||
sim_conent = text_embed_arr[5].get("text", "The request is command-line interface (CLI) snake game")
|
||||
sim_message = Message(content=sim_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(sim_message)
|
||||
assert len(new_messages) == 0 # similar, return []
|
||||
new_messages = await memory_storage.search_similar(sim_message)
|
||||
assert len(new_messages) == 1 # similar, return []
|
||||
|
||||
new_conent = text_embed_arr[6].get(
|
||||
"text", "Incorporate basic features of a snake game such as scoring and increasing difficulty"
|
||||
)
|
||||
new_message = Message(content=new_conent, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD)
|
||||
new_messages = memory_storage.search_dissimilar(new_message)
|
||||
assert new_messages[0].content == message.content
|
||||
new_messages = await memory_storage.search_similar(new_message)
|
||||
assert len(new_messages) == 0
|
||||
|
||||
memory_storage.clean()
|
||||
assert memory_storage.is_initialized is False
|
||||
|
|
|
|||
166
tests/metagpt/rag/engines/test_simple.py
Normal file
166
tests/metagpt/rag/engines/test_simple.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import Document, TextNode
|
||||
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.retrievers.base import ModifiableRAGRetriever
|
||||
|
||||
|
||||
class TestSimpleEngine:
|
||||
@pytest.fixture
|
||||
def mock_simple_directory_reader(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_retriever(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_retriever")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_rankers(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_rankers")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_response_synthesizer(self, mocker):
|
||||
return mocker.patch("metagpt.rag.engines.simple.get_response_synthesizer")
|
||||
|
||||
def test_from_docs(
|
||||
self,
|
||||
mocker,
|
||||
mock_simple_directory_reader,
|
||||
mock_vector_store_index,
|
||||
mock_get_retriever,
|
||||
mock_get_rankers,
|
||||
mock_get_response_synthesizer,
|
||||
):
|
||||
# Mock
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
Document(text="document1"),
|
||||
Document(text="document2"),
|
||||
]
|
||||
mock_get_retriever.return_value = mocker.MagicMock()
|
||||
mock_get_rankers.return_value = [mocker.MagicMock()]
|
||||
mock_get_response_synthesizer.return_value = mocker.MagicMock()
|
||||
|
||||
# Setup
|
||||
input_dir = "test_dir"
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
transformations = [mocker.MagicMock()]
|
||||
embed_model = mocker.MagicMock()
|
||||
llm = mocker.MagicMock()
|
||||
retriever_configs = [mocker.MagicMock()]
|
||||
ranker_configs = [mocker.MagicMock()]
|
||||
|
||||
# Execute
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_dir=input_dir,
|
||||
input_files=input_files,
|
||||
transformations=transformations,
|
||||
embed_model=embed_model,
|
||||
llm=llm,
|
||||
retriever_configs=retriever_configs,
|
||||
ranker_configs=ranker_configs,
|
||||
)
|
||||
|
||||
# Assertions
|
||||
mock_simple_directory_reader.assert_called_once_with(input_dir=input_dir, input_files=input_files)
|
||||
mock_vector_store_index.assert_called_once()
|
||||
mock_get_retriever.assert_called_once_with(
|
||||
configs=retriever_configs, index=mock_vector_store_index.return_value
|
||||
)
|
||||
mock_get_rankers.assert_called_once_with(configs=ranker_configs, llm=llm)
|
||||
mock_get_response_synthesizer.assert_called_once_with(llm=llm)
|
||||
assert isinstance(engine, SimpleEngine)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_asearch(self, mocker):
|
||||
# Mock
|
||||
test_query = "test query"
|
||||
expected_result = "expected result"
|
||||
mock_aquery = mocker.AsyncMock(return_value=expected_result)
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
engine.aquery = mock_aquery
|
||||
|
||||
# Execute
|
||||
result = await engine.asearch(test_query)
|
||||
|
||||
# Assertions
|
||||
mock_aquery.assert_called_once_with(test_query)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aretrieve(self, mocker):
|
||||
# Mock
|
||||
mock_query_bundle = mocker.patch("metagpt.rag.engines.simple.QueryBundle", return_value="query_bundle")
|
||||
mock_super_aretrieve = mocker.patch(
|
||||
"metagpt.rag.engines.simple.RetrieverQueryEngine.aretrieve", new_callable=mocker.AsyncMock
|
||||
)
|
||||
mock_super_aretrieve.return_value = [TextNode(text="node_with_score", metadata={"is_obj": False})]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mocker.MagicMock())
|
||||
test_query = "test query"
|
||||
|
||||
# Execute
|
||||
result = await engine.aretrieve(test_query)
|
||||
|
||||
# Assertions
|
||||
mock_query_bundle.assert_called_once_with(test_query)
|
||||
mock_super_aretrieve.assert_called_once_with("query_bundle")
|
||||
assert result[0].text == "node_with_score"
|
||||
|
||||
def test_add_docs(self, mocker):
|
||||
# Mock
|
||||
mock_simple_directory_reader = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
|
||||
mock_simple_directory_reader.return_value.load_data.return_value = [
|
||||
Document(text="document1"),
|
||||
Document(text="document2"),
|
||||
]
|
||||
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
mock_index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock_index._transformations = mocker.MagicMock()
|
||||
|
||||
mock_run_transformations = mocker.patch("metagpt.rag.engines.simple.run_transformations")
|
||||
mock_run_transformations.return_value = ["node1", "node2"]
|
||||
|
||||
# Setup
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mock_index)
|
||||
input_files = ["test_file1", "test_file2"]
|
||||
|
||||
# Execute
|
||||
engine.add_docs(input_files=input_files)
|
||||
|
||||
# Assertions
|
||||
mock_simple_directory_reader.assert_called_once_with(input_files=input_files)
|
||||
mock_retriever.add_nodes.assert_called_once_with(["node1", "node2"])
|
||||
|
||||
def test_add_objs(self, mocker):
|
||||
# Mock
|
||||
mock_retriever = mocker.MagicMock(spec=ModifiableRAGRetriever)
|
||||
|
||||
# Setup
|
||||
class CustomTextNode(TextNode):
|
||||
def rag_key(self):
|
||||
return ""
|
||||
|
||||
def model_dump_json(self):
|
||||
return ""
|
||||
|
||||
objs = [CustomTextNode(text=f"text_{i}", metadata={"obj": f"obj_{i}"}) for i in range(2)]
|
||||
engine = SimpleEngine(retriever=mock_retriever, index=mocker.MagicMock())
|
||||
|
||||
# Execute
|
||||
engine.add_objs(objs=objs)
|
||||
|
||||
# Assertions
|
||||
assert mock_retriever.add_nodes.call_count == 1
|
||||
for node in mock_retriever.add_nodes.call_args[0][0]:
|
||||
assert isinstance(node, TextNode)
|
||||
assert "is_obj" in node.metadata
|
||||
102
tests/metagpt/rag/factories/test_base.py
Normal file
102
tests/metagpt/rag/factories/test_base.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.rag.factories.base import ConfigBasedFactory, GenericFactory
|
||||
|
||||
|
||||
class TestGenericFactory:
|
||||
@pytest.fixture
|
||||
def creators(self):
|
||||
return {
|
||||
"type1": lambda name: f"Instance of type1 with {name}",
|
||||
"type2": lambda name: f"Instance of type2 with {name}",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def factory(self, creators):
|
||||
return GenericFactory(creators=creators)
|
||||
|
||||
def test_get_instance_success(self, factory):
|
||||
# Test successful retrieval of an instance
|
||||
key = "type1"
|
||||
instance = factory.get_instance(key, name="TestName")
|
||||
assert instance == "Instance of type1 with TestName"
|
||||
|
||||
def test_get_instance_failure(self, factory):
|
||||
# Test failure to retrieve an instance due to unregistered key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
factory.get_instance("unknown_key")
|
||||
assert "Creator not registered for key: unknown_key" in str(exc_info.value)
|
||||
|
||||
def test_get_instances_success(self, factory):
|
||||
# Test successful retrieval of multiple instances
|
||||
keys = ["type1", "type2"]
|
||||
instances = factory.get_instances(keys, name="TestName")
|
||||
expected = ["Instance of type1 with TestName", "Instance of type2 with TestName"]
|
||||
assert instances == expected
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"keys,expected_exception_message",
|
||||
[
|
||||
(["unknown_key"], "Creator not registered for key: unknown_key"),
|
||||
(["type1", "unknown_key"], "Creator not registered for key: unknown_key"),
|
||||
],
|
||||
)
|
||||
def test_get_instances_with_failure(self, factory, keys, expected_exception_message):
|
||||
# Test failure to retrieve instances due to at least one unregistered key
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
factory.get_instances(keys, name="TestName")
|
||||
assert expected_exception_message in str(exc_info.value)
|
||||
|
||||
|
||||
class DummyConfig:
|
||||
"""A dummy config class for testing."""
|
||||
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
|
||||
class TestConfigBasedFactory:
|
||||
@pytest.fixture
|
||||
def config_creators(self):
|
||||
return {
|
||||
DummyConfig: lambda config, **kwargs: f"Processed {config.name} with {kwargs.get('extra', 'no extra')}",
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def config_factory(self, config_creators):
|
||||
return ConfigBasedFactory(creators=config_creators)
|
||||
|
||||
def test_get_instance_success(self, config_factory):
|
||||
# Test successful retrieval of an instance
|
||||
config = DummyConfig(name="TestConfig")
|
||||
instance = config_factory.get_instance(config, extra="additional data")
|
||||
assert instance == "Processed TestConfig with additional data"
|
||||
|
||||
def test_get_instance_failure(self, config_factory):
|
||||
# Test failure to retrieve an instance due to unknown config type
|
||||
class UnknownConfig:
|
||||
pass
|
||||
|
||||
config = UnknownConfig()
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
config_factory.get_instance(config)
|
||||
assert "Unknown config:" in str(exc_info.value)
|
||||
|
||||
def test_val_from_config_or_kwargs_priority(self):
|
||||
# Test that the value from the config object has priority over kwargs
|
||||
config = DummyConfig(name="ConfigName")
|
||||
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
|
||||
assert result == "ConfigName"
|
||||
|
||||
def test_val_from_config_or_kwargs_fallback_to_kwargs(self):
|
||||
# Test fallback to kwargs when config object does not have the value
|
||||
config = DummyConfig(name=None)
|
||||
result = ConfigBasedFactory._val_from_config_or_kwargs("name", config, name="KwargsName")
|
||||
assert result == "KwargsName"
|
||||
|
||||
def test_val_from_config_or_kwargs_key_error(self):
|
||||
# Test KeyError when the key is not found in both config object and kwargs
|
||||
config = DummyConfig(name=None)
|
||||
with pytest.raises(KeyError) as exc_info:
|
||||
ConfigBasedFactory._val_from_config_or_kwargs("missing_key", config)
|
||||
assert "The key 'missing_key' is required but not provided" in str(exc_info.value)
|
||||
41
tests/metagpt/rag/factories/test_ranker.py
Normal file
41
tests/metagpt/rag/factories/test_ranker.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
import pytest
|
||||
from llama_index.core.llms import LLM
|
||||
from llama_index.core.postprocessor import LLMRerank
|
||||
|
||||
from metagpt.rag.factories.ranker import RankerFactory
|
||||
from metagpt.rag.schema import LLMRankerConfig
|
||||
|
||||
|
||||
class TestRankerFactory:
|
||||
@pytest.fixture
|
||||
def ranker_factory(self) -> RankerFactory:
|
||||
return RankerFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm(self, mocker):
|
||||
return mocker.MagicMock(spec=LLM)
|
||||
|
||||
def test_get_rankers_with_no_configs(self, ranker_factory: RankerFactory, mock_llm, mocker):
|
||||
mocker.patch.object(ranker_factory, "_extract_llm", return_value=mock_llm)
|
||||
default_rankers = ranker_factory.get_rankers()
|
||||
assert len(default_rankers) == 0
|
||||
|
||||
def test_get_rankers_with_configs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
rankers = ranker_factory.get_rankers(configs=[mock_config])
|
||||
assert len(rankers) == 1
|
||||
assert isinstance(rankers[0], LLMRerank)
|
||||
|
||||
def test_create_llm_ranker_creates_correct_instance(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
ranker = ranker_factory._create_llm_ranker(mock_config)
|
||||
assert isinstance(ranker, LLMRerank)
|
||||
|
||||
def test_extract_llm_from_config(self, ranker_factory: RankerFactory, mock_llm):
|
||||
mock_config = LLMRankerConfig(llm=mock_llm)
|
||||
extracted_llm = ranker_factory._extract_llm(config=mock_config)
|
||||
assert extracted_llm == mock_llm
|
||||
|
||||
def test_extract_llm_from_kwargs(self, ranker_factory: RankerFactory, mock_llm):
|
||||
extracted_llm = ranker_factory._extract_llm(llm=mock_llm)
|
||||
assert extracted_llm == mock_llm
|
||||
79
tests/metagpt/rag/factories/test_retriever.py
Normal file
79
tests/metagpt/rag/factories/test_retriever.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
import faiss
|
||||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
|
||||
from metagpt.rag.factories.retriever import RetrieverFactory
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
from metagpt.rag.retrievers.hybrid_retriever import SimpleHybridRetriever
|
||||
from metagpt.rag.schema import BM25RetrieverConfig, FAISSRetrieverConfig
|
||||
|
||||
|
||||
class TestRetrieverFactory:
|
||||
@pytest.fixture
|
||||
def retriever_factory(self):
|
||||
return RetrieverFactory()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_faiss_index(self, mocker):
|
||||
return mocker.MagicMock(spec=faiss.IndexFlatL2)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_store_index(self, mocker):
|
||||
mock = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
mock._embed_model = mocker.MagicMock()
|
||||
mock.docstore.docs.values.return_value = []
|
||||
return mock
|
||||
|
||||
def test_get_retriever_with_faiss_config(
|
||||
self, retriever_factory: RetrieverFactory, mock_faiss_index, mocker, mock_vector_store_index
|
||||
):
|
||||
mock_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mocker.patch("faiss.IndexFlatL2", return_value=mock_faiss_index)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, FAISSRetriever)
|
||||
|
||||
def test_get_retriever_with_bm25_config(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
mock_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_config])
|
||||
|
||||
assert isinstance(retriever, DynamicBM25Retriever)
|
||||
|
||||
def test_get_retriever_with_multiple_configs_returns_hybrid(
|
||||
self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index
|
||||
):
|
||||
mock_faiss_config = FAISSRetrieverConfig(dimensions=128)
|
||||
mock_bm25_config = BM25RetrieverConfig()
|
||||
mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
|
||||
retriever = retriever_factory.get_retriever(configs=[mock_faiss_config, mock_bm25_config])
|
||||
|
||||
assert isinstance(retriever, SimpleHybridRetriever)
|
||||
|
||||
def test_create_default_retriever(self, retriever_factory: RetrieverFactory, mocker, mock_vector_store_index):
|
||||
mocker.patch.object(retriever_factory, "_extract_index", return_value=mock_vector_store_index)
|
||||
mock_vector_store_index.as_retriever = mocker.MagicMock()
|
||||
|
||||
retriever = retriever_factory.get_retriever()
|
||||
|
||||
mock_vector_store_index.as_retriever.assert_called_once()
|
||||
assert retriever is mock_vector_store_index.as_retriever.return_value
|
||||
|
||||
def test_extract_index_from_config(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
mock_config = FAISSRetrieverConfig(index=mock_vector_store_index)
|
||||
|
||||
extracted_index = retriever_factory._extract_index(config=mock_config)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
|
||||
def test_extract_index_from_kwargs(self, retriever_factory: RetrieverFactory, mock_vector_store_index):
|
||||
extracted_index = retriever_factory._extract_index(index=mock_vector_store_index)
|
||||
|
||||
assert extracted_index == mock_vector_store_index
|
||||
37
tests/metagpt/rag/retrievers/test_bm25_retriever.py
Normal file
37
tests/metagpt/rag/retrievers/test_bm25_retriever.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import pytest
|
||||
from llama_index.core import VectorStoreIndex
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.bm25_retriever import DynamicBM25Retriever
|
||||
|
||||
|
||||
class TestDynamicBM25Retriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc1.get_content.return_value = "Document content 1"
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.doc2.get_content.return_value = "Document content 2"
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟index
|
||||
index = mocker.MagicMock(spec=VectorStoreIndex)
|
||||
|
||||
# 模拟nodes和tokenizer参数
|
||||
mock_nodes = []
|
||||
mock_tokenizer = mocker.MagicMock()
|
||||
self.mock_bm25okapi = mocker.patch("rank_bm25.BM25Okapi.__init__", return_value=None)
|
||||
|
||||
# 初始化DynamicBM25Retriever对象,并提供必需的参数
|
||||
self.retriever = DynamicBM25Retriever(nodes=mock_nodes, tokenizer=mock_tokenizer, index=index)
|
||||
|
||||
def test_add_docs_updates_nodes_and_corpus(self):
|
||||
# Execute
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
# Assertions
|
||||
assert len(self.retriever._nodes) == len(self.mock_nodes)
|
||||
assert len(self.retriever._corpus) == len(self.mock_nodes)
|
||||
self.retriever._tokenizer.assert_called()
|
||||
self.mock_bm25okapi.assert_called()
|
||||
22
tests/metagpt/rag/retrievers/test_faiss_retriever.py
Normal file
22
tests/metagpt/rag/retrievers/test_faiss_retriever.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import pytest
|
||||
from llama_index.core.schema import Node
|
||||
|
||||
from metagpt.rag.retrievers.faiss_retriever import FAISSRetriever
|
||||
|
||||
|
||||
class TestFAISSRetriever:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup(self, mocker):
|
||||
# 创建模拟的Document对象
|
||||
self.doc1 = mocker.MagicMock(spec=Node)
|
||||
self.doc2 = mocker.MagicMock(spec=Node)
|
||||
self.mock_nodes = [self.doc1, self.doc2]
|
||||
|
||||
# 模拟FAISSRetriever的_index属性
|
||||
self.mock_index = mocker.MagicMock()
|
||||
self.retriever = FAISSRetriever(self.mock_index)
|
||||
|
||||
def test_add_docs_calls_insert_for_each_document(self, mocker):
|
||||
self.retriever.add_nodes(self.mock_nodes)
|
||||
|
||||
assert self.mock_index.insert_nodes.assert_called
|
||||
39
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
39
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from llama_index.core.schema import NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
|
||||
|
||||
class TestSimpleHybridRetriever:
|
||||
@pytest.mark.asyncio
|
||||
async def test_aretrieve(self):
|
||||
question = "test query"
|
||||
|
||||
# Create mock retrievers
|
||||
mock_retriever1 = AsyncMock()
|
||||
mock_retriever1.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="1"), score=1.0),
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
]
|
||||
|
||||
mock_retriever2 = AsyncMock()
|
||||
mock_retriever2.aretrieve.return_value = [
|
||||
NodeWithScore(node=TextNode(id_="2"), score=0.95),
|
||||
NodeWithScore(node=TextNode(id_="3"), score=0.8),
|
||||
]
|
||||
|
||||
# Instantiate the SimpleHybridRetriever with the mock retrievers
|
||||
hybrid_retriever = SimpleHybridRetriever(mock_retriever1, mock_retriever2)
|
||||
|
||||
# Call the _aretrieve method
|
||||
results = await hybrid_retriever._aretrieve(question)
|
||||
|
||||
# Check if the results are as expected
|
||||
assert len(results) == 3 # Should be 3 unique nodes
|
||||
assert set(node.node.node_id for node in results) == {"1", "2", "3"}
|
||||
|
||||
# Check if the scores are correct (assuming you want the highest score)
|
||||
node_scores = {node.node.node_id: node.score for node in results}
|
||||
assert node_scores["2"] == 0.95
|
||||
Loading…
Add table
Add a link
Reference in a new issue