mock search engine api

This commit is contained in:
shenchucheng 2024-01-14 23:40:09 +08:00
parent b603e19bda
commit 9d1df5acd5
9 changed files with 1062 additions and 50 deletions

View file

@ -9,10 +9,12 @@
import pytest
from metagpt.actions import research
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@pytest.mark.asyncio
async def test_collect_links(mocker):
async def test_collect_links(mocker, search_engine_mocker):
async def mock_llm_ask(self, prompt: str, system_msgs):
if "Please provide up to 2 necessary keywords" in prompt:
return '["metagpt", "llm"]'
@ -26,13 +28,15 @@ async def test_collect_links(mocker):
return "[1,2]"
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
resp = await research.CollectLinks().run("The application of MetaGPT")
resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO)).run(
"The application of MetaGPT"
)
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
assert i in resp
@pytest.mark.asyncio
async def test_collect_links_with_rank_func(mocker):
async def test_collect_links_with_rank_func(mocker, search_engine_mocker):
rank_before = []
rank_after = []
url_per_query = 4
@ -45,7 +49,9 @@ async def test_collect_links_with_rank_func(mocker):
return results
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
resp = await research.CollectLinks(rank_func=rank_func).run("The application of MetaGPT")
resp = await research.CollectLinks(
search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func
).run("The application of MetaGPT")
for x, y, z in zip(rank_before, rank_after, resp.values()):
assert x[::-1] == y
assert [i["link"] for i in y] == z

View file

@ -4,7 +4,10 @@ from tempfile import TemporaryDirectory
import pytest
from metagpt.actions.research import CollectLinks
from metagpt.roles import researcher
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
async def mock_llm_ask(self, prompt: str, system_msgs):
@ -25,12 +28,16 @@ async def mock_llm_ask(self, prompt: str, system_msgs):
@pytest.mark.asyncio
async def test_researcher(mocker):
async def test_researcher(mocker, search_engine_mocker):
with TemporaryDirectory() as dirname:
topic = "dataiku vs. datarobot"
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
researcher.RESEARCH_PATH = Path(dirname)
await researcher.Researcher().run(topic)
role = researcher.Researcher()
for i in role.actions:
if isinstance(i, CollectLinks):
i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO)
await role.run(topic)
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")

View file

@ -7,20 +7,15 @@
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Callable
import pytest
import tests.data.search
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
search_cache_path = Path(tests.data.search.__path__[0])
class MockSearchEnine:
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
@ -46,24 +41,28 @@ class MockSearchEnine:
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
],
)
async def test_search_engine(search_engine_type, run_func: Callable, max_results: int, as_string: bool, aiohttp_mocker):
async def test_search_engine(
search_engine_type,
run_func: Callable,
max_results: int,
as_string: bool,
search_engine_mocker,
):
# Prerequisites
cache_json_path = None
# FIXME: 不能使用全局的config而是要自己实例化对应的config
search_engine_config = {}
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
assert config.search
cache_json_path = search_cache_path / f"serpapi-metagpt-{max_results}.json"
search_engine_config["serpapi_api_key"] = "mock-serpapi-key"
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
assert config.search
search_engine_config["google_api_key"] = "mock-google-key"
search_engine_config["google_cse_id"] = "mock-google-cse"
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
assert config.search
cache_json_path = search_cache_path / f"serper-metagpt-{max_results}.json"
search_engine_config["serper_api_key"] = "mock-serper-key"
if cache_json_path:
with open(cache_json_path) as f:
data = json.load(f)
aiohttp_mocker.set_json(data)
search_engine = SearchEngine(search_engine_type, run_func)
search_engine = SearchEngine(search_engine_type, run_func, **search_engine_config)
rsp = await search_engine.run("metagpt", max_results, as_string)
logger.info(rsp)
if as_string: