diff --git a/examples/search_kb.py b/examples/search_kb.py index 5d61bbe02..c70cad2fd 100644 --- a/examples/search_kb.py +++ b/examples/search_kb.py @@ -5,12 +5,13 @@ """ import asyncio -from metagpt.actions import Action +from langchain.embeddings import OpenAIEmbeddings + +from metagpt.config import CONFIG from metagpt.const import DATA_PATH from metagpt.document_store import FaissStore from metagpt.logs import logger from metagpt.roles import Sales -from metagpt.schema import Message """ example.json, e.g. [ @@ -26,14 +27,15 @@ from metagpt.schema import Message """ +def get_store(): + embedding = OpenAIEmbeddings(openai_api_key=CONFIG.openai_api_key, openai_api_base=CONFIG.openai_base_url) + return FaissStore(DATA_PATH / "example.json", embedding=embedding) + + async def search(): - store = FaissStore(DATA_PATH / "example.json") - role = Sales(profile="Sales", store=store) - role._watch({Action}) - queries = [ - Message(content="Which facial cleanser is good for oily skin?", cause_by=Action), - Message(content="Is L'Oreal good to use?", cause_by=Action), - ] + role = Sales(profile="Sales", store=get_store()) + queries = ["Which facial cleanser is good for oily skin?", "Is L'Oreal good to use?"] + for query in queries: logger.info(f"User: {query}") result = await role.run(query) diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index bc1319291..25af21795 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -5,7 +5,7 @@ @Author : alexanderwu @File : search_google.py """ -from typing import Optional +from typing import Any, Optional import pydantic from pydantic import Field, root_validator @@ -111,7 +111,7 @@ class SearchAndSummarize(Action): llm: BaseGPTAPI = Field(default_factory=LLM) config: None = Field(default_factory=Config) engine: Optional[SearchEngineType] = CONFIG.search_engine - search_func: Optional[str] = None + search_func: Optional[Any] = None search_engine: SearchEngine = None result = "" diff --git a/metagpt/document_store/base_store.py b/metagpt/document_store/base_store.py index 5de377d21..b719d1083 100644 --- a/metagpt/document_store/base_store.py +++ b/metagpt/document_store/base_store.py @@ -33,6 +33,7 @@ class LocalStore(BaseStore, ABC): raise FileNotFoundError self.config = Config() self.raw_data_path = raw_data_path + self.fname = self.raw_data_path.stem if not cache_dir: cache_dir = raw_data_path.parent self.cache_dir = cache_dir @@ -40,10 +41,9 @@ class LocalStore(BaseStore, ABC): if not self.store: self.store = self.write() - def _get_index_and_store_fname(self): - fname = self.raw_data_path.name.split(".")[0] - index_file = self.cache_dir / f"{fname}.index" - store_file = self.cache_dir / f"{fname}.pkl" + def _get_index_and_store_fname(self, index_ext=".index", pkl_ext=".pkl"): + index_file = self.cache_dir / f"{self.fname}{index_ext}" + store_file = self.cache_dir / f"{self.fname}{pkl_ext}" return index_file, store_file @abstractmethod diff --git a/metagpt/document_store/faiss_store.py b/metagpt/document_store/faiss_store.py index b1faa3538..320e7518f 100644 --- a/metagpt/document_store/faiss_store.py +++ b/metagpt/document_store/faiss_store.py @@ -6,13 +6,12 @@ @File : faiss_store.py """ import asyncio -import pickle from pathlib import Path from typing import Optional -import faiss from langchain.embeddings import OpenAIEmbeddings from langchain.vectorstores import FAISS +from langchain_core.embeddings import Embeddings from metagpt.const import DATA_PATH from metagpt.document import IndexableDocument @@ -21,35 +20,29 @@ from metagpt.logs import logger class FaissStore(LocalStore): - def __init__(self, raw_data_path: Path, cache_dir=None, meta_col="source", content_col="output"): + def __init__( + self, raw_data: Path, cache_dir=None, meta_col="source", content_col="output", embedding: Embeddings = None + ): self.meta_col = meta_col self.content_col = content_col - super().__init__(raw_data_path, cache_dir) + self.embedding = embedding or OpenAIEmbeddings() + super().__init__(raw_data, cache_dir) def _load(self) -> Optional["FaissStore"]: - index_file, store_file = self._get_index_and_store_fname() + index_file, store_file = self._get_index_and_store_fname(index_ext=".faiss") # langchain FAISS using .faiss + if not (index_file.exists() and store_file.exists()): logger.info("Missing at least one of index_file/store_file, load failed and return None") return None - index = faiss.read_index(str(index_file)) - with open(str(store_file), "rb") as f: - store = pickle.load(f) - store.index = index - return store + + return FAISS.load_local(self.raw_data_path.parent, self.embedding, self.fname) def _write(self, docs, metadatas): - store = FAISS.from_texts(docs, OpenAIEmbeddings(openai_api_version="2020-11-07"), metadatas=metadatas) + store = FAISS.from_texts(docs, self.embedding, metadatas=metadatas) return store def persist(self): - index_file, store_file = self._get_index_and_store_fname() - store = self.store - index = self.store.index - faiss.write_index(store.index, str(index_file)) - store.index = None - with open(store_file, "wb") as f: - pickle.dump(store, f) - store.index = index + self.store.save_local(self.raw_data_path.parent, self.fname) def search(self, query, expand_cols=False, sep="\n", *args, k=5, **kwargs): rsp = self.store.similarity_search(query, k=k, **kwargs) diff --git a/metagpt/roles/sales.py b/metagpt/roles/sales.py index 76abf10f3..1ef93f6f3 100644 --- a/metagpt/roles/sales.py +++ b/metagpt/roles/sales.py @@ -8,7 +8,8 @@ from typing import Optional -from metagpt.actions import SearchAndSummarize +from metagpt.actions import SearchAndSummarize, UserRequirement +from metagpt.document_store.base_store import BaseStore from metagpt.roles import Role from metagpt.tools import SearchEngineType @@ -23,7 +24,7 @@ class Sales(Role): "but pretend to be what I know. Note that each of my replies will be replied in the tone of a " "professional guide" - store: Optional[str] = None + store: Optional[BaseStore] = None def __init__(self, **kwargs): super().__init__(**kwargs) @@ -35,3 +36,4 @@ class Sales(Role): else: action = SearchAndSummarize() self._init_actions([action]) + self._watch([UserRequirement]) diff --git a/requirements.txt b/requirements.txt index eaff5c4b2..9954a9941 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ typer # godot==0.1.1 # google_api_python_client==2.93.0 lancedb==0.1.16 -langchain==0.0.231 +langchain==0.0.352 loguru==0.6.0 meilisearch==0.21.0 numpy==1.24.3