mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-05 22:02:38 +02:00
rag add es
This commit is contained in:
parent
8c218a1e55
commit
191a86f93e
7 changed files with 157 additions and 31 deletions
|
|
@ -1,6 +1,7 @@
|
|||
"""RAG pipeline"""
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -11,6 +12,9 @@ from metagpt.rag.schema import (
|
|||
BM25RetrieverConfig,
|
||||
ChromaIndexConfig,
|
||||
ChromaRetrieverConfig,
|
||||
ElasticsearchIndexConfig,
|
||||
ElasticsearchRetrieverConfig,
|
||||
ElasticsearchStoreConfig,
|
||||
FAISSRetrieverConfig,
|
||||
LLMRankerConfig,
|
||||
)
|
||||
|
|
@ -24,6 +28,17 @@ TRAVEL_QUESTION = "What does Bob like?"
|
|||
LLM_TIP = "If you not sure, just answer I don't know."
|
||||
|
||||
|
||||
def catch_exception(func):
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"{func.__name__} exception: {e}")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class Player(BaseModel):
|
||||
"""To demonstrate rag add objs."""
|
||||
|
||||
|
|
@ -39,12 +54,22 @@ class Player(BaseModel):
|
|||
class RAGExample:
|
||||
"""Show how to use RAG."""
|
||||
|
||||
def __init__(self):
|
||||
self.engine = SimpleEngine.from_docs(
|
||||
input_files=[DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
def __init__(self, engine: SimpleEngine = None):
|
||||
self._engine = engine
|
||||
|
||||
@property
|
||||
def engine(self):
|
||||
if not self._engine:
|
||||
self._engine = SimpleEngine.from_docs(
|
||||
input_files=[DOC_PATH],
|
||||
retriever_configs=[FAISSRetrieverConfig(), BM25RetrieverConfig()],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
return self._engine
|
||||
|
||||
@engine.setter
|
||||
def engine(self, value: SimpleEngine):
|
||||
self._engine = value
|
||||
|
||||
async def run_pipeline(self, question=QUESTION, print_title=True):
|
||||
"""This example run rag pipeline, use faiss&bm25 retriever and llm ranker, will print something like:
|
||||
|
|
@ -97,6 +122,7 @@ class RAGExample:
|
|||
self.engine.add_docs([travel_filepath])
|
||||
await self.run_pipeline(question=travel_question, print_title=False)
|
||||
|
||||
@catch_exception
|
||||
async def add_objects(self, print_title=True):
|
||||
"""This example show how to add objects.
|
||||
|
||||
|
|
@ -154,20 +180,43 @@ class RAGExample:
|
|||
"""
|
||||
self._print_title("Init And Query ChromaDB")
|
||||
|
||||
# save index
|
||||
# 1.save index
|
||||
output_dir = DATA_PATH / "rag"
|
||||
SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[ChromaRetrieverConfig(persist_path=output_dir)],
|
||||
)
|
||||
|
||||
# load index
|
||||
engine = SimpleEngine.from_index(
|
||||
index_config=ChromaIndexConfig(persist_path=output_dir),
|
||||
# 2.load index
|
||||
engine = SimpleEngine.from_index(index_config=ChromaIndexConfig(persist_path=output_dir))
|
||||
|
||||
# 3.query
|
||||
answer = await engine.aquery(TRAVEL_QUESTION)
|
||||
self._print_query_result(answer)
|
||||
|
||||
@catch_exception
|
||||
async def init_and_query_es(self):
|
||||
"""This example show how to use es. how to save and load index. will print something like:
|
||||
|
||||
Query Result:
|
||||
Bob likes traveling.
|
||||
|
||||
If `Unclosed client session`, it's llamaindex elasticsearch problem, maybe fixed later.
|
||||
"""
|
||||
self._print_title("Init And Query Elasticsearch")
|
||||
|
||||
# 1.create es index and save docs
|
||||
store_config = ElasticsearchStoreConfig(index_name="travel", es_url="http://127.0.0.1:9200")
|
||||
engine = SimpleEngine.from_docs(
|
||||
input_files=[TRAVEL_DOC_PATH],
|
||||
retriever_configs=[ElasticsearchRetrieverConfig(store_config=store_config)],
|
||||
)
|
||||
|
||||
# query
|
||||
answer = engine.query(TRAVEL_QUESTION)
|
||||
# 2.load index
|
||||
engine = SimpleEngine.from_index(index_config=ElasticsearchIndexConfig(store_config=store_config))
|
||||
|
||||
# 3.query
|
||||
answer = await engine.aquery(TRAVEL_QUESTION)
|
||||
self._print_query_result(answer)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -205,6 +254,7 @@ async def main():
|
|||
await e.add_objects()
|
||||
await e.init_objects()
|
||||
await e.init_and_query_chromadb()
|
||||
await e.init_and_query_es()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue