mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-02 14:45:17 +02:00
resolved current conflicts
This commit is contained in:
parent
730e2f912f
commit
193178b7d1
11 changed files with 290 additions and 269 deletions
|
|
@ -2,29 +2,27 @@
|
|||
# @Date : 2023/7/19 16:28
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import os
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from os.path import join
|
||||
from typing import List
|
||||
import json
|
||||
import io
|
||||
import base64
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import Config
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
|
||||
config = Config()
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
|
||||
"override_settings": {
|
||||
"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"
|
||||
},
|
||||
"override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
|
||||
"seed": -1,
|
||||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
|
|
@ -36,21 +34,20 @@ payload = {
|
|||
"tiling": False,
|
||||
"do_not_save_samples": False,
|
||||
"do_not_save_grid": False,
|
||||
'enable_hr': False,
|
||||
'hr_scale': 2,
|
||||
'hr_upscaler': 'Latent',
|
||||
'hr_second_pass_steps': 0,
|
||||
'hr_resize_x': 0,
|
||||
'hr_resize_y': 0,
|
||||
'hr_upscale_to_x': 0,
|
||||
'hr_upscale_to_y': 0,
|
||||
'truncate_x': 0,
|
||||
'truncate_y': 0,
|
||||
'applied_old_hires_behavior_to': None,
|
||||
"enable_hr": False,
|
||||
"hr_scale": 2,
|
||||
"hr_upscaler": "Latent",
|
||||
"hr_second_pass_steps": 0,
|
||||
"hr_resize_x": 0,
|
||||
"hr_resize_y": 0,
|
||||
"hr_upscale_to_x": 0,
|
||||
"hr_upscale_to_y": 0,
|
||||
"truncate_x": 0,
|
||||
"truncate_y": 0,
|
||||
"applied_old_hires_behavior_to": None,
|
||||
"eta": None,
|
||||
|
||||
"sampler_index": "DPM++ SDE Karras",
|
||||
"alwayson_scripts": {}
|
||||
"alwayson_scripts": {},
|
||||
}
|
||||
|
||||
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
||||
|
|
@ -60,14 +57,20 @@ class SDEngine:
|
|||
def __init__(self):
|
||||
# Initialize the SDEngine with configuration
|
||||
self.config = Config()
|
||||
self.sd_url = self.config.get('SD_URL')
|
||||
self.sd_url = self.config.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{self.config.get('SD_T2I_API')}"
|
||||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
||||
def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20"):
|
||||
|
||||
def construct_payload(
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
):
|
||||
# Configure the payload with provided inputs
|
||||
self.payload["prompt"] = prompt
|
||||
self.payload["negtive_prompt"] = negtive_prompt
|
||||
|
|
@ -76,13 +79,13 @@ class SDEngine:
|
|||
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
|
||||
logger.info(f"call sd payload is {self.payload}")
|
||||
return self.payload
|
||||
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
save_dir = WORKSPACE_ROOT / "resources"/"SD_Output"
|
||||
save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
|
||||
|
||||
|
||||
async def run_t2i(self, prompts: List):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
session = ClientSession()
|
||||
|
|
@ -90,24 +93,25 @@ class SDEngine:
|
|||
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
|
||||
self._save(results, save_name=f"output_{payload_idx}")
|
||||
await session.close()
|
||||
|
||||
async def run(self, url, payload, session):
|
||||
# Perform the HTTP POST request to the SD API
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json['images']
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: Add image-to-image interface call
|
||||
raise NotImplementedError
|
||||
async def run(self, url, payload, session):
|
||||
# Perform the HTTP POST request to the SD API
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json["images"]
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: 添加图生图接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_sam(self):
|
||||
# todo:添加SAM接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_sam(self):
|
||||
# todo: Add SAM interface call
|
||||
raise NotImplementedError
|
||||
|
||||
def decode_base64_to_image(img, save_name):
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
|
||||
|
|
@ -122,12 +126,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
|||
decode_base64_to_image(_img, save_name=save_name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
engine = SDEngine()
|
||||
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
|
||||
|
||||
|
||||
engine.construct_payload(prompt)
|
||||
|
||||
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(engine.run_t2i(prompt))
|
||||
|
|
|
|||
|
|
@ -7,118 +7,76 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import importlib
|
||||
from typing import Callable, Coroutine, Literal, overload
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
|
||||
from metagpt.tools.search_engine_serper import SerperWrapper
|
||||
|
||||
config = Config()
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
||||
class SearchEngine:
|
||||
"""
|
||||
TODO: Integrate Google Search and reverse proxy.
|
||||
Note: Google here requires a Proxifier or similar global proxy.
|
||||
- DDG: https://pypi.org/project/duckduckgo-search/
|
||||
- GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9
|
||||
"""
|
||||
def __init__(self, engine=None, run_func=None):
|
||||
self.config = Config()
|
||||
self.run_func = run_func
|
||||
self.engine = engine or self.config.search_engine
|
||||
"""Class representing a search engine.
|
||||
|
||||
@classmethod
|
||||
def run_google(cls, query, max_results=8):
|
||||
# results = ddg(query, max_results=max_results)
|
||||
results = google_official_search(query, num_results=max_results)
|
||||
logger.info(results)
|
||||
return results
|
||||
Args:
|
||||
engine: The search engine type. Defaults to the search engine specified in the config.
|
||||
run_func: The function to run the search. Defaults to None.
|
||||
|
||||
async def run(self, query: str, max_results=8):
|
||||
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
api = SerpAPIWrapper()
|
||||
rsp = await api.run(query)
|
||||
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
rsp = SearchEngine.run_google(query, max_results)
|
||||
elif self.engine == SearchEngineType.SERPER_GOOGLE:
|
||||
api = SerperWrapper()
|
||||
rsp = await api.run(query)
|
||||
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
rsp = self.run_func(query)
|
||||
Attributes:
|
||||
run_func: The function to run the search.
|
||||
engine: The search engine type.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
engine: SearchEngineType | None = None,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
|
||||
):
|
||||
engine = engine or CONFIG.search_engine
|
||||
if engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serpapi"
|
||||
run_func = importlib.import_module(module).SerpAPIWrapper().run
|
||||
elif engine == SearchEngineType.SERPER_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serper"
|
||||
run_func = importlib.import_module(module).SerperWrapper().run
|
||||
elif engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_googleapi"
|
||||
run_func = importlib.import_module(module).GoogleAPIWrapper().run
|
||||
elif engine == SearchEngineType.DUCK_DUCK_GO:
|
||||
module = "metagpt.tools.search_engine_ddg"
|
||||
run_func = importlib.import_module(module).DDGAPIWrapper().run
|
||||
elif engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
pass # run_func = run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return rsp
|
||||
self.engine = engine
|
||||
self.run_func = run_func
|
||||
|
||||
def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
) -> str:
|
||||
...
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
"""Run a search query.
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
Args:
|
||||
query: The search query.
|
||||
max_results: The maximum number of results to return. Defaults to 8.
|
||||
as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
|
||||
|
||||
try:
|
||||
api_key = config.google_api_key
|
||||
custom_search_engine_id = config.google_cse_id
|
||||
|
||||
with build("customsearch", "v1", developerKey=api_key) as service:
|
||||
|
||||
result = (
|
||||
service.cse()
|
||||
.list(q=query, cx=custom_search_engine_id, num=num_results)
|
||||
.execute()
|
||||
)
|
||||
logger.info(result)
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
# Create a list of only the URLs from the search results
|
||||
search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
error_details = json.loads(e.content.decode())
|
||||
|
||||
# Check if the error is related to an invalid or missing API key
|
||||
if error_details.get("error", {}).get(
|
||||
"code"
|
||||
) == 403 and "invalid API key" in error_details.get("error", {}).get(
|
||||
"message", ""
|
||||
):
|
||||
return "Error: The provided Google API key is invalid or missing."
|
||||
else:
|
||||
return f"Error: {e}"
|
||||
|
||||
# Return the list of search result URLs
|
||||
return search_results_details
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
|
||||
"""
|
||||
Return the results of a google search in a safe format.
|
||||
|
||||
Args:
|
||||
results (str | list): The search results.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps(
|
||||
[result for result in results]
|
||||
)
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
|
||||
if __name__ == '__main__':
|
||||
SearchEngine.run(query='wtf')
|
||||
|
||||
Returns:
|
||||
The search results as a string or a list of dictionaries.
|
||||
"""
|
||||
return await self.run_func(query, max_results=max_results, as_string=as_string)
|
||||
|
|
|
|||
|
|
@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(self, query: str, **kwargs: Any) -> str:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result async."""
|
||||
return self._process_response(await self.results(query))
|
||||
return self._process_response(await self.results(query, max_results), as_string=as_string)
|
||||
|
||||
async def results(self, query: str) -> dict:
|
||||
async def results(self, query: str, max_results: int) -> dict:
|
||||
"""Use aiohttp to run query through SerpAPI and return the results async."""
|
||||
|
||||
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
|
||||
params = self.get_params(query)
|
||||
params["source"] = "python"
|
||||
params["num"] = max_results
|
||||
if self.serpapi_api_key:
|
||||
params["serp_api_key"] = self.serpapi_api_key
|
||||
params["output"] = "json"
|
||||
|
|
@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel):
|
|||
return params
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict) -> str:
|
||||
def _process_response(res: dict, as_string: bool) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
# logger.debug(res)
|
||||
focus = ['title', 'snippet', 'link']
|
||||
focus = ["title", "snippet", "link"]
|
||||
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
|
||||
|
||||
if "error" in res.keys():
|
||||
|
|
@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel):
|
|||
toret = res["answer_box"]["answer"]
|
||||
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet"]
|
||||
elif (
|
||||
"answer_box" in res.keys()
|
||||
and "snippet_highlighted_words" in res["answer_box"].keys()
|
||||
):
|
||||
elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet_highlighted_words"][0]
|
||||
elif (
|
||||
"sports_results" in res.keys()
|
||||
and "game_spotlight" in res["sports_results"].keys()
|
||||
):
|
||||
elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
|
||||
toret = res["sports_results"]["game_spotlight"]
|
||||
elif (
|
||||
"knowledge_graph" in res.keys()
|
||||
and "description" in res["knowledge_graph"].keys()
|
||||
):
|
||||
elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
|
||||
toret = res["knowledge_graph"]["description"]
|
||||
elif "snippet" in res["organic_results"][0].keys():
|
||||
toret = res["organic_results"][0]["snippet"]
|
||||
|
|
@ -112,5 +104,10 @@ class SerpAPIWrapper(BaseModel):
|
|||
if res.get("organic_results"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic_results")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l)
|
||||
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(SerpAPIWrapper().run)
|
||||
|
|
|
|||
|
|
@ -36,16 +36,19 @@ class SerperWrapper(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(self, query: str, **kwargs: Any) -> str:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through Serper and parse result async."""
|
||||
queries = query.split("\n")
|
||||
return "\n".join([self._process_response(res) for res in await self.results(queries)])
|
||||
if isinstance(query, str):
|
||||
return self._process_response((await self.results([query], max_results))[0], as_string=as_string)
|
||||
else:
|
||||
results = [self._process_response(res, as_string) for res in await self.results(query, max_results)]
|
||||
return "\n".join(results) if as_string else results
|
||||
|
||||
async def results(self, queries: list[str]) -> dict:
|
||||
async def results(self, queries: list[str], max_results: int = 8) -> dict:
|
||||
"""Use aiohttp to run query through Serper and return the results async."""
|
||||
|
||||
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
|
||||
payloads = self.get_payloads(queries)
|
||||
payloads = self.get_payloads(queries, max_results)
|
||||
url = "https://google.serper.dev/search"
|
||||
headers = self.get_headers()
|
||||
return url, payloads, headers
|
||||
|
|
@ -61,12 +64,13 @@ class SerperWrapper(BaseModel):
|
|||
|
||||
return res
|
||||
|
||||
def get_payloads(self, queries: list[str]) -> Dict[str, str]:
|
||||
def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]:
|
||||
"""Get payloads for Serper."""
|
||||
payloads = []
|
||||
for query in queries:
|
||||
_payload = {
|
||||
"q": query,
|
||||
"num": max_results,
|
||||
}
|
||||
payloads.append({**self.payload, **_payload})
|
||||
return json.dumps(payloads, sort_keys=True)
|
||||
|
|
@ -79,7 +83,7 @@ class SerperWrapper(BaseModel):
|
|||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict) -> str:
|
||||
def _process_response(res: dict, as_string: bool = False) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
# logger.debug(res)
|
||||
focus = ['title', 'snippet', 'link']
|
||||
|
|
@ -117,5 +121,10 @@ class SerperWrapper(BaseModel):
|
|||
if res.get("organic"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l)
|
||||
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(SerperWrapper().run)
|
||||
|
|
|
|||
|
|
@ -1,22 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import importlib
|
||||
|
||||
from typing import Any, Callable, Coroutine, overload
|
||||
import importlib
|
||||
from typing import Any, Callable, Coroutine, Literal, overload
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
from bs4 import BeautifulSoup
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class WebBrowserEngine:
|
||||
def __init__(
|
||||
self,
|
||||
engine: WebBrowserEngineType | None = None,
|
||||
run_func: Callable[..., Coroutine[Any, Any, str | list[str]]] | None = None,
|
||||
parse_func: Callable[[str], str] | None = None,
|
||||
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
|
||||
):
|
||||
engine = engine or CONFIG.web_browser_engine
|
||||
|
||||
|
|
@ -30,31 +28,25 @@ class WebBrowserEngine:
|
|||
run_func = run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.parse_func = parse_func or get_page_content
|
||||
self.run_func = run_func
|
||||
self.engine = engine
|
||||
|
||||
@overload
|
||||
async def run(self, url: str) -> str:
|
||||
async def run(self, url: str) -> WebPage:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def run(self, url: str, *urls: str) -> list[str]:
|
||||
async def run(self, url: str, *urls: str) -> list[WebPage]:
|
||||
...
|
||||
|
||||
async def run(self, url: str, *urls: str) -> str | list[str]:
|
||||
page = await self.run_func(url, *urls)
|
||||
if isinstance(page, str):
|
||||
return self.parse_func(page)
|
||||
return [self.parse_func(i) for i in page]
|
||||
|
||||
|
||||
def get_page_content(page: str):
|
||||
soup = BeautifulSoup(page, "html.parser")
|
||||
return "\n".join(i.text.strip() for i in soup.find_all(["h1", "h2", "h3", "h4", "h5", "p", "pre"]))
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
return await self.run_func(url, *urls)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = asyncio.run(WebBrowserEngine().run("https://fuzhi.ai/"))
|
||||
print(text)
|
||||
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
|
||||
return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -2,16 +2,17 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
import importlib
|
||||
from concurrent import futures
|
||||
from copy import deepcopy
|
||||
from typing import Literal
|
||||
from metagpt.config import CONFIG
|
||||
import asyncio
|
||||
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
from concurrent import futures
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class SeleniumWrapper:
|
||||
|
|
@ -48,7 +49,7 @@ class SeleniumWrapper:
|
|||
self.loop = loop
|
||||
self.executor = executor
|
||||
|
||||
async def run(self, url: str, *urls: str) -> str | list[str]:
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
await self._run_precheck()
|
||||
|
||||
_scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url)
|
||||
|
|
@ -69,9 +70,15 @@ class SeleniumWrapper:
|
|||
|
||||
def _scrape_website(self, url):
|
||||
with self._get_driver() as driver:
|
||||
driver.get(url)
|
||||
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
|
||||
return driver.page_source
|
||||
try:
|
||||
driver.get(url)
|
||||
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
|
||||
inner_text = driver.execute_script("return document.body.innerText;")
|
||||
html = driver.page_source
|
||||
except Exception as e:
|
||||
inner_text = f"Fail to load page content for {e}"
|
||||
html = ""
|
||||
return WebPage(inner_text=inner_text, html=html, url=url)
|
||||
|
||||
|
||||
_webdriver_manager_types = {
|
||||
|
|
@ -97,6 +104,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
def _get_driver():
|
||||
options = Options()
|
||||
options.add_argument("--headless")
|
||||
options.add_argument("--enable-javascript")
|
||||
if browser_type == "chrome":
|
||||
options.add_argument("--no-sandbox")
|
||||
for i in args:
|
||||
|
|
@ -107,6 +115,9 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = asyncio.run(SeleniumWrapper("chrome").run("https://fuzhi.ai/"))
|
||||
print(text)
|
||||
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
|
||||
return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue