mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
remove global config in the search/browser engine
This commit is contained in:
parent
9ecdccd836
commit
9b613eec59
24 changed files with 351 additions and 309 deletions
|
|
@ -5,17 +5,20 @@
|
|||
import asyncio
|
||||
|
||||
from metagpt.roles import Searcher
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine, SearchEngineType
|
||||
|
||||
|
||||
async def main():
|
||||
question = "What are the most interesting human facts?"
|
||||
kwargs = {"api_key": "", "cse_id": "", "proxy": None}
|
||||
# Serper API
|
||||
# await Searcher(engine=SearchEngineType.SERPER_GOOGLE).run(question)
|
||||
# await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPER_GOOGLE, **kwargs)).run(question)
|
||||
# SerpAPI
|
||||
await Searcher(engine=SearchEngineType.SERPAPI_GOOGLE).run(question)
|
||||
# await Searcher(search_engine=SearchEngine(engine=SearchEngineType.SERPAPI_GOOGLE, **kwargs)).run(question)
|
||||
# Google API
|
||||
# await Searcher(engine=SearchEngineType.DIRECT_GOOGLE).run(question)
|
||||
# await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DIRECT_GOOGLE, **kwargs)).run(question)
|
||||
# DDG API
|
||||
await Searcher(search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO, **kwargs)).run(question)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -3,15 +3,15 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Callable, Optional, Union
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from pydantic import Field, parse_obj_as
|
||||
from pydantic import TypeAdapter, model_validator
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
|
||||
from metagpt.tools.web_browser_engine import WebBrowserEngine
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.text import generate_prompt_chunk, reduce_message_length
|
||||
|
||||
|
|
@ -81,10 +81,16 @@ class CollectLinks(Action):
|
|||
name: str = "CollectLinks"
|
||||
i_context: Optional[str] = None
|
||||
desc: str = "Collect links from a search engine."
|
||||
|
||||
search_engine: SearchEngine = Field(default_factory=SearchEngine)
|
||||
search_func: Optional[Any] = None
|
||||
search_engine: Optional[SearchEngine] = None
|
||||
rank_func: Optional[Callable[[list[str]], None]] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_engine_and_run_func(self):
|
||||
if self.search_engine is None:
|
||||
self.search_engine = SearchEngine.from_search_config(self.config.search, proxy=self.config.proxy)
|
||||
return self
|
||||
|
||||
async def run(
|
||||
self,
|
||||
topic: str,
|
||||
|
|
@ -107,7 +113,7 @@ class CollectLinks(Action):
|
|||
keywords = await self._aask(SEARCH_TOPIC_PROMPT, [system_text])
|
||||
try:
|
||||
keywords = OutputParser.extract_struct(keywords, list)
|
||||
keywords = parse_obj_as(list[str], keywords)
|
||||
keywords = TypeAdapter(list[str]).validate_python(keywords)
|
||||
except Exception as e:
|
||||
logger.exception(f"fail to get keywords related to the research topic '{topic}' for {e}")
|
||||
keywords = [topic]
|
||||
|
|
@ -133,7 +139,7 @@ class CollectLinks(Action):
|
|||
queries = await self._aask(prompt, [system_text])
|
||||
try:
|
||||
queries = OutputParser.extract_struct(queries, list)
|
||||
queries = parse_obj_as(list[str], queries)
|
||||
queries = TypeAdapter(list[str]).validate_python(queries)
|
||||
except Exception as e:
|
||||
logger.exception(f"fail to break down the research question due to {e}")
|
||||
queries = keywords
|
||||
|
|
@ -178,15 +184,17 @@ class WebBrowseAndSummarize(Action):
|
|||
i_context: Optional[str] = None
|
||||
desc: str = "Explore the web and provide summaries of articles and webpages."
|
||||
browse_func: Union[Callable[[list[str]], None], None] = None
|
||||
web_browser_engine: Optional[WebBrowserEngine] = WebBrowserEngineType.PLAYWRIGHT
|
||||
web_browser_engine: Optional[WebBrowserEngine] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.web_browser_engine = WebBrowserEngine(
|
||||
engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT,
|
||||
run_func=self.browse_func,
|
||||
)
|
||||
@model_validator(mode="after")
|
||||
def validate_engine_and_run_func(self):
|
||||
if self.web_browser_engine is None:
|
||||
self.web_browser_engine = WebBrowserEngine.from_browser_config(
|
||||
self.config.browser,
|
||||
browse_func=self.browse_func,
|
||||
proxy=self.config.proxy,
|
||||
)
|
||||
return self
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
@Author : alexanderwu
|
||||
@File : search_google.py
|
||||
"""
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
import pydantic
|
||||
from pydantic import model_validator
|
||||
|
|
@ -13,7 +13,6 @@ from pydantic import model_validator
|
|||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
SEARCH_AND_SUMMARIZE_SYSTEM = """### Requirements
|
||||
|
|
@ -105,21 +104,19 @@ You are a member of a professional butler team and will provide helpful suggesti
|
|||
class SearchAndSummarize(Action):
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
engine: Optional[SearchEngineType] = None
|
||||
search_func: Optional[Any] = None
|
||||
search_engine: SearchEngine = None
|
||||
result: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_engine_and_run_func(self):
|
||||
if self.engine is None:
|
||||
self.engine = self.config.search_engine
|
||||
try:
|
||||
search_engine = SearchEngine(engine=self.engine, run_func=self.search_func)
|
||||
except pydantic.ValidationError:
|
||||
search_engine = None
|
||||
def validate_search_engine(self):
|
||||
if self.search_engine is None:
|
||||
try:
|
||||
config = self.config
|
||||
search_engine = SearchEngine.from_search_config(config.search, proxy=config.proxy)
|
||||
except pydantic.ValidationError:
|
||||
search_engine = None
|
||||
|
||||
self.search_engine = search_engine
|
||||
self.search_engine = search_engine
|
||||
return self
|
||||
|
||||
async def run(self, context: list[Message], system_text=SEARCH_AND_SUMMARIZE_SYSTEM) -> str:
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class Config(CLIParams, YamlModel):
|
|||
proxy: str = ""
|
||||
|
||||
# Tool Parameters
|
||||
search: Optional[SearchConfig] = None
|
||||
search: SearchConfig = SearchConfig()
|
||||
browser: BrowserConfig = BrowserConfig()
|
||||
mermaid: MermaidConfig = MermaidConfig()
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,6 @@ class BrowserConfig(YamlModel):
|
|||
"""Config for Browser"""
|
||||
|
||||
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
|
||||
browser: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
|
||||
driver: Literal["chromium", "firefox", "webkit"] = "chromium"
|
||||
path: str = ""
|
||||
browser_type: Literal["chromium", "firefox", "webkit", "chrome", "firefox", "edge", "ie"] = "chromium"
|
||||
"""If the engine is Playwright, the value should be one of "chromium", "firefox", or "webkit". If it is Selenium, the value
|
||||
should be either "chrome", "firefox", "edge", or "ie"."""
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@
|
|||
@Author : alexanderwu
|
||||
@File : search_config.py
|
||||
"""
|
||||
from typing import Callable, Optional
|
||||
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
||||
|
|
@ -12,6 +14,7 @@ from metagpt.utils.yaml_model import YamlModel
|
|||
class SearchConfig(YamlModel):
|
||||
"""Config for Search"""
|
||||
|
||||
api_key: str
|
||||
api_type: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
|
||||
api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO
|
||||
api_key: str = ""
|
||||
cse_id: str = "" # for google
|
||||
search_func: Optional[Callable] = None
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.context import Context
|
||||
|
|
@ -17,7 +17,7 @@ from metagpt.provider.base_llm import BaseLLM
|
|||
class ContextMixin(BaseModel):
|
||||
"""Mixin class for context and config"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
||||
|
||||
# Pydantic has bug on _private_attr when using inheritance, so we use private_* instead
|
||||
# - https://github.com/pydantic/pydantic/issues/7142
|
||||
|
|
@ -32,15 +32,18 @@ class ContextMixin(BaseModel):
|
|||
# Env/Role/Action will use this llm as private llm, or use self.context._llm instance
|
||||
private_llm: Optional[BaseLLM] = Field(default=None, exclude=True)
|
||||
|
||||
def __init__(
|
||||
@model_validator(mode="after")
|
||||
def validate_extra(self):
|
||||
self._process_extra(**(self.model_extra or {}))
|
||||
return self
|
||||
|
||||
def _process_extra(
|
||||
self,
|
||||
context: Optional[Context] = None,
|
||||
config: Optional[Config] = None,
|
||||
llm: Optional[BaseLLM] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize with config"""
|
||||
super().__init__(**kwargs)
|
||||
"""Process the extra field"""
|
||||
self.set_context(context)
|
||||
self.set_config(config)
|
||||
self.set_llm(llm)
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions import SearchAndSummarize, UserRequirement
|
||||
from metagpt.document_store.base_store import BaseStore
|
||||
from metagpt.roles import Role
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
class Sales(Role):
|
||||
|
|
@ -29,14 +29,13 @@ class Sales(Role):
|
|||
|
||||
store: Optional[BaseStore] = Field(default=None, exclude=True)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._set_store(self.store)
|
||||
|
||||
def _set_store(self, store):
|
||||
if store:
|
||||
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=store.asearch)
|
||||
@model_validator(mode="after")
|
||||
def validate_stroe(self):
|
||||
if self.store:
|
||||
search_engine = SearchEngine.from_search_func(search_func=self.store.asearch, proxy=self.config.proxy)
|
||||
action = SearchAndSummarize(search_engine=search_engine, context=self.context)
|
||||
else:
|
||||
action = SearchAndSummarize()
|
||||
action = SearchAndSummarize
|
||||
self.set_actions([action])
|
||||
self._watch([UserRequirement])
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@
|
|||
the `cause_by` value in the `Message` to a string to support the new message distribution feature.
|
||||
"""
|
||||
|
||||
from pydantic import Field
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions import SearchAndSummarize
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
|
@ -16,7 +18,7 @@ from metagpt.actions.action_output import ActionOutput
|
|||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
||||
|
||||
class Searcher(Role):
|
||||
|
|
@ -28,33 +30,22 @@ class Searcher(Role):
|
|||
profile (str): Role profile.
|
||||
goal (str): Goal of the searcher.
|
||||
constraints (str): Constraints or limitations for the searcher.
|
||||
engine (SearchEngineType): The type of search engine to use.
|
||||
search_engine (SearchEngine): The search engine to use.
|
||||
"""
|
||||
|
||||
name: str = Field(default="Alice")
|
||||
profile: str = Field(default="Smart Assistant")
|
||||
goal: str = "Provide search services for users"
|
||||
constraints: str = "Answer is rich and complete"
|
||||
engine: SearchEngineType = SearchEngineType.SERPAPI_GOOGLE
|
||||
search_engine: Optional[SearchEngine] = None
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""
|
||||
Initializes the Searcher role with given attributes.
|
||||
|
||||
Args:
|
||||
name (str): Name of the searcher.
|
||||
profile (str): Role profile.
|
||||
goal (str): Goal of the searcher.
|
||||
constraints (str): Constraints or limitations for the searcher.
|
||||
engine (SearchEngineType): The type of search engine to use.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
self.set_actions([SearchAndSummarize(engine=self.engine)])
|
||||
|
||||
def set_search_func(self, search_func):
|
||||
"""Sets a custom search function for the searcher."""
|
||||
action = SearchAndSummarize(name="", engine=SearchEngineType.CUSTOM_ENGINE, search_func=search_func)
|
||||
self.set_actions([action])
|
||||
@model_validator(mode="after")
|
||||
def post_root(self):
|
||||
if self.search_engine:
|
||||
self.set_actions([SearchAndSummarize(search_engine=self.search_engine, context=self.context)])
|
||||
else:
|
||||
self.set_actions([SearchAndSummarize])
|
||||
return self
|
||||
|
||||
async def _act_sp(self) -> Message:
|
||||
"""Performs the search action in a single process."""
|
||||
|
|
|
|||
|
|
@ -8,14 +8,17 @@
|
|||
import importlib
|
||||
from typing import Callable, Coroutine, Literal, Optional, Union, overload
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
from semantic_kernel.skill_definition import sk_function
|
||||
|
||||
from metagpt.configs.search_config import SearchConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
||||
class SkSearchEngine:
|
||||
def __init__(self):
|
||||
self.search_engine = SearchEngine()
|
||||
def __init__(self, **kwargs):
|
||||
self.search_engine = SearchEngine(**kwargs)
|
||||
|
||||
@sk_function(
|
||||
description="searches results from Google. Useful when you need to find short "
|
||||
|
|
@ -28,43 +31,59 @@ class SkSearchEngine:
|
|||
return result
|
||||
|
||||
|
||||
class SearchEngine:
|
||||
"""Class representing a search engine.
|
||||
class SearchEngine(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
||||
|
||||
Args:
|
||||
engine: The search engine type. Defaults to the search engine specified in the config.
|
||||
run_func: The function to run the search. Defaults to None.
|
||||
engine: SearchEngineType = SearchEngineType.SERPER_GOOGLE
|
||||
run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None
|
||||
api_key: Optional[str] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
Attributes:
|
||||
run_func: The function to run the search.
|
||||
engine: The search engine type.
|
||||
"""
|
||||
@model_validator(mode="after")
|
||||
def validate_extra(self):
|
||||
data = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True)
|
||||
if self.model_extra:
|
||||
data.update(self.model_extra)
|
||||
self._process_extra(**data)
|
||||
return self
|
||||
|
||||
def __init__(
|
||||
def _process_extra(
|
||||
self,
|
||||
engine: Optional[SearchEngineType] = SearchEngineType.SERPER_GOOGLE,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
|
||||
run_func: Optional[Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
if engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serpapi"
|
||||
run_func = importlib.import_module(module).SerpAPIWrapper(**kwargs).run
|
||||
elif engine == SearchEngineType.SERPER_GOOGLE:
|
||||
elif self.engine == SearchEngineType.SERPER_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serper"
|
||||
run_func = importlib.import_module(module).SerperWrapper(**kwargs).run
|
||||
elif engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_googleapi"
|
||||
run_func = importlib.import_module(module).GoogleAPIWrapper(**kwargs).run
|
||||
elif engine == SearchEngineType.DUCK_DUCK_GO:
|
||||
elif self.engine == SearchEngineType.DUCK_DUCK_GO:
|
||||
module = "metagpt.tools.search_engine_ddg"
|
||||
run_func = importlib.import_module(module).DDGAPIWrapper(**kwargs).run
|
||||
elif engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
pass # run_func = run_func
|
||||
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
run_func = self.run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.engine = engine
|
||||
self.run_func = run_func
|
||||
|
||||
@classmethod
|
||||
def from_search_config(cls, config: SearchConfig, **kwargs):
|
||||
data = config.model_dump(exclude={"api_type", "search_func"})
|
||||
if config.search_func is not None:
|
||||
data["run_func"] = config.search_func
|
||||
|
||||
return cls(engine=config.api_type, **data, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_search_func(
|
||||
cls, search_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]], **kwargs
|
||||
):
|
||||
return cls(engine=SearchEngineType.CUSTOM_ENGINE, run_func=search_func, **kwargs)
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
|
|
@ -83,7 +102,13 @@ class SearchEngine:
|
|||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> Union[str, list[dict[str, str]]]:
|
||||
async def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: bool = True,
|
||||
ignore_errors: bool = False,
|
||||
) -> Union[str, list[dict[str, str]]]:
|
||||
"""Run a search query.
|
||||
|
||||
Args:
|
||||
|
|
@ -94,4 +119,11 @@ class SearchEngine:
|
|||
Returns:
|
||||
The search results as a string or a list of dictionaries.
|
||||
"""
|
||||
return await self.run_func(query, max_results=max_results, as_string=as_string)
|
||||
try:
|
||||
return await self.run_func(query, max_results=max_results, as_string=as_string)
|
||||
except Exception as e:
|
||||
# Handle errors in the API call
|
||||
logger.exception(f"fail to search {query} for {e}")
|
||||
if not ignore_errors:
|
||||
raise e
|
||||
return "" if as_string else []
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import json
|
||||
from concurrent import futures
|
||||
from typing import Literal, overload
|
||||
from typing import Literal, Optional, overload
|
||||
|
||||
from metagpt.config2 import config
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
try:
|
||||
from duckduckgo_search import DDGS
|
||||
|
|
@ -18,24 +18,16 @@ except ImportError:
|
|||
)
|
||||
|
||||
|
||||
class DDGAPIWrapper:
|
||||
"""Wrapper around duckduckgo_search API.
|
||||
class DDGAPIWrapper(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
To use this module, you should have the `duckduckgo_search` Python package installed.
|
||||
"""
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
):
|
||||
kwargs = {}
|
||||
if config.proxy:
|
||||
kwargs["proxies"] = config.proxy
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
self.ddgs = DDGS(**kwargs)
|
||||
@property
|
||||
def ddgs(self):
|
||||
return DDGS(proxies=self.proxy)
|
||||
|
||||
@overload
|
||||
def run(
|
||||
|
|
|
|||
|
|
@ -4,19 +4,16 @@ from __future__ import annotations
|
|||
|
||||
import asyncio
|
||||
import json
|
||||
import warnings
|
||||
from concurrent import futures
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httplib2
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
try:
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"To use this module, you should have the `google-api-python-client` Python package installed. "
|
||||
|
|
@ -27,40 +24,41 @@ except ImportError:
|
|||
class GoogleAPIWrapper(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
google_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
google_cse_id: Optional[str] = Field(default=None, validate_default=True)
|
||||
api_key: str
|
||||
cse_id: str
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
@field_validator("google_api_key", mode="before")
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_google_api_key(cls, val: str):
|
||||
val = val or config.search.api_key
|
||||
if not val:
|
||||
def validate_google(cls, values: dict) -> dict:
|
||||
if "google_api_key" in values:
|
||||
values.setdefault("api_key", values["google_api_key"])
|
||||
warnings.warn("`google_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)
|
||||
|
||||
if "api_key" not in values:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
|
||||
"ensure that the environment variable GOOGLE_API_KEY is set with your API key. You can obtain "
|
||||
"To use google search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
|
||||
"an API key from https://console.cloud.google.com/apis/credentials."
|
||||
)
|
||||
return val
|
||||
|
||||
@field_validator("google_cse_id", mode="before")
|
||||
@classmethod
|
||||
def check_google_cse_id(cls, val: str):
|
||||
val = val or config.search.cse_id
|
||||
if not val:
|
||||
if "google_cse_id" in values:
|
||||
values.setdefault("cse_id", values["google_cse_id"])
|
||||
warnings.warn("`google_cse_id` is deprecated, use `cse_id` instead", DeprecationWarning, stacklevel=2)
|
||||
|
||||
if "cse_id" not in values:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
|
||||
"ensure that the environment variable GOOGLE_CSE_ID is set with your API key. You can obtain "
|
||||
"an API key from https://programmablesearchengine.google.com/controlpanel/create."
|
||||
"To use google search engine, make sure you provide the `cse_id` when constructing an object. You can obtain "
|
||||
"the cse_id from https://programmablesearchengine.google.com/controlpanel/create."
|
||||
)
|
||||
return val
|
||||
return values
|
||||
|
||||
@property
|
||||
def google_api_client(self):
|
||||
build_kwargs = {"developerKey": self.google_api_key}
|
||||
if config.proxy:
|
||||
parse_result = urlparse(config.proxy)
|
||||
build_kwargs = {"developerKey": self.api_key}
|
||||
if self.proxy:
|
||||
parse_result = urlparse(self.proxy)
|
||||
proxy_type = parse_result.scheme
|
||||
if proxy_type == "https":
|
||||
proxy_type = "http"
|
||||
|
|
@ -96,17 +94,11 @@ class GoogleAPIWrapper(BaseModel):
|
|||
"""
|
||||
loop = self.loop or asyncio.get_event_loop()
|
||||
future = loop.run_in_executor(
|
||||
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.google_cse_id).execute
|
||||
self.executor, self.google_api_client.list(q=query, num=max_results, cx=self.cse_id).execute
|
||||
)
|
||||
try:
|
||||
result = await future
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
logger.exception(f"fail to search {query} for {e}")
|
||||
search_results = []
|
||||
result = await future
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
focus = focus or ["snippet", "link", "title"]
|
||||
details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
|
||||
|
|
|
|||
|
|
@ -5,18 +5,17 @@
|
|||
@Author : alexanderwu
|
||||
@File : search_engine_serpapi.py
|
||||
"""
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class SerpAPIWrapper(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
search_engine: Any = None #: :meta private:
|
||||
api_key: str
|
||||
params: dict = Field(
|
||||
default_factory=lambda: {
|
||||
"engine": "google",
|
||||
|
|
@ -25,21 +24,22 @@ class SerpAPIWrapper(BaseModel):
|
|||
"hl": "en",
|
||||
}
|
||||
)
|
||||
# should add `validate_default=True` to check with default value
|
||||
serpapi_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
@field_validator("serpapi_api_key", mode="before")
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_serpapi_api_key(cls, val: str):
|
||||
val = val or config.search.api_key
|
||||
if not val:
|
||||
def validate_serpapi(cls, values: dict) -> dict:
|
||||
if "serpapi_api_key" in values:
|
||||
values.setdefault("api_key", values["serpapi_api_key"])
|
||||
warnings.warn("`serpapi_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)
|
||||
|
||||
if "api_key" not in values:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
|
||||
"ensure that the environment variable SERPAPI_API_KEY is set with your API key. You can obtain "
|
||||
"an API key from https://serpapi.com/."
|
||||
"To use serpapi search engine, make sure you provide the `api_key` when constructing an object. You can obtain"
|
||||
" an API key from https://serpapi.com/."
|
||||
)
|
||||
return val
|
||||
return values
|
||||
|
||||
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result async."""
|
||||
|
|
@ -60,11 +60,11 @@ class SerpAPIWrapper(BaseModel):
|
|||
url, params = construct_url_and_params()
|
||||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params) as response:
|
||||
async with session.get(url, params=params, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get(url, params=params) as response:
|
||||
async with self.aiosession.get(url, params=params, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
|
|
@ -73,7 +73,7 @@ class SerpAPIWrapper(BaseModel):
|
|||
def get_params(self, query: str) -> Dict[str, str]:
|
||||
"""Get parameters for SerpAPI."""
|
||||
_params = {
|
||||
"api_key": self.serpapi_api_key,
|
||||
"api_key": self.api_key,
|
||||
"q": query,
|
||||
}
|
||||
params = {**self.params, **_params}
|
||||
|
|
|
|||
|
|
@ -6,33 +6,34 @@
|
|||
@File : search_engine_serpapi.py
|
||||
"""
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
|
||||
class SerperWrapper(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
search_engine: Any = None #: :meta private:
|
||||
api_key: str
|
||||
payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10})
|
||||
serper_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
@field_validator("serper_api_key", mode="before")
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_serper_api_key(cls, val: str):
|
||||
val = val or config.search.api_key
|
||||
if not val:
|
||||
def validate_serper(cls, values: dict) -> dict:
|
||||
if "serper_api_key" in values:
|
||||
values.setdefault("api_key", values["serper_api_key"])
|
||||
warnings.warn("`serper_api_key` is deprecated, use `api_key` instead", DeprecationWarning, stacklevel=2)
|
||||
|
||||
if "api_key" not in values:
|
||||
raise ValueError(
|
||||
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "
|
||||
"ensure that the environment variable SERPER_API_KEY is set with your API key. You can obtain "
|
||||
"To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
|
||||
"an API key from https://serper.dev/."
|
||||
)
|
||||
return val
|
||||
return values
|
||||
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through Serper and parse result async."""
|
||||
|
|
@ -54,11 +55,11 @@ class SerperWrapper(BaseModel):
|
|||
url, payloads, headers = construct_url_and_payload_and_headers()
|
||||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, data=payloads, headers=headers) as response:
|
||||
async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get.post(url, data=payloads, headers=headers) as response:
|
||||
async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
|
|
@ -76,7 +77,7 @@ class SerperWrapper(BaseModel):
|
|||
return json.dumps(payloads, sort_keys=True)
|
||||
|
||||
def get_headers(self) -> Dict[str, str]:
|
||||
headers = {"X-API-KEY": self.serper_api_key, "Content-Type": "application/json"}
|
||||
headers = {"X-API-KEY": self.api_key, "Content-Type": "application/json"}
|
||||
return headers
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -1,36 +1,49 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import Any, Callable, Coroutine, overload
|
||||
from typing import Any, Callable, Coroutine, Optional, Union, overload
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from metagpt.configs.browser_config import BrowserConfig
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class WebBrowserEngine:
|
||||
def __init__(
|
||||
self,
|
||||
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT,
|
||||
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
|
||||
):
|
||||
if engine is None:
|
||||
raise NotImplementedError
|
||||
class WebBrowserEngine(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow")
|
||||
|
||||
if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT:
|
||||
engine: WebBrowserEngineType = WebBrowserEngineType.PLAYWRIGHT
|
||||
run_func: Optional[Callable[..., Coroutine[Any, Any, Union[WebPage, list[WebPage]]]]] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_extra(self):
|
||||
data = self.model_dump(exclude={"engine"}, exclude_none=True, exclude_defaults=True)
|
||||
if self.model_extra:
|
||||
data.update(self.model_extra)
|
||||
self._process_extra(**data)
|
||||
return self
|
||||
|
||||
def _process_extra(self, **kwargs):
|
||||
if self.engine is WebBrowserEngineType.PLAYWRIGHT:
|
||||
module = "metagpt.tools.web_browser_engine_playwright"
|
||||
run_func = importlib.import_module(module).PlaywrightWrapper().run
|
||||
elif WebBrowserEngineType(engine) is WebBrowserEngineType.SELENIUM:
|
||||
run_func = importlib.import_module(module).PlaywrightWrapper(**kwargs).run
|
||||
elif self.engine is WebBrowserEngineType.SELENIUM:
|
||||
module = "metagpt.tools.web_browser_engine_selenium"
|
||||
run_func = importlib.import_module(module).SeleniumWrapper().run
|
||||
elif WebBrowserEngineType(engine) is WebBrowserEngineType.CUSTOM:
|
||||
run_func = run_func
|
||||
run_func = importlib.import_module(module).SeleniumWrapper(**kwargs).run
|
||||
elif self.engine is WebBrowserEngineType.CUSTOM:
|
||||
run_func = self.run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.run_func = run_func
|
||||
self.engine = engine
|
||||
|
||||
@classmethod
|
||||
def from_browser_config(cls, config: BrowserConfig, **kwargs):
|
||||
data = config.model_dump()
|
||||
return cls(**data, **kwargs)
|
||||
|
||||
@overload
|
||||
async def run(self, url: str) -> WebPage:
|
||||
|
|
|
|||
|
|
@ -6,16 +6,16 @@ from __future__ import annotations
|
|||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from typing import Literal, Optional
|
||||
|
||||
from playwright.async_api import async_playwright
|
||||
from pydantic import BaseModel, Field, PrivateAttr
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class PlaywrightWrapper:
|
||||
class PlaywrightWrapper(BaseModel):
|
||||
"""Wrapper around Playwright.
|
||||
|
||||
To use this module, you should have the `playwright` Python package installed and ensure that
|
||||
|
|
@ -24,24 +24,23 @@ class PlaywrightWrapper:
|
|||
command `playwright install` for the first time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
browser_type: Literal["chromium", "firefox", "webkit"] | None = "chromium",
|
||||
launch_kwargs: dict | None = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.browser_type = browser_type
|
||||
launch_kwargs = launch_kwargs or {}
|
||||
if config.proxy and "proxy" not in launch_kwargs:
|
||||
browser_type: Literal["chromium", "firefox", "webkit"] = "chromium"
|
||||
launch_kwargs: dict = Field(default_factory=dict)
|
||||
proxy: Optional[str] = None
|
||||
context_kwargs: dict = Field(default_factory=dict)
|
||||
_has_run_precheck: bool = PrivateAttr(False)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
launch_kwargs = self.launch_kwargs
|
||||
if self.proxy and "proxy" not in launch_kwargs:
|
||||
args = launch_kwargs.get("args", [])
|
||||
if not any(str.startswith(i, "--proxy-server=") for i in args):
|
||||
launch_kwargs["proxy"] = {"server": config.proxy}
|
||||
self.launch_kwargs = launch_kwargs
|
||||
context_kwargs = {}
|
||||
launch_kwargs["proxy"] = {"server": self.proxy}
|
||||
|
||||
if "ignore_https_errors" in kwargs:
|
||||
context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"]
|
||||
self._context_kwargs = context_kwargs
|
||||
self._has_run_precheck = False
|
||||
self.context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"]
|
||||
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
async with async_playwright() as ap:
|
||||
|
|
@ -55,7 +54,7 @@ class PlaywrightWrapper:
|
|||
return await _scrape(browser, url)
|
||||
|
||||
async def _scrape(self, browser, url):
|
||||
context = await browser.new_context(**self._context_kwargs)
|
||||
context = await browser.new_context(**self.context_kwargs)
|
||||
page = await context.new_page()
|
||||
async with page:
|
||||
try:
|
||||
|
|
@ -75,8 +74,8 @@ class PlaywrightWrapper:
|
|||
executable_path = Path(browser_type.executable_path)
|
||||
if not executable_path.exists() and "executable_path" not in self.launch_kwargs:
|
||||
kwargs = {}
|
||||
if config.proxy:
|
||||
kwargs["env"] = {"ALL_PROXY": config.proxy}
|
||||
if self.proxy:
|
||||
kwargs["env"] = {"ALL_PROXY": self.proxy}
|
||||
await _install_browsers(self.browser_type, **kwargs)
|
||||
|
||||
if self._has_run_precheck:
|
||||
|
|
|
|||
|
|
@ -7,19 +7,19 @@ import asyncio
|
|||
import importlib
|
||||
from concurrent import futures
|
||||
from copy import deepcopy
|
||||
from typing import Literal
|
||||
from typing import Callable, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
from webdriver_manager.core.download_manager import WDMDownloadManager
|
||||
from webdriver_manager.core.http import WDMHttpClient
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class SeleniumWrapper:
|
||||
class SeleniumWrapper(BaseModel):
|
||||
"""Wrapper around Selenium.
|
||||
|
||||
To use this module, you should check the following:
|
||||
|
|
@ -31,25 +31,28 @@ class SeleniumWrapper:
|
|||
can scrape web pages using the Selenium WebBrowserEngine.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome",
|
||||
launch_kwargs: dict | None = None,
|
||||
*,
|
||||
loop: asyncio.AbstractEventLoop | None = None,
|
||||
executor: futures.Executor | None = None,
|
||||
) -> None:
|
||||
self.browser_type = browser_type
|
||||
launch_kwargs = launch_kwargs or {}
|
||||
if config.proxy and "proxy-server" not in launch_kwargs:
|
||||
launch_kwargs["proxy-server"] = config.proxy
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
self.executable_path = launch_kwargs.pop("executable_path", None)
|
||||
self.launch_args = [f"--{k}={v}" for k, v in launch_kwargs.items()]
|
||||
self._has_run_precheck = False
|
||||
self._get_driver = None
|
||||
self.loop = loop
|
||||
self.executor = executor
|
||||
browser_type: Literal["chrome", "firefox", "edge", "ie"] = "chrome"
|
||||
launch_kwargs: dict = Field(default_factory=dict)
|
||||
proxy: Optional[str] = None
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
_has_run_precheck: bool = PrivateAttr(False)
|
||||
_get_driver: Optional[Callable] = PrivateAttr(None)
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
super().__init__(**kwargs)
|
||||
if self.proxy and "proxy-server" not in self.launch_kwargs:
|
||||
self.launch_kwargs["proxy-server"] = self.proxy
|
||||
|
||||
@property
|
||||
def launch_args(self):
|
||||
return [f"--{k}={v}" for k, v in self.launch_kwargs.items() if k != "executable_path"]
|
||||
|
||||
@property
|
||||
def executable_path(self):
|
||||
return self.launch_kwargs.get("executable_path")
|
||||
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
await self._run_precheck()
|
||||
|
|
@ -66,7 +69,9 @@ class SeleniumWrapper:
|
|||
self.loop = self.loop or asyncio.get_event_loop()
|
||||
self._get_driver = await self.loop.run_in_executor(
|
||||
self.executor,
|
||||
lambda: _gen_get_driver_func(self.browser_type, *self.launch_args, executable_path=self.executable_path),
|
||||
lambda: _gen_get_driver_func(
|
||||
self.browser_type, *self.launch_args, executable_path=self.executable_path, proxy=self.proxy
|
||||
),
|
||||
)
|
||||
self._has_run_precheck = True
|
||||
|
||||
|
|
@ -92,13 +97,17 @@ _webdriver_manager_types = {
|
|||
|
||||
|
||||
class WDMHttpProxyClient(WDMHttpClient):
|
||||
def __init__(self, proxy: str = None):
|
||||
super().__init__()
|
||||
self.proxy = proxy
|
||||
|
||||
def get(self, url, **kwargs):
|
||||
if "proxies" not in kwargs and config.proxy:
|
||||
kwargs["proxies"] = {"all_proxy": config.proxy}
|
||||
if "proxies" not in kwargs and self.proxy:
|
||||
kwargs["proxies"] = {"all_proxy": self.proxy}
|
||||
return super().get(url, **kwargs)
|
||||
|
||||
|
||||
def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
||||
def _gen_get_driver_func(browser_type, *args, executable_path=None, proxy=None):
|
||||
WebDriver = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.webdriver"), "WebDriver")
|
||||
Service = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.service"), "Service")
|
||||
Options = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.options"), "Options")
|
||||
|
|
@ -106,7 +115,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
if not executable_path:
|
||||
module_name, type_name = _webdriver_manager_types[browser_type]
|
||||
DriverManager = getattr(importlib.import_module(module_name), type_name)
|
||||
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient()))
|
||||
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient(proxy=proxy)))
|
||||
# driver_manager.driver_cache.find_driver(driver_manager.driver))
|
||||
executable_path = driver_manager.install()
|
||||
|
||||
|
|
|
|||
|
|
@ -28,9 +28,9 @@ async def test_collect_links(mocker, search_engine_mocker, context):
|
|||
return "[1,2]"
|
||||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_llm_ask)
|
||||
resp = await research.CollectLinks(search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), context=context).run(
|
||||
"The application of MetaGPT"
|
||||
)
|
||||
resp = await research.CollectLinks(
|
||||
search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO), context=context
|
||||
).run("The application of MetaGPT")
|
||||
for i in ["MetaGPT use cases", "The roadmap of MetaGPT", "The function of MetaGPT", "What llm MetaGPT support"]:
|
||||
assert i in resp
|
||||
|
||||
|
|
@ -50,7 +50,9 @@ async def test_collect_links_with_rank_func(mocker, search_engine_mocker, contex
|
|||
|
||||
mocker.patch("metagpt.provider.base_llm.BaseLLM.aask", mock_collect_links_llm_ask)
|
||||
resp = await research.CollectLinks(
|
||||
search_engine=SearchEngine(SearchEngineType.DUCK_DUCK_GO), rank_func=rank_func, context=context
|
||||
search_engine=SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO),
|
||||
rank_func=rank_func,
|
||||
context=context,
|
||||
).run("The application of MetaGPT")
|
||||
for x, y, z in zip(rank_before, rank_after, resp.values()):
|
||||
assert x[::-1] == y
|
||||
|
|
|
|||
|
|
@ -16,6 +16,6 @@ async def test_google_search(search_engine_mocker):
|
|||
result = await google_search(
|
||||
seed.input,
|
||||
engine=SearchEngineType.SERPER_GOOGLE,
|
||||
serper_api_key="mock-serper-key",
|
||||
api_key="mock-serper-key",
|
||||
)
|
||||
assert result != ""
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ async def test_researcher(mocker, search_engine_mocker, context):
|
|||
role = researcher.Researcher(context=context)
|
||||
for i in role.actions:
|
||||
if isinstance(i, CollectLinks):
|
||||
i.search_engine = SearchEngine(SearchEngineType.DUCK_DUCK_GO)
|
||||
i.search_engine = SearchEngine(engine=SearchEngineType.DUCK_DUCK_GO)
|
||||
await role.run(topic)
|
||||
assert (researcher.RESEARCH_PATH / f"{topic}.md").read_text().startswith("# Research Report")
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from typing import Callable
|
|||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.configs.search_config import SearchConfig
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
|
@ -49,27 +50,34 @@ async def test_search_engine(
|
|||
search_engine_mocker,
|
||||
):
|
||||
# Prerequisites
|
||||
search_engine_config = {}
|
||||
search_engine_config = {"engine": search_engine_type, "run_func": run_func}
|
||||
|
||||
if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["serpapi_api_key"] = "mock-serpapi-key"
|
||||
search_engine_config["api_key"] = "mock-serpapi-key"
|
||||
elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["google_api_key"] = "mock-google-key"
|
||||
search_engine_config["google_cse_id"] = "mock-google-cse"
|
||||
search_engine_config["api_key"] = "mock-google-key"
|
||||
search_engine_config["cse_id"] = "mock-google-cse"
|
||||
elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
|
||||
assert config.search
|
||||
search_engine_config["serper_api_key"] = "mock-serper-key"
|
||||
search_engine_config["api_key"] = "mock-serper-key"
|
||||
|
||||
search_engine = SearchEngine(search_engine_type, run_func, **search_engine_config)
|
||||
rsp = await search_engine.run("metagpt", max_results, as_string)
|
||||
logger.info(rsp)
|
||||
if as_string:
|
||||
assert isinstance(rsp, str)
|
||||
else:
|
||||
assert isinstance(rsp, list)
|
||||
assert len(rsp) <= max_results
|
||||
async def test(search_engine):
|
||||
rsp = await search_engine.run("metagpt", max_results, as_string)
|
||||
logger.info(rsp)
|
||||
if as_string:
|
||||
assert isinstance(rsp, str)
|
||||
else:
|
||||
assert isinstance(rsp, list)
|
||||
assert len(rsp) <= max_results
|
||||
|
||||
await test(SearchEngine(**search_engine_config))
|
||||
search_engine_config["api_type"] = search_engine_config.pop("engine")
|
||||
if run_func:
|
||||
await test(SearchEngine.from_search_func(run_func))
|
||||
search_engine_config["search_func"] = search_engine_config.pop("run_func")
|
||||
await test(SearchEngine.from_search_config(SearchConfig(**search_engine_config)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.tools import web_browser_engine_playwright
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
|
@ -19,26 +18,22 @@ from metagpt.utils.parse_html import WebPage
|
|||
ids=["chromium-normal", "firefox-normal", "webkit-normal"],
|
||||
)
|
||||
async def test_scrape_web_page(browser_type, use_proxy, kwagrs, url, urls, proxy, capfd):
|
||||
global_proxy = config.proxy
|
||||
try:
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
config.proxy = proxy_url
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, **kwagrs)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
proxy_url = None
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
browser = web_browser_engine_playwright.PlaywrightWrapper(browser_type=browser_type, proxy=proxy_url, **kwagrs)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
finally:
|
||||
config.proxy = global_proxy
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
import browsers
|
||||
import pytest
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.tools import web_browser_engine_selenium
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
|
@ -40,27 +39,22 @@ from metagpt.utils.parse_html import WebPage
|
|||
async def test_scrape_web_page(browser_type, use_proxy, url, urls, proxy, capfd):
|
||||
# Prerequisites
|
||||
# firefox, chrome, Microsoft Edge
|
||||
proxy_url = None
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type, proxy=proxy_url)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
|
||||
global_proxy = config.proxy
|
||||
try:
|
||||
if use_proxy:
|
||||
server, proxy_url = await proxy()
|
||||
config.proxy = proxy_url
|
||||
browser = web_browser_engine_selenium.SeleniumWrapper(browser_type=browser_type)
|
||||
result = await browser.run(url)
|
||||
assert isinstance(result, WebPage)
|
||||
assert "MetaGPT" in result.inner_text
|
||||
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
finally:
|
||||
config.proxy = global_proxy
|
||||
if urls:
|
||||
results = await browser.run(url, *urls)
|
||||
assert isinstance(results, list)
|
||||
assert len(results) == len(urls) + 1
|
||||
assert all(("MetaGPT" in i.inner_text) for i in results)
|
||||
if use_proxy:
|
||||
server.close()
|
||||
assert "Proxy:" in capfd.readouterr().out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -13,7 +13,8 @@ class MockAioResponse:
|
|||
|
||||
def __init__(self, session, method, url, **kwargs) -> None:
|
||||
fn = self.check_funcs.get((method, url))
|
||||
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(kwargs, sort_keys=True)}"
|
||||
_kwargs = {k: v for k, v in kwargs.items() if k != "proxy"}
|
||||
self.key = f"{self.name}-{method}-{url}-{fn(kwargs) if fn else json.dumps(_kwargs, sort_keys=True)}"
|
||||
self.mng = self.response = None
|
||||
if self.key not in self.rsp_cache:
|
||||
self.mng = origin_request(session, method, url, **kwargs)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue