update timeout

This commit is contained in:
seehi 2024-08-02 09:00:51 +08:00
parent 997ec8fcb7
commit 033211eb4b
5 changed files with 86 additions and 37 deletions

View file

@@ -206,6 +206,7 @@ class WebBrowseAndSummarize(Action):
query: str,
system_text: str = RESEARCH_BASE_SYSTEM,
use_concurrent_summarization: bool = False,
per_page_timeout: Optional[float] = None,
) -> dict[str, str]:
"""Run the action to browse the web and provide summaries.
@@ -215,11 +216,12 @@ class WebBrowseAndSummarize(Action):
query: The research question.
system_text: The system text.
use_concurrent_summarization: Whether to concurrently summarize the content of the webpage by LLM.
per_page_timeout: The maximum time for fetching a single page in seconds.
Returns:
A dictionary containing the URLs as keys and their summaries as values.
"""
contents = await self._fetch_web_contents(url, *urls)
contents = await self._fetch_web_contents(url, *urls, per_page_timeout=per_page_timeout)
all_urls = [url] + list(urls)
summarize_tasks = [self._summarize_content(content, query, system_text) for content in contents]
@@ -228,37 +230,52 @@ class WebBrowseAndSummarize(Action):
return result
async def _fetch_web_contents(self, url: str, *urls: str) -> list[str]:
async def _fetch_web_contents(self, url: str, *urls: str, per_page_timeout: Optional[float] = None) -> list[str]:
"""Fetch web contents from given URLs."""
contents = await self.web_browser_engine.run(url, *urls)
contents = await self.web_browser_engine.run(url, *urls, per_page_timeout=per_page_timeout)
return [contents] if not urls else contents
async def _summarize_content(self, content: str, query: str, system_text: str) -> tuple[str, str]:
async def _summarize_content(self, content: str, query: str, system_text: str) -> str:
"""Summarize web content."""
try:
prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}")
prompt_template = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content="{}")
content = content.inner_text
content = content.inner_text
chunk_summaries = []
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096):
logger.debug(prompt)
if self._is_content_invalid(content):
logger.warning(f"Invalid content detected: {content[:10]}...")
return None
chunk_summaries = []
for prompt in generate_prompt_chunk(content, prompt_template, self.llm.model, system_text, 4096):
logger.debug(prompt)
summary = await self._aask(prompt, [system_text])
if summary == "Not relevant.":
continue
chunk_summaries.append(summary)
if not chunk_summaries:
return None
if len(chunk_summaries) == 1:
return chunk_summaries[0]
content = "\n".join(chunk_summaries)
prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content)
summary = await self._aask(prompt, [system_text])
if summary == "Not relevant.":
continue
chunk_summaries.append(summary)
if not chunk_summaries:
return summary
except Exception as e:
logger.error(f"Error summarizing content: {e}")
return None
if len(chunk_summaries) == 1:
return chunk_summaries[0]
def _is_content_invalid(self, content: str) -> bool:
"""Check if the content is invalid based on specific starting phrases."""
content = "\n".join(chunk_summaries)
prompt = WEB_BROWSE_AND_SUMMARIZE_PROMPT.format(query=query, content=content)
summary = await self._aask(prompt, [system_text])
return summary
invalid_starts = ["Fail to load page", "Access Denied"]
return any(content.strip().startswith(phrase) for phrase in invalid_starts)
async def _execute_summarize_tasks(self, tasks: list[Coroutine[Any, Any, str]], use_concurrent: bool) -> list[str]:
"""Execute summarize tasks either concurrently or sequentially."""

View file

@@ -4,11 +4,12 @@ from __future__ import annotations
import json
from pydantic import Field
from pydantic import Field, model_validator
from metagpt.actions import Action
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize
from metagpt.logs import logger
from metagpt.tools.web_browser_engine import WebBrowserEngine
from metagpt.utils.common import CodeParser
REWRITE_QUERY_PROMPT = """
@@ -62,9 +63,26 @@ class SearchEnhancedQA(Action):
default=CollectLinks(), description="Action to collect relevant links from a search engine."
)
web_browse_and_summarize_action: WebBrowseAndSummarize = Field(
default=WebBrowseAndSummarize(),
default=None,
description="Action to explore the web and provide summaries of articles and webpages.",
)
per_page_timeout: float = Field(
default=10, description="The maximum time for fetching a single page is in seconds. Defaults to 10s."
)
java_script_enabled: bool = Field(
default=False, description="Whether or not to enable JavaScript in the web browser context. Defaults to False."
)
@model_validator(mode="after")
def initialize(self):
if self.web_browse_and_summarize_action is None:
self.web_browser_engine = WebBrowserEngine.from_browser_config(
self.config.browser, proxy=self.config.proxy, java_script_enabled=self.java_script_enabled
)
self.web_browse_and_summarize_action = WebBrowseAndSummarize(web_browser_engine=self.web_browser_engine)
return self
async def run(self, query: str, rewrite_query: bool = True) -> str:
"""Answer a query by leveraging web search results.
@@ -202,7 +220,9 @@ class SearchEnhancedQA(Action):
dict[str, str]: Mapping of URLs to their summaries.
"""
return await self.web_browse_and_summarize_action.run(*urls, query=query, use_concurrent_summarization=True)
return await self.web_browse_and_summarize_action.run(
*urls, query=query, use_concurrent_summarization=True, per_page_timeout=self.per_page_timeout
)
async def _generate_answer(self, query: str, context: str) -> str:
"""Generate an answer using the query and context.

View file

@@ -92,14 +92,14 @@ class WebBrowserEngine(BaseModel):
return cls(**data, **kwargs)
@overload
async def run(self, url: str) -> WebPage:
async def run(self, url: str, per_page_timeout: float = None) -> WebPage:
...
@overload
async def run(self, url: str, *urls: str) -> list[WebPage]:
async def run(self, url: str, *urls: str, per_page_timeout: float = None) -> list[WebPage]:
...
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
async def run(self, url: str, *urls: str, per_page_timeout: float = None) -> WebPage | list[WebPage]:
"""Runs the browser engine to load one or more web pages.
This method is the implementation of the overloaded run signatures. It delegates the task
@@ -108,8 +108,9 @@ class WebBrowserEngine(BaseModel):
Args:
url: The URL of the first web page to load.
*urls: Additional URLs of web pages to load, if any.
per_page_timeout: The maximum time for fetching a single page in seconds.
Returns:
A WebPage object if a single URL is provided, or a list of WebPage objects if multiple URLs are provided.
"""
return await self.run_func(url, *urls)
return await self.run_func(url, *urls, per_page_timeout=per_page_timeout)

View file

@@ -42,7 +42,10 @@ class PlaywrightWrapper(BaseModel):
if "ignore_https_errors" in kwargs:
self.context_kwargs["ignore_https_errors"] = kwargs["ignore_https_errors"]
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
if "java_script_enabled" in kwargs:
self.context_kwargs["java_script_enabled"] = kwargs["java_script_enabled"]
async def run(self, url: str, *urls: str, per_page_timeout: float = None) -> WebPage | list[WebPage]:
async with async_playwright() as ap:
browser_type = getattr(ap, self.browser_type)
await self._run_precheck(browser_type)
@@ -50,11 +53,17 @@ class PlaywrightWrapper(BaseModel):
_scrape = self._scrape
if urls:
return await asyncio.gather(_scrape(browser, url), *(_scrape(browser, i) for i in urls))
return await _scrape(browser, url)
return await asyncio.gather(
_scrape(browser, url, per_page_timeout), *(_scrape(browser, i, per_page_timeout) for i in urls)
)
return await _scrape(browser, url, per_page_timeout)
async def _scrape(self, browser, url):
async def _scrape(self, browser, url, timeout: float = None):
context = await browser.new_context(**self.context_kwargs)
if timeout is not None:
context.set_default_timeout(timeout * 1000) # playwright uses milliseconds.
page = await context.new_page()
async with page:
try:

View file

@@ -54,14 +54,16 @@ class SeleniumWrapper(BaseModel):
def executable_path(self):
return self.launch_kwargs.get("executable_path")
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
async def run(self, url: str, *urls: str, per_page_timeout: float = None) -> WebPage | list[WebPage]:
await self._run_precheck()
_scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url)
_scrape = lambda url, per_page_timeout: self.loop.run_in_executor(
self.executor, self._scrape_website, url, per_page_timeout
)
if urls:
return await asyncio.gather(_scrape(url), *(_scrape(i) for i in urls))
return await _scrape(url)
return await asyncio.gather(_scrape(url, per_page_timeout), *(_scrape(i, per_page_timeout) for i in urls))
return await _scrape(url, per_page_timeout)
async def _run_precheck(self):
if self._has_run_precheck:
@@ -75,11 +77,11 @@ class SeleniumWrapper(BaseModel):
)
self._has_run_precheck = True
def _scrape_website(self, url):
def _scrape_website(self, url, timeout: float = None):
with self._get_driver() as driver:
try:
driver.get(url)
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
WebDriverWait(driver, timeout or 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
inner_text = driver.execute_script("return document.body.innerText;")
html = driver.page_source
except Exception as e: