From c62c870ab9312c5d45ddde45299df9fb5af8750b Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Mon, 7 Aug 2023 15:15:10 +0800 Subject: [PATCH] Unified Search Engine API with support for returning structured data. --- metagpt/tools/__init__.py | 1 + metagpt/tools/search_engine.py | 164 ++++++++-------------- metagpt/tools/search_engine_ddg.py | 107 ++++++++++++++ metagpt/tools/search_engine_googleapi.py | 117 +++++++++++++++ metagpt/tools/search_engine_serpapi.py | 34 +++-- metagpt/tools/search_engine_serper.py | 26 ++-- tests/metagpt/tools/test_search_engine.py | 46 ++++-- 7 files changed, 351 insertions(+), 144 deletions(-) create mode 100644 metagpt/tools/search_engine_ddg.py create mode 100644 metagpt/tools/search_engine_googleapi.py diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py index f9b7abc52..e1f921c05 100644 --- a/metagpt/tools/__init__.py +++ b/metagpt/tools/__init__.py @@ -14,6 +14,7 @@ class SearchEngineType(Enum): SERPAPI_GOOGLE = auto() DIRECT_GOOGLE = auto() SERPER_GOOGLE = auto() + DUCK_DUCK_GO = auto() CUSTOM_ENGINE = auto() diff --git a/metagpt/tools/search_engine.py b/metagpt/tools/search_engine.py index cfd4e8789..d28700054 100644 --- a/metagpt/tools/search_engine.py +++ b/metagpt/tools/search_engine.py @@ -7,122 +7,76 @@ """ from __future__ import annotations -import json +import importlib +from typing import Callable, Coroutine, Literal, overload -from metagpt.config import Config -from metagpt.logs import logger -from metagpt.tools.search_engine_serpapi import SerpAPIWrapper -from metagpt.tools.search_engine_serper import SerperWrapper - -config = Config() +from metagpt.config import CONFIG from metagpt.tools import SearchEngineType class SearchEngine: - """ - TODO: 合入Google Search 并进行反代 - 注:这里Google需要挂Proxifier或者类似全局代理 - - DDG: https://pypi.org/project/duckduckgo-search/ - - GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9 - """ - def __init__(self, engine=None, run_func=None): - self.config = Config() - self.run_func = run_func - self.engine = engine or self.config.search_engine + """Class representing a search engine. - @classmethod - def run_google(cls, query, max_results=8): - # results = ddg(query, max_results=max_results) - results = google_official_search(query, num_results=max_results) - logger.info(results) - return results + Args: + engine: The search engine type. Defaults to the search engine specified in the config. + run_func: The function to run the search. Defaults to None. - async def run(self, query: str, max_results=8): - if self.engine == SearchEngineType.SERPAPI_GOOGLE: - api = SerpAPIWrapper() - rsp = await api.run(query) - elif self.engine == SearchEngineType.DIRECT_GOOGLE: - rsp = SearchEngine.run_google(query, max_results) - elif self.engine == SearchEngineType.SERPER_GOOGLE: - api = SerperWrapper() - rsp = await api.run(query) - elif self.engine == SearchEngineType.CUSTOM_ENGINE: - rsp = self.run_func(query) + Attributes: + run_func: The function to run the search. + engine: The search engine type. + """ + def __init__( + self, + engine: SearchEngineType | None = None, + run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None, + ): + engine = engine or CONFIG.search_engine + if engine == SearchEngineType.SERPAPI_GOOGLE: + module = "metagpt.tools.search_engine_serpapi" + run_func = importlib.import_module(module).SerpAPIWrapper().run + elif engine == SearchEngineType.SERPER_GOOGLE: + module = "metagpt.tools.search_engine_serper" + run_func = importlib.import_module(module).SerperWrapper().run + elif engine == SearchEngineType.DIRECT_GOOGLE: + module = "metagpt.tools.search_engine_googleapi" + run_func = importlib.import_module(module).GoogleAPIWrapper().run + elif engine == SearchEngineType.DUCK_DUCK_GO: + module = "metagpt.tools.search_engine_ddg" + run_func = importlib.import_module(module).DDGAPIWrapper().run + elif engine == SearchEngineType.CUSTOM_ENGINE: + pass # run_func = run_func else: raise NotImplementedError - return rsp + self.engine = engine + self.run_func = run_func + @overload + def run( + self, + query: str, + max_results: int = 8, + as_string: Literal[True] = True, + ) -> str: + ... -def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]: - """Return the results of a Google search using the official Google API + @overload + def run( + self, + query: str, + max_results: int = 8, + as_string: Literal[False] = False, + ) -> list[dict[str, str]]: + ... - Args: - query (str): The search query. - num_results (int): The number of results to return. + async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]: + """Run a search query. - Returns: - str: The results of the search. - """ + Args: + query: The search query. + max_results: The maximum number of results to return. Defaults to 8. + as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True. - from googleapiclient.discovery import build - from googleapiclient.errors import HttpError - - try: - api_key = config.google_api_key - custom_search_engine_id = config.google_cse_id - - with build("customsearch", "v1", developerKey=api_key) as service: - - result = ( - service.cse() - .list(q=query, cx=custom_search_engine_id, num=num_results) - .execute() - ) - logger.info(result) - # Extract the search result items from the response - search_results = result.get("items", []) - - # Create a list of only the URLs from the search results - search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results] - - except HttpError as e: - # Handle errors in the API call - error_details = json.loads(e.content.decode()) - - # Check if the error is related to an invalid or missing API key - if error_details.get("error", {}).get( - "code" - ) == 403 and "invalid API key" in error_details.get("error", {}).get( - "message", "" - ): - return "Error: The provided Google API key is invalid or missing." - else: - return f"Error: {e}" - # google_result can be a list or a string depending on the search results - - # Return the list of search result URLs - return search_results_details - - -def safe_google_results(results: str | list) -> str: - """ - Return the results of a google search in a safe format. - - Args: - results (str | list): The search results. - - Returns: - str: The results of the search. - """ - if isinstance(results, list): - safe_message = json.dumps( - # FIXME: # .encode("utf-8", "ignore") 这里去掉了,但是AutoGPT里有,很奇怪 - [result for result in results] - ) - else: - safe_message = results.encode("utf-8", "ignore").decode("utf-8") - return safe_message - - -if __name__ == '__main__': - SearchEngine.run(query='wtf') + Returns: + The search results as a string or a list of dictionaries. + """ + return await self.run_func(query, max_results=max_results, as_string=as_string) diff --git a/metagpt/tools/search_engine_ddg.py b/metagpt/tools/search_engine_ddg.py new file mode 100644 index 000000000..c054afed1 --- /dev/null +++ b/metagpt/tools/search_engine_ddg.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python + +from __future__ import annotations + +import asyncio +import json +from concurrent import futures +from typing import Literal, overload + +from duckduckgo_search import DDGS +from googleapiclient.errors import HttpError + +from metagpt.config import CONFIG +from metagpt.logs import logger + + +class DDGAPIWrapper: + """Wrapper around duckduckgo_search API. + + To use this module, you should have the `duckduckgo_search` Python package installed. + """ + def __init__( + self, + *, + loop: asyncio.AbstractEventLoop | None = None, + executor: futures.Executor | None = None, + ): + kwargs = {} + if CONFIG.global_proxy: + kwargs["proxies"] = CONFIG.global_proxy + self.loop = loop + self.executor = executor + self.ddgs = DDGS(**kwargs) + + @overload + def run( + self, + query: str, + max_results: int = 8, + as_string: Literal[True] = True, + focus: list[str] | None = None, + ) -> str: + ... + + @overload + def run( + self, + query: str, + max_results: int = 8, + as_string: Literal[False] = False, + focus: list[str] | None = None, + ) -> list[dict[str, str]]: + ... + + async def run( + self, + query: str, + max_results: int = 8, + as_string: bool = True, + ) -> str | list[dict]: + """Return the results of a Google search using the official Google API + + Args: + query: The search query. + max_results: The number of results to return. + as_string: A boolean flag to determine the return type of the results. If True, the function will + return a formatted string with the search results. If False, it will return a list of dictionaries + containing detailed information about each search result. + + Returns: + The results of the search. + """ + loop = self.loop or asyncio.get_event_loop() + future = loop.run_in_executor( + self.executor, + self._search_from_ddgs, + query, + max_results, + ) + try: + search_results = await future + # Extract the search result items from the response + + except HttpError as e: + # Handle errors in the API call + logger.exception(f"fail to search {query} for {e}") + search_results = [] + + # Return the list of search result URLs + if as_string: + return json.dumps(search_results, ensure_ascii=False) + return search_results + + def _search_from_ddgs(self, query: str, max_results: int): + return [ + { + "link": i["href"], + "snippet": i["body"], + "title": i["title"] + } for (_, i) in zip(range(max_results), self.ddgs.text(query)) + ] + + +if __name__ == "__main__": + import fire + + fire.Fire(DDGAPIWrapper().run) diff --git a/metagpt/tools/search_engine_googleapi.py b/metagpt/tools/search_engine_googleapi.py new file mode 100644 index 000000000..c226ca8d2 --- /dev/null +++ b/metagpt/tools/search_engine_googleapi.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +from __future__ import annotations + +import asyncio +import json +from concurrent import futures +from urllib.parse import urlparse + +import httplib2 +from googleapiclient.discovery import build +from googleapiclient.errors import HttpError + +from metagpt.config import CONFIG +from metagpt.logs import logger + + +class GoogleAPIWrapper: + """Wrapper around GoogleAPI. + + To use this module, you should have the `google-api-python-client` Python package installed + and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See + https://programmablesearchengine.google.com/controlpanel/all. + """ + def __init__( + self, + *, + loop: asyncio.AbstractEventLoop | None = None, + executor: futures.Executor | None = None, + ): + build_kwargs = {"developerKey": CONFIG.google_api_key} + if CONFIG.global_proxy: + parse_result = urlparse(CONFIG.global_proxy) + proxy_type = parse_result.scheme + if proxy_type == "https": + proxy_type = "http" + build_kwargs["http"] = httplib2.Http( + proxy_info=httplib2.ProxyInfo( + getattr(httplib2.socks, f"PROXY_TYPE_{proxy_type.upper()}"), + parse_result.hostname, + parse_result.port, + ), + ) + service = build("customsearch", "v1", **build_kwargs) + self.google_api_client = service.cse() + self.custom_search_engine_id = CONFIG.google_cse_id + self.loop = loop + self.executor = executor + + async def run( + self, + query: str, + max_results: int = 8, + as_string: bool = True, + focus: list[str] | None = None, + ) -> str | list[dict]: + """Return the results of a Google search using the official Google API. + + Args: + query: The search query. + max_results: The number of results to return. + as_string: A boolean flag to determine the return type of the results. If True, the function will + return a formatted string with the search results. If False, it will return a list of dictionaries + containing detailed information about each search result. + focus: Specific information to be focused on from each search result. + + Returns: + The results of the search. + """ + loop = self.loop or asyncio.get_event_loop() + future = loop.run_in_executor( + self.executor, + self.google_api_client.list( + q=query, + num=max_results, + cx=self.custom_search_engine_id + ).execute + ) + try: + result = await future + # Extract the search result items from the response + search_results = result.get("items", []) + + except HttpError as e: + # Handle errors in the API call + logger.exception(f"fail to search {query} for {e}") + search_results = [] + + focus = focus or ["snippet", "link", "title"] + details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results] + # Return the list of search result URLs + if as_string: + return safe_google_results(details) + + return details + + +def safe_google_results(results: str | list) -> str: + """Return the results of a google search in a safe format. + + Args: + results: The search results. + + Returns: + The results of the search. + """ + if isinstance(results, list): + safe_message = json.dumps([result for result in results]) + else: + safe_message = results.encode("utf-8", "ignore").decode("utf-8") + return safe_message + + +if __name__ == "__main__": + import fire + + fire.Fire(GoogleAPIWrapper().run) diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 28033f237..3d2d7cfe4 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel): class Config: arbitrary_types_allowed = True - async def run(self, query: str, **kwargs: Any) -> str: + async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through SerpAPI and parse result async.""" - return self._process_response(await self.results(query)) + return self._process_response(await self.results(query, max_results), as_string=as_string) - async def results(self, query: str) -> dict: + async def results(self, query: str, max_results: int) -> dict: """Use aiohttp to run query through SerpAPI and return the results async.""" def construct_url_and_params() -> Tuple[str, Dict[str, str]]: params = self.get_params(query) params["source"] = "python" + params["num"] = max_results if self.serpapi_api_key: params["serp_api_key"] = self.serpapi_api_key params["output"] = "json" @@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel): return params @staticmethod - def _process_response(res: dict) -> str: + def _process_response(res: dict, as_string: bool) -> str: """Process response from SerpAPI.""" # logger.debug(res) - focus = ['title', 'snippet', 'link'] + focus = ["title", "snippet", "link"] get_focused = lambda x: {i: j for i, j in x.items() if i in focus} if "error" in res.keys(): @@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel): toret = res["answer_box"]["answer"] elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): toret = res["answer_box"]["snippet"] - elif ( - "answer_box" in res.keys() - and "snippet_highlighted_words" in res["answer_box"].keys() - ): + elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys(): toret = res["answer_box"]["snippet_highlighted_words"][0] - elif ( - "sports_results" in res.keys() - and "game_spotlight" in res["sports_results"].keys() - ): + elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys(): toret = res["sports_results"]["game_spotlight"] - elif ( - "knowledge_graph" in res.keys() - and "description" in res["knowledge_graph"].keys() - ): + elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys(): toret = res["knowledge_graph"]["description"] elif "snippet" in res["organic_results"][0].keys(): toret = res["organic_results"][0]["snippet"] @@ -112,4 +104,10 @@ class SerpAPIWrapper(BaseModel): if res.get("organic_results"): toret_l += [get_focused(i) for i in res.get("organic_results")] - return str(toret) + '\n' + str(toret_l) + return str(toret) + '\n' + str(toret_l) if as_string else toret_l + + +if __name__ == "__main__": + import fire + + fire.Fire(SerpAPIWrapper().run) diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index 80c2f8001..2ae2c3b7d 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -36,16 +36,19 @@ class SerperWrapper(BaseModel): class Config: arbitrary_types_allowed = True - async def run(self, query: str, **kwargs: Any) -> str: + async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through Serper and parse result async.""" - queries = query.split("\n") - return "\n".join([self._process_response(res) for res in await self.results(queries)]) + if isinstance(query, str): + return self._process_response((await self.results([query], max_results))[0], as_string=as_string) + else: + results = [self._process_response(res, as_string) for res in await self.results(query, max_results)] + return "\n".join(results) if as_string else results - async def results(self, queries: list[str]) -> dict: + async def results(self, queries: list[str], max_results: int = 8) -> dict: """Use aiohttp to run query through Serper and return the results async.""" def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]: - payloads = self.get_payloads(queries) + payloads = self.get_payloads(queries, max_results) url = "https://google.serper.dev/search" headers = self.get_headers() return url, payloads, headers @@ -61,12 +64,13 @@ class SerperWrapper(BaseModel): return res - def get_payloads(self, queries: list[str]) -> Dict[str, str]: + def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]: """Get payloads for Serper.""" payloads = [] for query in queries: _payload = { "q": query, + "num": max_results, } payloads.append({**self.payload, **_payload}) return json.dumps(payloads, sort_keys=True) @@ -79,7 +83,7 @@ class SerperWrapper(BaseModel): return headers @staticmethod - def _process_response(res: dict) -> str: + def _process_response(res: dict, as_string: bool = False) -> str: """Process response from SerpAPI.""" # logger.debug(res) focus = ['title', 'snippet', 'link'] @@ -117,4 +121,10 @@ class SerperWrapper(BaseModel): if res.get("organic"): toret_l += [get_focused(i) for i in res.get("organic")] - return str(toret) + '\n' + str(toret_l) + return str(toret) + '\n' + str(toret_l) if as_string else toret_l + + +if __name__ == "__main__": + import fire + + fire.Fire(SerperWrapper().run) diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index 2418c7b26..a7fe063a6 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -5,24 +5,44 @@ @Author : alexanderwu @File : test_search_engine.py """ +from __future__ import annotations import pytest from metagpt.logs import logger +from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine +class MockSearchEnine: + async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]: + rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)] + return "\n".join(rets) if as_string else rets + + @pytest.mark.asyncio -@pytest.mark.usefixtures("llm_api") -async def test_search_engine(llm_api): - search_engine = SearchEngine() - poetries = [ - # ("北京美食", "北京"), - ("屈臣氏", "屈臣氏") - ] - for i, j in poetries: - rsp = await search_engine.run(i) - # rsp = context.llm.ask_batch([prompt]) - logger.info(rsp) - # assert any(j in k['body'] for k in rsp) - assert len(rsp) > 0 +@pytest.mark.parametrize( + ("search_engine_typpe", "run_func", "max_results", "as_string"), + [ + (SearchEngineType.SERPAPI_GOOGLE, None, 8, True), + (SearchEngineType.SERPAPI_GOOGLE, None, 4, False), + (SearchEngineType.DIRECT_GOOGLE, None, 8, True), + (SearchEngineType.DIRECT_GOOGLE, None, 6, False), + (SearchEngineType.SERPER_GOOGLE, None, 8, True), + (SearchEngineType.SERPER_GOOGLE, None, 6, False), + (SearchEngineType.DUCK_DUCK_GO, None, 8, True), + (SearchEngineType.DUCK_DUCK_GO, None, 6, False), + (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False), + (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False), + + ], +) +async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ): + search_engine = SearchEngine(search_engine_typpe, run_func) + rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string) + logger.info(rsp) + if as_string: + assert isinstance(rsp, str) + else: + assert isinstance(rsp, list) + assert len(rsp) == max_results