mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
rag pipeline
This commit is contained in:
parent
cc91df59e5
commit
29d36948bf
18 changed files with 372 additions and 15 deletions
|
|
@ -1,10 +0,0 @@
|
|||
[
|
||||
{
|
||||
"source": "Which facial cleanser is good for oily skin?",
|
||||
"output": "ABC cleanser is preferred by many with oily skin."
|
||||
},
|
||||
{
|
||||
"source": "Is L'Oreal good to use?",
|
||||
"output": "L'Oreal is a popular brand with many positive reviews."
|
||||
}
|
||||
]
|
||||
Binary file not shown.
102
examples/rag_pipeline.py
Normal file
102
examples/rag_pipeline.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""RAG pipeline"""
|
||||
import asyncio
|
||||
|
||||
import faiss
|
||||
from llama_index import (
|
||||
ServiceContext,
|
||||
SimpleDirectoryReader,
|
||||
StorageContext,
|
||||
VectorStoreIndex,
|
||||
)
|
||||
from llama_index.postprocessor import LLMRerank
|
||||
from llama_index.retrievers import BM25Retriever, VectorIndexRetriever
|
||||
from llama_index.vector_stores.faiss import FaissVectorStore
|
||||
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
from metagpt.rag.llm import get_default_llm
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
DOC_PATH = EXAMPLE_PATH / "data/rag.txt"
|
||||
QUESTION = "What are key qualities to be a good writer?"
|
||||
TOPK = 5
|
||||
|
||||
|
||||
def print_result(nodes, extra="retrieve"):
    """Print a separator, a "<extra> result" heading and one line per node.

    Args:
        nodes: Iterable of objects exposing ``text`` and ``score`` attributes.
        extra: Label placed in front of " result" in the heading line.
    """
    separator = "-" * 50
    print(separator)
    print(f"{extra} result")
    for position, item in enumerate(nodes):
        preview = item.text[:10]  # only the first 10 characters of the text are shown
        print(f"{position}. {preview}..., {item.score}")
|
||||
|
||||
|
||||
async def rag_pipeline():
    """Run the example RAG flow: FAISS and BM25 retrieval, a hybrid merge, then an LLM rerank.

    Every stage prints its nodes through ``print_result``. A run prints something like:

    --------------------------------------------------
    faiss retrieve result
    0. I highly r..., 0.3958844542503357
    1. I wrote cu..., 0.41629382967948914
    2. Productivi..., 0.4318419098854065
    3. Some sort ..., 0.45991092920303345
    --------------------------------------------------
    bm25 retrieve result
    0. I highly r..., 0.19445682103516615
    1. Some sort ..., 0.18688966233196197
    2. Productivi..., 0.17071309618829872
    3. I wrote cu..., 0.15878996566615383
    --------------------------------------------------
    hybrid retrieve result
    0. I highly r..., 0.3958844542503357
    1. I wrote cu..., 0.41629382967948914
    2. Productivi..., 0.4318419098854065
    3. Some sort ..., 0.45991092920303345
    --------------------------------------------------
    llm ranker result
    0. Productivi..., 10.0
    1. I wrote cu..., 7.0
    2. I highly r..., 5.0
    """
    # Documents: read the example file (other readers can load other sources).
    docs = SimpleDirectoryReader(input_files=[DOC_PATH]).load_data()

    # Service context: one bundle holding the llm, the embedding model and the node parser.
    llm = get_default_llm()
    embed_model = get_embedding()
    context = ServiceContext.from_defaults(llm=llm, embed_model=embed_model)

    # Nodes: the documents split into chunks by the context's node parser.
    chunks = context.node_parser.get_nodes_from_documents(docs)

    # Index (FAISS): 1536 is the vector dimension of text-embedding-ada-002.
    flat_index = faiss.IndexFlatL2(1536)
    store = FaissVectorStore(faiss_index=flat_index)
    storage = StorageContext.from_defaults(vector_store=store)
    index = VectorStoreIndex(nodes=chunks, storage_context=storage, service_context=context)

    # Retriever (FAISS)
    faiss_retriever = VectorIndexRetriever(index=index, similarity_top_k=TOPK)
    faiss_hits = await faiss_retriever.aretrieve(QUESTION)
    print_result(faiss_hits, extra="faiss retrieve")

    # Retriever (BM25)
    bm25_retriever = BM25Retriever.from_defaults(nodes=chunks, similarity_top_k=TOPK)
    bm25_hits = await bm25_retriever.aretrieve(QUESTION)
    print_result(bm25_hits, extra="bm25 retrieve")

    # Retriever (hybrid): merges the two retrievers above.
    hybrid_retriever = SimpleHybridRetriever(faiss_retriever, bm25_retriever)
    hybrid_hits = await hybrid_retriever.aretrieve(QUESTION)
    print_result(hybrid_hits, extra="hybrid retrieve")

    # Ranker (LLM): note that it reranks the FAISS hits, not the hybrid ones.
    reranker = LLMRerank(top_n=TOPK, service_context=context)
    reranked = reranker.postprocess_nodes(faiss_hits, query_str=QUESTION)
    print_result(reranked, extra="llm ranker")
|
||||
|
||||
|
||||
async def main():
    """Entry coroutine of the example: run the RAG pipeline once."""
    pipeline = rag_pipeline()
    await pipeline


# Only run the example when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
19
examples/rag_search.py
Normal file
19
examples/rag_search.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""Agent with RAG search"""
|
||||
import asyncio
|
||||
|
||||
from examples.rag_pipeline import DOC_PATH, QUESTION, TOPK
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
async def search():
    """Ask QUESTION to a Sales role backed by a SimpleEngine built from DOC_PATH, then log the reply."""
    engine = SimpleEngine.from_docs(input_files=[DOC_PATH], similarity_top_k=TOPK)
    sales = Sales(profile="Sales", store=engine)
    answer = await sales.run(QUESTION)
    logger.info(answer)


# Only run the example when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(search())
|
||||
0
metagpt/rag/__init__.py
Normal file
0
metagpt/rag/__init__.py
Normal file
3
metagpt/rag/engines/__init__.py
Normal file
3
metagpt/rag/engines/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from metagpt.rag.engines.simple import SimpleEngine
|
||||
|
||||
__all__ = ["SimpleEngine"]
|
||||
48
metagpt/rag/engines/simple.py
Normal file
48
metagpt/rag/engines/simple.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Simple Engine."""
|
||||
from typing import Optional
|
||||
|
||||
from llama_index import ServiceContext, SimpleDirectoryReader, VectorStoreIndex
|
||||
from llama_index.constants import DEFAULT_SIMILARITY_TOP_K
|
||||
from llama_index.embeddings.base import BaseEmbedding
|
||||
from llama_index.llms.llm import LLM
|
||||
from llama_index.query_engine import RetrieverQueryEngine
|
||||
from llama_index.retrievers import VectorIndexRetriever
|
||||
|
||||
from metagpt.rag.llm import get_default_llm
|
||||
from metagpt.utils.embedding import get_embedding
|
||||
|
||||
|
||||
class SimpleEngine(RetrieverQueryEngine):
    """Query engine that retrieves from a vector index built over local documents.

    Build an instance with :meth:`from_docs`; :meth:`asearch` provides the
    ``tools.SearchInterface`` coroutine by delegating to ``aquery``.
    """

    @classmethod
    def from_docs(
        cls,
        input_dir: Optional[str] = None,
        input_files: Optional[list] = None,
        embed_model: Optional[BaseEmbedding] = None,
        llm: Optional[LLM] = None,
        # node parser kwargs
        chunk_size: Optional[int] = None,
        chunk_overlap: Optional[int] = None,
        # retrieve kwargs
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
    ) -> "SimpleEngine":
        """Build an engine from documents on disk, keeping the setup simple and straightforward.

        Args:
            input_dir: Directory passed to ``SimpleDirectoryReader``.
            input_files: Explicit file list passed to ``SimpleDirectoryReader``.
            embed_model: Embedding model; falls back to ``get_embedding()`` when falsy.
            llm: LLM; falls back to ``get_default_llm()`` when falsy.
            chunk_size: Forwarded to ``ServiceContext.from_defaults``.
            chunk_overlap: Forwarded to ``ServiceContext.from_defaults``.
            similarity_top_k: Forwarded to ``VectorIndexRetriever``.

        Returns:
            An instance of ``cls`` wrapping a ``VectorIndexRetriever`` over the loaded documents.
        """
        documents = SimpleDirectoryReader(input_dir=input_dir, input_files=input_files).load_data()
        service_context = ServiceContext.from_defaults(
            embed_model=embed_model or get_embedding(),
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
            llm=llm or get_default_llm(),
        )
        index = VectorStoreIndex.from_documents(documents, service_context=service_context)
        retriever = VectorIndexRetriever(index=index, similarity_top_k=similarity_top_k)

        # Instantiate via ``cls`` so that subclasses get an instance of their own type.
        return cls(retriever=retriever)

    async def asearch(self, content: str, **kwargs) -> str:
        """Implement tools.SearchInterface by forwarding ``content`` to ``aquery``.

        ``kwargs`` are accepted for interface compatibility and are not used.
        NOTE(review): the return value is whatever ``aquery`` returns; confirm it matches ``str``.
        """
        return await self.aquery(content)
|
||||
7
metagpt/rag/llm.py
Normal file
7
metagpt/rag/llm.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from llama_index.llms import OpenAI
|
||||
|
||||
from metagpt.config2 import config
|
||||
|
||||
|
||||
def get_default_llm() -> OpenAI:
    """Create a llama_index ``OpenAI`` LLM from the ``base_url`` and ``api_key`` in ``config.llm``."""
    llm_config = config.llm
    return OpenAI(api_base=llm_config.base_url, api_key=llm_config.api_key)
|
||||
0
metagpt/rag/rankers/__init__.py
Normal file
0
metagpt/rag/rankers/__init__.py
Normal file
20
metagpt/rag/rankers/base.py
Normal file
20
metagpt/rag/rankers/base.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""Base Ranker."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from llama_index import QueryBundle
|
||||
from llama_index.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.schema import NodeWithScore
|
||||
|
||||
|
||||
class RAGRanker(BaseNodePostprocessor):
    """Base class for RAG rankers; extends llama_index's ``BaseNodePostprocessor``."""

    @abstractmethod
    def _postprocess_nodes(
        self, nodes: list[NodeWithScore], query_bundle: Optional[QueryBundle] = None
    ) -> list[NodeWithScore]:
        """Post-process ``nodes`` for the optional ``query_bundle``; subclasses must implement this."""
|
||||
4
metagpt/rag/retrievers/__init__.py
Normal file
4
metagpt/rag/retrievers/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
"""init"""
|
||||
from metagpt.rag.retrievers.hybrid import SimpleHybridRetriever
|
||||
|
||||
__all__ = ["SimpleHybridRetriever"]
|
||||
18
metagpt/rag/retrievers/base.py
Normal file
18
metagpt/rag/retrievers/base.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
"""Base retriever."""
|
||||
|
||||
|
||||
from abc import abstractmethod
|
||||
|
||||
from llama_index.retrievers import BaseRetriever
|
||||
from llama_index.schema import NodeWithScore, QueryType
|
||||
|
||||
|
||||
class RAGRetriever(BaseRetriever):
    """Base class for RAG retrievers; extends llama_index's ``BaseRetriever``."""

    @abstractmethod
    async def _aretrieve(self, query: QueryType) -> list[NodeWithScore]:
        """Asynchronously retrieve nodes for ``query``; subclasses must implement this."""

    def _retrieve(self, query: QueryType) -> list[NodeWithScore]:
        """Synchronous counterpart; the body is empty here, so it returns None unless overridden."""
|
||||
36
metagpt/rag/retrievers/hybrid.py
Normal file
36
metagpt/rag/retrievers/hybrid.py
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
"""Hybrid retriever."""
|
||||
from llama_index.schema import QueryType
|
||||
|
||||
from metagpt.rag.retrievers.base import RAGRetriever
|
||||
|
||||
|
||||
class SimpleHybridRetriever(RAGRetriever):
    """Composite retriever that aggregates search results from multiple retrievers."""

    def __init__(self, *retrievers):
        """Store the given retrievers in call order.

        Args:
            *retrievers: Objects exposing an ``aretrieve`` coroutine.
        """
        # Materialise the varargs tuple so the attribute really is the list its annotation promises.
        self.retrievers: list[RAGRetriever] = list(retrievers)
        super().__init__()

    async def _aretrieve(self, query: QueryType, **kwargs):
        """Query every configured retriever in order and merge the results.

        Each retriever is awaited one after another with ``query`` and ``kwargs``.
        Results are de-duplicated on ``node.node_id``: the first occurrence of a node
        is kept (with the score it had there) and later duplicates are dropped, so
        the output follows retriever order and then each retriever's own order.

        Returns:
            A list of the unique scored nodes.
        """
        merged = []
        seen_ids = set()
        for retriever in self.retrievers:
            for hit in await retriever.aretrieve(query, **kwargs):
                node_id = hit.node.node_id
                if node_id in seen_ids:
                    continue
                seen_ids.add(node_id)
                merged.append(hit)
        return merged
|
||||
|
|
@ -11,7 +11,6 @@ from typing import Optional
|
|||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions import SearchAndSummarize, UserRequirement
|
||||
from metagpt.document_store.base_store import BaseStore
|
||||
from metagpt.roles import Role
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
|
@ -27,7 +26,7 @@ class Sales(Role):
|
|||
"delivered with the professionalism and courtesy expected of a seasoned sales guide."
|
||||
)
|
||||
|
||||
store: Optional[BaseStore] = Field(default=None, exclude=True)
|
||||
store: Optional[object] = Field(default=None, exclude=True) # must inplement tools.SearchInterface
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_stroe(self):
|
||||
|
|
|
|||
|
|
@ -30,3 +30,8 @@ class WebBrowserEngineType(Enum):
|
|||
def __missing__(cls, key):
|
||||
"""Default type conversion"""
|
||||
return cls.CUSTOM
|
||||
|
||||
|
||||
class SearchInterface:
    """Interface for objects that provide an asynchronous ``asearch`` coroutine."""

    async def asearch(self, *args, **kwargs):
        """Placeholder search coroutine: accepts any arguments and returns None."""
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ def mock_openai_embed_documents(self, texts: list[str], chunk_size: Optional[int
|
|||
async def test_search_json(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -39,7 +39,7 @@ async def test_search_json(mocker):
|
|||
async def test_search_xlsx(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/example.xlsx")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
@ -50,7 +50,7 @@ async def test_search_xlsx(mocker):
|
|||
async def test_write(mocker):
|
||||
mocker.patch("langchain_community.embeddings.openai.OpenAIEmbeddings.embed_documents", mock_openai_embed_documents)
|
||||
|
||||
store = FaissStore(EXAMPLE_PATH / "example.xlsx", meta_col="Answer", content_col="Question")
|
||||
store = FaissStore(EXAMPLE_PATH / "data/example.xlsx", meta_col="Answer", content_col="Question")
|
||||
_faiss_store = store.write()
|
||||
assert _faiss_store.storage_context.docstore
|
||||
assert _faiss_store.storage_context.vector_store.client
|
||||
|
|
|
|||
67
tests/metagpt/rag/engine/test_simple.py
Normal file
67
tests/metagpt/rag/engine/test_simple.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
from unittest.mock import AsyncMock

import pytest

# SimpleEngine is exported by metagpt.rag.engines; metagpt/rag/__init__.py is empty,
# so importing it from metagpt.rag would fail.
from metagpt.rag.engines import SimpleEngine
|
||||
|
||||
|
||||
class TestSimpleEngineFromDocs:
    """Unit tests for ``SimpleEngine.from_docs`` and ``SimpleEngine.asearch`` with collaborators mocked."""

    def test_from_docs(self, mocker):
        """``from_docs`` must forward its arguments to each patched llama_index collaborator."""
        # Patch everything from_docs touches inside metagpt.rag.engines.simple.
        reader_cls = mocker.patch("metagpt.rag.engines.simple.SimpleDirectoryReader")
        reader_cls.return_value.load_data.return_value = ["document1", "document2"]
        context_factory = mocker.patch("metagpt.rag.engines.simple.ServiceContext.from_defaults")
        index_factory = mocker.patch("metagpt.rag.engines.simple.VectorStoreIndex.from_documents")
        retriever_cls = mocker.patch("metagpt.rag.engines.simple.VectorIndexRetriever")

        # Arguments, grouped by the collaborator that should receive them.
        reader_kwargs = {"input_dir": "test_dir", "input_files": ["test_file1", "test_file2"]}
        context_kwargs = {
            "embed_model": mocker.MagicMock(),
            "llm": mocker.MagicMock(),
            "chunk_size": 100,
            "chunk_overlap": 10,
        }
        top_k = 5

        engine = SimpleEngine.from_docs(**reader_kwargs, **context_kwargs, similarity_top_k=top_k)

        reader_cls.assert_called_once_with(**reader_kwargs)
        context_factory.assert_called_once_with(**context_kwargs)
        index_factory.assert_called_once_with(
            ["document1", "document2"], service_context=context_factory.return_value
        )
        retriever_cls.assert_called_once_with(index=index_factory.return_value, similarity_top_k=top_k)
        assert isinstance(engine, SimpleEngine)

    @pytest.mark.asyncio
    async def test_asearch_calls_aquery(self, mocker):
        """``asearch`` must await ``aquery`` with the query and return its result."""
        query = "test query"
        expected = "expected result"
        fake_aquery = AsyncMock(return_value=expected)

        # Replace aquery on a real engine instance built around a mock retriever.
        engine = SimpleEngine(retriever=mocker.MagicMock())
        engine.aquery = fake_aquery

        outcome = await engine.asearch(query)

        fake_aquery.assert_called_once_with(query)
        assert outcome == expected
|
||||
39
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
39
tests/metagpt/rag/retrievers/test_hybrid_retriever.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from llama_index.schema import NodeWithScore, TextNode
|
||||
|
||||
from metagpt.rag.retrievers import SimpleHybridRetriever
|
||||
|
||||
|
||||
class TestSimpleHybridRetriever:
    """Tests for the merge-and-deduplicate behaviour of ``SimpleHybridRetriever``."""

    @pytest.mark.asyncio
    async def test_aretrieve(self):
        """Two retrievers that share node "2" must yield exactly nodes "1", "2" and "3"."""
        question = "test query"

        def scored(node_id, score):
            # Small factory for a scored text node with a fixed id.
            return NodeWithScore(node=TextNode(id_=node_id), score=score)

        # Mock retrievers whose result sets overlap on node "2".
        first = AsyncMock()
        first.aretrieve.return_value = [scored("1", 1.0), scored("2", 0.95)]
        second = AsyncMock()
        second.aretrieve.return_value = [scored("2", 0.95), scored("3", 0.8)]

        hybrid = SimpleHybridRetriever(first, second)

        # Call the protected coroutine directly.
        results = await hybrid._aretrieve(question)

        # Three unique nodes are expected.
        assert len(results) == 3
        assert {hit.node.node_id for hit in results} == {"1", "2", "3"}

        # Both copies of node "2" carry 0.95, so the surviving copy must too.
        score_by_id = {hit.node.node_id: hit.score for hit in results}
        assert score_by_id["2"] == 0.95
|
||||
Loading…
Add table
Add a link
Reference in a new issue