remove global config in the search/browser engine

This commit is contained in:
shenchucheng 2024-02-01 16:47:29 +08:00
parent 9ecdccd836
commit 9b613eec59
24 changed files with 351 additions and 309 deletions

View file

@ -5,17 +5,20 @@
import asyncio
from metagpt.roles import Searcher
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine, SearchEngineType
async def main():
question = "What are the most interesting human facts?"
kwargs = {"api_key": "", "cse_id": "", "proxy": None}
# Serper API
# await Searcher(engine=SearchEngineType.SERPER_GOOGLE).run(question)
# await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPER_GOOGLE, **kwargs)).run(question)
# SerpAPI
await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run(question)
# await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPAPI_GOOGLE, **kwargs)).run(question)
# Google API
# await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run(question)
# await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DIRECT_GOOGLE, **kwargs)).run(question)
# DDG API
await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO, **kwargs)).run(question)
if __name__ == "__main__":

View file

@ -3,15 +3,15 @@
from __future__ import annotations
import asyncio
from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, Union
from pydantic import Field, parse_obj_as
from pydantic import TypeAdapter, model_validator
from metagpt.actions import Action
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.tools.search_engine import SearchEngine
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
from metagpt.tools.web_browser_engine import WebBrowserEngine
from metagpt.utils.common import OutputParser
from metagpt.utils.text import generate_prompt_chunk, reduce_message_length
@ -81,10 +81,16 @@ class CollectLinks(Action):
name: str = "CollectLinks"
i_context: Optional[str] = None
desc: str = "Collect links from a search engine."
search_engine: SearchEngine = Field(default_factory=SearchEngine)
search_func: Optional[Any] = None
search_engine: Optional[SearchEngine] = None
rank_func: Optional[Callable[[list[str]], None]] = None
@model_validator(mode="after")
def validate_engine_and_run_func(self):
    """Fall back to a config-driven search engine when none was supplied.

    Runs after field validation. An explicitly provided ``search_engine`` is
    left untouched; otherwise one is built from ``self.config.search`` with
    ``self.config.proxy`` forwarded as the proxy.

    Returns:
        The validated instance itself.
    """
    if self.search_engine is not None:
        return self
    self.search_engine = SearchEngine.from_search_config(self.config.search, proxy=self.config.proxy)
    return self
async def run(
self,
topic: str,
@ -107,7 +113,7 @@ class CollectLinks(Action):
keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
try:
keywords = OutputParser.extract_struct(keywords, list)
keywords = parse_obj_as(list[str], keywords)
keywords = TypeAdapter(list[str]).validate_python(keywords)
except Exception as e:
logger.exception(f"fail to get keywords related to the research topic '{topic}' for {e}")
keywords = [topic]
@ -133,7 +139,7 @@ class CollectLinks(Action):
queries = await self._aask(prompt, [system_text])
try:
queries = OutputParser.extract_struct(queries, list)
queries = parse_obj_as(list[str], queries)
queries = TypeAdapter(list[str]).validate_python(queries)
except Exception as e:
logger.exception(f"fail to break down the research question due to {e}")
queries = keywords
@ -178,15 +184,17 @@ class WebBrowseAndSummarize(Action):
i_context: Optional[str] = None
desc: str = "Explore the web and provide summaries of articles and webpages."
browse_func: Union[Callable[[list[str]], None], None] = None
web_browser_engine: Optional[WebBrowserEngine] = WebBrowserEngineType.PLAYWRIGHT
web_browser_engine: Optional[WebBrowserEngine] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT,
run_func=self.browse_func,
)
@model_validator(mode="after")
def validate_engine_and_run_func(self):
    """Fall back to a config-driven browser engine when none was supplied.

    Runs after field validation. An explicitly provided ``web_browser_engine``
    is kept as is; otherwise one is built from ``self.config.browser``, with
    ``browse_func`` and ``self.config.proxy`` forwarded as keyword arguments.

    Returns:
        The validated instance itself.
    """
    if self.web_browser_engine is not None:
        return self
    engine_kwargs = {"browse_func": self.browse_func, "proxy": self.config.proxy}
    self.web_browser_engine = WebBrowserEngine.from_browser_config(self.config.browser, **engine_kwargs)
    return self
async def run(
self,

View file

@ -5,7 +5,7 @@
@Author : alexanderwu
@File : search_google.py
"""
from typing import Any, Optional
from typing import Optional
import pydantic
from pydantic import model_validator
@ -13,7 +13,6 @@ from pydantic import model_validator
from metagpt.actions import Action
from metagpt.logs import logger
from metagpt.schema import Message
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements
@ -105,21 +104,19 @@ You are a member of a professional butler team and will provide helpful suggesti
class SearchAndSummarize(Action):
name: str = ""
content: Optional[str] = None
engine: Optional[SearchEngineType] = None
search_func: Optional[Any] = None
search_engine: SearchEngine = None
result: str = ""
@model_validator(mode="after")
def validate_engine_and_run_func(self):
if self.engine is None:
self.engine = self.config.search_engine
try:
search_engine = SearchEngine(engine=self.engine, run_func=self.search_func)
except pydantic.ValidationError:
search_engine = None
def validate_search_engine(self):
if self.search_engine is None:
try:
config = self.config
search_engine = SearchEngine.from_search_config(config.search, proxy=config.proxy)
except pydantic.ValidationError:
search_engine = None
self.search_engine = search_engine
self.search_engine = search_engine
return self
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:

View file

@ -51,7 +51,7 @@ class Config(CLIParams, YamlModel):
proxy: str = ""
# Tool Parameters
search: Optional[SearchConfig] = None
search: SearchConfig = SearchConfig()
browser: BrowserConfig = BrowserConfig()
mermaid: MermaidConfig = MermaidConfig()

View file

@ -15,6 +15,6 @@ class BrowserConfig(YamlModel):
"""Config for Browser"""
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
driver: Literal["chromium", "firefox", "webkit"] = "chromium"
path: str = ""
# Accepted values are the union of the Playwright and Selenium browser names
# described below; "firefox" belongs to both sets and is listed only once.
browser_type: Literal["chromium", "firefox", "webkit", "chrome", "edge", "ie"] = "chromium"
"""If the engine is Playwright, the value should be one of "chromium", "firefox", or "webkit". If it is Selenium, the value
should be either "chrome", "firefox", "edge", or "ie"."""

View file

@ -5,6 +5,8 @@
@Author : alexanderwu
@File : search_config.py
"""
from typing import Callable, Optional
from metagpt.tools import SearchEngineType
from metagpt.utils.yaml_model import YamlModel
@ -12,6 +14,7 @@ from metagpt.utils.yaml_model import YamlModel
class SearchConfig(YamlModel):
"""Config for Search"""
api_key: str
api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO
api_key: str = ""
cse_id: str = "" # for google
search_func: Optional[Callable] = None

View file

@ -7,7 +7,7 @@
"""
from typing import Optional
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from metagpt.config2 import Config
from metagpt.context import Context
@ -17,7 +17,7 @@ from metagpt.provider.base_llm import BaseLLM
class ContextMixin(BaseModel):
"""Mixin class for context and config"""
model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
# Pydantic has bug on _private_attr when using inheritance, so we use private_* instead
# - https://github.com/pydantic/pydantic/issues/7142
@ -32,15 +32,18 @@ class ContextMixin(BaseModel):
# Env/Role/Action will use this llm as private llm, or use self.context._llm instance
private_llm: Optional[BaseLLM] = Field(default=None, exclude=True)
def __init__(
@model_validator(mode="after")
def validate_extra(self):
    """Forward any extra constructor keywords to ``_process_extra``.

    ``model_extra`` may be empty or ``None``; in that case ``_process_extra``
    is still called, just without arguments.

    Returns:
        The validated instance itself.
    """
    extras = self.model_extra or {}
    self._process_extra(**extras)
    return self
def _process_extra(
self,
context: Optional[Context] = None,
config: Optional[Config] = None,
llm: Optional[BaseLLM] = None,
**kwargs,
):
"""Initialize with config"""
super().__init__(**kwargs)
"""Process the extra field"""
self.set_context(context)
self.set_config(config)
self.set_llm(llm)

View file

@ -8,12 +8,12 @@
from typing import Optional
from pydantic import Field
from pydantic import Field, model_validator
from metagpt.actions import SearchAndSummarize, UserRequirement
from metagpt.document_store.base_store import BaseStore
from metagpt.roles import Role
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
class Sales(Role):
@ -29,14 +29,13 @@ class Sales(Role):
store: Optional[BaseStore] = Field(default=None, exclude=True)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self._set_store(self.store)
def _set_store(self, store):
if store:
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch)
@model_validator(mode="after")
def validate_stroe(self):
if self.store:
search_engine = SearchEngine.from_search_func(search_func=self.store.asearch, proxy=self.config.proxy)
action = SearchAndSummarize(search_engine=search_engine, context=self.context)
else:
action = SearchAndSummarize()
action = SearchAndSummarize
self.set_actions([action])
self._watch([UserRequirement])
return self

View file

@ -8,7 +8,9 @@
the `cause_by` value in the `Message` to a string to support the new message distribution feature.
"""
from pydantic import Field
from typing import Optional
from pydantic import Field, model_validator
from metagpt.actions import SearchAndSummarize
from metagpt.actions.action_node import ActionNode
@ -16,7 +18,7 @@ from metagpt.actions.action_output import ActionOutput
from metagpt.logs import logger
from metagpt.roles import Role
from metagpt.schema import Message
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
class Searcher(Role):
@ -28,33 +30,22 @@ class Searcher(Role):
profile (str): Role profile.
goal (str): Goal of the searcher.
constraints (str): Constraints or limitations for the searcher.
engine (SearchEngineType): The type of search engine to use.
search_engine (SearchEngine): The search engine to use.
"""
name: str = Field(default="Alice")
profile: str = Field(default="Smart Assistant")
goal: str = "Provide search services for users"
constraints: str = "Answer is rich and complete"
engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
search_engine: Optional[SearchEngine] = None
def __init__(self, **kwargs) -> None:
"""
Initializes the Searcher role with given attributes.
Args:
name (str): Name of the searcher.
profile (str): Role profile.
goal (str): Goal of the searcher.
constraints (str): Constraints or limitations for the searcher.
engine (SearchEngineType): The type of search engine to use.
"""
super().__init__(**kwargs)
self.set_actions([SearchAndSummarize(engine=self.engine)])
def set_search_func(self, search_func):
"""Sets a custom search function for the searcher."""
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func)
self.set_actions([action])
@model_validator(mode="after")
def post_root(self):
    """Register the role's single ``SearchAndSummarize`` action.

    With a truthy ``search_engine``, an action instance bound to that engine
    and to ``self.context`` is registered; otherwise the bare action class is
    handed to ``set_actions``.

    Returns:
        The validated instance itself.
    """
    engine = self.search_engine
    action = SearchAndSummarize(search_engine=engine, context=self.context) if engine else SearchAndSummarize
    self.set_actions([action])
    return self
async def _act_sp(self) -> Message:
"""Performs the search action in a single process."""

View file

@ -8,14 +8,17 @@
import importlib
from typing import Callable, Coroutine, Literal, Optional, Union, overload
from pydantic import BaseModel, ConfigDict, model_validator
from semantic_kernel.skill_definition import sk_function
from metagpt.configs.search_config import SearchConfig
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
class SkSearchEngine:
def __init__(self):
self.search_engine = SearchEngine()
def __init__(self, **kwargs):
self.search_engine = SearchEngine(**kwargs)
@sk_function(
description="searches results from Google. Useful when you need to find short "
@ -28,43 +31,59 @@ class SkSearchEngine:
return result
class SearchEngine:
"""Class representing a search engine.
class SearchEngine(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
Args:
engine: The search engine type. Defaults to the search engine specified in the config.
run_func: The function to run the search. Defaults to None.
engine: SearchEngineType = SearchEngineType.SERPER_GOOGLE
run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None
api_key: Optional[str] = None
proxy: Optional[str] = None
Attributes:
run_func: The function to run the search.
engine: The search engine type.
"""
@model_validator(mode="after")
def validate_extra(self):
    """Collect the wrapper keyword arguments and hand them to ``_process_extra``.

    The arguments are the dumped fields (without ``engine``, ``None`` values
    and defaults) overlaid with any extra keywords stored in ``model_extra``.

    Returns:
        The validated instance itself.
    """
    dumped = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True)
    # Extras are applied last, so they win over a dumped entry with the same key.
    wrapper_kwargs = {**dumped, **(self.model_extra or {})}
    self._process_extra(**wrapper_kwargs)
    return self
def __init__(
def _process_extra(
self,
engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE,
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None,
**kwargs,
):
if engine == SearchEngineType.SERPAPI_GOOGLE:
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
module = "metagpt.tools.search_engine_serpapi"
run_func = importlib.import_module(module).SerpAPIWrapper(**kwargs).run
elif engine == SearchEngineType.SERPER_GOOGLE:
elif self.engine == SearchEngineType.SERPER_GOOGLE:
module = "metagpt.tools.search_engine_serper"
run_func = importlib.import_module(module).SerperWrapper(**kwargs).run
elif engine == SearchEngineType.DIRECT_GOOGLE:
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
module = "metagpt.tools.search_engine_googleapi"
run_func = importlib.import_module(module).GoogleAPIWrapper(**kwargs).run
elif engine == SearchEngineType.DUCK_DUCK_GO:
elif self.engine == SearchEngineType.DUCK_DUCK_GO:
module = "metagpt.tools.search_engine_ddg"
run_func = importlib.import_module(module).DDGAPIWrapper(**kwargs).run
elif engine == SearchEngineType.CUSTOM_ENGINE:
pass # run_func = run_func
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
run_func = self.run_func
else:
raise NotImplementedError
self.engine = engine
self.run_func = run_func
@classmethod
def from_search_config(cls, config: SearchConfig, **kwargs):
    """Alternate constructor: build a search engine from a ``SearchConfig``.

    Args:
        config: Search settings. ``api_type`` selects the engine, and a
            non-``None`` ``search_func`` is passed on as ``run_func``; every
            other dumped field is forwarded as a keyword argument.
        **kwargs: Additional keyword arguments for the constructor. A key
            that is also a dumped config field raises ``TypeError``
            (duplicate keyword argument).

    Returns:
        A new instance of ``cls``.
    """
    params = config.model_dump(exclude={"api_type", "search_func"})
    custom_func = config.search_func
    if custom_func is not None:
        params["run_func"] = custom_func
    return cls(engine=config.api_type, **params, **kwargs)
@classmethod
def from_search_func(
    cls,
    search_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]],
    **kwargs,
):
    """Alternate constructor: wrap a user-supplied coroutine as a custom engine.

    Args:
        search_func: Coroutine function used as the engine's ``run_func``.
        **kwargs: Additional keyword arguments for the constructor.

    Returns:
        A new instance of ``cls`` whose engine is ``SearchEngineType.CUSTOM_ENGINE``.
    """
    init_kwargs = {"engine": SearchEngineType.CUSTOM_ENGINE, "run_func": search_func}
    return cls(**init_kwargs, **kwargs)
@overload
def run(
self,
@ -83,7 +102,13 @@ class SearchEngine:
) -> list[dict[str, str]]:
...
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> Union[str, list[dict[str, str]]]:
async def run(
self,
query: str,
max_results: int = 8,
as_string: bool = True,
ignore_errors: bool = False,
) -> Union[str, list[dict[str, str]]]:
"""Run a search query.
Args:
@ -94,4 +119,11 @@ class SearchEngine:
Returns:
The search results as a string or a list of dictionaries.
"""
return await self.run_func(query, max_results=max_results, as_string=as_string)
try:
return await self.run_func(query, max_results=max_results, as_string=as_string)
except Exception as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
if not ignore_errors:
raise e
return "" if as_string else []

View file

@ -5,9 +5,9 @@ from __future__ import annotations
import asyncio
import json
from concurrent import futures
from typing import Literal, overload
from typing import Literal, Optional, overload
from metagpt.config2 import config
from pydantic import BaseModel, ConfigDict
try:
from duckduckgo_search import DDGS
@ -18,24 +18,16 @@ except ImportError:
)
class DDGAPIWrapper:
"""Wrapper around duckduckgo_search API.
class DDGAPIWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
To use this module, you should have the `duckduckgo_search` Python package installed.
"""
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
proxy: Optional[str] = None
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
):
kwargs = {}
if config.proxy:
kwargs["proxies"] = config.proxy
self.loop = loop
self.executor = executor
self.ddgs = DDGS(**kwargs)
@property
def ddgs(self):
    """Return a ``DDGS`` client configured with this wrapper's ``proxy``.

    A new client object is constructed on every access.
    """
    client = DDGS(proxies=self.proxy)
    return client
@overload
def run(

View file

@ -4,19 +4,16 @@ from __future__ import annotations
import asyncio
import json
import warnings
from concurrent import futures
from typing import Optional
from urllib.parse import urlparse
import httplib2
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config2 import config
from metagpt.logs import logger
from pydantic import BaseModel, ConfigDict, model_validator
try:
from googleapiclient.discovery import build
from googleapiclient.errors import HttpError
except ImportError:
raise ImportError(
"To use this module, you should have the `google-api-python-client` Python package installed. "
@ -27,40 +24,41 @@ except ImportError:
class GoogleAPIWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
google_api_key: Optional[str] = Field(default=None, validate_default=True)
google_cse_id: Optional[str] = Field(default=None, validate_default=True)
api_key: str
cse_id: str
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
proxy: Optional[str] = None
@field_validator("google_api_key", mode="before")
@model_validator(mode="before")
@classmethod
def check_google_api_key(cls, val: str):
val = val or config.search.api_key
if not val:
def validate_google(cls, values: dict) -> dict:
if "google_api_key" in values:
values.setdefault("api_key", values["google_api_key"])
warnings.warn("`google_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)
if "api_key" not in values:
raise ValueError(
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
"ensure that the environment variable GOOGLE_API_KEY is set with your API key. You can obtain "
"To use google search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
"an API key from https://console.cloud.google.com/apis/credentials."
)
return val
@field_validator("google_cse_id", mode="before")
@classmethod
def check_google_cse_id(cls, val: str):
val = val or config.search.cse_id
if not val:
if "google_cse_id" in values:
values.setdefault("cse_id", values["google_cse_id"])
warnings.warn("`google_cse_id` is deprecated, use `cse_id` instead", DeprecationWarning, stacklevel=2)
if "cse_id" not in values:
raise ValueError(
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
"ensure that the environment variable GOOGLE_CSE_ID is set with your API key. You can obtain "
"an API key from https://programmablesearchengine.google.com/controlpanel/create."
"To use google search engine, make sure you provide the `cse_id` when constructing an object. You can obtain "
"the cse_id from https://programmablesearchengine.google.com/controlpanel/create."
)
return val
return values
@property
def google_api_client(self):
build_kwargs = {"developerKey": self.google_api_key}
if config.proxy:
parse_result = urlparse(config.proxy)
build_kwargs = {"developerKey": self.api_key}
if self.proxy:
parse_result = urlparse(self.proxy)
proxy_type = parse_result.scheme
if proxy_type == "https":
proxy_type = "http"
@ -96,17 +94,11 @@ class GoogleAPIWrapper(BaseModel):
"""
loop = self.loop or asyncio.get_event_loop()
future = loop.run_in_executor(
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.google_cse_id).execute
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.cse_id).execute
)
try:
result = await future
# Extract the search result items from the response
search_results = result.get("items", [])
except HttpError as e:
# Handle errors in the API call
logger.exception(f"fail to search {query} for {e}")
search_results = []
result = await future
# Extract the search result items from the response
search_results = result.get("items", [])
focus = focus or ["snippet", "link", "title"]
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]

View file

@ -5,18 +5,17 @@
@Author : alexanderwu
@File : search_engine_serpapi.py
"""
import warnings
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config2 import config
from pydantic import BaseModel, ConfigDict, Field, model_validator
class SerpAPIWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
search_engine: Any = None #: :meta private:
api_key: str
params: dict = Field(
default_factory=lambda: {
"engine": "google",
@ -25,21 +24,22 @@ class SerpAPIWrapper(BaseModel):
"hl": "en",
}
)
# should add `validate_default=True` to check with default value
serpapi_api_key: Optional[str] = Field(default=None, validate_default=True)
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@field_validator("serpapi_api_key", mode="before")
@model_validator(mode="before")
@classmethod
def check_serpapi_api_key(cls, val: str):
val = val or config.search.api_key
if not val:
def validate_serpapi(cls, values: dict) -> dict:
if "serpapi_api_key" in values:
values.setdefault("api_key", values["serpapi_api_key"])
warnings.warn("`serpapi_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)
if "api_key" not in values:
raise ValueError(
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
"ensure that the environment variable SERPAPI_API_KEY is set with your API key. You can obtain "
"an API key from https://serpapi.com/."
"To use serpapi search engine, make sure you provide the `api_key` when constructing an object. You can obtain"
" an API key from https://serpapi.com/."
)
return val
return values
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through SerpAPI and parse result async."""
@ -60,11 +60,11 @@ class SerpAPIWrapper(BaseModel):
url, params = construct_url_and_params()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
async with session.get(url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get(url, params=params) as response:
async with self.aiosession.get(url, params=params, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
@ -73,7 +73,7 @@ class SerpAPIWrapper(BaseModel):
def get_params(self, query: str) -> Dict[str, str]:
"""Get parameters for SerpAPI."""
_params = {
"api_key": self.serpapi_api_key,
"api_key": self.api_key,
"q": query,
}
params = {**self.params, **_params}

View file

@ -6,33 +6,34 @@
@File : search_engine_serpapi.py
"""
import json
import warnings
from typing import Any, Dict, Optional, Tuple
import aiohttp
from pydantic import BaseModel, ConfigDict, Field, field_validator
from metagpt.config2 import config
from pydantic import BaseModel, ConfigDict, Field, model_validator
class SerperWrapper(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
search_engine: Any = None #: :meta private:
api_key: str
payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10})
serper_api_key: Optional[str] = Field(default=None, validate_default=True)
aiosession: Optional[aiohttp.ClientSession] = None
proxy: Optional[str] = None
@field_validator("serper_api_key", mode="before")
@model_validator(mode="before")
@classmethod
def check_serper_api_key(cls, val: str):
val = val or config.search.api_key
if not val:
def validate_serper(cls, values: dict) -> dict:
if "serper_api_key" in values:
values.setdefault("api_key", values["serper_api_key"])
warnings.warn("`serper_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)
if "api_key" not in values:
raise ValueError(
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "
"ensure that the environment variable SERPER_API_KEY is set with your API key. You can obtain "
"To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
"an API key from https://serper.dev/."
)
return val
return values
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
"""Run query through Serper and parse result async."""
@ -54,11 +55,11 @@ class SerperWrapper(BaseModel):
url, payloads, headers = construct_url_and_payload_and_headers()
if not self.aiosession:
async with aiohttp.ClientSession() as session:
async with session.post(url, data=payloads, headers=headers) as response:
async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
else:
async with self.aiosession.get.post(url, data=payloads, headers=headers) as response:
async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
response.raise_for_status()
res = await response.json()
@ -76,7 +77,7 @@ class SerperWrapper(BaseModel):
return json.dumps(payloads, sort_keys=True)
def get_headers(self) -> Dict[str, str]:
headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"}
headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
return headers
@staticmethod

View file

@ -1,36 +1,49 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import annotations
import importlib
from typing import Any, Callable, Coroutine, overload
from typing import Any, Callable, Coroutine, Optional, Union, overload
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.tools import WebBrowserEngineType
from metagpt.utils.parse_html import WebPage
class WebBrowserEngine:
def __init__(
self,
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT,
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
):
if engine is None:
raise NotImplementedError
class WebBrowserEngine(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT:
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
run_func: Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]] = None
proxy: Optional[str] = None
@model_validator(mode="after")
def validate_extra(self):
    """Collect the wrapper keyword arguments and hand them to ``_process_extra``.

    The arguments are the dumped fields (without ``engine``, ``None`` values
    and defaults) overlaid with any extra keywords stored in ``model_extra``.

    Returns:
        The validated instance itself.
    """
    dumped = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True)
    # Extras are applied last, so they win over a dumped entry with the same key.
    wrapper_kwargs = {**dumped, **(self.model_extra or {})}
    self._process_extra(**wrapper_kwargs)
    return self
def _process_extra(self, **kwargs):
if self.engine is WebBrowserEngineType.PLAYWRIGHT:
module = "metagpt.tools.web_browser_engine_playwright"
run_func = importlib.import_module(module).PlaywrightWrapper().run
elif WebBrowserEngineType(engine) is WebBrowserEngineType.SELENIUM:
run_func = importlib.import_module(module).PlaywrightWrapper(**kwargs).run
elif self.engine is WebBrowserEngineType.SELENIUM:
module = "metagpt.tools.web_browser_engine_selenium"
run_func = importlib.import_module(module).SeleniumWrapper().run
elif WebBrowserEngineType(engine) is WebBrowserEngineType.CUSTOM:
run_func = run_func
run_func = importlib.import_module(module).SeleniumWrapper(**kwargs).run
elif self.engine is WebBrowserEngineType.CUSTOM:
run_func = self.run_func
else:
raise NotImplementedError
self.run_func = run_func
self.engine = engine
@classmethod
def from_browser_config(cls, config: BrowserConfig, **kwargs):
    """Alternate constructor: build a browser engine from a ``BrowserConfig``.

    Args:
        config: Browser settings; every dumped field is forwarded as a
            keyword argument.
        **kwargs: Additional keyword arguments for the constructor. A key
            that is also a dumped config field raises ``TypeError``
            (duplicate keyword argument).

    Returns:
        A new instance of ``cls``.
    """
    return cls(**config.model_dump(), **kwargs)
@overload
async def run(self, url: str) -> WebPage:

View file

@ -6,16 +6,16 @@ from __future__ import annotations
import asyncio
import sys
from pathlib import Path
from typing import Literal
from typing import Literal, Optional
from playwright.async_api import async_playwright
from pydantic import BaseModel, Field, PrivateAttr
from metagpt.config2 import config
from metagpt.logs import logger
from metagpt.utils.parse_html import WebPage
class PlaywrightWrapper:
class PlaywrightWrapper(BaseModel):
"""Wrapper around Playwright.
To use this module, you should have the `playwright` Python package installed and ensure that
@ -24,24 +24,23 @@ class PlaywrightWrapper:
command `playwright install` for the first time.
"""
def __init__(
self,
browser_type: Literal["chromium", "firefox", "webkit"] | None = "chromium",
launch_kwargs: dict | None = None,
**kwargs,
) -> None:
self.browser_type = browser_type
launch_kwargs = launch_kwargs or {}
if config.proxy and "proxy" not in launch_kwargs:
browser_type: Literal["chromium", "firefox", "webkit"] = "chromium"
launch_kwargs: dict = Field(default_factory=dict)
proxy: Optional[str] = None
context_kwargs: dict = Field(default_factory=dict)
_has_run_precheck: bool = PrivateAttr(False)
def __init__(self, **kwargs):
super().__init__(**kwargs)
launch_kwargs = self.launch_kwargs
if self.proxy and "proxy" not in launch_kwargs:
args = launch_kwargs.get("args", [])
if not any(str.startswith(i, "--proxy-server=") for i in args):
launch_kwargs["proxy"] = {"server": config.proxy}
self.launch_kwargs = launch_kwargs
context_kwargs = {}
launch_kwargs["proxy"] = {"server": self.proxy}
if "ignore_https_errors" in kwargs:
context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"]
self._context_kwargs = context_kwargs
self._has_run_precheck = False
self.context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"]
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
async with async_playwright() as ap:
@ -55,7 +54,7 @@ class PlaywrightWrapper:
return await _scrape(browser, url)
async def _scrape(self, browser, url):
context = await browser.new_context(**self._context_kwargs)
context = await browser.new_context(**self.context_kwargs)
page = await context.new_page()
async with page:
try:
@ -75,8 +74,8 @@ class PlaywrightWrapper:
executable_path = Path(browser_type.executable_path)
if not executable_path.exists() and "executable_path" not in self.launch_kwargs:
kwargs = {}
if config.proxy:
kwargs["env"] = {"ALL_PROXY": config.proxy}
if self.proxy:
kwargs["env"] = {"ALL_PROXY": self.proxy}
await _install_browsers(self.browser_type, **kwargs)
if self._has_run_precheck:

View file

@ -7,19 +7,19 @@ import asyncio
import importlib
from concurrent import futures
from copy import deepcopy
from typing import Literal
from typing import Callable, Literal, Optional
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.wait import WebDriverWait
from webdriver_manager.core.download_manager import WDMDownloadManager
from webdriver_manager.core.http import WDMHttpClient
from metagpt.config2 import config
from metagpt.utils.parse_html import WebPage
class SeleniumWrapper:
class SeleniumWrapper(BaseModel):
"""Wrapper around Selenium.
To use this module, you should check the following:
@ -31,25 +31,28 @@ class SeleniumWrapper:
can scrape web pages using the Selenium WebBrowserEngine.
"""
def __init__(
self,
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome",
launch_kwargs: dict | None = None,
*,
loop: asyncio.AbstractEventLoop | None = None,
executor: futures.Executor | None = None,
) -> None:
self.browser_type = browser_type
launch_kwargs = launch_kwargs or {}
if config.proxy and "proxy-server" not in launch_kwargs:
launch_kwargs["proxy-server"] = config.proxy
model_config = ConfigDict(arbitrary_types_allowed=True)
self.executable_path = launch_kwargs.pop("executable_path", None)
self.launch_args = [f"--{k}={v}" for k, v in launch_kwargs.items()]
self._has_run_precheck = False
self._get_driver = None
self.loop = loop
self.executor = executor
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
launch_kwargs: dict = Field(default_factory=dict)
proxy: Optional[str] = None
loop: Optional[asyncio.AbstractEventLoop] = None
executor: Optional[futures.Executor] = None
_has_run_precheck: bool = PrivateAttr(False)
_get_driver: Optional[Callable] = PrivateAttr(None)
def __init__(self, **kwargs) -> None:
    """Build the wrapper and mirror ``proxy`` into the browser launch options.

    All keyword arguments are forwarded to the base-class constructor.
    Afterwards, when a proxy is configured and ``launch_kwargs`` does not
    already carry a ``"proxy-server"`` entry, the proxy is stored under that
    key; an existing ``"proxy-server"`` entry is left untouched.
    """
    super().__init__(**kwargs)
    proxy_server = self.proxy
    if not proxy_server:
        return
    # setdefault only writes when the key is absent, so an explicit
    # "proxy-server" launch option always wins over ``proxy``.
    self.launch_kwargs.setdefault("proxy-server", proxy_server)
@property
def launch_args(self):
    """Return ``launch_kwargs`` rendered as ``--key=value`` strings.

    The ``"executable_path"`` entry is skipped; every other entry is
    rendered in the dictionary's iteration order.
    """
    flags = []
    for option, value in self.launch_kwargs.items():
        if option == "executable_path":
            continue
        flags.append(f"--{option}={value}")
    return flags
@property
def executable_path(self):
    """Return the ``"executable_path"`` launch option, or ``None`` when it is not set."""
    try:
        return self.launch_kwargs["executable_path"]
    except KeyError:
        return None
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
await self._run_precheck()
@ -66,7 +69,9 @@ class SeleniumWrapper:
self.loop = self.loop or asyncio.get_event_loop()
self._get_driver = await self.loop.run_in_executor(
self.executor,
lambda: _gen_get_driver_func(self.browser_type, *self.launch_args, executable_path=self.executable_path),
lambda: _gen_get_driver_func(
self.browser_type, *self.launch_args, executable_path=self.executable_path, proxy=self.proxy
),
)
self._has_run_precheck = True
@ -92,13 +97,17 @@ _webdriver_manager_types = {
class WDMHttpProxyClient(WDMHttpClient):
    """``WDMHttpClient`` that sends its requests through an injected proxy.

    The proxy is supplied by the caller at construction time; no global
    configuration is consulted, so two clients can use different proxies.
    """

    def __init__(self, proxy: Optional[str] = None):
        """Create the client.

        Args:
            proxy: Proxy URL to use for requests, or ``None`` for no proxy.
        """
        super().__init__()
        self.proxy = proxy

    def get(self, url, **kwargs):
        """Fetch ``url``, adding the configured proxy unless the caller set one.

        Args:
            url: Address to request.
            **kwargs: Forwarded to the parent ``get``. An explicit ``proxies``
                entry is passed through unchanged.

        Returns:
            Whatever the parent ``get`` returns.
        """
        # Only the instance proxy is used: a global fallback here would
        # silently override the proxy the caller injected.
        if "proxies" not in kwargs and self.proxy:
            kwargs["proxies"] = {"all_proxy": self.proxy}
        return super().get(url, **kwargs)
def _gen_get_driver_func(browser_type, *args, executable_path=None):
def _gen_get_driver_func(browser_type, *args, executable_path=None, proxy=None):
WebDriver = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.webdriver"), "WebDriver")
Service = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.service"), "Service")
Options = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.options"), "Options")
@ -106,7 +115,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
if not executable_path:
module_name, type_name = _webdriver_manager_types[browser_type]
DriverManager = getattr(importlib.import_module(module_name), type_name)
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient()))
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient(proxy=proxy)))
# driver_manager.driver_cache.find_driver(driver_manager.driver))
executable_path = driver_manager.install()

View file

@ -28,9 +28,9 @@ async def test_collect_links(mocker, search_engine_mocker, context):
return "[1,2]"
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), context=context).run(
"The application of MetaGPT"
)
resp = await research.CollectLinks(
search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO), context=context
).run("The application of MetaGPT")
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
assert i in resp
@ -50,7 +50,9 @@ async def test_collect_links_with_rank_func(mocker, search_engine_mocker, contex
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
resp = await research.CollectLinks(
search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func, context=context
search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO),
rank_func=rank_func,
context=context,
).run("The application of MetaGPT")
for x, y, z in zip(rank_before, rank_after, resp.values()):
assert x[::-1] == y

View file

@ -16,6 +16,6 @@ async def test_google_search(search_engine_mocker):
result = await google_search(
seed.input,
engine=SearchEngineType.SERPER_GOOGLE,
serper_api_key="mock-serper-key",
api_key="mock-serper-key",
)
assert result != ""

View file

@ -36,7 +36,7 @@ async def test_researcher(mocker, search_engine_mocker, context):
role = researcher.Researcher(context=context)
for i in role.actions:
if isinstance(i, CollectLinks):
i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO)
i.search_engine = SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO)
await role.run(topic)
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")

View file

@ -12,6 +12,7 @@ from typing import Callable
import pytest
from metagpt.config2 import config
from metagpt.configs.search_config import SearchConfig
from metagpt.logs import logger
from metagpt.tools import SearchEngineType
from metagpt.tools.search_engine import SearchEngine
@ -49,27 +50,34 @@ async def test_search_engine(
search_engine_mocker,
):
# Prerequisites
search_engine_config = {}
search_engine_config = {"engine": search_engine_type, "run_func": run_func}
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
assert config.search
search_engine_config["serpapi_api_key"] = "mock-serpapi-key"
search_engine_config["api_key"] = "mock-serpapi-key"
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
assert config.search
search_engine_config["google_api_key"] = "mock-google-key"
search_engine_config["google_cse_id"] = "mock-google-cse"
search_engine_config["api_key"] = "mock-google-key"
search_engine_config["cse_id"] = "mock-google-cse"
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
assert config.search
search_engine_config["serper_api_key"] = "mock-serper-key"
search_engine_config["api_key"] = "mock-serper-key"
search_engine = SearchEngine(search_engine_type, run_func, **search_engine_config)
rsp = await search_engine.run("metagpt", max_results, as_string)
logger.info(rsp)
if as_string:
assert isinstance(rsp, str)
else:
assert isinstance(rsp, list)
assert len(rsp) <= max_results
async def test(search_engine):
    """Run one query through ``search_engine`` and check the response type.

    With ``as_string`` set the response must be a ``str``; otherwise it must
    be a ``list`` holding at most ``max_results`` items.
    """
    response = await search_engine.run("metagpt", max_results, as_string)
    logger.info(response)
    if not as_string:
        assert isinstance(response, list)
        assert len(response) <= max_results
        return
    assert isinstance(response, str)
await test(SearchEngine(**search_engine_config))
search_engine_config["api_type"] = search_engine_config.pop("engine")
if run_func:
await test(SearchEngine.from_search_func(run_func))
search_engine_config["search_func"] = search_engine_config.pop("run_func")
await test(SearchEngine.from_search_config(SearchConfig(**search_engine_config)))
if __name__ == "__main__":

View file

@ -3,7 +3,6 @@
import pytest
from metagpt.config2 import config
from metagpt.tools import web_browser_engine_playwright
from metagpt.utils.parse_html import WebPage
@ -19,26 +18,22 @@ from metagpt.utils.parse_html import WebPage
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
)
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd):
    """Scrape one page, then several, with the Playwright wrapper.

    When ``use_proxy`` is set, the ``proxy`` fixture is awaited to obtain a
    ``(server, proxy_url)`` pair and the URL is passed straight to the
    wrapper; no global configuration is read or mutated. The test finally
    checks that captured stdout contains ``"Proxy:"`` (presumably printed by
    the proxy fixture when traffic goes through it -- confirm against the
    fixture).
    """
    server = None
    proxy_url = None
    if use_proxy:
        server, proxy_url = await proxy()
    try:
        browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, proxy=proxy_url, **kwagrs)
        result = await browser.run(url)
        assert isinstance(result, WebPage)
        assert "MetaGPT" in result.inner_text

        if urls:
            results = await browser.run(url, *urls)
            assert isinstance(results, list)
            assert len(results) == len(urls) + 1
            assert all(("MetaGPT" in i.inner_text) for i in results)
    finally:
        # Close the proxy server even when an assertion above fails, so a
        # failing case does not leak the listening server.
        if server is not None:
            server.close()
    if use_proxy:
        assert "Proxy:" in capfd.readouterr().out
if __name__ == "__main__":

View file

@ -4,7 +4,6 @@
import browsers
import pytest
from metagpt.config2 import config
from metagpt.tools import web_browser_engine_selenium
from metagpt.utils.parse_html import WebPage
@ -40,27 +39,22 @@ from metagpt.utils.parse_html import WebPage
async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd):
    """Scrape one page, then several, with the Selenium wrapper.

    When ``use_proxy`` is set, the ``proxy`` fixture is awaited to obtain a
    ``(server, proxy_url)`` pair and the URL is passed straight to the
    wrapper; no global configuration is read or mutated. The test finally
    checks that captured stdout contains ``"Proxy:"`` (presumably printed by
    the proxy fixture when traffic goes through it -- confirm against the
    fixture).
    """
    # Prerequisites
    # firefox, chrome, Microsoft Edge
    server = None
    proxy_url = None
    if use_proxy:
        server, proxy_url = await proxy()
    try:
        browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type, proxy=proxy_url)
        result = await browser.run(url)
        assert isinstance(result, WebPage)
        assert "MetaGPT" in result.inner_text

        if urls:
            results = await browser.run(url, *urls)
            assert isinstance(results, list)
            assert len(results) == len(urls) + 1
            assert all(("MetaGPT" in i.inner_text) for i in results)
    finally:
        # Close the proxy server even when an assertion above fails, so a
        # failing case does not leak the listening server.
        if server is not None:
            server.close()
    if use_proxy:
        assert "Proxy:" in capfd.readouterr().out
if __name__ == "__main__":

View file

@ -13,7 +13,8 @@ class MockAioResponse:
def __init__(self, session, method, url, **kwargs) -> None:
fn = self.check_funcs.get((method, url))
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
_kwargs = {k: v for k, v in kwargs.items() if k != "proxy"}
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(_kwargs, sort_keys=True)}"
self.mng = self.response = None
if self.key not in self.rsp_cache:
self.mng = origin_request(session, method, url, **kwargs)