Make the SearchEngine more user-friendly.

This commit is contained in:
shenchucheng 2023-08-17 17:37:20 +08:00
parent 7e329a478a
commit 6e6e91660d
10 changed files with 133 additions and 111 deletions

View file

@ -7,15 +7,15 @@
"""
from enum import Enum, auto
from enum import Enum
class SearchEngineType(Enum):
SERPAPI_GOOGLE = auto()
DIRECT_GOOGLE = auto()
SERPER_GOOGLE = auto()
DUCK_DUCK_GO = auto()
CUSTOM_ENGINE = auto()
SERPAPI_GOOGLE = "serpapi"
SERPER_GOOGLE = "serper"
DIRECT_GOOGLE = "google"
DUCK_DUCK_GO = "ddg"
CUSTOM_ENGINE = "custom"
class WebBrowserEngineType(Enum):

View file

@ -7,11 +7,15 @@ import json
from concurrent import futures
from typing import Literal, overload
from duckduckgo_search import DDGS
from googleapiclient.errors import HttpError
try:
from duckduckgo_search import DDGS
except ImportError:
raise ImportError(
"To use this module, you should have the `duckduckgo_search` Python package installed. "
"You can install it by running the command: `pip install -e.[search-ddg]`"
)
from metagpt.config import CONFIG
from metagpt.logs import logger
class DDGAPIWrapper:
@ -19,6 +23,7 @@ class DDGAPIWrapper:
To use this module, you should have the `duckduckgo_search` Python package installed.
"""
def __init__(
self,
*,
@ -77,15 +82,8 @@ class DDGAPIWrapper:
query,
max_results,
)
try:
search_results = await future
# Extract the search result items from the response
search_results = await future
except HttpError as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
# Return the list of search result URLs
if as_string:
return json.dumps(search_results, ensure_ascii=False)
@ -93,11 +91,8 @@ class DDGAPIWrapper:
def _search_from_ddgs(self, query: str, max_results: int):
return [
{
"link": i["href"],
"snippet": i["body"],
"title": i["title"]
} for (_, i) in zip(range(max_results), self.ddgs.text(query))
{"link": i["href"], "snippet": i["body"], "title": i["title"]}
for (_, i) in zip(range(max_results), self.ddgs.text(query))
]

View file

@ -5,30 +5,61 @@ from __future__ import annotations
import asyncio
import json
from concurrent import futures
from typing import Optional
from urllib.parse import urlparse
import httplib2
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
from pydantic import BaseModel, validator
from metagpt.config import CONFIG
from metagpt.logs import logger
try:
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
except ImportError:
raise ImportError(
"To use this module, you should have the `google-api-python-client` Python package installed. "
"You can install it by running the command: `pip install -e.[search-google]`"
)
class GoogleAPIWrapper:
"""Wrapper around GoogleAPI.
To use this module, you should have the `google-api-python-client` Python package installed
and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See
https://programmablesearchengine.google.com/controlpanel/all.
"""
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
build_kwargs = {"developerKey": CONFIG.google_api_key}
class GoogleAPIWrapper(BaseModel):
google_api_key: Optional[str] = None
google_cse_id: Optional[str] = None
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
class Config:
arbitrary_types_allowed = True
@validator("google_api_key", always=True)
@classmethod
def check_google_api_key(cls, val: str):
val = val or CONFIG.google_api_key
if not val:
raise ValueError(
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
"ensure that the environment variable GOOGLE_API_KEY is set with your API key. You can obtain "
"an API key from https://console.cloud.google.com/apis/credentials."
)
return val
@validator("google_cse_id", always=True)
@classmethod
def check_google_cse_id(cls, val: str):
val = val or CONFIG.google_cse_id
if not val:
raise ValueError(
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
"ensure that the environment variable GOOGLE_CSE_ID is set with your API key. You can obtain "
"an API key from https://programmablesearchengine.google.com/controlpanel/create."
)
return val
@property
def google_api_client(self):
build_kwargs = {"developerKey": self.google_api_key}
if CONFIG.global_proxy:
parse_result = urlparse(CONFIG.global_proxy)
proxy_type = parse_result.scheme
@ -42,10 +73,7 @@ class GoogleAPIWrapper:
),
)
service = build("customsearch", "v1", **build_kwargs)
self.google_api_client = service.cse()
self.custom_search_engine_id = CONFIG.google_cse_id
self.loop = loop
self.executor = executor
return service.cse()
async def run(
self,
@ -69,12 +97,7 @@ class GoogleAPIWrapper:
"""
loop = self.loop or asyncio.get_event_loop()
future = loop.run_in_executor(
self.executor,
self.google_api_client.list(
q=query,
num=max_results,
cx=self.custom_search_engine_id
).execute
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.google_cse_id).execute
)
try:
result = await future
@ -85,13 +108,13 @@ class GoogleAPIWrapper:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
focus = focus or ["snippet", "link", "title"]
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
# Return the list of search result URLs
if as_string:
return safe_google_results(details)
return details

View file

@ -8,19 +8,12 @@
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
from metagpt.config import Config
from metagpt.config import CONFIG
class SerpAPIWrapper(BaseModel):
"""Wrapper around SerpAPI.
To use, you should have the ``google-search-results`` python package installed,
and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
`serpapi_api_key` as a named parameter to the constructor.
"""
search_engine: Any #: :meta private:
params: dict = Field(
default={
@ -30,14 +23,25 @@ class SerpAPIWrapper(BaseModel):
"hl": "en",
}
)
config = Config()
serpapi_api_key: Optional[str] = config.serpapi_api_key
serpapi_api_key: Optional[str] = None
aiosession: Optional[aiohttp.ClientSession] = None
class Config:
arbitrary_types_allowed = True
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
@validator("serpapi_api_key", always=True)
@classmethod
def check_serpapi_api_key(cls, val: str):
val = val or CONFIG.serpapi_api_key
if not val:
raise ValueError(
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
"ensure that the environment variable SERPAPI_API_KEY is set with your API key. You can obtain "
"an API key from https://serpapi.com/."
)
return val
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result async."""
return self._process_response(await self.results(query, max_results), as_string=as_string)
@ -48,8 +52,6 @@ class SerpAPIWrapper(BaseModel):
params = self.get_params(query)
params["source"] = "python"
params["num"] = max_results
if self.serpapi_api_key:
params["serp_api_key"] = self.serpapi_api_key
params["output"] = "json"
url = "https://serpapi.com/search"
return url, params
@ -104,7 +106,7 @@ class SerpAPIWrapper(BaseModel):
if res.get("organic_results"):
toret_l += [get_focused(i) for i in res.get("organic_results")]
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
return str(toret) + "\n" + str(toret_l) if as_string else toret_l
if __name__ == "__main__":

View file

@ -9,33 +9,32 @@ import json
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, validator
from metagpt.config import Config
from metagpt.config import CONFIG
class SerperWrapper(BaseModel):
"""Wrapper around SerpAPI.
To use, you should have the ``google-search-results`` python package installed,
and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
`serpapi_api_key` as a named parameter to the constructor.
"""
search_engine: Any #: :meta private:
payload: dict = Field(
default={
"page": 1,
"num": 10
}
)
config = Config()
serper_api_key: Optional[str] = config.serper_api_key
payload: dict = Field(default={"page": 1, "num": 10})
serper_api_key: Optional[str] = None
aiosession: Optional[aiohttp.ClientSession] = None
class Config:
arbitrary_types_allowed = True
@validator("serper_api_key", always=True)
@classmethod
def check_serper_api_key(cls, val: str):
val = val or CONFIG.serper_api_key
if not val:
raise ValueError(
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "
"ensure that the environment variable SERPER_API_KEY is set with your API key. You can obtain "
"an API key from https://serper.dev/."
)
return val
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through Serper and parse result async."""
if isinstance(query, str):
@ -76,18 +75,17 @@ class SerperWrapper(BaseModel):
return json.dumps(payloads, sort_keys=True)
def get_headers(self) -> Dict[str, str]:
headers = {
'X-API-KEY': self.serper_api_key,
'Content-Type': 'application/json'
}
headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"}
return headers
@staticmethod
def _process_response(res: dict, as_string: bool = False) -> str:
"""Process response from SerpAPI."""
# logger.debug(res)
focus = ['title', 'snippet', 'link']
def get_focused(x): return {i: j for i, j in x.items() if i in focus}
focus = ["title", "snippet", "link"]
def get_focused(x):
return {i: j for i, j in x.items() if i in focus}
if "error" in res.keys():
raise ValueError(f"Got error from SerpAPI: {res['error']}")
@ -95,20 +93,11 @@ class SerperWrapper(BaseModel):
toret = res["answer_box"]["answer"]
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
toret = res["answer_box"]["snippet"]
elif (
"answer_box" in res.keys()
and "snippet_highlighted_words" in res["answer_box"].keys()
):
elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
toret = res["answer_box"]["snippet_highlighted_words"][0]
elif (
"sports_results" in res.keys()
and "game_spotlight" in res["sports_results"].keys()
):
elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
toret = res["sports_results"]["game_spotlight"]
elif (
"knowledge_graph" in res.keys()
and "description" in res["knowledge_graph"].keys()
):
elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
toret = res["knowledge_graph"]["description"]
elif "snippet" in res["organic"][0].keys():
toret = res["organic"][0]["snippet"]
@ -121,7 +110,7 @@ class SerperWrapper(BaseModel):
if res.get("organic"):
toret_l += [get_focused(i) for i in res.get("organic")]
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
return str(toret) + "\n" + str(toret_l) if as_string else toret_l
if __name__ == "__main__":