From 6e6e91660db88230fda8a25f947fc36fb6c14d28 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Thu, 17 Aug 2023 17:37:20 +0800 Subject: [PATCH] Make the SearchEngine more user-friendly. --- .dockerignore | 7 +++ config/config.yaml | 5 ++ metagpt/config.py | 20 +++--- metagpt/tools/__init__.py | 12 ++-- metagpt/tools/search_engine_ddg.py | 27 ++++----- metagpt/tools/search_engine_googleapi.py | 77 +++++++++++++++--------- metagpt/tools/search_engine_serpapi.py | 32 +++++----- metagpt/tools/search_engine_serper.py | 61 ++++++++----------- requirements.txt | 1 - setup.py | 2 + 10 files changed, 133 insertions(+), 111 deletions(-) create mode 100644 .dockerignore diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..2968dd34d --- /dev/null +++ b/.dockerignore @@ -0,0 +1,7 @@ +workspace +tmp +build +workspace +dist +data +geckodriver.log diff --git a/config/config.yaml b/config/config.yaml index 303f4824b..274cdf469 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -25,12 +25,17 @@ RPM: 10 #### for Search +## Supported values: serpapi/google/serper/ddg +#SEARCH_ENGINE: serpapi + ## Visit https://serpapi.com/ to get key. #SERPAPI_API_KEY: "YOUR_API_KEY" + ## Visit https://console.cloud.google.com/apis/credentials to get key. #GOOGLE_API_KEY: "YOUR_API_KEY" ## Visit https://programmablesearchengine.google.com/controlpanel/create to get id. #GOOGLE_CSE_ID: "YOUR_CSE_ID" + ## Visit https://serper.dev/ to get key. #SERPER_API_KEY: "YOUR_API_KEY" diff --git a/metagpt/config.py b/metagpt/config.py index d7339caf5..21f180455 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -45,8 +45,9 @@ class Config(metaclass=Singleton): self.global_proxy = self._get("GLOBAL_PROXY") self.openai_api_key = self._get("OPENAI_API_KEY") self.anthropic_api_key = self._get("Anthropic_API_KEY") - if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) \ - and (not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key): + if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and ( + not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key + ): raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY first") self.openai_api_base = self._get("OPENAI_API_BASE") if not self.openai_api_base or "YOUR_API_BASE" == self.openai_api_base: @@ -62,26 +63,25 @@ class Config(metaclass=Singleton): self.max_tokens_rsp = self._get("MAX_TOKENS", 2048) self.deployment_id = self._get("DEPLOYMENT_ID") - self.claude_api_key = self._get('Anthropic_API_KEY') + self.claude_api_key = self._get("Anthropic_API_KEY") self.serpapi_api_key = self._get("SERPAPI_API_KEY") self.serper_api_key = self._get("SERPER_API_KEY") self.google_api_key = self._get("GOOGLE_API_KEY") self.google_cse_id = self._get("GOOGLE_CSE_ID") - self.search_engine = self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE) - - self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", "playwright")) + self.search_engine = SearchEngineType(self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE)) + self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", WebBrowserEngineType.PLAYWRIGHT)) self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium") self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome") - self.long_term_memory = self._get('LONG_TERM_MEMORY', False) + self.long_term_memory = self._get("LONG_TERM_MEMORY", False) if self.long_term_memory: logger.warning("LONG_TERM_MEMORY is True") self.max_budget = self._get("MAX_BUDGET", 10.0) self.total_cost = 0.0 - self.puppeteer_config = self._get("PUPPETEER_CONFIG","") - self.mmdc = self._get("MMDC","mmdc") - self.calc_usage = self._get("CALC_USAGE",True) + self.puppeteer_config = self._get("PUPPETEER_CONFIG", "") + self.mmdc = self._get("MMDC", "mmdc") + self.calc_usage = self._get("CALC_USAGE", True) self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY") self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT") diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py index e1f921c05..d98087e4b 100644 --- a/metagpt/tools/__init__.py +++ b/metagpt/tools/__init__.py @@ -7,15 +7,15 @@ """ -from enum import Enum, auto +from enum import Enum class SearchEngineType(Enum): - SERPAPI_GOOGLE = auto() - DIRECT_GOOGLE = auto() - SERPER_GOOGLE = auto() - DUCK_DUCK_GO = auto() - CUSTOM_ENGINE = auto() + SERPAPI_GOOGLE = "serpapi" + SERPER_GOOGLE = "serper" + DIRECT_GOOGLE = "google" + DUCK_DUCK_GO = "ddg" + CUSTOM_ENGINE = "custom" class WebBrowserEngineType(Enum): diff --git a/metagpt/tools/search_engine_ddg.py b/metagpt/tools/search_engine_ddg.py index c054afed1..57bc61b82 100644 --- a/metagpt/tools/search_engine_ddg.py +++ b/metagpt/tools/search_engine_ddg.py @@ -7,11 +7,15 @@ import json from concurrent import futures from typing import Literal, overload -from duckduckgo_search import DDGS -from googleapiclient.errors import HttpError +try: + from duckduckgo_search import DDGS +except ImportError: + raise ImportError( + "To use this module, you should have the `duckduckgo_search` Python package installed. " + "You can install it by running the command: `pip install -e.[search-ddg]`" + ) from metagpt.config import CONFIG -from metagpt.logs import logger class DDGAPIWrapper: @@ -19,6 +23,7 @@ class DDGAPIWrapper: To use this module, you should have the `duckduckgo_search` Python package installed. """ + def __init__( self, *, @@ -77,15 +82,8 @@ class DDGAPIWrapper: query, max_results, ) - try: - search_results = await future - # Extract the search result items from the response + search_results = await future - except HttpError as e: - # Handle errors in the API call - logger.exception(f"fail to search {query} for {e}") - search_results = [] - # Return the list of search result URLs if as_string: return json.dumps(search_results, ensure_ascii=False) @@ -93,11 +91,8 @@ class DDGAPIWrapper: def _search_from_ddgs(self, query: str, max_results: int): return [ - { - "link": i["href"], - "snippet": i["body"], - "title": i["title"] - } for (_, i) in zip(range(max_results), self.ddgs.text(query)) + {"link": i["href"], "snippet": i["body"], "title": i["title"]} + for (_, i) in zip(range(max_results), self.ddgs.text(query)) ] diff --git a/metagpt/tools/search_engine_googleapi.py b/metagpt/tools/search_engine_googleapi.py index c226ca8d2..b9faf2ced 100644 --- a/metagpt/tools/search_engine_googleapi.py +++ b/metagpt/tools/search_engine_googleapi.py @@ -5,30 +5,61 @@ from __future__ import annotations import asyncio import json from concurrent import futures +from typing import Optional from urllib.parse import urlparse import httplib2 -from googleapiclient.discovery import build -from googleapiclient.errors import HttpError +from pydantic import BaseModel, validator from metagpt.config import CONFIG from metagpt.logs import logger +try: + from googleapiclient.discovery import build + from googleapiclient.errors import HttpError +except ImportError: + raise ImportError( + "To use this module, you should have the `google-api-python-client` Python package installed. " + "You can install it by running the command: `pip install -e.[search-google]`" + ) -class GoogleAPIWrapper: - """Wrapper around GoogleAPI. - To use this module, you should have the `google-api-python-client` Python package installed - and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See - https://programmablesearchengine.google.com/controlpanel/all. - """ - def __init__( - self, - *, - loop: asyncio.AbstractEventLoop | None = None, - executor: futures.Executor | None = None, - ): - build_kwargs = {"developerKey": CONFIG.google_api_key} +class GoogleAPIWrapper(BaseModel): + google_api_key: Optional[str] = None + google_cse_id: Optional[str] = None + loop: Optional[asyncio.AbstractEventLoop] = None + executor: Optional[futures.Executor] = None + + class Config: + arbitrary_types_allowed = True + + @validator("google_api_key", always=True) + @classmethod + def check_google_api_key(cls, val: str): + val = val or CONFIG.google_api_key + if not val: + raise ValueError( + "To use, make sure you provide the google_api_key when constructing an object. Alternatively, " + "ensure that the environment variable GOOGLE_API_KEY is set with your API key. You can obtain " + "an API key from https://console.cloud.google.com/apis/credentials." + ) + return val + + @validator("google_cse_id", always=True) + @classmethod + def check_google_cse_id(cls, val: str): + val = val or CONFIG.google_cse_id + if not val: + raise ValueError( + "To use, make sure you provide the google_cse_id when constructing an object. Alternatively, " + "ensure that the environment variable GOOGLE_CSE_ID is set with your API key. You can obtain " + "an API key from https://programmablesearchengine.google.com/controlpanel/create." + ) + return val + + @property + def google_api_client(self): + build_kwargs = {"developerKey": self.google_api_key} if CONFIG.global_proxy: parse_result = urlparse(CONFIG.global_proxy) proxy_type = parse_result.scheme @@ -42,10 +73,7 @@ class GoogleAPIWrapper: ), ) service = build("customsearch", "v1", **build_kwargs) - self.google_api_client = service.cse() - self.custom_search_engine_id = CONFIG.google_cse_id - self.loop = loop - self.executor = executor + return service.cse() async def run( self, @@ -69,12 +97,7 @@ class GoogleAPIWrapper: """ loop = self.loop or asyncio.get_event_loop() future = loop.run_in_executor( - self.executor, - self.google_api_client.list( - q=query, - num=max_results, - cx=self.custom_search_engine_id - ).execute + self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.google_cse_id).execute ) try: result = await future @@ -85,13 +108,13 @@ class GoogleAPIWrapper: # Handle errors in the API call logger.exception(f"fail to search {query} for {e}") search_results = [] - + focus = focus or ["snippet", "link", "title"] details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results] # Return the list of search result URLs if as_string: return safe_google_results(details) - + return details diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 3d2d7cfe4..750184198 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -8,19 +8,12 @@ from typing import Any, Dict, Optional, Tuple import aiohttp -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator -from metagpt.config import Config +from metagpt.config import CONFIG class SerpAPIWrapper(BaseModel): - """Wrapper around SerpAPI. - - To use, you should have the ``google-search-results`` python package installed, - and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass - `serpapi_api_key` as a named parameter to the constructor. - """ - search_engine: Any #: :meta private: params: dict = Field( default={ @@ -30,14 +23,25 @@ class SerpAPIWrapper(BaseModel): "hl": "en", } ) - config = Config() - serpapi_api_key: Optional[str] = config.serpapi_api_key + serpapi_api_key: Optional[str] = None aiosession: Optional[aiohttp.ClientSession] = None class Config: arbitrary_types_allowed = True - async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: + @validator("serpapi_api_key", always=True) + @classmethod + def check_serpapi_api_key(cls, val: str): + val = val or CONFIG.serpapi_api_key + if not val: + raise ValueError( + "To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, " + "ensure that the environment variable SERPAPI_API_KEY is set with your API key. You can obtain " + "an API key from https://serpapi.com/." + ) + return val + + async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through SerpAPI and parse result async.""" return self._process_response(await self.results(query, max_results), as_string=as_string) @@ -48,8 +52,6 @@ class SerpAPIWrapper(BaseModel): params = self.get_params(query) params["source"] = "python" params["num"] = max_results - if self.serpapi_api_key: - params["serp_api_key"] = self.serpapi_api_key params["output"] = "json" url = "https://serpapi.com/search" return url, params @@ -104,7 +106,7 @@ class SerpAPIWrapper(BaseModel): if res.get("organic_results"): toret_l += [get_focused(i) for i in res.get("organic_results")] - return str(toret) + '\n' + str(toret_l) if as_string else toret_l + return str(toret) + "\n" + str(toret_l) if as_string else toret_l if __name__ == "__main__": diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index 2ae2c3b7d..0eec2694b 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -9,33 +9,32 @@ import json from typing import Any, Dict, Optional, Tuple import aiohttp -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, validator -from metagpt.config import Config +from metagpt.config import CONFIG class SerperWrapper(BaseModel): - """Wrapper around SerpAPI. - - To use, you should have the ``google-search-results`` python package installed, - and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass - `serpapi_api_key` as a named parameter to the constructor. - """ - search_engine: Any #: :meta private: - payload: dict = Field( - default={ - "page": 1, - "num": 10 - } - ) - config = Config() - serper_api_key: Optional[str] = config.serper_api_key + payload: dict = Field(default={"page": 1, "num": 10}) + serper_api_key: Optional[str] = None aiosession: Optional[aiohttp.ClientSession] = None class Config: arbitrary_types_allowed = True + @validator("serper_api_key", always=True) + @classmethod + def check_serper_api_key(cls, val: str): + val = val or CONFIG.serper_api_key + if not val: + raise ValueError( + "To use, make sure you provide the serper_api_key when constructing an object. Alternatively, " + "ensure that the environment variable SERPER_API_KEY is set with your API key. You can obtain " + "an API key from https://serper.dev/." + ) + return val + async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through Serper and parse result async.""" if isinstance(query, str): @@ -76,18 +75,17 @@ class SerperWrapper(BaseModel): return json.dumps(payloads, sort_keys=True) def get_headers(self) -> Dict[str, str]: - headers = { - 'X-API-KEY': self.serper_api_key, - 'Content-Type': 'application/json' - } + headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"} return headers @staticmethod def _process_response(res: dict, as_string: bool = False) -> str: """Process response from SerpAPI.""" # logger.debug(res) - focus = ['title', 'snippet', 'link'] - def get_focused(x): return {i: j for i, j in x.items() if i in focus} + focus = ["title", "snippet", "link"] + + def get_focused(x): + return {i: j for i, j in x.items() if i in focus} if "error" in res.keys(): raise ValueError(f"Got error from SerpAPI: {res['error']}") @@ -95,20 +93,11 @@ class SerperWrapper(BaseModel): toret = res["answer_box"]["answer"] elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): toret = res["answer_box"]["snippet"] - elif ( - "answer_box" in res.keys() - and "snippet_highlighted_words" in res["answer_box"].keys() - ): + elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys(): toret = res["answer_box"]["snippet_highlighted_words"][0] - elif ( - "sports_results" in res.keys() - and "game_spotlight" in res["sports_results"].keys() - ): + elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys(): toret = res["sports_results"]["game_spotlight"] - elif ( - "knowledge_graph" in res.keys() - and "description" in res["knowledge_graph"].keys() - ): + elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys(): toret = res["knowledge_graph"]["description"] elif "snippet" in res["organic"][0].keys(): toret = res["organic"][0]["snippet"] @@ -121,7 +110,7 @@ class SerperWrapper(BaseModel): if res.get("organic"): toret_l += [get_focused(i) for i in res.get("organic")] - return str(toret) + '\n' + str(toret_l) if as_string else toret_l + return str(toret) + "\n" + str(toret_l) if as_string else toret_l if __name__ == "__main__": diff --git a/requirements.txt b/requirements.txt index c5440abe0..efc2ea3e7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,6 @@ channels==4.0.0 # chromadb==0.3.22 # Django==4.1.5 # docx==0.2.4 -duckduckgo_search==2.9.4 #faiss==1.5.3 faiss_cpu==1.7.4 fire==0.4.0 diff --git a/setup.py b/setup.py index 2a8edaae7..a88f9de92 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,8 @@ setup( extras_require={ "playwright": ["playwright>=1.26", "beautifulsoup4"], "selenium": ["selenium>4", "webdriver_manager", "beautifulsoup4"], + "search-google": ["google-api-python-client==2.94.0"], + "search-ddg": ["duckduckgo-search==3.8.5"], }, cmdclass={ "install_mermaid": InstallMermaidCLI,