mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-03 12:52:37 +02:00
Merge branch 'dev' into code_intepreter
This commit is contained in:
commit
891e35b92f
108 changed files with 5271 additions and 408 deletions
|
|
@ -8,14 +8,23 @@
|
|||
import importlib
|
||||
from typing import Callable, Coroutine, Literal, Optional, Union, overload
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
from semantic_kernel.skill_definition import sk_function
|
||||
|
||||
from metagpt.configs.search_config import SearchConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
||||
class SkSearchEngine:
|
||||
def __init__(self):
|
||||
self.search_engine = SearchEngine()
|
||||
"""A search engine class for executing searches.
|
||||
|
||||
Attributes:
|
||||
search_engine: The search engine instance used for executing searches.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
    """Build the wrapped ``SearchEngine``, forwarding every keyword argument to it."""
    engine = SearchEngine(**kwargs)
    self.search_engine = engine
|
||||
|
||||
@sk_function(
|
||||
description="searches results from Google. Useful when you need to find short "
|
||||
|
|
@ -28,43 +37,85 @@ class SkSearchEngine:
|
|||
return result
|
||||
|
||||
|
||||
class SearchEngine:
|
||||
"""Class representing a search engine.
|
||||
|
||||
Args:
|
||||
engine: The search engine type. Defaults to the search engine specified in the config.
|
||||
run_func: The function to run the search. Defaults to None.
|
||||
class SearchEngine(BaseModel):
|
||||
"""A model for configuring and executing searches with different search engines.
|
||||
|
||||
Attributes:
|
||||
run_func: The function to run the search.
|
||||
engine: The search engine type.
|
||||
model_config: Configuration for the model allowing arbitrary types.
|
||||
engine: The type of search engine to use.
|
||||
run_func: An optional callable for running the search. If not provided, it will be determined based on the engine.
|
||||
api_key: An optional API key for the search engine.
|
||||
proxy: An optional proxy for the search engine requests.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
||||
|
||||
engine: SearchEngineType = SearchEngineType.SERPER_GOOGLE
|
||||
run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None
|
||||
api_key: Optional[str] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
@model_validator(mode="after")
def validate_extra(self):
    """Collect the non-default settings and hand them to ``_process_extra``.

    The dump leaves out ``engine``, ``None`` values and fields still at their
    defaults; any extra fields stored on the model are merged on top.

    Returns:
        The model instance itself.
    """
    settings = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True)
    settings.update(self.model_extra or {})
    self._process_extra(**settings)
    return self
|
||||
|
||||
def _process_extra(
    self,
    run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None,
    **kwargs,
):
    """Resolve ``self.run_func`` from ``self.engine``.

    Built-in engines import their wrapper module lazily, instantiate the wrapper
    with ``kwargs`` and use its ``run`` method; the custom engine keeps whatever
    ``self.run_func`` already holds.

    Args:
        run_func: Accepted so it is not forwarded to a wrapper via ``kwargs``;
            its value is always replaced before being stored.
        **kwargs: Keyword arguments passed to the wrapper constructor.

    Raises:
        NotImplementedError: If ``self.engine`` is not a supported engine type.
    """
    # engine type -> (module path, wrapper class name)
    wrappers = {
        SearchEngineType.SERPAPI_GOOGLE: ("metagpt.tools.search_engine_serpapi", "SerpAPIWrapper"),
        SearchEngineType.SERPER_GOOGLE: ("metagpt.tools.search_engine_serper", "SerperWrapper"),
        SearchEngineType.DIRECT_GOOGLE: ("metagpt.tools.search_engine_googleapi", "GoogleAPIWrapper"),
        SearchEngineType.DUCK_DUCK_GO: ("metagpt.tools.search_engine_ddg", "DDGAPIWrapper"),
    }

    if self.engine in wrappers:
        module_path, wrapper_name = wrappers[self.engine]
        wrapper_cls = getattr(importlib.import_module(module_path), wrapper_name)
        run_func = wrapper_cls(**kwargs).run
    elif self.engine == SearchEngineType.CUSTOM_ENGINE:
        run_func = self.run_func
    else:
        raise NotImplementedError

    self.run_func = run_func
|
||||
|
||||
@classmethod
def from_search_config(cls, config: SearchConfig, **kwargs):
    """Create a SearchEngine instance from a SearchConfig.

    Args:
        config: The search configuration to use. ``api_type`` becomes the engine and
            ``search_func`` (when not None) becomes ``run_func``; the remaining dumped
            fields are passed through unchanged.
        **kwargs: Additional constructor arguments. A keyword that repeats a key taken
            from ``config`` overrides it instead of raising ``TypeError``.

    Returns:
        A new instance of ``cls``.
    """
    data = config.model_dump(exclude={"api_type", "search_func"})
    if config.search_func is not None:
        data["run_func"] = config.search_func

    # Merge into one mapping first: `cls(engine=..., **data, **kwargs)` fails with
    # "got multiple values for keyword argument" when a key appears twice.
    return cls(**{"engine": config.api_type, **data, **kwargs})
|
||||
|
||||
@classmethod
def from_search_func(
    cls, search_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]], **kwargs
):
    """Build a custom-engine SearchEngine around a user-supplied search coroutine.

    Args:
        search_func: The callable stored as ``run_func``.
        **kwargs: Further keyword arguments forwarded to the constructor.

    Returns:
        A new instance of ``cls`` with ``engine`` set to ``SearchEngineType.CUSTOM_ENGINE``.
    """
    custom_engine = SearchEngineType.CUSTOM_ENGINE
    return cls(engine=custom_engine, run_func=search_func, **kwargs)
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
|
|
@ -83,15 +134,29 @@ class SearchEngine:
|
|||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
async def run(
    self,
    query: str,
    max_results: int = 8,
    as_string: bool = True,
    ignore_errors: bool = False,
) -> Union[str, list[dict[str, str]]]:
    """Run a search query.

    Args:
        query: The search query.
        max_results: The maximum number of results to return. Defaults to 8.
        as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
        ignore_errors: Whether to ignore errors during the search. Defaults to False.

    Returns:
        The search results as a string or a list of dictionaries. When the search fails and
        ``ignore_errors`` is True, an empty string (``as_string=True``) or an empty list.

    Raises:
        Exception: Whatever ``run_func`` raised, when ``ignore_errors`` is False.
    """
    try:
        return await self.run_func(query, max_results=max_results, as_string=as_string)
    except Exception as e:
        # Always log the failure, even when the caller asked to ignore it.
        logger.exception(f"fail to search {query} for {e}")
        if not ignore_errors:
            # Bare `raise` re-raises the active exception without adding a frame.
            raise
        return "" if as_string else []
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import json
|
||||
from concurrent import futures
|
||||
from typing import Literal, overload
|
||||
from typing import Literal, Optional, overload
|
||||
|
||||
from metagpt.config2 import config
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
|
|
@ -18,24 +18,16 @@ except ImportError:
|
|||
)
|
||||
|
||||
|
||||
class DDGAPIWrapper:
|
||||
"""Wrapper around duckduckgo_search API.
|
||||
class DDGAPIWrapper(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
To use this module, you should have the `duckduckgo_search` Python package installed.
|
||||
"""
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
):
|
||||
kwargs = {}
|
||||
if config.proxy:
|
||||
kwargs["proxies"] = config.proxy
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
self.ddgs = DDGS(**kwargs)
|
||||
@property
def ddgs(self):
    """Return a ``DDGS`` client created on each access with ``self.proxy`` as its proxies."""
    client = DDGS(proxies=self.proxy)
    return client
|
||||
|
||||
@overload
|
||||
def run(
|
||||
|
|
|
|||
|
|
@ -4,19 +4,16 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import warnings
|
||||
from concurrent import futures
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httplib2
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
try:
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use this module, you should have the `google-api-python-client` Python package installed. "
|
||||
|
|
@ -27,40 +24,41 @@ except ImportError:
|
|||
class GoogleAPIWrapper(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
google_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
google_cse_id: Optional[str] = Field(default=None, validate_default=True)
|
||||
api_key: str
|
||||
cse_id: str
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def validate_google(cls, values: dict) -> dict:
    """Map deprecated key names onto the current ones and require both credentials.

    ``google_api_key`` / ``google_cse_id`` are copied to ``api_key`` / ``cse_id`` when the
    new name is absent, and a ``DeprecationWarning`` is emitted for each deprecated name
    that is present.

    Args:
        values: The raw input data for the model.

    Returns:
        The input data with ``api_key`` and ``cse_id`` populated.

    Raises:
        ValueError: If ``api_key`` or ``cse_id`` is missing after the mapping.
    """
    if isinstance(values, dict):
        # Work on a copy so a dict handed in by the caller is not modified.
        values = values.copy()

    if "google_api_key" in values:
        values.setdefault("api_key", values["google_api_key"])
        warnings.warn("`google_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)

    if "api_key" not in values:
        raise ValueError(
            "To use google search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
            "an API key from https://console.cloud.google.com/apis/credentials."
        )

    if "google_cse_id" in values:
        values.setdefault("cse_id", values["google_cse_id"])
        warnings.warn("`google_cse_id` is deprecated, use `cse_id` instead", DeprecationWarning, stacklevel=2)

    if "cse_id" not in values:
        raise ValueError(
            "To use google search engine, make sure you provide the `cse_id` when constructing an object. You can obtain "
            "the cse_id from https://programmablesearchengine.google.com/controlpanel/create."
        )
    return values
|
||||
|
||||
@property
|
||||
def google_api_client(self):
|
||||
build_kwargs = {"developerKey": self.google_api_key}
|
||||
if config.proxy:
|
||||
parse_result = urlparse(config.proxy)
|
||||
build_kwargs = {"developerKey": self.api_key}
|
||||
if self.proxy:
|
||||
parse_result = urlparse(self.proxy)
|
||||
proxy_type = parse_result.scheme
|
||||
if proxy_type == "https":
|
||||
proxy_type = "http"
|
||||
|
|
@ -96,17 +94,11 @@ class GoogleAPIWrapper(BaseModel):
|
|||
"""
|
||||
loop = self.loop or asyncio.get_event_loop()
|
||||
future = loop.run_in_executor(
|
||||
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.google_cse_id).execute
|
||||
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.cse_id).execute
|
||||
)
|
||||
try:
|
||||
result = await future
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
logger.exception(f"fail to search {query} for {e}")
|
||||
search_results = []
|
||||
result = await future
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
focus = focus or ["snippet", "link", "title"]
|
||||
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
|
||||
|
|
|
|||
|
|
@ -5,18 +5,17 @@
|
|||
@Author : alexanderwu
|
||||
@File : search_engine_serpapi.py
|
||||
"""
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class SerpAPIWrapper(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
search_engine: Any = None #: :meta private:
|
||||
api_key: str
|
||||
params: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"engine": "google",
|
||||
|
|
@ -25,21 +24,22 @@ class SerpAPIWrapper(BaseModel):
|
|||
"hl": "en",
|
||||
}
|
||||
)
|
||||
# should add `validate_default=True` to check with default value
|
||||
serpapi_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def validate_serpapi(cls, values: dict) -> dict:
    """Map the deprecated ``serpapi_api_key`` onto ``api_key`` and require an API key.

    Args:
        values: The raw input data for the model.

    Returns:
        The input data with ``api_key`` populated.

    Raises:
        ValueError: If ``api_key`` is missing after the mapping.
    """
    if isinstance(values, dict):
        # Work on a copy so a dict handed in by the caller is not modified.
        values = values.copy()

    if "serpapi_api_key" in values:
        values.setdefault("api_key", values["serpapi_api_key"])
        warnings.warn("`serpapi_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)

    if "api_key" not in values:
        raise ValueError(
            "To use serpapi search engine, make sure you provide the `api_key` when constructing an object. You can obtain"
            " an API key from https://serpapi.com/."
        )
    return values
|
||||
|
||||
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result async."""
|
||||
|
|
@ -60,11 +60,11 @@ class SerpAPIWrapper(BaseModel):
|
|||
url, params = construct_url_and_params()
|
||||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
async with session.get(url, params=params, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get(url, params=params) as response:
|
||||
async with self.aiosession.get(url, params=params, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ class SerpAPIWrapper(BaseModel):
|
|||
def get_params(self, query: str) -> Dict[str, str]:
|
||||
"""Get parameters for SerpAPI."""
|
||||
_params = {
|
||||
"api_key": self.serpapi_api_key,
|
||||
"api_key": self.api_key,
|
||||
"q": query,
|
||||
}
|
||||
params = {**self.params, **_params}
|
||||
|
|
|
|||
|
|
@ -6,33 +6,34 @@
|
|||
@File : search_engine_serpapi.py
|
||||
"""
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class SerperWrapper(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
search_engine: Any = None #: :meta private:
|
||||
api_key: str
|
||||
payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10})
|
||||
serper_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
@classmethod
def validate_serper(cls, values: dict) -> dict:
    """Map the deprecated ``serper_api_key`` onto ``api_key`` and require an API key.

    Args:
        values: The raw input data for the model.

    Returns:
        The input data with ``api_key`` populated.

    Raises:
        ValueError: If ``api_key`` is missing after the mapping.
    """
    if isinstance(values, dict):
        # Work on a copy so a dict handed in by the caller is not modified.
        values = values.copy()

    if "serper_api_key" in values:
        values.setdefault("api_key", values["serper_api_key"])
        warnings.warn("`serper_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)

    if "api_key" not in values:
        raise ValueError(
            "To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
            "an API key from https://serper.dev/."
        )
    return values
|
||||
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through Serper and parse result async."""
|
||||
|
|
@ -54,11 +55,11 @@ class SerperWrapper(BaseModel):
|
|||
url, payloads, headers = construct_url_and_payload_and_headers()
|
||||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, data=payloads, headers=headers) as response:
|
||||
async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get.post(url, data=payloads, headers=headers) as response:
|
||||
async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
|
|
@ -76,7 +77,7 @@ class SerperWrapper(BaseModel):
|
|||
return json.dumps(payloads, sort_keys=True)
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
    """Return the request headers: the API key plus a JSON content type."""
    return {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,36 +1,95 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import Any, Callable, Coroutine, overload
|
||||
from typing import Any, Callable, Coroutine, Optional, Union, overload
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class WebBrowserEngine:
|
||||
def __init__(
|
||||
self,
|
||||
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT,
|
||||
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
|
||||
):
|
||||
if engine is None:
|
||||
raise NotImplementedError
|
||||
class WebBrowserEngine(BaseModel):
|
||||
"""Defines a web browser engine configuration for automated browsing and data extraction.
|
||||
|
||||
if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT:
|
||||
This class encapsulates the configuration and operational logic for different web browser engines,
|
||||
such as Playwright, Selenium, or custom implementations. It provides a unified interface to run
|
||||
browser automation tasks.
|
||||
|
||||
Attributes:
|
||||
model_config: Configuration dictionary allowing arbitrary types and extra fields.
|
||||
engine: The type of web browser engine to use.
|
||||
run_func: An optional coroutine function to run the browser engine.
|
||||
proxy: An optional proxy server URL to use with the browser engine.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
||||
|
||||
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
|
||||
run_func: Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
@model_validator(mode="after")
def validate_extra(self):
    """Collect the non-default settings and hand them to ``_process_extra``.

    The dump leaves out ``engine``, ``None`` values and fields still at their
    defaults; any extra fields stored on the model are merged on top.

    Returns:
        The instance itself after processing the extra data.
    """
    settings = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True)
    settings.update(self.model_extra or {})
    self._process_extra(**settings)
    return self
|
||||
|
||||
def _process_extra(self, **kwargs):
    """Resolve ``self.run_func`` from ``self.engine``.

    Playwright and Selenium import their wrapper module lazily, instantiate the
    wrapper with ``kwargs`` and use its ``run`` method; the custom engine keeps
    whatever ``self.run_func`` already holds.

    Args:
        **kwargs: Keyword arguments passed to the wrapper constructor.

    Raises:
        NotImplementedError: If the engine type is not supported.
    """
    # engine type -> (module path, wrapper class name)
    wrappers = {
        WebBrowserEngineType.PLAYWRIGHT: ("metagpt.tools.web_browser_engine_playwright", "PlaywrightWrapper"),
        WebBrowserEngineType.SELENIUM: ("metagpt.tools.web_browser_engine_selenium", "SeleniumWrapper"),
    }

    if self.engine in wrappers:
        module_path, wrapper_name = wrappers[self.engine]
        wrapper_cls = getattr(importlib.import_module(module_path), wrapper_name)
        run_func = wrapper_cls(**kwargs).run
    elif self.engine is WebBrowserEngineType.CUSTOM:
        run_func = self.run_func
    else:
        raise NotImplementedError

    self.run_func = run_func
|
||||
|
||||
@classmethod
def from_browser_config(cls, config: BrowserConfig, **kwargs):
    """Create a WebBrowserEngine instance from a BrowserConfig and extra keyword arguments.

    Args:
        config: A BrowserConfig object whose dumped fields form the base configuration.
        **kwargs: Optional additional keyword arguments to override or extend the
            configuration. A keyword that repeats a config field replaces it.

    Returns:
        A new instance of ``cls`` built from the merged settings.
    """
    settings = config.model_dump()
    # Merge before the call: `cls(**data, **kwargs)` raises TypeError on a repeated
    # key, which would make the documented "override" impossible.
    settings.update(kwargs)
    return cls(**settings)
|
||||
|
||||
@overload
|
||||
async def run(self, url: str) -> WebPage:
|
||||
|
|
@ -41,4 +100,16 @@ class WebBrowserEngine:
|
|||
...
|
||||
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
    """Load one or more web pages through the configured run function.

    This is the implementation behind the overloaded ``run`` signatures; it
    passes every URL straight to ``self.run_func`` and returns its result.

    Args:
        url: The URL of the first web page to load.
        *urls: Additional URLs of web pages to load, if any.

    Returns:
        Whatever ``self.run_func`` returns for the given URLs.
    """
    pages = await self.run_func(url, *urls)
    return pages
|
||||
|
|
|
|||
|
|
@ -6,15 +6,16 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class PlaywrightWrapper:
|
||||
class PlaywrightWrapper(BaseModel):
|
||||
"""Wrapper around Playwright.
|
||||
|
||||
To use this module, you should have the `playwright` Python package installed and ensure that
|
||||
|
|
@ -23,28 +24,23 @@ class PlaywrightWrapper:
|
|||
command `playwright install` for the first time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
browser_type: Literal["chromium", "firefox", "webkit"] | None = "chromium",
|
||||
launch_kwargs: dict | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
from metagpt.config2 import (
|
||||
config, # avoid circular import error when importing tools"
|
||||
)
|
||||
browser_type: Literal["chromium", "firefox", "webkit"] = "chromium"
|
||||
launch_kwargs: dict = Field(default_factory=dict)
|
||||
proxy: Optional[str] = None
|
||||
context_kwargs: dict = Field(default_factory=dict)
|
||||
_has_run_precheck: bool = PrivateAttr(False)
|
||||
|
||||
self.browser_type = browser_type
|
||||
launch_kwargs = launch_kwargs or {}
|
||||
if config.proxy and "proxy" not in launch_kwargs:
|
||||
def __init__(self, **kwargs):
    """Initialise the model, then derive proxy and context options.

    After field validation, ``self.proxy`` is written into ``launch_kwargs["proxy"]``
    unless the launch options already carry a ``proxy`` entry or a
    ``--proxy-server=`` argument. An ``ignore_https_errors`` keyword, when given,
    is copied into ``context_kwargs``.

    Args:
        **kwargs: Field values for the model, plus the optional ``ignore_https_errors``.
    """
    super().__init__(**kwargs)

    launch_options = self.launch_kwargs
    if self.proxy and "proxy" not in launch_options:
        proxy_in_args = any(str.startswith(arg, "--proxy-server=") for arg in launch_options.get("args", []))
        if not proxy_in_args:
            launch_options["proxy"] = {"server": self.proxy}

    if "ignore_https_errors" in kwargs:
        self.context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"]
|
||||
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
async with async_playwright() as ap:
|
||||
|
|
@ -58,7 +54,7 @@ class PlaywrightWrapper:
|
|||
return await _scrape(browser, url)
|
||||
|
||||
async def _scrape(self, browser, url):
|
||||
context = await browser.new_context(**self._context_kwargs)
|
||||
context = await browser.new_context(**self.context_kwargs)
|
||||
page = await context.new_page()
|
||||
async with page:
|
||||
try:
|
||||
|
|
@ -78,8 +74,8 @@ class PlaywrightWrapper:
|
|||
executable_path = Path(browser_type.executable_path)
|
||||
if not executable_path.exists() and "executable_path" not in self.launch_kwargs:
|
||||
kwargs = {}
|
||||
if config.proxy:
|
||||
kwargs["env"] = {"ALL_PROXY": config.proxy}
|
||||
if self.proxy:
|
||||
kwargs["env"] = {"ALL_PROXY": self.proxy}
|
||||
await _install_browsers(self.browser_type, **kwargs)
|
||||
|
||||
if self._has_run_precheck:
|
||||
|
|
|
|||
|
|
@ -7,19 +7,19 @@ import asyncio
|
|||
import importlib
|
||||
from concurrent import futures
|
||||
from copy import deepcopy
|
||||
from typing import Literal
|
||||
from typing import Callable, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
from webdriver_manager.core.download_manager import WDMDownloadManager
|
||||
from webdriver_manager.core.http import WDMHttpClient
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class SeleniumWrapper:
|
||||
class SeleniumWrapper(BaseModel):
|
||||
"""Wrapper around Selenium.
|
||||
|
||||
To use this module, you should check the following:
|
||||
|
|
@ -31,25 +31,28 @@ class SeleniumWrapper:
|
|||
can scrape web pages using the Selenium WebBrowserEngine.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome",
|
||||
launch_kwargs: dict | None = None,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
) -> None:
|
||||
self.browser_type = browser_type
|
||||
launch_kwargs = launch_kwargs or {}
|
||||
if config.proxy and "proxy-server" not in launch_kwargs:
|
||||
launch_kwargs["proxy-server"] = config.proxy
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
self.executable_path = launch_kwargs.pop("executable_path", None)
|
||||
self.launch_args = [f"--{k}={v}" for k, v in launch_kwargs.items()]
|
||||
self._has_run_precheck = False
|
||||
self._get_driver = None
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
|
||||
launch_kwargs: dict = Field(default_factory=dict)
|
||||
proxy: Optional[str] = None
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
_has_run_precheck: bool = PrivateAttr(False)
|
||||
_get_driver: Optional[Callable] = PrivateAttr(None)
|
||||
|
||||
def __init__(self, **kwargs) -> None:
    """Initialise the model and default the ``proxy-server`` launch option to ``self.proxy``.

    An existing ``proxy-server`` entry in ``launch_kwargs`` is left untouched.

    Args:
        **kwargs: Field values for the model.
    """
    super().__init__(**kwargs)
    if not self.proxy:
        return
    self.launch_kwargs.setdefault("proxy-server", self.proxy)
|
||||
|
||||
@property
def launch_args(self):
    """Render ``launch_kwargs`` as ``--key=value`` strings, leaving out ``executable_path``."""
    rendered = []
    for option, value in self.launch_kwargs.items():
        if option == "executable_path":
            continue
        rendered.append(f"--{option}={value}")
    return rendered
|
||||
|
||||
@property
def executable_path(self):
    """Return the ``executable_path`` entry of ``launch_kwargs``, or None when it is absent."""
    options = self.launch_kwargs
    return options["executable_path"] if "executable_path" in options else None
|
||||
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
await self._run_precheck()
|
||||
|
|
@ -66,7 +69,9 @@ class SeleniumWrapper:
|
|||
self.loop = self.loop or asyncio.get_event_loop()
|
||||
self._get_driver = await self.loop.run_in_executor(
|
||||
self.executor,
|
||||
lambda: _gen_get_driver_func(self.browser_type, *self.launch_args, executable_path=self.executable_path),
|
||||
lambda: _gen_get_driver_func(
|
||||
self.browser_type, *self.launch_args, executable_path=self.executable_path, proxy=self.proxy
|
||||
),
|
||||
)
|
||||
self._has_run_precheck = True
|
||||
|
||||
|
|
@ -92,13 +97,17 @@ _webdriver_manager_types = {
|
|||
|
||||
|
||||
class WDMHttpProxyClient(WDMHttpClient):
    """``WDMHttpClient`` that applies a default proxy to its GET requests."""

    def __init__(self, proxy: Optional[str] = None):
        """Store the proxy URL to apply by default.

        Args:
            proxy: Proxy URL, or None to send requests without an injected proxy.
        """
        super().__init__()
        self.proxy = proxy

    def get(self, url, **kwargs):
        """Perform a GET, adding ``proxies`` when the caller supplied none and a proxy is set.

        Args:
            url: The URL to fetch.
            **kwargs: Extra arguments passed through to the parent ``get``.

        Returns:
            Whatever the parent ``get`` returns.
        """
        if "proxies" not in kwargs and self.proxy:
            # Requests selects a proxy by scheme key or the catch-all "all" key;
            # "all_proxy" is the environment-variable spelling and never matches.
            # NOTE(review): assumes the parent `get` forwards kwargs to `requests.get` — confirm.
            kwargs["proxies"] = {"all": self.proxy}
        return super().get(url, **kwargs)
|
||||
|
||||
|
||||
def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
||||
def _gen_get_driver_func(browser_type, *args, executable_path=None, proxy=None):
|
||||
WebDriver = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.webdriver"), "WebDriver")
|
||||
Service = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.service"), "Service")
|
||||
Options = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.options"), "Options")
|
||||
|
|
@ -106,7 +115,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
if not executable_path:
|
||||
module_name, type_name = _webdriver_manager_types[browser_type]
|
||||
DriverManager = getattr(importlib.import_module(module_name), type_name)
|
||||
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient()))
|
||||
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient(proxy=proxy)))
|
||||
# driver_manager.driver_cache.find_driver(driver_manager.driver))
|
||||
executable_path = driver_manager.install()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue