mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
upgrade langchain and simplify faiss load/save
This commit is contained in:
parent
49377c9db0
commit
322ac4aa40
6 changed files with 34 additions and 38 deletions
|
|
@ -5,12 +5,13 @@
|
|||
"""
|
||||
import asyncio
|
||||
|
||||
from metagpt.actions import Action
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.document_store import FaissStore
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Sales
|
||||
from metagpt.schema import Message
|
||||
|
||||
""" example.json, e.g.
|
||||
[
|
||||
|
|
@ -26,14 +27,15 @@ from metagpt.schema import Message
|
|||
"""
|
||||
|
||||
|
||||
def get_store():
|
||||
embedding = OpenAIEmbeddings(openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url)
|
||||
return FaissStore(DATA_PATH / "example.json", embedding=embedding)
|
||||
|
||||
|
||||
async def search():
|
||||
store = FaissStore(DATA_PATH / "example.json")
|
||||
role = Sales(profile="Sales", store=store)
|
||||
role._watch({Action})
|
||||
queries = [
|
||||
Message(content="Which facial cleanser is good for oily skin?", cause_by=Action),
|
||||
Message(content="Is L'Oreal good to use?", cause_by=Action),
|
||||
]
|
||||
role = Sales(profile="Sales", store=get_store())
|
||||
queries = ["Which facial cleanser is good for oily skin?", "Is L'Oreal good to use?"]
|
||||
|
||||
for query in queries:
|
||||
logger.info(f"User: {query}")
|
||||
result = await role.run(query)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
@Author : alexanderwu
|
||||
@File : search_google.py
|
||||
"""
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
from pydantic import Field, root_validator
|
||||
|
|
@ -111,7 +111,7 @@ class SearchAndSummarize(Action):
|
|||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
config: None = Field(default_factory=Config)
|
||||
engine: Optional[SearchEngineType] = CONFIG.search_engine
|
||||
search_func: Optional[str] = None
|
||||
search_func: Optional[Any] = None
|
||||
search_engine: SearchEngine = None
|
||||
|
||||
result = ""
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class LocalStore(BaseStore, ABC):
|
|||
raise FileNotFoundError
|
||||
self.config = Config()
|
||||
self.raw_data_path = raw_data_path
|
||||
self.fname = self.raw_data_path.name.split(".")[0]
|
||||
if not cache_dir:
|
||||
cache_dir = raw_data_path.parent
|
||||
self.cache_dir = cache_dir
|
||||
|
|
@ -40,10 +41,9 @@ class LocalStore(BaseStore, ABC):
|
|||
if not self.store:
|
||||
self.store = self.write()
|
||||
|
||||
def _get_index_and_store_fname(self):
|
||||
fname = self.raw_data_path.name.split(".")[0]
|
||||
index_file = self.cache_dir / f"{fname}.index"
|
||||
store_file = self.cache_dir / f"{fname}.pkl"
|
||||
def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"):
|
||||
index_file = self.cache_dir / f"{self.fname}{index_ext}"
|
||||
store_file = self.cache_dir / f"{self.fname}{pkl_ext}"
|
||||
return index_file, store_file
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -6,13 +6,12 @@
|
|||
@File : faiss_store.py
|
||||
"""
|
||||
import asyncio
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import faiss
|
||||
from langchain.embeddings import OpenAIEmbeddings
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain_core.embeddings import Embeddings
|
||||
|
||||
from metagpt.const import DATA_PATH
|
||||
from metagpt.document import IndexableDocument
|
||||
|
|
@ -21,35 +20,29 @@ from metagpt.logs import logger
|
|||
|
||||
|
||||
class FaissStore(LocalStore):
|
||||
def __init__(self, raw_data_path: Path, cache_dir=None, meta_col="source", content_col="output"):
|
||||
def __init__(
|
||||
self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None
|
||||
):
|
||||
self.meta_col = meta_col
|
||||
self.content_col = content_col
|
||||
super().__init__(raw_data_path, cache_dir)
|
||||
self.embedding = embedding or OpenAIEmbeddings()
|
||||
super().__init__(raw_data, cache_dir)
|
||||
|
||||
def _load(self) -> Optional["FaissStore"]:
|
||||
index_file, store_file = self._get_index_and_store_fname()
|
||||
index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain's FAISS saves its index file with a .faiss extension
|
||||
|
||||
if not (index_file.exists() and store_file.exists()):
|
||||
logger.info("Missing at least one of index_file/store_file, load failed and return None")
|
||||
return None
|
||||
index = faiss.read_index(str(index_file))
|
||||
with open(str(store_file), "rb") as f:
|
||||
store = pickle.load(f)
|
||||
store.index = index
|
||||
return store
|
||||
|
||||
return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname)
|
||||
|
||||
def _write(self, docs, metadatas):
|
||||
store = FAISS.from_texts(docs, OpenAIEmbeddings(openai_api_version="2020-11-07"), metadatas=metadatas)
|
||||
store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas)
|
||||
return store
|
||||
|
||||
def persist(self):
|
||||
index_file, store_file = self._get_index_and_store_fname()
|
||||
store = self.store
|
||||
index = self.store.index
|
||||
faiss.write_index(store.index, str(index_file))
|
||||
store.index = None
|
||||
with open(store_file, "wb") as f:
|
||||
pickle.dump(store, f)
|
||||
store.index = index
|
||||
self.store.save_local(self.raw_data_path.parent, self.fname)
|
||||
|
||||
def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs):
|
||||
rsp = self.store.similarity_search(query, k=k, **kwargs)
|
||||
|
|
|
|||
|
|
@ -6,9 +6,9 @@
|
|||
@File : sales.py
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from metagpt.actions import SearchAndSummarize
|
||||
from metagpt.actions import SearchAndSummarize, UserRequirement
|
||||
from metagpt.roles import Role
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
|
@ -23,7 +23,7 @@ class Sales(Role):
|
|||
"but pretend to be what I know. Note that each of my replies will be replied in the tone of a "
|
||||
"professional guide"
|
||||
|
||||
store: Optional[str] = None
|
||||
store: Optional[Any] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -35,3 +35,4 @@ class Sales(Role):
|
|||
else:
|
||||
action = SearchAndSummarize()
|
||||
self._init_actions([action])
|
||||
self._watch([UserRequirement])
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ typer
|
|||
# godot==0.1.1
|
||||
# google_api_python_client==2.93.0
|
||||
lancedb==0.1.16
|
||||
langchain==0.0.231
|
||||
langchain==0.0.352
|
||||
loguru==0.6.0
|
||||
meilisearch==0.21.0
|
||||
numpy==1.24.3
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue