Mirror of https://github.com/FoundationAgents/MetaGPT.git (synced 2026-04-30 19:36:24 +02:00)
mock search engine api
This commit is contained in: parent b603e19bda, commit 9d1df5acd5
9 changed files with 1062 additions and 50 deletions
@@ -44,19 +44,20 @@ class SearchEngine:
         self,
         engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE,
         run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
+        **kwargs,
     ):
         if engine == SearchEngineType.SERPAPI_GOOGLE:
             module = "metagpt.tools.search_engine_serpapi"
-            run_func = importlib.import_module(module).SerpAPIWrapper().run
+            run_func = importlib.import_module(module).SerpAPIWrapper(**kwargs).run
         elif engine == SearchEngineType.SERPER_GOOGLE:
             module = "metagpt.tools.search_engine_serper"
-            run_func = importlib.import_module(module).SerperWrapper().run
+            run_func = importlib.import_module(module).SerperWrapper(**kwargs).run
         elif engine == SearchEngineType.DIRECT_GOOGLE:
             module = "metagpt.tools.search_engine_googleapi"
-            run_func = importlib.import_module(module).GoogleAPIWrapper().run
+            run_func = importlib.import_module(module).GoogleAPIWrapper(**kwargs).run
         elif engine == SearchEngineType.DUCK_DUCK_GO:
             module = "metagpt.tools.search_engine_ddg"
-            run_func = importlib.import_module(module).DDGAPIWrapper().run
+            run_func = importlib.import_module(module).DDGAPIWrapper(**kwargs).run
         elif engine == SearchEngineType.CUSTOM_ENGINE:
             pass  # run_func = run_func
         else:
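With this change, any extra keyword arguments given to SearchEngine are forwarded to the selected engine wrapper, so credentials no longer have to come from global config. A minimal sketch of a caller, not part of this commit; the serper_api_key name mirrors the mock key used in the test changes below and is an assumption about SerperWrapper's signature:

    from metagpt.tools import SearchEngineType
    from metagpt.tools.search_engine import SearchEngine

    async def search_once():
        # extra kwargs are passed straight into SerperWrapper(**kwargs) by the hunk above;
        # a real run needs a valid key, the mock value only mirrors the tests below
        engine = SearchEngine(SearchEngineType.SERPER_GOOGLE, serper_api_key="mock-serper-key")
        return await engine.run("metagpt", 6, False)  # run(query, max_results, as_string)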
@@ -12,6 +12,7 @@ import logging
 import os
 import re
 import uuid
+from typing import Callable
 
 import pytest
 
@@ -20,6 +21,9 @@ from metagpt.context import CONTEXT
 from metagpt.llm import LLM
 from metagpt.logs import logger
 from metagpt.utils.git_repository import GitRepository
+from tests.mock.mock_aiohttp import MockAioResponse
+from tests.mock.mock_curl_cffi import MockCurlCffiResponse
+from tests.mock.mock_httplib2 import MockHttplib2Response
 from tests.mock.mock_llm import MockLLM
 
 RSP_CACHE_NEW = {}  # used globally for producing new and useful only response cache
@@ -164,39 +168,63 @@ def new_filename(mocker):
     yield mocker
 
 
+@pytest.fixture(scope="session")
+def search_rsp_cache():
+    rsp_cache_file_path = TEST_DATA_PATH / "search_rsp_cache.json"  # read repo-provided
+    if os.path.exists(rsp_cache_file_path):
+        with open(rsp_cache_file_path, "r") as f1:
+            rsp_cache_json = json.load(f1)
+    else:
+        rsp_cache_json = {}
+    yield rsp_cache_json
+    with open(rsp_cache_file_path, "w") as f2:
+        json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
+
+
 @pytest.fixture
 def aiohttp_mocker(mocker):
-    class MockAioResponse:
-        async def json(self, *args, **kwargs):
-            return self._json
-
-        def set_json(self, json):
-            self._json = json
-
-    response = MockAioResponse()
-
-    class MockCTXMng:
-        async def __aenter__(self):
-            return response
-
-        async def __aexit__(self, *args, **kwargs):
-            pass
-
-        def __await__(self):
-            yield
-            return response
-
-    def mock_request(self, method, url, **kwargs):
-        return MockCTXMng()
+    MockResponse = type("MockResponse", (MockAioResponse,), {})
 
     def wrap(method):
         def run(self, url, **kwargs):
-            return mock_request(self, method, url, **kwargs)
+            return MockResponse(self, method, url, **kwargs)
 
         return run
 
-    mocker.patch("aiohttp.ClientSession.request", mock_request)
+    mocker.patch("aiohttp.ClientSession.request", MockResponse)
     for i in ["get", "post", "delete", "patch"]:
         mocker.patch(f"aiohttp.ClientSession.{i}", wrap(i))
+    yield MockResponse
 
-    yield response
+
+@pytest.fixture
+def curl_cffi_mocker(mocker):
+    MockResponse = type("MockResponse", (MockCurlCffiResponse,), {})
+
+    def request(self, *args, **kwargs):
+        return MockResponse(self, *args, **kwargs)
+
+    mocker.patch("curl_cffi.requests.Session.request", request)
+    yield MockResponse
+
+
+@pytest.fixture
+def httplib2_mocker(mocker):
+    MockResponse = type("MockResponse", (MockHttplib2Response,), {})
+
+    def request(self, *args, **kwargs):
+        return MockResponse(self, *args, **kwargs)
+
+    mocker.patch("httplib2.Http.request", request)
+    yield MockResponse
+
+
+@pytest.fixture
+def search_engine_mocker(aiohttp_mocker, curl_cffi_mocker, httplib2_mocker, search_rsp_cache):
+    # aiohttp_mocker: serpapi/serper
+    # httplib2_mocker: google
+    # curl_cffi_mocker: ddg
+    check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
+    aiohttp_mocker.rsp_cache = httplib2_mocker.rsp_cache = curl_cffi_mocker.rsp_cache = search_rsp_cache
+    aiohttp_mocker.check_funcs = httplib2_mocker.check_funcs = curl_cffi_mocker.check_funcs = check_funcs
+    yield check_funcs
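Net effect of the conftest changes: search_rsp_cache round-trips tests/data/search_rsp_cache.json once per session, each *_mocker patches its HTTP library with a caching response class, and search_engine_mocker wires the shared cache and per-request check functions into all three. A hypothetical test that replays a DuckDuckGo search offline once the cache is populated (test name and query are illustrative, not from this commit):

    import pytest

    from metagpt.tools import SearchEngineType
    from metagpt.tools.search_engine import SearchEngine

    @pytest.mark.asyncio
    async def test_ddg_search_offline(search_engine_mocker):
        # DDG traffic goes through curl_cffi, so MockCurlCffiResponse serves it from
        # search_rsp_cache.json; only uncached requests would hit the real network.
        rsp = await SearchEngine(SearchEngineType.DUCK_DUCK_GO).run("metagpt", 4, True)
        assert rsp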
tests/data/search_rsp_cache.json (new file, 879 lines)
File diff suppressed because one or more lines are too long
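The cache file is a flat JSON object mapping each mocker's request key to the recorded response (parsed JSON for aiohttp, raw body text for httplib2 and curl_cffi). A made-up entry, written as a Python literal purely to illustrate the shape; the URLs and payloads are placeholders:

    rsp_cache = {
        # key format: f"{name}-{method}-{url}-{check_fn(kwargs) or json.dumps(kwargs)}"
        "aiohttp-POST-https://google.serper.dev/search-<normalized kwargs>": {"organic": []},
        "curl-cffi-GET-https://duckduckgo.com/-<normalized kwargs>": "<html>...</html>",
    }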
@@ -9,10 +9,12 @@
 import pytest
 
 from metagpt.actions import research
+from metagpt.tools import SearchEngineType
+from metagpt.tools.search_engine import SearchEngine
 
 
 @pytest.mark.asyncio
-async def test_collect_links(mocker):
+async def test_collect_links(mocker, search_engine_mocker):
     async def mock_llm_ask(self, prompt: str, system_msgs):
         if "Please provide up to 2 necessary keywords" in prompt:
             return '["metagpt", "llm"]'
@@ -26,13 +28,15 @@ async def test_collect_links(mocker):
             return "[1,2]"
 
     mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
-    resp = await research.CollectLinks().run("The application of MetaGPT")
+    resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO)).run(
+        "The application of MetaGPT"
+    )
     for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
         assert i in resp
 
 
 @pytest.mark.asyncio
-async def test_collect_links_with_rank_func(mocker):
+async def test_collect_links_with_rank_func(mocker, search_engine_mocker):
     rank_before = []
     rank_after = []
     url_per_query = 4
@@ -45,7 +49,9 @@ async def test_collect_links_with_rank_func(mocker):
         return results
 
     mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
-    resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
+    resp = await research.CollectLinks(
+        search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func
+    ).run("The application of MetaGPT")
     for x, y, z in zip(rank_before, rank_after, resp.values()):
         assert x[::-1] == y
         assert [i["link"] for i in y] == z
@@ -4,7 +4,10 @@ from tempfile import TemporaryDirectory
 
 import pytest
 
+from metagpt.actions.research import CollectLinks
 from metagpt.roles import researcher
+from metagpt.tools import SearchEngineType
+from metagpt.tools.search_engine import SearchEngine
 
 
 async def mock_llm_ask(self, prompt: str, system_msgs):
@@ -25,12 +28,16 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
 
 
 @pytest.mark.asyncio
-async def test_researcher(mocker):
+async def test_researcher(mocker, search_engine_mocker):
     with TemporaryDirectory() as dirname:
         topic = "dataiku vs. datarobot"
         mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
         researcher.RESEARCH_PATH = Path(dirname)
-        await researcher.Researcher().run(topic)
+        role = researcher.Researcher()
+        for i in role.actions:
+            if isinstance(i, CollectLinks):
+                i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO)
+        await role.run(topic)
         assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
 
 
@@ -7,20 +7,15 @@
 """
 from __future__ import annotations
 
-import json
-from pathlib import Path
 from typing import Callable
 
 import pytest
 
-import tests.data.search
 from metagpt.config2 import config
 from metagpt.logs import logger
 from metagpt.tools import SearchEngineType
 from metagpt.tools.search_engine import SearchEngine
 
-search_cache_path = Path(tests.data.search.__path__[0])
-
 
 class MockSearchEnine:
     async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
@@ -46,24 +41,28 @@ class MockSearchEnine:
         (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
     ],
 )
-async def test_search_engine(search_engine_type, run_func: Callable, max_results: int, as_string: bool, aiohttp_mocker):
+async def test_search_engine(
+    search_engine_type,
+    run_func: Callable,
+    max_results: int,
+    as_string: bool,
+    search_engine_mocker,
+):
     # Prerequisites
-    cache_json_path = None
+    # FIXME: should not use the global config here; instantiate a dedicated config instead
+    search_engine_config = {}
 
     if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
         assert config.search
-        cache_json_path = search_cache_path / f"serpapi-metagpt-{max_results}.json"
+        search_engine_config["serpapi_api_key"] = "mock-serpapi-key"
     elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
         assert config.search
+        search_engine_config["google_api_key"] = "mock-google-key"
+        search_engine_config["google_cse_id"] = "mock-google-cse"
     elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
         assert config.search
-        cache_json_path = search_cache_path / f"serper-metagpt-{max_results}.json"
+        search_engine_config["serper_api_key"] = "mock-serper-key"
-
-    if cache_json_path:
-        with open(cache_json_path) as f:
-            data = json.load(f)
-            aiohttp_mocker.set_json(data)
-    search_engine = SearchEngine(search_engine_type, run_func)
+    search_engine = SearchEngine(search_engine_type, run_func, **search_engine_config)
     rsp = await search_engine.run("metagpt", max_results, as_string)
     logger.info(rsp)
     if as_string:
tests/mock/mock_aiohttp.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+import json
+from typing import Callable
+
+from aiohttp.client import ClientSession
+
+origin_request = ClientSession.request
+
+
+class MockAioResponse:
+    check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
+    rsp_cache: dict[str, str] = {}
+    name = "aiohttp"
+
+    def __init__(self, session, method, url, **kwargs) -> None:
+        fn = self.check_funcs.get((method, url))
+        self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
+        self.mng = self.response = None
+        if self.key not in self.rsp_cache:
+            self.mng = origin_request(session, method, url, **kwargs)
+
+    async def __aenter__(self):
+        if self.response:
+            await self.response.__aenter__()
+        elif self.mng:
+            self.response = await self.mng.__aenter__()
+        return self
+
+    async def __aexit__(self, *args, **kwargs):
+        if self.response:
+            await self.response.__aexit__(*args, **kwargs)
+            self.response = None
+        elif self.mng:
+            await self.mng.__aexit__(*args, **kwargs)
+            self.mng = None
+
+    async def json(self, *args, **kwargs):
+        if self.key in self.rsp_cache:
+            return self.rsp_cache[self.key]
+        data = await self.response.json(*args, **kwargs)
+        self.rsp_cache[self.key] = data
+        return data
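The cache key combines the transport name, HTTP method, URL, and either json.dumps of the request kwargs or the output of a registered check function, so tests can collapse volatile fields (API keys, pagination tokens) into a stable key. A small self-contained illustration of that key scheme; the endpoint and kwargs layout are assumptions, not taken from this diff:

    import json

    check_funcs = {("GET", "https://serpapi.com/search"): lambda kw: kw["params"]["q"]}

    def cache_key(name, method, url, kwargs):
        fn = check_funcs.get((method, url))
        return f"{name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"

    # keyed only on the query, so a rotated api_key still hits the same cache entry
    print(cache_key("aiohttp", "GET", "https://serpapi.com/search",
                    {"params": {"q": "metagpt", "api_key": "mock"}}))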
tests/mock/mock_curl_cffi.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+import json
+from typing import Callable
+
+from curl_cffi import requests
+
+origin_request = requests.Session.request
+
+
+class MockCurlCffiResponse(requests.Response):
+    check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
+    rsp_cache: dict[str, str] = {}
+    name = "curl-cffi"
+
+    def __init__(self, session, method, url, **kwargs) -> None:
+        super().__init__()
+        fn = self.check_funcs.get((method, url))
+        self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
+        self.response = None
+        if self.key not in self.rsp_cache:
+            response = origin_request(session, method, url, **kwargs)
+            self.rsp_cache[self.key] = response.content.decode()
+        self.content = self.rsp_cache[self.key].encode()
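Because curl_cffi_mocker patches Session.request with this class, constructing MockCurlCffiResponse is the request: a cache hit replays the stored body, a miss performs the real call once and records it. A minimal sketch with a pre-seeded cache entry so nothing touches the network; the URL and body are made up:

    from tests.mock.mock_curl_cffi import MockCurlCffiResponse

    # key mirrors __init__ above: f"curl-cffi-GET-{url}-{json.dumps({}, sort_keys=True)}"
    MockCurlCffiResponse.rsp_cache["curl-cffi-GET-https://duckduckgo.com/-{}"] = "<html>cached</html>"
    rsp = MockCurlCffiResponse(None, "GET", "https://duckduckgo.com/")
    assert rsp.content == b"<html>cached</html>"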
tests/mock/mock_httplib2.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+import json
+from typing import Callable
+from urllib.parse import parse_qsl, urlparse
+
+import httplib2
+
+origin_request = httplib2.Http.request
+
+
+class MockHttplib2Response(httplib2.Response):
+    check_funcs: dict[tuple[str, str], Callable[[dict], str]] = {}
+    rsp_cache: dict[str, str] = {}
+    name = "httplib2"
+
+    def __init__(self, http, uri, method="GET", **kwargs) -> None:
+        url = uri.split("?")[0]
+        result = urlparse(uri)
+        params = dict(parse_qsl(result.query))
+        fn = self.check_funcs.get((method, uri))
+        new_kwargs = {"params": params}
+        key = f"{self.name}-{method}-{url}-{fn(new_kwargs) if fn else json.dumps(new_kwargs)}"
+        if key not in self.rsp_cache:
+            _, self.content = origin_request(http, uri, method, **kwargs)
+            self.rsp_cache[key] = self.content.decode()
+        self.content = self.rsp_cache[key]
+
+    def __iter__(self):
+        yield self
+        yield self.content.encode()
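The __iter__ at the end exists because httplib2.Http.request returns a (response, content) pair that callers unpack; yielding self and then the encoded body lets the mock instance stand in for that tuple. A minimal sketch with a pre-seeded cache entry so no real request is issued; the URL and body are made up:

    from tests.mock.mock_httplib2 import MockHttplib2Response

    key = 'httplib2-GET-https://example.com/-{"params": {"q": "metagpt"}}'
    MockHttplib2Response.rsp_cache[key] = '{"items": []}'
    resp, content = MockHttplib2Response(None, "https://example.com/?q=metagpt")
    assert content == b'{"items": []}'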