From 07a1d229cf08f89595c10f7d198ca9aa6b0e550d Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Sat, 2 Sep 2023 18:03:31 +0800 Subject: [PATCH] restoresearch engine code --- metagpt/tools/search_engine.py | 33 ++++++++-------- metagpt/tools/search_engine_ddg.py | 48 +++++++++++------------ metagpt/tools/search_engine_googleapi.py | 13 +++--- metagpt/tools/search_engine_serpapi.py | 6 +-- metagpt/tools/search_engine_serper.py | 4 +- tests/metagpt/tools/test_search_engine.py | 19 +++++---- 6 files changed, 62 insertions(+), 61 deletions(-) diff --git a/metagpt/tools/search_engine.py b/metagpt/tools/search_engine.py index 5b8b7f046..db8c091d1 100644 --- a/metagpt/tools/search_engine.py +++ b/metagpt/tools/search_engine.py @@ -4,12 +4,11 @@ @Time : 2023/5/6 20:15 @Author : alexanderwu @File : search_engine.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ from __future__ import annotations import importlib -from typing import Callable, Coroutine, Literal, overload, Dict +from typing import Callable, Coroutine, Literal, overload from metagpt.config import CONFIG from metagpt.tools import SearchEngineType @@ -28,23 +27,23 @@ class SearchEngine: """ def __init__( - self, - engine: SearchEngineType | None = None, - run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None + self, + engine: SearchEngineType | None = None, + run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None, ): engine = engine or CONFIG.search_engine if engine == SearchEngineType.SERPAPI_GOOGLE: module = "metagpt.tools.search_engine_serpapi" - run_func = importlib.import_module(module).SerpAPIWrapper(**CONFIG.options).run + run_func = importlib.import_module(module).SerpAPIWrapper().run elif engine == SearchEngineType.SERPER_GOOGLE: module = "metagpt.tools.search_engine_serper" - run_func = importlib.import_module(module).SerperWrapper(**CONFIG.options).run + run_func = importlib.import_module(module).SerperWrapper().run elif engine == SearchEngineType.DIRECT_GOOGLE: module = "metagpt.tools.search_engine_googleapi" - run_func = importlib.import_module(module).GoogleAPIWrapper(**CONFIG.options).run + run_func = importlib.import_module(module).GoogleAPIWrapper().run elif engine == SearchEngineType.DUCK_DUCK_GO: module = "metagpt.tools.search_engine_ddg" - run_func = importlib.import_module(module).DDGAPIWrapper(**CONFIG.options).run + run_func = importlib.import_module(module).DDGAPIWrapper().run elif engine == SearchEngineType.CUSTOM_ENGINE: pass # run_func = run_func else: @@ -54,19 +53,19 @@ class SearchEngine: @overload def run( - self, - query: str, - max_results: int = 8, - as_string: Literal[True] = True, + self, + query: str, + max_results: int = 8, + as_string: Literal[True] = True, ) -> str: ... @overload def run( - self, - query: str, - max_results: int = 8, - as_string: Literal[False] = False, + self, + query: str, + max_results: int = 8, + as_string: Literal[False] = False, ) -> list[dict[str, str]]: ... diff --git a/metagpt/tools/search_engine_ddg.py b/metagpt/tools/search_engine_ddg.py index 78562c77e..57bc61b82 100644 --- a/metagpt/tools/search_engine_ddg.py +++ b/metagpt/tools/search_engine_ddg.py @@ -1,14 +1,11 @@ #!/usr/bin/env python -""" -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. -""" from __future__ import annotations import asyncio import json from concurrent import futures -from typing import Literal, overload, Optional +from typing import Literal, overload try: from duckduckgo_search import DDGS @@ -18,6 +15,8 @@ except ImportError: "You can install it by running the command: `pip install -e.[search-ddg]`" ) +from metagpt.config import CONFIG + class DDGAPIWrapper: """Wrapper around duckduckgo_search API. @@ -26,44 +25,43 @@ class DDGAPIWrapper: """ def __init__( - self, - *, - global_proxy: Optional[str] = None, - loop: asyncio.AbstractEventLoop | None = None, - executor: futures.Executor | None = None, + self, + *, + loop: asyncio.AbstractEventLoop | None = None, + executor: futures.Executor | None = None, ): kwargs = {} - if global_proxy: - kwargs["proxies"] = global_proxy + if CONFIG.global_proxy: + kwargs["proxies"] = CONFIG.global_proxy self.loop = loop self.executor = executor self.ddgs = DDGS(**kwargs) @overload def run( - self, - query: str, - max_results: int = 8, - as_string: Literal[True] = True, - focus: list[str] | None = None, + self, + query: str, + max_results: int = 8, + as_string: Literal[True] = True, + focus: list[str] | None = None, ) -> str: ... @overload def run( - self, - query: str, - max_results: int = 8, - as_string: Literal[False] = False, - focus: list[str] | None = None, + self, + query: str, + max_results: int = 8, + as_string: Literal[False] = False, + focus: list[str] | None = None, ) -> list[dict[str, str]]: ... async def run( - self, - query: str, - max_results: int = 8, - as_string: bool = True, + self, + query: str, + max_results: int = 8, + as_string: bool = True, ) -> str | list[dict]: """Return the results of a Google search using the official Google API diff --git a/metagpt/tools/search_engine_googleapi.py b/metagpt/tools/search_engine_googleapi.py index b5aeb5875..b9faf2ced 100644 --- a/metagpt/tools/search_engine_googleapi.py +++ b/metagpt/tools/search_engine_googleapi.py @@ -1,8 +1,5 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. -""" from __future__ import annotations import asyncio @@ -14,6 +11,7 @@ from urllib.parse import urlparse import httplib2 from pydantic import BaseModel, validator +from metagpt.config import CONFIG from metagpt.logs import logger try: @@ -29,7 +27,6 @@ except ImportError: class GoogleAPIWrapper(BaseModel): google_api_key: Optional[str] = None google_cse_id: Optional[str] = None - global_proxy: Optional[str] = None loop: Optional[asyncio.AbstractEventLoop] = None executor: Optional[futures.Executor] = None @@ -39,6 +36,7 @@ class GoogleAPIWrapper(BaseModel): @validator("google_api_key", always=True) @classmethod def check_google_api_key(cls, val: str): + val = val or CONFIG.google_api_key if not val: raise ValueError( "To use, make sure you provide the google_api_key when constructing an object. Alternatively, " @@ -49,7 +47,8 @@ class GoogleAPIWrapper(BaseModel): @validator("google_cse_id", always=True) @classmethod - def check_google_cse_id(cls, val): + def check_google_cse_id(cls, val: str): + val = val or CONFIG.google_cse_id if not val: raise ValueError( "To use, make sure you provide the google_cse_id when constructing an object. Alternatively, " @@ -61,8 +60,8 @@ class GoogleAPIWrapper(BaseModel): @property def google_api_client(self): build_kwargs = {"developerKey": self.google_api_key} - if self.global_proxy: - parse_result = urlparse(self.global_proxy) + if CONFIG.global_proxy: + parse_result = urlparse(CONFIG.global_proxy) proxy_type = parse_result.scheme if proxy_type == "https": proxy_type = "http" diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 1b93a91e9..750184198 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -4,14 +4,13 @@ @Time : 2023/5/23 18:27 @Author : alexanderwu @File : search_engine_serpapi.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ from typing import Any, Dict, Optional, Tuple import aiohttp from pydantic import BaseModel, Field, validator -from metagpt.config import Config +from metagpt.config import CONFIG class SerpAPIWrapper(BaseModel): @@ -33,6 +32,7 @@ class SerpAPIWrapper(BaseModel): @validator("serpapi_api_key", always=True) @classmethod def check_serpapi_api_key(cls, val: str): + val = val or CONFIG.serpapi_api_key if not val: raise ValueError( "To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, " @@ -112,4 +112,4 @@ class SerpAPIWrapper(BaseModel): if __name__ == "__main__": import fire - fire.Fire(SerpAPIWrapper(Config().runtime_options).run) + fire.Fire(SerpAPIWrapper().run) diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index 849839f05..0eec2694b 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -4,7 +4,6 @@ @Time : 2023/5/23 18:27 @Author : alexanderwu @File : search_engine_serpapi.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ import json from typing import Any, Dict, Optional, Tuple @@ -12,6 +11,8 @@ from typing import Any, Dict, Optional, Tuple import aiohttp from pydantic import BaseModel, Field, validator +from metagpt.config import CONFIG + class SerperWrapper(BaseModel): search_engine: Any #: :meta private: @@ -25,6 +26,7 @@ class SerperWrapper(BaseModel): @validator("serper_api_key", always=True) @classmethod def check_serper_api_key(cls, val: str): + val = val or CONFIG.serper_api_key if not val: raise ValueError( "To use, make sure you provide the serper_api_key when constructing an object. Alternatively, " diff --git a/tests/metagpt/tools/test_search_engine.py b/tests/metagpt/tools/test_search_engine.py index 35ccdf78b..25bce124a 100644 --- a/tests/metagpt/tools/test_search_engine.py +++ b/tests/metagpt/tools/test_search_engine.py @@ -4,13 +4,11 @@ @Time : 2023/5/2 17:46 @Author : alexanderwu @File : test_search_engine.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation. """ from __future__ import annotations import pytest -from metagpt.config import Config from metagpt.logs import logger from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -18,7 +16,9 @@ from metagpt.tools.search_engine import SearchEngine class MockSearchEnine: async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]: - rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)] + rets = [ + {"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results) + ] return "\n".join(rets) if as_string else rets @@ -36,13 +36,16 @@ class MockSearchEnine: (SearchEngineType.DUCK_DUCK_GO, None, 6, False), (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False), (SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False), - ], ) -async def test_search_engine(search_engine_typpe, run_func, max_results, as_string): - conf = Config() - search_engine = SearchEngine(options=conf.runtime_options, engine=search_engine_typpe, run_func=run_func) - rsp = await search_engine.run(query="metagpt", max_results=max_results, as_string=as_string) +async def test_search_engine( + search_engine_typpe, + run_func, + max_results, + as_string, +): + search_engine = SearchEngine(search_engine_typpe, run_func) + rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string) logger.info(rsp) if as_string: assert isinstance(rsp, str)