mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-26 09:16:21 +02:00
Make the SearchEngine more user-friendly.
This commit is contained in:
parent
7e329a478a
commit
6e6e91660d
10 changed files with 133 additions and 111 deletions
|
|
@ -7,15 +7,15 @@
|
|||
"""
|
||||
|
||||
|
||||
from enum import Enum, auto
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class SearchEngineType(Enum):
|
||||
SERPAPI_GOOGLE = auto()
|
||||
DIRECT_GOOGLE = auto()
|
||||
SERPER_GOOGLE = auto()
|
||||
DUCK_DUCK_GO = auto()
|
||||
CUSTOM_ENGINE = auto()
|
||||
SERPAPI_GOOGLE = "serpapi"
|
||||
SERPER_GOOGLE = "serper"
|
||||
DIRECT_GOOGLE = "google"
|
||||
DUCK_DUCK_GO = "ddg"
|
||||
CUSTOM_ENGINE = "custom"
|
||||
|
||||
|
||||
class WebBrowserEngineType(Enum):
|
||||
|
|
|
|||
|
|
@ -7,11 +7,15 @@ import json
|
|||
from concurrent import futures
|
||||
from typing import Literal, overload
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
from googleapiclient.errors import HttpError
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use this module, you should have the `duckduckgo_search` Python package installed. "
|
||||
"You can install it by running the command: `pip install -e.[search-ddg]`"
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class DDGAPIWrapper:
|
||||
|
|
@ -19,6 +23,7 @@ class DDGAPIWrapper:
|
|||
|
||||
To use this module, you should have the `duckduckgo_search` Python package installed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
|
|
@ -77,15 +82,8 @@ class DDGAPIWrapper:
|
|||
query,
|
||||
max_results,
|
||||
)
|
||||
try:
|
||||
search_results = await future
|
||||
# Extract the search result items from the response
|
||||
search_results = await future
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
logger.exception(f"fail to search {query} for {e}")
|
||||
search_results = []
|
||||
|
||||
# Return the list of search result URLs
|
||||
if as_string:
|
||||
return json.dumps(search_results, ensure_ascii=False)
|
||||
|
|
@ -93,11 +91,8 @@ class DDGAPIWrapper:
|
|||
|
||||
def _search_from_ddgs(self, query: str, max_results: int):
|
||||
return [
|
||||
{
|
||||
"link": i["href"],
|
||||
"snippet": i["body"],
|
||||
"title": i["title"]
|
||||
} for (_, i) in zip(range(max_results), self.ddgs.text(query))
|
||||
{"link": i["href"], "snippet": i["body"], "title": i["title"]}
|
||||
for (_, i) in zip(range(max_results), self.ddgs.text(query))
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,30 +5,61 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import json
|
||||
from concurrent import futures
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httplib2
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
try:
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use this module, you should have the `google-api-python-client` Python package installed. "
|
||||
"You can install it by running the command: `pip install -e.[search-google]`"
|
||||
)
|
||||
|
||||
class GoogleAPIWrapper:
|
||||
"""Wrapper around GoogleAPI.
|
||||
|
||||
To use this module, you should have the `google-api-python-client` Python package installed
|
||||
and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See
|
||||
https://programmablesearchengine.google.com/controlpanel/all.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
):
|
||||
build_kwargs = {"developerKey": CONFIG.google_api_key}
|
||||
class GoogleAPIWrapper(BaseModel):
|
||||
google_api_key: Optional[str] = None
|
||||
google_cse_id: Optional[str] = None
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("google_api_key", always=True)
@classmethod
def check_google_api_key(cls, val: str):
    """Resolve the Google API key, falling back to the global configuration.

    An explicitly supplied, non-empty value is used as-is; otherwise
    ``CONFIG.google_api_key`` is consulted.

    Raises:
        ValueError: if neither source provides a non-empty key.
    """
    api_key = val or CONFIG.google_api_key
    if api_key:
        return api_key
    raise ValueError(
        "To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
        "ensure that the environment variable GOOGLE_API_KEY is set with your API key. You can obtain "
        "an API key from https://console.cloud.google.com/apis/credentials."
    )
|
||||
|
||||
@validator("google_cse_id", always=True)
@classmethod
def check_google_cse_id(cls, val: Optional[str]) -> str:
    """Resolve the Programmable Search Engine ID, falling back to the global configuration.

    An explicitly supplied, non-empty value is used as-is; otherwise
    ``CONFIG.google_cse_id`` is consulted.

    Raises:
        ValueError: if neither source provides a non-empty search engine ID.
    """
    val = val or CONFIG.google_cse_id
    if not val:
        # The message names the search engine ID (not an API key), since that
        # is the value that is missing here.
        raise ValueError(
            "To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
            "ensure that the environment variable GOOGLE_CSE_ID is set with your search engine ID. You can obtain "
            "a search engine ID from https://programmablesearchengine.google.com/controlpanel/create."
        )
    return val
|
||||
|
||||
@property
|
||||
def google_api_client(self):
|
||||
build_kwargs = {"developerKey": self.google_api_key}
|
||||
if CONFIG.global_proxy:
|
||||
parse_result = urlparse(CONFIG.global_proxy)
|
||||
proxy_type = parse_result.scheme
|
||||
|
|
@ -42,10 +73,7 @@ class GoogleAPIWrapper:
|
|||
),
|
||||
)
|
||||
service = build("customsearch", "v1", **build_kwargs)
|
||||
self.google_api_client = service.cse()
|
||||
self.custom_search_engine_id = CONFIG.google_cse_id
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
return service.cse()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
@ -69,12 +97,7 @@ class GoogleAPIWrapper:
|
|||
"""
|
||||
loop = self.loop or asyncio.get_event_loop()
|
||||
future = loop.run_in_executor(
|
||||
self.executor,
|
||||
self.google_api_client.list(
|
||||
q=query,
|
||||
num=max_results,
|
||||
cx=self.custom_search_engine_id
|
||||
).execute
|
||||
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.google_cse_id).execute
|
||||
)
|
||||
try:
|
||||
result = await future
|
||||
|
|
@ -85,13 +108,13 @@ class GoogleAPIWrapper:
|
|||
# Handle errors in the API call
|
||||
logger.exception(f"fail to search {query} for {e}")
|
||||
search_results = []
|
||||
|
||||
|
||||
focus = focus or ["snippet", "link", "title"]
|
||||
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
|
||||
# Return the list of search result URLs
|
||||
if as_string:
|
||||
return safe_google_results(details)
|
||||
|
||||
|
||||
return details
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,19 +8,12 @@
|
|||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class SerpAPIWrapper(BaseModel):
|
||||
"""Wrapper around SerpAPI.
|
||||
|
||||
To use, you should have the ``google-search-results`` python package installed,
|
||||
and the environment variable ``SERPAPI_API_KEY`` set with your API key, or pass
|
||||
`serpapi_api_key` as a named parameter to the constructor.
|
||||
"""
|
||||
|
||||
search_engine: Any #: :meta private:
|
||||
params: dict = Field(
|
||||
default={
|
||||
|
|
@ -30,14 +23,25 @@ class SerpAPIWrapper(BaseModel):
|
|||
"hl": "en",
|
||||
}
|
||||
)
|
||||
config = Config()
|
||||
serpapi_api_key: Optional[str] = config.serpapi_api_key
|
||||
serpapi_api_key: Optional[str] = None
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
@validator("serpapi_api_key", always=True)
@classmethod
def check_serpapi_api_key(cls, val: str):
    """Resolve the SerpAPI key, falling back to the global configuration.

    An explicitly supplied, non-empty value is used as-is; otherwise
    ``CONFIG.serpapi_api_key`` is consulted.

    Raises:
        ValueError: if neither source provides a non-empty key.
    """
    resolved_key = val or CONFIG.serpapi_api_key
    if resolved_key:
        return resolved_key
    raise ValueError(
        "To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
        "ensure that the environment variable SERPAPI_API_KEY is set with your API key. You can obtain "
        "an API key from https://serpapi.com/."
    )
|
||||
|
||||
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
    """Query SerpAPI asynchronously and return the processed response.

    The raw response is fetched with ``self.results`` and then handed to
    ``self._process_response`` together with ``as_string``. Extra keyword
    arguments are accepted but not used by this method.
    """
    raw_response = await self.results(query, max_results)
    return self._process_response(raw_response, as_string=as_string)
|
||||
|
||||
|
|
@ -48,8 +52,6 @@ class SerpAPIWrapper(BaseModel):
|
|||
params = self.get_params(query)
|
||||
params["source"] = "python"
|
||||
params["num"] = max_results
|
||||
if self.serpapi_api_key:
|
||||
params["serp_api_key"] = self.serpapi_api_key
|
||||
params["output"] = "json"
|
||||
url = "https://serpapi.com/search"
|
||||
return url, params
|
||||
|
|
@ -104,7 +106,7 @@ class SerpAPIWrapper(BaseModel):
|
|||
if res.get("organic_results"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic_results")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
return str(toret) + "\n" + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -9,33 +9,32 @@ import json
|
|||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class SerperWrapper(BaseModel):
|
||||
"""Wrapper around Serper.
|
||||
|
||||
To use, you should have the ``google-search-results`` python package installed,
|
||||
and the environment variable ``SERPER_API_KEY`` set with your API key, or pass
|
||||
`serper_api_key` as a named parameter to the constructor.
|
||||
"""
|
||||
|
||||
search_engine: Any #: :meta private:
|
||||
payload: dict = Field(
|
||||
default={
|
||||
"page": 1,
|
||||
"num": 10
|
||||
}
|
||||
)
|
||||
config = Config()
|
||||
serper_api_key: Optional[str] = config.serper_api_key
|
||||
payload: dict = Field(default={"page": 1, "num": 10})
|
||||
serper_api_key: Optional[str] = None
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("serper_api_key", always=True)
@classmethod
def check_serper_api_key(cls, val: str):
    """Resolve the Serper API key, falling back to the global configuration.

    An explicitly supplied, non-empty value is used as-is; otherwise
    ``CONFIG.serper_api_key`` is consulted.

    Raises:
        ValueError: if neither source provides a non-empty key.
    """
    key = val or CONFIG.serper_api_key
    if key:
        return key
    raise ValueError(
        "To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "
        "ensure that the environment variable SERPER_API_KEY is set with your API key. You can obtain "
        "an API key from https://serper.dev/."
    )
|
||||
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through Serper and parse result async."""
|
||||
if isinstance(query, str):
|
||||
|
|
@ -76,18 +75,17 @@ class SerperWrapper(BaseModel):
|
|||
return json.dumps(payloads, sort_keys=True)
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
headers = {
|
||||
'X-API-KEY': self.serper_api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"}
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict, as_string: bool = False) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
# logger.debug(res)
|
||||
focus = ['title', 'snippet', 'link']
|
||||
def get_focused(x): return {i: j for i, j in x.items() if i in focus}
|
||||
focus = ["title", "snippet", "link"]
|
||||
|
||||
def get_focused(x):
|
||||
return {i: j for i, j in x.items() if i in focus}
|
||||
|
||||
if "error" in res.keys():
|
||||
raise ValueError(f"Got error from SerpAPI: {res['error']}")
|
||||
|
|
@ -95,20 +93,11 @@ class SerperWrapper(BaseModel):
|
|||
toret = res["answer_box"]["answer"]
|
||||
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet"]
|
||||
elif (
|
||||
"answer_box" in res.keys()
|
||||
and "snippet_highlighted_words" in res["answer_box"].keys()
|
||||
):
|
||||
elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet_highlighted_words"][0]
|
||||
elif (
|
||||
"sports_results" in res.keys()
|
||||
and "game_spotlight" in res["sports_results"].keys()
|
||||
):
|
||||
elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
|
||||
toret = res["sports_results"]["game_spotlight"]
|
||||
elif (
|
||||
"knowledge_graph" in res.keys()
|
||||
and "description" in res["knowledge_graph"].keys()
|
||||
):
|
||||
elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
|
||||
toret = res["knowledge_graph"]["description"]
|
||||
elif "snippet" in res["organic"][0].keys():
|
||||
toret = res["organic"][0]["snippet"]
|
||||
|
|
@ -121,7 +110,7 @@ class SerperWrapper(BaseModel):
|
|||
if res.get("organic"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
return str(toret) + "\n" + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue