mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-17 15:35:21 +02:00
Unified Search Engine API with support for returning structured data.
This commit is contained in:
parent
930d18962f
commit
c62c870ab9
7 changed files with 351 additions and 144 deletions
|
|
@ -14,6 +14,7 @@ class SearchEngineType(Enum):
|
|||
SERPAPI_GOOGLE = auto()
|
||||
DIRECT_GOOGLE = auto()
|
||||
SERPER_GOOGLE = auto()
|
||||
DUCK_DUCK_GO = auto()
|
||||
CUSTOM_ENGINE = auto()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,122 +7,76 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import importlib
|
||||
from typing import Callable, Coroutine, Literal, overload
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
|
||||
from metagpt.tools.search_engine_serper import SerperWrapper
|
||||
|
||||
config = Config()
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
||||
class SearchEngine:
|
||||
"""
|
||||
TODO: 合入Google Search 并进行反代
|
||||
注:这里Google需要挂Proxifier或者类似全局代理
|
||||
- DDG: https://pypi.org/project/duckduckgo-search/
|
||||
- GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9
|
||||
"""
|
||||
def __init__(self, engine=None, run_func=None):
|
||||
self.config = Config()
|
||||
self.run_func = run_func
|
||||
self.engine = engine or self.config.search_engine
|
||||
"""Class representing a search engine.
|
||||
|
||||
@classmethod
|
||||
def run_google(cls, query, max_results=8):
|
||||
# results = ddg(query, max_results=max_results)
|
||||
results = google_official_search(query, num_results=max_results)
|
||||
logger.info(results)
|
||||
return results
|
||||
Args:
|
||||
engine: The search engine type. Defaults to the search engine specified in the config.
|
||||
run_func: The function to run the search. Defaults to None.
|
||||
|
||||
async def run(self, query: str, max_results=8):
|
||||
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
api = SerpAPIWrapper()
|
||||
rsp = await api.run(query)
|
||||
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
rsp = SearchEngine.run_google(query, max_results)
|
||||
elif self.engine == SearchEngineType.SERPER_GOOGLE:
|
||||
api = SerperWrapper()
|
||||
rsp = await api.run(query)
|
||||
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
rsp = self.run_func(query)
|
||||
Attributes:
|
||||
run_func: The function to run the search.
|
||||
engine: The search engine type.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
engine: SearchEngineType | None = None,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
|
||||
):
|
||||
engine = engine or CONFIG.search_engine
|
||||
if engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serpapi"
|
||||
run_func = importlib.import_module(module).SerpAPIWrapper().run
|
||||
elif engine == SearchEngineType.SERPER_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serper"
|
||||
run_func = importlib.import_module(module).SerperWrapper().run
|
||||
elif engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_googleapi"
|
||||
run_func = importlib.import_module(module).GoogleAPIWrapper().run
|
||||
elif engine == SearchEngineType.DUCK_DUCK_GO:
|
||||
module = "metagpt.tools.search_engine_ddg"
|
||||
run_func = importlib.import_module(module).DDGAPIWrapper().run
|
||||
elif engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
pass # run_func = run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return rsp
|
||||
self.engine = engine
|
||||
self.run_func = run_func
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
) -> str:
|
||||
...
|
||||
|
||||
def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
"""Run a search query.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
Args:
|
||||
query: The search query.
|
||||
max_results: The maximum number of results to return. Defaults to 8.
|
||||
as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
try:
|
||||
api_key = config.google_api_key
|
||||
custom_search_engine_id = config.google_cse_id
|
||||
|
||||
with build("customsearch", "v1", developerKey=api_key) as service:
|
||||
|
||||
result = (
|
||||
service.cse()
|
||||
.list(q=query, cx=custom_search_engine_id, num=num_results)
|
||||
.execute()
|
||||
)
|
||||
logger.info(result)
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
# Create a list of only the URLs from the search results
|
||||
search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
error_details = json.loads(e.content.decode())
|
||||
|
||||
# Check if the error is related to an invalid or missing API key
|
||||
if error_details.get("error", {}).get(
|
||||
"code"
|
||||
) == 403 and "invalid API key" in error_details.get("error", {}).get(
|
||||
"message", ""
|
||||
):
|
||||
return "Error: The provided Google API key is invalid or missing."
|
||||
else:
|
||||
return f"Error: {e}"
|
||||
# google_result can be a list or a string depending on the search results
|
||||
|
||||
# Return the list of search result URLs
|
||||
return search_results_details
|
||||
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
|
||||
"""
|
||||
Return the results of a google search in a safe format.
|
||||
|
||||
Args:
|
||||
results (str | list): The search results.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps(
|
||||
# FIXME: # .encode("utf-8", "ignore") 这里去掉了,但是AutoGPT里有,很奇怪
|
||||
[result for result in results]
|
||||
)
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
SearchEngine.run(query='wtf')
|
||||
Returns:
|
||||
The search results as a string or a list of dictionaries.
|
||||
"""
|
||||
return await self.run_func(query, max_results=max_results, as_string=as_string)
|
||||
|
|
|
|||
107
metagpt/tools/search_engine_ddg.py
Normal file
107
metagpt/tools/search_engine_ddg.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from concurrent import futures
|
||||
from typing import Literal, overload
|
||||
|
||||
from duckduckgo_search import DDGS
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class DDGAPIWrapper:
|
||||
"""Wrapper around duckduckgo_search API.
|
||||
|
||||
To use this module, you should have the `duckduckgo_search` Python package installed.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
):
|
||||
kwargs = {}
|
||||
if CONFIG.global_proxy:
|
||||
kwargs["proxies"] = CONFIG.global_proxy
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
self.ddgs = DDGS(**kwargs)
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
focus: list[str] | None = None,
|
||||
) -> str:
|
||||
...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
focus: list[str] | None = None,
|
||||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
async def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: bool = True,
|
||||
) -> str | list[dict]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
max_results: The number of results to return.
|
||||
as_string: A boolean flag to determine the return type of the results. If True, the function will
|
||||
return a formatted string with the search results. If False, it will return a list of dictionaries
|
||||
containing detailed information about each search result.
|
||||
|
||||
Returns:
|
||||
The results of the search.
|
||||
"""
|
||||
loop = self.loop or asyncio.get_event_loop()
|
||||
future = loop.run_in_executor(
|
||||
self.executor,
|
||||
self._search_from_ddgs,
|
||||
query,
|
||||
max_results,
|
||||
)
|
||||
try:
|
||||
search_results = await future
|
||||
# Extract the search result items from the response
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
logger.exception(f"fail to search {query} for {e}")
|
||||
search_results = []
|
||||
|
||||
# Return the list of search result URLs
|
||||
if as_string:
|
||||
return json.dumps(search_results, ensure_ascii=False)
|
||||
return search_results
|
||||
|
||||
def _search_from_ddgs(self, query: str, max_results: int):
|
||||
return [
|
||||
{
|
||||
"link": i["href"],
|
||||
"snippet": i["body"],
|
||||
"title": i["title"]
|
||||
} for (_, i) in zip(range(max_results), self.ddgs.text(query))
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(DDGAPIWrapper().run)
|
||||
117
metagpt/tools/search_engine_googleapi.py
Normal file
117
metagpt/tools/search_engine_googleapi.py
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from concurrent import futures
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httplib2
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class GoogleAPIWrapper:
|
||||
"""Wrapper around GoogleAPI.
|
||||
|
||||
To use this module, you should have the `google-api-python-client` Python package installed
|
||||
and set property values for the configurations `GOOGLE_API_KEY` and `GOOGLE_CSE_ID`. See
|
||||
https://programmablesearchengine.google.com/controlpanel/all.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
):
|
||||
build_kwargs = {"developerKey": CONFIG.google_api_key}
|
||||
if CONFIG.global_proxy:
|
||||
parse_result = urlparse(CONFIG.global_proxy)
|
||||
proxy_type = parse_result.scheme
|
||||
if proxy_type == "https":
|
||||
proxy_type = "http"
|
||||
build_kwargs["http"] = httplib2.Http(
|
||||
proxy_info=httplib2.ProxyInfo(
|
||||
getattr(httplib2.socks, f"PROXY_TYPE_{proxy_type.upper()}"),
|
||||
parse_result.hostname,
|
||||
parse_result.port,
|
||||
),
|
||||
)
|
||||
service = build("customsearch", "v1", **build_kwargs)
|
||||
self.google_api_client = service.cse()
|
||||
self.custom_search_engine_id = CONFIG.google_cse_id
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
|
||||
async def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: bool = True,
|
||||
focus: list[str] | None = None,
|
||||
) -> str | list[dict]:
|
||||
"""Return the results of a Google search using the official Google API.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
max_results: The number of results to return.
|
||||
as_string: A boolean flag to determine the return type of the results. If True, the function will
|
||||
return a formatted string with the search results. If False, it will return a list of dictionaries
|
||||
containing detailed information about each search result.
|
||||
focus: Specific information to be focused on from each search result.
|
||||
|
||||
Returns:
|
||||
The results of the search.
|
||||
"""
|
||||
loop = self.loop or asyncio.get_event_loop()
|
||||
future = loop.run_in_executor(
|
||||
self.executor,
|
||||
self.google_api_client.list(
|
||||
q=query,
|
||||
num=max_results,
|
||||
cx=self.custom_search_engine_id
|
||||
).execute
|
||||
)
|
||||
try:
|
||||
result = await future
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
logger.exception(f"fail to search {query} for {e}")
|
||||
search_results = []
|
||||
|
||||
focus = focus or ["snippet", "link", "title"]
|
||||
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
|
||||
# Return the list of search result URLs
|
||||
if as_string:
|
||||
return safe_google_results(details)
|
||||
|
||||
return details
|
||||
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
|
||||
"""Return the results of a google search in a safe format.
|
||||
|
||||
Args:
|
||||
results: The search results.
|
||||
|
||||
Returns:
|
||||
The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps([result for result in results])
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(GoogleAPIWrapper().run)
|
||||
|
|
@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(self, query: str, **kwargs: Any) -> str:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result async."""
|
||||
return self._process_response(await self.results(query))
|
||||
return self._process_response(await self.results(query, max_results), as_string=as_string)
|
||||
|
||||
async def results(self, query: str) -> dict:
|
||||
async def results(self, query: str, max_results: int) -> dict:
|
||||
"""Use aiohttp to run query through SerpAPI and return the results async."""
|
||||
|
||||
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
|
||||
params = self.get_params(query)
|
||||
params["source"] = "python"
|
||||
params["num"] = max_results
|
||||
if self.serpapi_api_key:
|
||||
params["serp_api_key"] = self.serpapi_api_key
|
||||
params["output"] = "json"
|
||||
|
|
@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel):
|
|||
return params
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict) -> str:
|
||||
def _process_response(res: dict, as_string: bool) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
# logger.debug(res)
|
||||
focus = ['title', 'snippet', 'link']
|
||||
focus = ["title", "snippet", "link"]
|
||||
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
|
||||
|
||||
if "error" in res.keys():
|
||||
|
|
@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel):
|
|||
toret = res["answer_box"]["answer"]
|
||||
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet"]
|
||||
elif (
|
||||
"answer_box" in res.keys()
|
||||
and "snippet_highlighted_words" in res["answer_box"].keys()
|
||||
):
|
||||
elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet_highlighted_words"][0]
|
||||
elif (
|
||||
"sports_results" in res.keys()
|
||||
and "game_spotlight" in res["sports_results"].keys()
|
||||
):
|
||||
elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
|
||||
toret = res["sports_results"]["game_spotlight"]
|
||||
elif (
|
||||
"knowledge_graph" in res.keys()
|
||||
and "description" in res["knowledge_graph"].keys()
|
||||
):
|
||||
elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
|
||||
toret = res["knowledge_graph"]["description"]
|
||||
elif "snippet" in res["organic_results"][0].keys():
|
||||
toret = res["organic_results"][0]["snippet"]
|
||||
|
|
@ -112,4 +104,10 @@ class SerpAPIWrapper(BaseModel):
|
|||
if res.get("organic_results"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic_results")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l)
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(SerpAPIWrapper().run)
|
||||
|
|
|
|||
|
|
@ -36,16 +36,19 @@ class SerperWrapper(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(self, query: str, **kwargs: Any) -> str:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through Serper and parse result async."""
|
||||
queries = query.split("\n")
|
||||
return "\n".join([self._process_response(res) for res in await self.results(queries)])
|
||||
if isinstance(query, str):
|
||||
return self._process_response((await self.results([query], max_results))[0], as_string=as_string)
|
||||
else:
|
||||
results = [self._process_response(res, as_string) for res in await self.results(query, max_results)]
|
||||
return "\n".join(results) if as_string else results
|
||||
|
||||
async def results(self, queries: list[str]) -> dict:
|
||||
async def results(self, queries: list[str], max_results: int = 8) -> dict:
|
||||
"""Use aiohttp to run query through Serper and return the results async."""
|
||||
|
||||
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
|
||||
payloads = self.get_payloads(queries)
|
||||
payloads = self.get_payloads(queries, max_results)
|
||||
url = "https://google.serper.dev/search"
|
||||
headers = self.get_headers()
|
||||
return url, payloads, headers
|
||||
|
|
@ -61,12 +64,13 @@ class SerperWrapper(BaseModel):
|
|||
|
||||
return res
|
||||
|
||||
def get_payloads(self, queries: list[str]) -> Dict[str, str]:
|
||||
def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]:
|
||||
"""Get payloads for Serper."""
|
||||
payloads = []
|
||||
for query in queries:
|
||||
_payload = {
|
||||
"q": query,
|
||||
"num": max_results,
|
||||
}
|
||||
payloads.append({**self.payload, **_payload})
|
||||
return json.dumps(payloads, sort_keys=True)
|
||||
|
|
@ -79,7 +83,7 @@ class SerperWrapper(BaseModel):
|
|||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict) -> str:
|
||||
def _process_response(res: dict, as_string: bool = False) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
# logger.debug(res)
|
||||
focus = ['title', 'snippet', 'link']
|
||||
|
|
@ -117,4 +121,10 @@ class SerperWrapper(BaseModel):
|
|||
if res.get("organic"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l)
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(SerperWrapper().run)
|
||||
|
|
|
|||
|
|
@ -5,24 +5,44 @@
|
|||
@Author : alexanderwu
|
||||
@File : test_search_engine.py
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
class MockSearchEnine:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
|
||||
return "\n".join(rets) if as_string else rets
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("llm_api")
|
||||
async def test_search_engine(llm_api):
|
||||
search_engine = SearchEngine()
|
||||
poetries = [
|
||||
# ("北京美食", "北京"),
|
||||
("屈臣氏", "屈臣氏")
|
||||
]
|
||||
for i, j in poetries:
|
||||
rsp = await search_engine.run(i)
|
||||
# rsp = context.llm.ask_batch([prompt])
|
||||
logger.info(rsp)
|
||||
# assert any(j in k['body'] for k in rsp)
|
||||
assert len(rsp) > 0
|
||||
@pytest.mark.parametrize(
|
||||
("search_engine_typpe", "run_func", "max_results", "as_string"),
|
||||
[
|
||||
(SearchEngineType.SERPAPI_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.SERPAPI_GOOGLE, None, 4, False),
|
||||
(SearchEngineType.DIRECT_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.DIRECT_GOOGLE, None, 6, False),
|
||||
(SearchEngineType.SERPER_GOOGLE, None, 8, True),
|
||||
(SearchEngineType.SERPER_GOOGLE, None, 6, False),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 8, True),
|
||||
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
|
||||
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
|
||||
|
||||
],
|
||||
)
|
||||
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string, ):
|
||||
search_engine = SearchEngine(search_engine_typpe, run_func)
|
||||
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
|
||||
logger.info(rsp)
|
||||
if as_string:
|
||||
assert isinstance(rsp, str)
|
||||
else:
|
||||
assert isinstance(rsp, list)
|
||||
assert len(rsp) == max_results
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue