mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
upgrade llama-index to v0.10
This commit is contained in:
parent
19a9a98c0b
commit
c02dc5cea8
10 changed files with 32 additions and 25 deletions
|
|
@@ -3,7 +3,7 @@ import asyncio
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.const import EXAMPLE_PATH
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import (
|
||||
BM25RetrieverConfig,
|
||||
|
|
@@ -11,9 +11,14 @@ from metagpt.rag.schema import (
|
|||
LLMRankerConfig,
|
||||
)
|
||||
|
||||
DOC_PATH = EXAMPLE_PATH / "data/rag_writer.txt"
|
||||
DOC_PATH = EXAMPLE_DATA_PATH / "rag/writer.txt"
|
||||
QUESTION = "What are key qualities to be a good writer?"
|
||||
|
||||
TRAVEL_DOC_PATH = EXAMPLE_DATA_PATH / "rag/travel.txt"
|
||||
TRAVEL_QUESTION = "What does Bojan like?"
|
||||
|
||||
LLM_TIP = "If you not sure, just answer I don't know"
|
||||
|
||||
|
||||
class RAGExample:
|
||||
"""Show how to use RAG."""
|
||||
|
|
@@ -63,8 +68,8 @@ class RAGExample:
|
|||
"""
|
||||
self._print_title("RAG Add Docs")
|
||||
|
||||
travel_question = "What does Bojan like? If you not sure, just answer I don't know"
|
||||
travel_filepath = EXAMPLE_PATH / "data/rag_travel.txt"
|
||||
travel_question = f"{TRAVEL_QUESTION}{LLM_TIP}"
|
||||
travel_filepath = TRAVEL_DOC_PATH
|
||||
|
||||
print("[Before add docs]")
|
||||
await self.rag_pipeline(question=travel_question, print_title=False)
|
||||
|
|
@@ -83,7 +88,7 @@ class RAGExample:
|
|||
0. 100m Sprin..., 10.0
|
||||
|
||||
[Object Detail]
|
||||
{'name': 'foo', 'goal': 'Win The Game', 'tool': 'Red Bull Energy Drink'}
|
||||
{'name': 'Mike', 'goal': 'Win The 100-meter Sprint', 'tool': 'Red Bull Energy Drink'}
|
||||
"""
|
||||
|
||||
self._print_title("RAG Add Objs")
|
||||
|
|
@@ -92,21 +97,21 @@ class RAGExample:
|
|||
"""Player"""
|
||||
|
||||
name: str = ""
|
||||
goal: str = "Win The Game"
|
||||
goal: str = "Win The 100-meter Sprint"
|
||||
tool: str = "Red Bull Energy Drink"
|
||||
|
||||
def rag_key(self) -> str:
|
||||
"""For search"""
|
||||
return self.goal
|
||||
|
||||
foo = Player(name="foo")
|
||||
question = f"{foo.rag_key()}"
|
||||
player = Player(name="Mike")
|
||||
question = f"{player.rag_key()}{LLM_TIP}"
|
||||
|
||||
print("[Before add objs]")
|
||||
await self._retrieve_and_print(question)
|
||||
|
||||
print("[After add objs]")
|
||||
self.engine.add_objs([foo])
|
||||
self.engine.add_objs([player])
|
||||
nodes = await self._retrieve_and_print(question)
|
||||
|
||||
print("[Object Detail]")
|
||||
|
|
|
|||
|
|
@@ -6,23 +6,14 @@
|
|||
"""
|
||||
import asyncio
|
||||
|
||||
from llama_index.embeddings import OpenAIEmbedding
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.const import DATA_PATH, EXAMPLE_PATH
|
||||
from metagpt.const import EXAMPLE_DATA_PATH
|
||||
from metagpt.document_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Sales
|
||||
|
||||
|
||||
def get_store():
|
||||
llm = config.get_openai_llm()
|
||||
embedding = OpenAIEmbedding(api_key=llm.api_key, api_base=llm.base_url)
|
||||
return FaissStore(DATA_PATH / "example.json", embedding=embedding)
|
||||
|
||||
|
||||
async def search():
|
||||
store = FaissStore(EXAMPLE_PATH / "example.json")
|
||||
store = FaissStore(EXAMPLE_DATA_PATH / "search_kb/example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
query = "Which facial cleanser is good for oily skin?"
|
||||
result = await role.run(query)
|
||||
|
|
|
|||
|
|
@@ -49,6 +49,7 @@ METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT
|
|||
DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace"
|
||||
|
||||
EXAMPLE_PATH = METAGPT_ROOT / "examples"
|
||||
EXAMPLE_DATA_PATH = EXAMPLE_PATH / "data"
|
||||
DATA_PATH = METAGPT_ROOT / "data"
|
||||
TEST_DATA_PATH = METAGPT_ROOT / "tests/data"
|
||||
RESEARCH_PATH = DATA_PATH / "research"
|
||||
|
|
|
|||
|
|
@@ -40,7 +40,7 @@ class FaissStore(LocalStore):
|
|||
return None
|
||||
vector_store = FaissVectorStore.from_persist_dir(persist_dir=self.cache_dir)
|
||||
storage_context = StorageContext.from_defaults(persist_dir=self.cache_dir, vector_store=vector_store)
|
||||
index = load_index_from_storage(storage_context)
|
||||
index = load_index_from_storage(storage_context, embed_model=self.embedding)
|
||||
|
||||
return index
|
||||
|
||||
|
|
@@ -54,7 +54,9 @@ class FaissStore(LocalStore):
|
|||
# doc_store.add_documents(nodes)
|
||||
vector_store = FaissVectorStore(faiss_index=faiss.IndexFlatL2(1536))
|
||||
storage_context = StorageContext.from_defaults(vector_store=vector_store)
|
||||
index = VectorStoreIndex.from_documents(documents=documents, storage_context=storage_context)
|
||||
index = VectorStoreIndex.from_documents(
|
||||
documents=documents, storage_context=storage_context, embed_model=self.embedding
|
||||
)
|
||||
|
||||
return index
|
||||
|
||||
|
|
|
|||
|
|
@@ -18,6 +18,7 @@ from llama_index.core.response_synthesizers import (
|
|||
)
|
||||
from llama_index.core.retrievers import BaseRetriever
|
||||
from llama_index.core.schema import (
|
||||
BaseNode,
|
||||
NodeWithScore,
|
||||
QueryBundle,
|
||||
QueryType,
|
||||
|
|
@@ -110,15 +111,22 @@ class SimpleEngine(RetrieverQueryEngine):
|
|||
|
||||
documents = SimpleDirectoryReader(input_files=input_files).load_data()
|
||||
nodes = run_transformations(documents, transformations=self.index._transformations)
|
||||
self.retriever.add_nodes(nodes)
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def add_objs(self, objs: list[RAGObject]):
|
||||
"""Adds objects to the retriever, storing each object's original form in metadata for future reference."""
|
||||
self._ensure_retriever_modifiable()
|
||||
|
||||
nodes = [TextNode(text=obj.rag_key(), metadata={"obj": obj.model_dump()}) for obj in objs]
|
||||
self.retriever.add_nodes(nodes)
|
||||
self._save_nodes(nodes)
|
||||
|
||||
def _ensure_retriever_modifiable(self):
|
||||
if not isinstance(self.retriever, ModifiableRAGRetriever):
|
||||
raise TypeError(f"the retriever is not modifiable: {type(self.retriever)}")
|
||||
|
||||
def _save_nodes(self, nodes: list[BaseNode]):
|
||||
# for search in memory
|
||||
self.retriever.add_nodes(nodes)
|
||||
|
||||
# for persist
|
||||
self.index.insert_nodes(nodes)
|
||||
|
|
|
|||
|
|
@@ -22,7 +22,7 @@ class SimpleHybridRetriever(RAGRetriever):
|
|||
"""
|
||||
all_nodes = []
|
||||
for retriever in self.retrievers:
|
||||
# 防止retriever可能改变query的属性
|
||||
# Prevent retriever changing query
|
||||
query_copy = copy.deepcopy(query)
|
||||
nodes = await retriever.aretrieve(query_copy, **kwargs)
|
||||
all_nodes.extend(nodes)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue