Unified Search Engine API with support for returning structured data.

This commit is contained in:
shenchucheng 2023-08-07 15:15:10 +08:00
parent 930d18962f
commit c62c870ab9
7 changed files with 351 additions and 144 deletions

View file

@ -14,6 +14,7 @@ class SearchEngineType(Enum):
SERPAPI_GOOGLE = auto()
DIRECT_GOOGLE = auto()
SERPER_GOOGLE = auto()
DUCK_DUCK_GO = auto()
CUSTOM_ENGINE = auto()

View file

@ -7,122 +7,76 @@
"""
from __future__ import annotations
import json
import importlib
from typing import Callable, Coroutine, Literal, overload
from metagpt.config import Config
from metagpt.logs import logger
from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
from metagpt.tools.search_engine_serper import SerperWrapper
config = Config()
from metagpt.config import CONFIG
from metagpt.tools import SearchEngineType
class SearchEngine:
"""
TODO: 合入Google Search 并进行反代
这里Google需要挂Proxifier或者类似全局代理
- DDG: https://pypi.org/project/duckduckgo-search/
- GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9
"""
def __init__(self, engine=None, run_func=None):
self.config = Config()
self.run_func = run_func
self.engine = engine or self.config.search_engine
"""Class representing a search engine.
@classmethod
def run_google(cls, query, max_results=8):
# results = ddg(query, max_results=max_results)
results = google_official_search(query, num_results=max_results)
logger.info(results)
return results
Args:
engine: The search engine type. Defaults to the search engine specified in the config.
run_func: The function to run the search. Defaults to None.
async def run(self, query: str, max_results=8):
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
api = SerpAPIWrapper()
rsp = await api.run(query)
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
rsp = SearchEngine.run_google(query, max_results)
elif self.engine == SearchEngineType.SERPER_GOOGLE:
api = SerperWrapper()
rsp = await api.run(query)
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
rsp = self.run_func(query)
Attributes:
run_func: The function to run the search.
engine: The search engine type.
"""
def __init__(
self,
engine: SearchEngineType | None = None,
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
):
engine = engine or CONFIG.search_engine
if engine == SearchEngineType.SERPAPI_GOOGLE:
module = "metagpt.tools.search_engine_serpapi"
run_func = importlib.import_module(module).SerpAPIWrapper().run
elif engine == SearchEngineType.SERPER_GOOGLE:
module = "metagpt.tools.search_engine_serper"
run_func = importlib.import_module(module).SerperWrapper().run
elif engine == SearchEngineType.DIRECT_GOOGLE:
module = "metagpt.tools.search_engine_googleapi"
run_func = importlib.import_module(module).GoogleAPIWrapper().run
elif engine == SearchEngineType.DUCK_DUCK_GO:
module = "metagpt.tools.search_engine_ddg"
run_func = importlib.import_module(module).DDGAPIWrapper().run
elif engine == SearchEngineType.CUSTOM_ENGINE:
pass # run_func = run_func
else:
raise NotImplementedError
return rsp
self.engine = engine
self.run_func = run_func
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
) -> str:
...
def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
"""Return the results of a Google search using the official Google API
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
) -> list[dict[str, str]]:
...
Args:
query (str): The search query.
num_results (int): The number of results to return.
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
"""Run a search query.
Returns:
str: The results of the search.
"""
Args:
query: The search query.
max_results: The maximum number of results to return. Defaults to 8.
as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
try:
api_key = config.google_api_key
custom_search_engine_id = config.google_cse_id
with build("customsearch", "v1", developerKey=api_key) as service:
result = (
service.cse()
.list(q=query, cx=custom_search_engine_id, num=num_results)
.execute()
)
logger.info(result)
# Extract the search result items from the response
search_results = result.get("items", [])
# Create a list of only the URLs from the search results
search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
except HttpError as e:
# Handle errors in the API call
error_details = json.loads(e.content.decode())
# Check if the error is related to an invalid or missing API key
if error_details.get("error", {}).get(
"code"
) == 403 and "invalid API key" in error_details.get("error", {}).get(
"message", ""
):
return "Error: The provided Google API key is invalid or missing."
else:
return f"Error: {e}"
# google_result can be a list or a string depending on the search results
# Return the list of search result URLs
return search_results_details
def safe_google_results(results: str | list) -> str:
"""
Return the results of a google search in a safe format.
Args:
results (str | list): The search results.
Returns:
str: The results of the search.
"""
if isinstance(results, list):
safe_message = json.dumps(
# FIXME: # .encode("utf-8", "ignore") 这里去掉了但是AutoGPT里有很奇怪
[result for result in results]
)
else:
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
return safe_message
if __name__ == '__main__':
SearchEngine.run(query='wtf')
Returns:
The search results as a string or a list of dictionaries.
"""
return await self.run_func(query, max_results=max_results, as_string=as_string)

View file

@ -0,0 +1,107 @@
#!/usr/bin/env python
from __future__ import annotations
import asyncio
import json
from concurrent import futures
from typing import Literal, overload
from duckduckgo_search import DDGS
from googleapiclient.errors import HttpError
from metagpt.config import CONFIG
from metagpt.logs import logger
class DDGAPIWrapper:
"""Wrapper around duckduckgo_search API.
To use this module, you should have the `duckduckgo_search` Python package installed.
"""
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
kwargs = {}
if CONFIG.global_proxy:
kwargs["proxies"] = CONFIG.global_proxy
self.loop = loop
self.executor = executor
self.ddgs = DDGS(**kwargs)
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
focus: list[str] | None = None,
) -> str:
...
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
focus: list[str] | None = None,
) -> list[dict[str, str]]:
...
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
) -> str | list[dict]:
"""Return the results of a Google search using the official Google API
Args:
query: The search query.
max_results: The number of results to return.
as_string: A boolean flag to determine the return type of the results. If True, the function will
return a formatted string with the search results. If False, it will return a list of dictionaries
containing detailed information about each search result.
Returns:
The results of the search.
"""
loop = self.loop or asyncio.get_event_loop()
future = loop.run_in_executor(
self.executor,
self._search_from_ddgs,
query,
max_results,
)
try:
search_results = await future
# Extract the search result items from the response
except HttpError as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
# Return the list of search result URLs
if as_string:
return json.dumps(search_results, ensure_ascii=False)
return search_results
def _search_from_ddgs(self, query: str, max_results: int):
return [
{
"link": i["href"],
"snippet": i["body"],
"title": i["title"]
} for (_, i) in zip(range(max_results), self.ddgs.text(query))
]
if __name__ == "__main__":
import fire
fire.Fire(DDGAPIWrapper().run)

View file

@ -0,0 +1,117 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import asyncio
import json
from concurrent import futures
from urllib.parse import urlparse
import httplib2
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from metagpt.config import CONFIG
from metagpt.logs import logger
class GoogleAPIWrapper:
"""Wrapper around GoogleAPI.
To use this module, you should have the `google-api-python-client` Python package installed
and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See
https://programmablesearchengine.google.com/controlpanel/all.
"""
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
build_kwargs = {"developerKey": CONFIG.google_api_key}
if CONFIG.global_proxy:
parse_result = urlparse(CONFIG.global_proxy)
proxy_type = parse_result.scheme
if proxy_type == "https":
proxy_type = "http"
build_kwargs["http"] = httplib2.Http(
proxy_info=httplib2.ProxyInfo(
getattr(httplib2.socks, f"PROXY_TYPE_{proxy_type.upper()}"),
parse_result.hostname,
parse_result.port,
),
)
service = build("customsearch", "v1", **build_kwargs)
self.google_api_client = service.cse()
self.custom_search_engine_id = CONFIG.google_cse_id
self.loop = loop
self.executor = executor
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
focus: list[str] | None = None,
) -> str | list[dict]:
"""Return the results of a Google search using the official Google API.
Args:
query: The search query.
max_results: The number of results to return.
as_string: A boolean flag to determine the return type of the results. If True, the function will
return a formatted string with the search results. If False, it will return a list of dictionaries
containing detailed information about each search result.
focus: Specific information to be focused on from each search result.
Returns:
The results of the search.
"""
loop = self.loop or asyncio.get_event_loop()
future = loop.run_in_executor(
self.executor,
self.google_api_client.list(
q=query,
num=max_results,
cx=self.custom_search_engine_id
).execute
)
try:
result = await future
# Extract the search result items from the response
search_results = result.get("items", [])
except HttpError as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
focus = focus or ["snippet", "link", "title"]
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
# Return the list of search result URLs
if as_string:
return safe_google_results(details)
return details
def safe_google_results(results: str | list) -> str:
"""Return the results of a google search in a safe format.
Args:
results: The search results.
Returns:
The results of the search.
"""
if isinstance(results, list):
safe_message = json.dumps([result for result in results])
else:
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
return safe_message
if __name__ == "__main__":
import fire
fire.Fire(GoogleAPIWrapper().run)

View file

@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel):
class Config:
arbitrary_types_allowed = True
async def run(self, query: str, **kwargs: Any) -> str:
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result async."""
return self._process_response(await self.results(query))
return self._process_response(await self.results(query, max_results), as_string=as_string)
async def results(self, query: str) -> dict:
async def results(self, query: str, max_results: int) -> dict:
"""Use aiohttp to run query through SerpAPI and return the results async."""
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
if self.serpapi_api_key:
params["serp_api_key"] = self.serpapi_api_key
params["output"] = "json"
@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel):
return params
@staticmethod
def _process_response(res: dict) -> str:
def _process_response(res: dict, as_string: bool) -> str:
"""Process response from SerpAPI."""
# logger.debug(res)
focus = ['title', 'snippet', 'link']
focus = ["title", "snippet", "link"]
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
if "error" in res.keys():
@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel):
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
toret = res["sports_results"]["game_spotlight"]
elif (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
toret = res["knowledge_graph"]["description"]
elif "snippet" in res["organic_results"][0].keys():
toret = res["organic_results"][0]["snippet"]
@ -112,4 +104,10 @@ class SerpAPIWrapper(BaseModel):
if res.get("organic_results"):
toret_l += [get_focused(i) for i in res.get("organic_results")]
return str(toret) + '\n' + str(toret_l)
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
if __name__ == "__main__":
import fire
fire.Fire(SerpAPIWrapper().run)

View file

@ -36,16 +36,19 @@ class SerperWrapper(BaseModel):
class Config:
arbitrary_types_allowed = True
async def run(self, query: str, **kwargs: Any) -> str:
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through Serper and parse result async."""
queries = query.split("\n")
return "\n".join([self._process_response(res) for res in await self.results(queries)])
if isinstance(query, str):
return self._process_response((await self.results([query], max_results))[0], as_string=as_string)
else:
results = [self._process_response(res, as_string) for res in await self.results(query, max_results)]
return "\n".join(results) if as_string else results
async def results(self, queries: list[str]) -> dict:
async def results(self, queries: list[str], max_results: int = 8) -> dict:
"""Use aiohttp to run query through Serper and return the results async."""
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
payloads = self.get_payloads(queries)
payloads = self.get_payloads(queries, max_results)
url = "https://google.serper.dev/search"
headers = self.get_headers()
return url, payloads, headers
@ -61,12 +64,13 @@ class SerperWrapper(BaseModel):
return res
def get_payloads(self, queries: list[str]) -> Dict[str, str]:
def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]:
"""Get payloads for Serper."""
payloads = []
for query in queries:
_payload = {
"q": query,
"num": max_results,
}
payloads.append({**self.payload, **_payload})
return json.dumps(payloads, sort_keys=True)
@ -79,7 +83,7 @@ class SerperWrapper(BaseModel):
return headers
@staticmethod
def _process_response(res: dict) -> str:
def _process_response(res: dict, as_string: bool = False) -> str:
"""Process response from SerpAPI."""
# logger.debug(res)
focus = ['title', 'snippet', 'link']
@ -117,4 +121,10 @@ class SerperWrapper(BaseModel):
if res.get("organic"):
toret_l += [get_focused(i) for i in res.get("organic")]
return str(toret) + '\n' + str(toret_l)
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
if __name__ == "__main__":
import fire
fire.Fire(SerperWrapper().run)

View file

@ -5,24 +5,44 @@
@Author : alexanderwu
@File : test_search_engine.py
"""
from __future__ import annotations
import pytest
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
class MockSearchEnine:
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
return "\n".join(rets) if as_string else rets
@pytest.mark.asyncio
@pytest.mark.usefixtures("llm_api")
async def test_search_engine(llm_api):
search_engine = SearchEngine()
poetries = [
# ("北京美食", "北京"),
("屈臣氏", "屈臣氏")
]
for i, j in poetries:
rsp = await search_engine.run(i)
# rsp = context.llm.ask_batch([prompt])
logger.info(rsp)
# assert any(j in k['body'] for k in rsp)
assert len(rsp) > 0
@pytest.mark.parametrize(
("search_engine_typpe", "run_func", "max_results", "as_string"),
[
(SearchEngineType.SERPAPI_GOOGLE, None, 8, True),
(SearchEngineType.SERPAPI_GOOGLE, None, 4, False),
(SearchEngineType.DIRECT_GOOGLE, None, 8, True),
(SearchEngineType.DIRECT_GOOGLE, None, 6, False),
(SearchEngineType.SERPER_GOOGLE, None, 8, True),
(SearchEngineType.SERPER_GOOGLE, None, 6, False),
(SearchEngineType.DUCK_DUCK_GO, None, 8, True),
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
],
)
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ):
search_engine = SearchEngine(search_engine_typpe, run_func)
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
logger.info(rsp)
if as_string:
assert isinstance(rsp, str)
else:
assert isinstance(rsp, list)
assert len(rsp) == max_results