diff --git a/examples/data/rag_travel.txt b/examples/data/rag/travel.txt similarity index 100% rename from examples/data/rag_travel.txt rename to examples/data/rag/travel.txt diff --git a/examples/data/rag_writer.txt b/examples/data/rag/writer.txt similarity index 100% rename from examples/data/rag_writer.txt rename to examples/data/rag/writer.txt diff --git a/examples/data/example.json b/examples/data/search_kb/example.json similarity index 100% rename from examples/data/example.json rename to examples/data/search_kb/example.json diff --git a/examples/data/example.xlsx b/examples/data/search_kb/example.xlsx similarity index 100% rename from examples/data/example.xlsx rename to examples/data/search_kb/example.xlsx diff --git a/examples/rag_pipeline.py b/examples/rag_pipeline.py index 70c592a1e..6e8e5a2cc 100644 --- a/examples/rag_pipeline.py +++ b/examples/rag_pipeline.py @@ -3,7 +3,7 @@ import asyncio from pydantic import BaseModel -from metagpt.const import EXAMPLE_PATH +from metagpt.const import EXAMPLE_DATA_PATH from metagpt.rag.engines import SimpleEngine from metagpt.rag.schema import ( BM25RetrieverConfig, @@ -11,9 +11,14 @@ from metagpt.rag.schema import ( LLMRankerConfig, ) -DOC_PATH = EXAMPLE_PATH / "data/rag_writer.txt" +DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt" QUESTION = "What are key qualities to be a good writer?" +TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt" +TRAVEL_QUESTION = "What does Bojan like?" + +LLM_TIP = "If you not sure, just answer I don't know" + class RAGExample: """Show how to use RAG.""" @@ -63,8 +68,8 @@ class RAGExample: """ self._print_title("RAG Add Docs") - travel_question = "What does Bojan like? If you not sure, just answer I don't know" - travel_filepath = EXAMPLE_PATH / "data/rag_travel.txt" + travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}" + travel_filepath = TRAVEL_DOC_PATH print("[Before add docs]") await self.rag_pipeline(question=travel_question, print_title=False) @@ -83,7 +88,7 @@ class RAGExample: 0. 100m Sprin..., 10.0 [Object Detail] - {'name': 'foo', 'goal': 'Win The Game', 'tool': 'Red Bull Energy Drink'} + {'name': 'Mike', 'goal': 'Win The 100-meter Sprint', 'tool': 'Red Bull Energy Drink'} """ self._print_title("RAG Add Objs") @@ -92,21 +97,21 @@ class RAGExample: """Player""" name: str = "" - goal: str = "Win The Game" + goal: str = "Win The 100-meter Sprint" tool: str = "Red Bull Energy Drink" def rag_key(self) -> str: """For search""" return self.goal - foo = Player(name="foo") - question = f"{foo.rag_key()}" + player = Player(name="Mike") + question = f"{player.rag_key()}{LLM_TIP}" print("[Before add objs]") await self._retrieve_and_print(question) print("[After add objs]") - self.engine.add_objs([foo]) + self.engine.add_objs([player]) nodes = await self._retrieve_and_print(question) print("[Object Detail]") diff --git a/examples/search_kb.py b/examples/search_kb.py index ec234b7e9..c52977b43 100644 --- a/examples/search_kb.py +++ b/examples/search_kb.py @@ -6,23 +6,14 @@ """ import asyncio -from llama_index.embeddings import OpenAIEmbedding - -from metagpt.config2 import config -from metagpt.const import DATA_PATH, EXAMPLE_PATH +from metagpt.const import EXAMPLE_DATA_PATH from metagpt.document_store import FaissStore from metagpt.logs import logger from metagpt.roles import Sales -def get_store(): - llm = config.get_openai_llm() - embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url) - return FaissStore(DATA_PATH / "example.json", embedding=embedding) - - async def search(): - store = FaissStore(EXAMPLE_PATH / "example.json") + store = FaissStore(EXAMPLE_DATA_PATH / "search_kb/example.json") role = Sales(profile="Sales", store=store) query = "Which facial cleanser is good for oily skin?" result = await role.run(query) diff --git a/metagpt/const.py b/metagpt/const.py index a5e3ea9c2..6dbbfe0c1 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -49,6 +49,7 @@ METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace" EXAMPLE_PATH = METAGPT_ROOT / "examples" +EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data" DATA_PATH = METAGPT_ROOT / "data" TEST_DATA_PATH = METAGPT_ROOT / "tests/data" RESEARCH_PATH = DATA_PATH / "research" diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index f8ce05072..25d1211b3 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -40,7 +40,7 @@ class FaissStore(LocalStore): return None vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir) storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store) - index = load_index_from_storage(storage_context) + index = load_index_from_storage(storage_context, embed_model=self.embedding) return index @@ -54,7 +54,9 @@ class FaissStore(LocalStore): # doc_store.add_documents(nodes) vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536)) storage_context = StorageContext.from_defaults(vector_store=vector_store) - index = VectorStoreIndex.from_documents(documents=documents, storage_context=storage_context) + index = VectorStoreIndex.from_documents( + documents=documents, storage_context=storage_context, embed_model=self.embedding + ) return index diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index ca09f1059..5f81f6309 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -18,6 +18,7 @@ from llama_index.core.response_synthesizers import ( ) from llama_index.core.retrievers import BaseRetriever from llama_index.core.schema import ( + BaseNode, NodeWithScore, QueryBundle, QueryType, @@ -110,15 +111,22 @@ class SimpleEngine(RetrieverQueryEngine): documents = SimpleDirectoryReader(input_files=input_files).load_data() nodes = run_transformations(documents, transformations=self.index._transformations) - self.retriever.add_nodes(nodes) + self._save_nodes(nodes) def add_objs(self, objs: list[RAGObject]): """Adds objects to the retriever, storing each object's original form in metadata for future reference.""" self._ensure_retriever_modifiable() nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj.model_dump()}) for obj in objs] - self.retriever.add_nodes(nodes) + self._save_nodes(nodes) def _ensure_retriever_modifiable(self): if not isinstance(self.retriever, ModifiableRAGRetriever): raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}") + + def _save_nodes(self, nodes: list[BaseNode]): + # for search in memory + self.retriever.add_nodes(nodes) + + # for persist + self.index.insert_nodes(nodes) diff --git a/metagpt/rag/retrievers/hybrid_retriever.py b/metagpt/rag/retrievers/hybrid_retriever.py index 3074a4053..1a752855a 100644 --- a/metagpt/rag/retrievers/hybrid_retriever.py +++ b/metagpt/rag/retrievers/hybrid_retriever.py @@ -22,7 +22,7 @@ class SimpleHybridRetriever(RAGRetriever): """ all_nodes = [] for retriever in self.retrievers: - # 防止retriever可能改变query的属性 + # Prevent retriever changing query query_copy = copy.deepcopy(query) nodes = await retriever.aretrieve(query_copy, **kwargs) all_nodes.extend(nodes)