restoresearch engine code

This commit is contained in:
shenchucheng 2023-09-02 18:03:31 +08:00
parent 635ed09372
commit 07a1d229cf
6 changed files with 62 additions and 61 deletions

View file

@ -4,12 +4,11 @@
@Time : 2023/5/6 20:15
@Author : alexanderwu
@File : search_engine.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import importlib
from typing import Callable, Coroutine, Literal, overload, Dict
from typing import Callable, Coroutine, Literal, overload
from metagpt.config import CONFIG
from metagpt.tools import SearchEngineType
@ -28,23 +27,23 @@ class SearchEngine:
"""
def __init__(
self,
engine: SearchEngineType | None = None,
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None
self,
engine: SearchEngineType | None = None,
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
):
engine = engine or CONFIG.search_engine
if engine == SearchEngineType.SERPAPI_GOOGLE:
module = "metagpt.tools.search_engine_serpapi"
run_func = importlib.import_module(module).SerpAPIWrapper(**CONFIG.options).run
run_func = importlib.import_module(module).SerpAPIWrapper().run
elif engine == SearchEngineType.SERPER_GOOGLE:
module = "metagpt.tools.search_engine_serper"
run_func = importlib.import_module(module).SerperWrapper(**CONFIG.options).run
run_func = importlib.import_module(module).SerperWrapper().run
elif engine == SearchEngineType.DIRECT_GOOGLE:
module = "metagpt.tools.search_engine_googleapi"
run_func = importlib.import_module(module).GoogleAPIWrapper(**CONFIG.options).run
run_func = importlib.import_module(module).GoogleAPIWrapper().run
elif engine == SearchEngineType.DUCK_DUCK_GO:
module = "metagpt.tools.search_engine_ddg"
run_func = importlib.import_module(module).DDGAPIWrapper(**CONFIG.options).run
run_func = importlib.import_module(module).DDGAPIWrapper().run
elif engine == SearchEngineType.CUSTOM_ENGINE:
pass # run_func = run_func
else:
@ -54,19 +53,19 @@ class SearchEngine:
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
) -> str:
...
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
) -> list[dict[str, str]]:
...

View file

@ -1,14 +1,11 @@
#!/usr/bin/env python
"""
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import asyncio
import json
from concurrent import futures
from typing import Literal, overload, Optional
from typing import Literal, overload
try:
from duckduckgo_search import DDGS
@ -18,6 +15,8 @@ except ImportError:
"You can install it by running the command: `pip install -e.[search-ddg]`"
)
from metagpt.config import CONFIG
class DDGAPIWrapper:
"""Wrapper around duckduckgo_search API.
@ -26,44 +25,43 @@ class DDGAPIWrapper:
"""
def __init__(
self,
*,
global_proxy: Optional[str] = None,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
kwargs = {}
if global_proxy:
kwargs["proxies"] = global_proxy
if CONFIG.global_proxy:
kwargs["proxies"] = CONFIG.global_proxy
self.loop = loop
self.executor = executor
self.ddgs = DDGS(**kwargs)
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
focus: list[str] | None = None,
self,
query: str,
max_results: int = 8,
as_string: Literal[True] = True,
focus: list[str] | None = None,
) -> str:
...
@overload
def run(
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
focus: list[str] | None = None,
self,
query: str,
max_results: int = 8,
as_string: Literal[False] = False,
focus: list[str] | None = None,
) -> list[dict[str, str]]:
...
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
self,
query: str,
max_results: int = 8,
as_string: bool = True,
) -> str | list[dict]:
"""Return the results of a Google search using the official Google API

View file

@ -1,8 +1,5 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import asyncio
@ -14,6 +11,7 @@ from urllib.parse import urlparse
import httplib2
from pydantic import BaseModel, validator
from metagpt.config import CONFIG
from metagpt.logs import logger
try:
@ -29,7 +27,6 @@ except ImportError:
class GoogleAPIWrapper(BaseModel):
google_api_key: Optional[str] = None
google_cse_id: Optional[str] = None
global_proxy: Optional[str] = None
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
@ -39,6 +36,7 @@ class GoogleAPIWrapper(BaseModel):
@validator("google_api_key", always=True)
@classmethod
def check_google_api_key(cls, val: str):
val = val or CONFIG.google_api_key
if not val:
raise ValueError(
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
@ -49,7 +47,8 @@ class GoogleAPIWrapper(BaseModel):
@validator("google_cse_id", always=True)
@classmethod
def check_google_cse_id(cls, val):
def check_google_cse_id(cls, val: str):
val = val or CONFIG.google_cse_id
if not val:
raise ValueError(
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
@ -61,8 +60,8 @@ class GoogleAPIWrapper(BaseModel):
@property
def google_api_client(self):
build_kwargs = {"developerKey": self.google_api_key}
if self.global_proxy:
parse_result = urlparse(self.global_proxy)
if CONFIG.global_proxy:
parse_result = urlparse(CONFIG.global_proxy)
proxy_type = parse_result.scheme
if proxy_type == "https":
proxy_type = "http"

View file

@ -4,14 +4,13 @@
@Time : 2023/5/23 18:27
@Author : alexanderwu
@File : search_engine_serpapi.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field, validator
from metagpt.config import Config
from metagpt.config import CONFIG
class SerpAPIWrapper(BaseModel):
@ -33,6 +32,7 @@ class SerpAPIWrapper(BaseModel):
@validator("serpapi_api_key", always=True)
@classmethod
def check_serpapi_api_key(cls, val: str):
val = val or CONFIG.serpapi_api_key
if not val:
raise ValueError(
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
@ -112,4 +112,4 @@ class SerpAPIWrapper(BaseModel):
if __name__ == "__main__":
import fire
fire.Fire(SerpAPIWrapper(Config().runtime_options).run)
fire.Fire(SerpAPIWrapper().run)

View file

@ -4,7 +4,6 @@
@Time : 2023/5/23 18:27
@Author : alexanderwu
@File : search_engine_serpapi.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
import json
from typing import Any, Dict, Optional, Tuple
@ -12,6 +11,8 @@ from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, Field, validator
from metagpt.config import CONFIG
class SerperWrapper(BaseModel):
search_engine: Any #: :meta private:
@ -25,6 +26,7 @@ class SerperWrapper(BaseModel):
@validator("serper_api_key", always=True)
@classmethod
def check_serper_api_key(cls, val: str):
val = val or CONFIG.serper_api_key
if not val:
raise ValueError(
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "

View file

@ -4,13 +4,11 @@
@Time : 2023/5/2 17:46
@Author : alexanderwu
@File : test_search_engine.py
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
"""
from __future__ import annotations
import pytest
from metagpt.config import Config
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@ -18,7 +16,9 @@ from metagpt.tools.search_engine import SearchEngine
class MockSearchEnine:
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
rets = [{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)]
rets = [
{"url": "https://metagpt.com/mock/{i}", "title": query, "snippet": query * i} for i in range(max_results)
]
return "\n".join(rets) if as_string else rets
@ -36,13 +36,16 @@ class MockSearchEnine:
(SearchEngineType.DUCK_DUCK_GO, None, 6, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 8, False),
(SearchEngineType.CUSTOM_ENGINE, MockSearchEnine().run, 6, False),
],
)
async def test_search_engine(search_engine_typpe, run_func, max_results, as_string):
conf = Config()
search_engine = SearchEngine(options=conf.runtime_options, engine=search_engine_typpe, run_func=run_func)
rsp = await search_engine.run(query="metagpt", max_results=max_results, as_string=as_string)
async def test_search_engine(
search_engine_typpe,
run_func,
max_results,
as_string,
):
search_engine = SearchEngine(search_engine_typpe, run_func)
rsp = await search_engine.run("metagpt", max_results=max_results, as_string=as_string)
logger.info(rsp)
if as_string:
assert isinstance(rsp, str)