Mock search engine API responses in tests so they run without network access or API keys

This commit is contained in:
shenchucheng 2024-01-14 23:40:09 +08:00
parent b603e19bda
commit 9d1df5acd5
9 changed files with 1062 additions and 50 deletions

View file

@ -44,19 +44,20 @@ class SearchEngine:
self,
engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE,
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
**kwargs,
):
if engine == SearchEngineType.SERPAPI_GOOGLE:
module = "metagpt.tools.search_engine_serpapi"
run_func = importlib.import_module(module).SerpAPIWrapper().run
run_func = importlib.import_module(module).SerpAPIWrapper(**kwargs).run
elif engine == SearchEngineType.SERPER_GOOGLE:
module = "metagpt.tools.search_engine_serper"
run_func = importlib.import_module(module).SerperWrapper().run
run_func = importlib.import_module(module).SerperWrapper(**kwargs).run
elif engine == SearchEngineType.DIRECT_GOOGLE:
module = "metagpt.tools.search_engine_googleapi"
run_func = importlib.import_module(module).GoogleAPIWrapper().run
run_func = importlib.import_module(module).GoogleAPIWrapper(**kwargs).run
elif engine == SearchEngineType.DUCK_DUCK_GO:
module = "metagpt.tools.search_engine_ddg"
run_func = importlib.import_module(module).DDGAPIWrapper().run
run_func = importlib.import_module(module).DDGAPIWrapper(**kwargs).run
elif engine == SearchEngineType.CUSTOM_ENGINE:
pass # run_func = run_func
else:

View file

@ -12,6 +12,7 @@ import logging
import os
import re
import uuid
from typing import Callable
import pytest
@ -20,6 +21,9 @@ from metagpt.context import CONTEXT
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.git_repository import GitRepository
from tests.mock.mock_aiohttp import MockAioResponse
from tests.mock.mock_curl_cffi import MockCurlCffiResponse
from tests.mock.mock_httplib2 import MockHttplib2Response
from tests.mock.mock_llm import MockLLM
RSP_CACHE_NEW = {} # used globally for producing new and useful only response cache
@ -164,39 +168,63 @@ def new_filename(mocker):
yield mocker
@pytest.fixture(scope="session")
def search_rsp_cache():
    """Session-scoped dict of recorded search-engine responses, persisted to disk.

    Loads the repo-provided cache file if present, yields the (mutable) dict so
    HTTP mockers can record new responses into it, then writes the updated
    cache back at session teardown.
    """
    rsp_cache_file_path = TEST_DATA_PATH / "search_rsp_cache.json"  # read repo-provided
    # Prefer pathlib I/O over os.path.exists + open, and pin the encoding so
    # the JSON round-trips identically across platforms.
    if rsp_cache_file_path.exists():
        rsp_cache_json = json.loads(rsp_cache_file_path.read_text(encoding="utf-8"))
    else:
        rsp_cache_json = {}
    yield rsp_cache_json
    # Persist anything newly recorded during this session.
    rsp_cache_file_path.write_text(
        json.dumps(rsp_cache_json, indent=4, ensure_ascii=False), encoding="utf-8"
    )
@pytest.fixture
def aiohttp_mocker(mocker):
    """Patch aiohttp.ClientSession so every request goes through MockAioResponse.

    Yields a per-test subclass of MockAioResponse so tests can attach
    `rsp_cache` / `check_funcs` without mutating shared class state.
    """
    # NOTE(review): this span interleaved the removed inline mock with the new
    # implementation; only the post-change implementation is kept here.
    # Fresh subclass per test so class-level caches do not leak between tests.
    MockResponse = type("MockResponse", (MockAioResponse,), {})

    def wrap(method):
        # Bind `method` now; each verb helper (get/post/...) forwards to the mock.
        def run(self, url, **kwargs):
            return MockResponse(self, method, url, **kwargs)

        return run

    mocker.patch("aiohttp.ClientSession.request", MockResponse)
    for i in ["get", "post", "delete", "patch"]:
        mocker.patch(f"aiohttp.ClientSession.{i}", wrap(i))
    yield MockResponse
@pytest.fixture
def curl_cffi_mocker(mocker):
    """Route curl_cffi session requests through a per-test caching mock response."""
    # Subclass per test so cache/check-func state stays test-local.
    MockResponse = type("MockResponse", (MockCurlCffiResponse,), {})

    def fake_request(self, *args, **kwargs):
        # Every Session.request call constructs (and thereby resolves) a mock.
        return MockResponse(self, *args, **kwargs)

    mocker.patch("curl_cffi.requests.Session.request", fake_request)
    yield MockResponse
@pytest.fixture
def httplib2_mocker(mocker):
    """Route httplib2.Http.request through a per-test caching mock response."""
    # Subclass per test so cache/check-func state stays test-local.
    MockResponse = type("MockResponse", (MockHttplib2Response,), {})

    def fake_request(self, *args, **kwargs):
        # Constructing the mock performs the (cached) request.
        return MockResponse(self, *args, **kwargs)

    mocker.patch("httplib2.Http.request", fake_request)
    yield MockResponse
@pytest.fixture
def search_engine_mocker(aiohttp_mocker, curl_cffi_mocker, httplib2_mocker, search_rsp_cache):
    """Wire the three HTTP-layer mockers to one shared response cache.

    aiohttp_mocker covers serpapi/serper, httplib2_mocker covers the Google
    API client, and curl_cffi_mocker covers DuckDuckGo.
    """
    check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
    # Point every mock class at the same cache and the same key-check table.
    for mock_cls in (aiohttp_mocker, httplib2_mocker, curl_cffi_mocker):
        mock_cls.rsp_cache = search_rsp_cache
        mock_cls.check_funcs = check_funcs
    yield check_funcs

File diff suppressed because one or more lines are too long

View file

@ -9,10 +9,12 @@
import pytest
from metagpt.actions import research
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@pytest.mark.asyncio
async def test_collect_links(mocker):
async def test_collect_links(mocker, search_engine_mocker):
async def mock_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
@ -26,13 +28,15 @@ async def test_collect_links(mocker):
return "[1,2]"
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
resp = await research.CollectLinks().run("The application of MetaGPT")
resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO)).run(
"The application of MetaGPT"
)
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
assert i in resp
@pytest.mark.asyncio
async def test_collect_links_with_rank_func(mocker):
async def test_collect_links_with_rank_func(mocker, search_engine_mocker):
rank_before = []
rank_after = []
url_per_query = 4
@ -45,7 +49,9 @@ async def test_collect_links_with_rank_func(mocker):
return results
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
resp = await research.CollectLinks(
search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func
).run("The application of MetaGPT")
for x, y, z in zip(rank_before, rank_after, resp.values()):
assert x[::-1] == y
assert [i["link"] for i in y] == z

View file

@ -4,7 +4,10 @@ from tempfile import TemporaryDirectory
import pytest
from metagpt.actions.research import CollectLinks
from metagpt.roles import researcher
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
async def mock_llm_ask(self, prompt: str, system_msgs):
@ -25,12 +28,16 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
@pytest.mark.asyncio
async def test_researcher(mocker, search_engine_mocker):
    """End-to-end Researcher run with mocked LLM and mocked search engines."""
    # NOTE(review): this span interleaved the removed pre-change lines with the
    # added ones; only the post-change test is kept here.
    with TemporaryDirectory() as dirname:
        topic = "dataiku vs. datarobot"
        mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
        researcher.RESEARCH_PATH = Path(dirname)
        role = researcher.Researcher()
        # Use DuckDuckGo so no API keys are needed; HTTP is served from cache.
        for i in role.actions:
            if isinstance(i, CollectLinks):
                i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO)
        await role.run(topic)
        assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")

View file

@ -7,20 +7,15 @@
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Callable
import pytest
import tests.data.search
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
search_cache_path = Path(tests.data.search.__path__[0])
class MockSearchEnine:
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
@ -46,24 +41,28 @@ class MockSearchEnine:
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
],
)
async def test_search_engine(search_engine_type, run_func: Callable, max_results: int, as_string: bool, aiohttp_mocker):
async def test_search_engine(
search_engine_type,
run_func: Callable,
max_results: int,
as_string: bool,
search_engine_mocker,
):
# Prerequisites
cache_json_path = None
# FIXME: 不能使用全局的config而是要自己实例化对应的config
search_engine_config = {}
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
assert config.search
cache_json_path = search_cache_path / f"serpapi-metagpt-{max_results}.json"
search_engine_config["serpapi_api_key"] = "mock-serpapi-key"
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
assert config.search
search_engine_config["google_api_key"] = "mock-google-key"
search_engine_config["google_cse_id"] = "mock-google-cse"
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
assert config.search
cache_json_path = search_cache_path / f"serper-metagpt-{max_results}.json"
search_engine_config["serper_api_key"] = "mock-serper-key"
if cache_json_path:
with open(cache_json_path) as f:
data = json.load(f)
aiohttp_mocker.set_json(data)
search_engine = SearchEngine(search_engine_type, run_func)
search_engine = SearchEngine(search_engine_type, run_func, **search_engine_config)
rsp = await search_engine.run("metagpt", max_results, as_string)
logger.info(rsp)
if as_string:

View file

@ -0,0 +1,41 @@
import json
from typing import Callable
from aiohttp.client import ClientSession
origin_request = ClientSession.request
class MockAioResponse:
    """Stand-in for aiohttp's request context manager that caches JSON bodies.

    On a cache hit no network call is made; on a miss the saved original
    ``ClientSession.request`` is invoked and its JSON body is recorded.
    Class-level state is shared by design — tests rebind ``rsp_cache`` and
    ``check_funcs`` on a per-test subclass.
    """

    # (method, url) -> function that canonicalizes request kwargs into a key suffix
    check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
    # cache key -> recorded JSON payload
    rsp_cache: dict[str, str] = {}
    name = "aiohttp"

    def __init__(self, session, method, url, **kwargs) -> None:
        # Build the cache key; sort_keys makes the fallback key deterministic.
        fn = self.check_funcs.get((method, url))
        self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
        self.mng = self.response = None
        if self.key not in self.rsp_cache:
            # Cache miss: defer a real request via the saved original method.
            self.mng = origin_request(session, method, url, **kwargs)

    async def __aenter__(self):
        # Enter the underlying response/context manager only when one exists
        # (i.e. on a cache miss); on a hit this is a no-op wrapper.
        if self.response:
            await self.response.__aenter__()
        elif self.mng:
            self.response = await self.mng.__aenter__()
        return self

    async def __aexit__(self, *args, **kwargs):
        # Mirror __aenter__: close whichever of response/mng is live, then
        # clear it so re-entry behaves predictably.
        if self.response:
            await self.response.__aexit__(*args, **kwargs)
            self.response = None
        elif self.mng:
            await self.mng.__aexit__(*args, **kwargs)
            self.mng = None

    async def json(self, *args, **kwargs):
        # Serve from cache when possible; otherwise record the live response.
        if self.key in self.rsp_cache:
            return self.rsp_cache[self.key]
        data = await self.response.json(*args, **kwargs)
        self.rsp_cache[self.key] = data
        return data

View file

@ -0,0 +1,22 @@
import json
from typing import Callable
from curl_cffi import requests
origin_request = requests.Session.request
class MockCurlCffiResponse(requests.Response):
    """curl_cffi Response stand-in that replays cached bodies for known requests.

    Class-level state is shared by design; tests rebind ``rsp_cache`` and
    ``check_funcs`` on a per-test subclass.
    """

    # (method, url) -> function that canonicalizes request kwargs into a key suffix
    check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
    # cache key -> recorded response body (str)
    rsp_cache: dict[str, str] = {}
    name = "curl-cffi"

    def __init__(self, session, method, url, **kwargs) -> None:
        super().__init__()
        key_fn = self.check_funcs.get((method, url))
        suffix = key_fn(kwargs) if key_fn else json.dumps(kwargs, sort_keys=True)
        self.key = f"{self.name}-{method}-{url}-{suffix}"
        self.response = None
        if self.key not in self.rsp_cache:
            # Cache miss: perform the real request once and record its body.
            live = origin_request(session, method, url, **kwargs)
            self.rsp_cache[self.key] = live.content.decode()
        self.content = self.rsp_cache[self.key].encode()

View file

@ -0,0 +1,29 @@
import json
from typing import Callable
from urllib.parse import parse_qsl, urlparse
import httplib2
origin_request = httplib2.Http.request
class MockHttplib2Response(httplib2.Response):
    """httplib2 response stand-in that replays cached bodies for known requests.

    ``httplib2.Http.request`` returns a ``(response, content)`` pair, so this
    object is iterable and unpacks into that same shape.
    """

    # (method, uri) -> function that canonicalizes query params into a key suffix
    check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
    # cache key -> recorded response body (str)
    rsp_cache: dict[str, str] = {}
    name = "httplib2"

    def __init__(self, http, uri, method="GET", **kwargs) -> None:
        # Key on the bare URL plus the parsed query string, not the raw URI.
        url = uri.split("?")[0]
        parsed = urlparse(uri)
        params = dict(parse_qsl(parsed.query))
        # NOTE(review): the check-func lookup uses the full `uri` while the key
        # uses the stripped `url` — presumably intentional; confirm with callers.
        checker = self.check_funcs.get((method, uri))
        new_kwargs = {"params": params}
        suffix = checker(new_kwargs) if checker else json.dumps(new_kwargs)
        key = f"{self.name}-{method}-{url}-{suffix}"
        if key not in self.rsp_cache:
            # Cache miss: issue the real request once and record its body.
            _, self.content = origin_request(http, uri, method, **kwargs)
            self.rsp_cache[key] = self.content.decode()
        self.content = self.rsp_cache[key]

    def __iter__(self):
        # Unpacks as (response, content-bytes), mirroring Http.request's return.
        yield self
        yield self.content.encode()