From 193178b7d1616cb383e24292a5045040ca0c42cc Mon Sep 17 00:00:00 2001 From: brucemeek <113046530+brucemeek@users.noreply.github.com> Date: Tue, 15 Aug 2023 06:51:39 -0500 Subject: [PATCH] resolved current conflicts --- Dockerfile | 16 +- metagpt/actions/__init__.py | 13 +- metagpt/config.py | 20 +-- metagpt/provider/openai_api.py | 79 ++++++--- metagpt/tools/sd_engine.py | 98 ++++++------ metagpt/tools/search_engine.py | 160 +++++++------------ metagpt/tools/search_engine_serpapi.py | 35 ++-- metagpt/tools/search_engine_serper.py | 27 ++-- metagpt/tools/web_browser_engine.py | 36 ++--- metagpt/tools/web_browser_engine_selenium.py | 33 ++-- metagpt/utils/mermaid.py | 42 +++-- 11 files changed, 290 insertions(+), 269 deletions(-) diff --git a/Dockerfile b/Dockerfile index be37f1df6..537bbc72e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,10 +1,10 @@ # Use a base image with Python3.9 and Nodejs20 slim version FROM nikolaik/python-nodejs:python3.9-nodejs20-slim -# Install Debian software needed by MetaGPT +# Install Debian software needed by MetaGPT and clean up in one RUN command to reduce image size RUN apt update &&\ apt install -y git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\ - apt clean + apt clean && rm -rf /var/lib/apt/lists/* # Install Mermaid CLI globally ENV CHROME_BIN="/usr/bin/chromium" \ @@ -15,13 +15,11 @@ RUN npm install -g @mermaid-js/mermaid-cli &&\ # Install Python dependencies and install MetaGPT COPY . /app/metagpt -RUN cd /app/metagpt &&\ - mkdir workspace &&\ - pip install -r requirements.txt &&\ - pip cache purge &&\ +WORKDIR /app/metagpt +RUN mkdir workspace &&\ + pip install --no-cache-dir -r requirements.txt &&\ python setup.py install -WORKDIR /app/metagpt - # Running with an infinite loop using the tail command -CMD ["sh", "-c", "tail -f /dev/null"] \ No newline at end of file +CMD ["sh", "-c", "tail -f /dev/null"] + diff --git a/metagpt/actions/__init__.py b/metagpt/actions/__init__.py index 165349728..b004bd58e 100644 --- a/metagpt/actions/__init__.py +++ b/metagpt/actions/__init__.py @@ -15,6 +15,7 @@ from metagpt.actions.design_api import WriteDesign from metagpt.actions.design_api_review import DesignReview from metagpt.actions.design_filenames import DesignFilenames from metagpt.actions.project_management import AssignTasks, WriteTasks +from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch from metagpt.actions.run_code import RunCode from metagpt.actions.search_and_summarize import SearchAndSummarize from metagpt.actions.write_code import WriteCode @@ -26,6 +27,7 @@ from metagpt.actions.write_test import WriteTest class ActionType(Enum): """All types of Actions, used for indexing.""" + ADD_REQUIREMENT = BossRequirement WRITE_PRD = WritePRD WRITE_PRD_REVIEW = WritePRDReview @@ -40,4 +42,13 @@ class ActionType(Enum): WRITE_TASKS = WriteTasks ASSIGN_TASKS = AssignTasks SEARCH_AND_SUMMARIZE = SearchAndSummarize - \ No newline at end of file + COLLECT_LINKS = CollectLinks + WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize + CONDUCT_RESEARCH = ConductResearch + + +__all__ = [ + "ActionType", + "Action", + "ActionOutput", +] diff --git a/metagpt/config.py b/metagpt/config.py index 3753bb3b0..faeffd777 100644 --- a/metagpt/config.py +++ b/metagpt/config.py @@ -4,14 +4,14 @@ Provide configuration, singleton """ import os -import openai +import openai import yaml from metagpt.const import PROJECT_ROOT from metagpt.logs import logger -from metagpt.utils.singleton import Singleton from metagpt.tools import SearchEngineType, WebBrowserEngineType +from metagpt.utils.singleton import Singleton class NotConfiguredException(Exception): @@ -46,7 +46,6 @@ class Config(metaclass=Singleton): self.openai_api_key = self._get("OPENAI_API_KEY") if not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key: raise NotConfiguredException("Set OPENAI_API_KEY first") - self.openai_api_base = self._get("OPENAI_API_BASE") if not self.openai_api_base or "YOUR_API_BASE" == self.openai_api_base: openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy @@ -67,21 +66,22 @@ class Config(metaclass=Singleton): self.google_api_key = self._get("GOOGLE_API_KEY") self.google_cse_id = self._get("GOOGLE_CSE_ID") self.search_engine = self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE) - + self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", "playwright")) self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium") self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome") - + self.long_term_memory = self._get('LONG_TERM_MEMORY', False) if self.long_term_memory: logger.warning("LONG_TERM_MEMORY is True") self.max_budget = self._get("MAX_BUDGET", 10.0) self.total_cost = 0.0 - self.puppeteer_config = self._get("PUPPETEER_CONFIG","") - self.mmdc = self._get("MMDC","mmdc") - self.update_costs = self._get("UPDATE_COSTS",True) - self.calc_usage = self._get("CALC_USAGE",True) - + self.puppeteer_config = self._get("PUPPETEER_CONFIG", "") + self.mmdc = self._get("MMDC", "mmdc") + self.update_costs = self._get("UPDATE_COSTS", True) + self.calc_usage = self._get("CALC_USAGE", True) + self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY") + self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT") def _init_with_config_files_and_env(self, configs: dict, yaml_file): """Load from config/key.yaml, config/config.yaml, and env in decreasing order of priority""" diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 6f7c33c4f..86b63770c 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # -*- coding: utf-8 -*- """ @Time : 2023/5/5 23:08 @@ -7,10 +6,11 @@ """ import asyncio import time -from functools import wraps from typing import NamedTuple import openai +from openai.error import APIConnectionError +from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type from metagpt.config import CONFIG from metagpt.logs import logger @@ -20,8 +20,10 @@ from metagpt.utils.token_counter import ( TOKEN_COSTS, count_message_tokens, count_string_tokens, + get_max_completion_tokens, ) +<<<<<<< main def retry(max_retries): def decorator(f): @wraps(f) @@ -41,10 +43,21 @@ class RateLimiter: def __init__(self, rpm): self.last_call_time = 0 self.interval = 1.1 * 60 / rpm # Here 1.1 is used because even if the calls are made strictly on time, they will still be QOS'd; consider switching to simple error retry later +======= + +class RateLimiter: + """Rate control class, each call goes through wait_if_needed, sleep if rate control is needed""" + + def __init__(self, rpm): + self.last_call_time = 0 + # Here 1.1 is used because even if the calls are made strictly according to time, + # they will still be QOS'd; consider switching to simple error retry later + self.interval = 1.1 * 60 / rpm +>>>>>>> main self.rpm = rpm def split_batches(self, batch): - return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)] + return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)] async def wait_if_needed(self, num_requests): current_time = time.time() @@ -64,7 +77,8 @@ class Costs(NamedTuple): total_budget: float class CostManager(metaclass=Singleton): - """Calculate the cost of using the interface""" + """计算使用接口的开销""" + def __init__(self): self.total_prompt_tokens = 0 self.total_completion_tokens = 0 @@ -82,13 +96,12 @@ class CostManager(metaclass=Singleton): """ self.total_prompt_tokens += prompt_tokens self.total_completion_tokens += completion_tokens - cost = ( - prompt_tokens * TOKEN_COSTS[model]["prompt"] - + completion_tokens * TOKEN_COSTS[model]["completion"] - ) / 1000 + cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000 self.total_cost += cost - logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | " - f"Current cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}") + logger.info( + f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | " + f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" + ) CONFIG.total_cost = self.total_cost def get_total_prompt_tokens(self): @@ -122,14 +135,25 @@ def get_costs(self) -> Costs: """Get all costs""" return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget) +def log_and_reraise(retry_state): + logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") + logger.warning(""" +Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ +See FAQ 5.8 +""") + raise retry_state.outcome.exception() + + class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): """ Check https://platform.openai.com/examples for examples """ + def __init__(self): self.__init_openai(CONFIG) self.llm = openai self.model = CONFIG.openai_api_model + self.auto_max_tokens = False self._cost_manager = CostManager() RateLimiter.__init__(self, rpm=self.rpm) @@ -143,10 +167,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self.rpm = int(config.get("RPM", 10)) async def _achat_completion_stream(self, messages: list[dict]) -> str: - response = await openai.ChatCompletion.acreate( - **self._cons_kwargs(messages), - stream=True - ) + response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True) # create variables to collect the stream of chunks collected_chunks = [] @@ -154,41 +175,42 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): # iterate through the stream of events async for chunk in response: collected_chunks.append(chunk) # save the event response - chunk_message = chunk['choices'][0]['delta'] # extract the message + chunk_message = chunk["choices"][0]["delta"] # extract the message collected_messages.append(chunk_message) # save the message if "content" in chunk_message: print(chunk_message["content"], end="") print() - full_reply_content = ''.join([m.get('content', '') for m in collected_messages]) + full_reply_content = "".join([m.get("content", "") for m in collected_messages]) usage = self._calc_usage(messages, full_reply_content) self._update_costs(usage) return full_reply_content def _cons_kwargs(self, messages: list[dict]) -> dict: - if CONFIG.openai_api_type == 'azure': + if CONFIG.openai_api_type == "azure": kwargs = { "deployment_id": CONFIG.deployment_id, "messages": messages, - "max_tokens": CONFIG.max_tokens_rsp, + "max_tokens": self.get_max_tokens(messages), "n": 1, "stop": None, - "temperature": 0.3 + "temperature": 0.3, } else: kwargs = { "model": self.model, "messages": messages, - "max_tokens": CONFIG.max_tokens_rsp, + "max_tokens": self.get_max_tokens(messages), "n": 1, "stop": None, - "temperature": 0.3 + "temperature": 0.3, } + kwargs["timeout"] = 3 return kwargs async def _achat_completion(self, messages: list[dict]) -> dict: rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages)) - self._update_costs(rsp.get('usage')) + self._update_costs(rsp.get("usage")) return rsp def _chat_completion(self, messages: list[dict]) -> dict: @@ -206,7 +228,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): # messages = self.messages_to_dict(messages) return await self._achat_completion(messages) - @retry(max_retries=6) + @retry( + stop=stop_after_attempt(3), + wait=wait_fixed(1), + after=after_log(logger, logger.level('WARNING').name), + retry=retry_if_exception_type(APIConnectionError), + retry_error_callback=log_and_reraise, + ) async def acompletion_text(self, messages: list[dict], stream=False) -> str: """when streaming, print each token in place.""" if stream: @@ -257,3 +285,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): def get_costs(self) -> Costs: return self._cost_manager.get_costs() + + def get_max_tokens(self, messages: list[dict]): + if not self.auto_max_tokens: + return CONFIG.max_tokens_rsp + return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index e212c2fc7..1d9cd0b2a 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -2,29 +2,27 @@ # @Date : 2023/7/19 16:28 # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : -import os import asyncio +import base64 +import io +import json +import os from os.path import join from typing import List -import json -import io -import base64 from aiohttp import ClientSession from PIL import Image, PngImagePlugin -from metagpt.logs import logger from metagpt.config import Config from metagpt.const import WORKSPACE_ROOT +from metagpt.logs import logger config = Config() payload = { "prompt": "", "negative_prompt": "(easynegative:0.8),black, dark,Low resolution", - "override_settings": { - "sd_model_checkpoint": "galaxytimemachinesGTM_photoV20" - }, + "override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"}, "seed": -1, "batch_size": 1, "n_iter": 1, @@ -36,21 +34,20 @@ payload = { "tiling": False, "do_not_save_samples": False, "do_not_save_grid": False, - 'enable_hr': False, - 'hr_scale': 2, - 'hr_upscaler': 'Latent', - 'hr_second_pass_steps': 0, - 'hr_resize_x': 0, - 'hr_resize_y': 0, - 'hr_upscale_to_x': 0, - 'hr_upscale_to_y': 0, - 'truncate_x': 0, - 'truncate_y': 0, - 'applied_old_hires_behavior_to': None, + "enable_hr": False, + "hr_scale": 2, + "hr_upscaler": "Latent", + "hr_second_pass_steps": 0, + "hr_resize_x": 0, + "hr_resize_y": 0, + "hr_upscale_to_x": 0, + "hr_upscale_to_y": 0, + "truncate_x": 0, + "truncate_y": 0, + "applied_old_hires_behavior_to": None, "eta": None, - "sampler_index": "DPM++ SDE Karras", - "alwayson_scripts": {} + "alwayson_scripts": {}, } default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" @@ -60,14 +57,20 @@ class SDEngine: def __init__(self): # Initialize the SDEngine with configuration self.config = Config() - self.sd_url = self.config.get('SD_URL') + self.sd_url = self.config.get("SD_URL") self.sd_t2i_url = f"{self.sd_url}{self.config.get('SD_T2I_API')}" # Define default payload settings for SD API self.payload = payload logger.info(self.sd_t2i_url) - - def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512, - sd_model="galaxytimemachinesGTM_photoV20"): + + def construct_payload( + self, + prompt, + negtive_prompt=default_negative_prompt, + width=512, + height=512, + sd_model="galaxytimemachinesGTM_photoV20", + ): # Configure the payload with provided inputs self.payload["prompt"] = prompt self.payload["negtive_prompt"] = negtive_prompt @@ -76,13 +79,13 @@ class SDEngine: self.payload["override_settings"]["sd_model_checkpoint"] = sd_model logger.info(f"call sd payload is {self.payload}") return self.payload - + def _save(self, imgs, save_name=""): - save_dir = WORKSPACE_ROOT / "resources"/"SD_Output" + save_dir = WORKSPACE_ROOT / "resources" / "SD_Output" if not os.path.exists(save_dir): os.makedirs(save_dir, exist_ok=True) batch_decode_base64_to_image(imgs, save_dir, save_name=save_name) - + async def run_t2i(self, prompts: List): # Asynchronously run the SD API for multiple prompts session = ClientSession() @@ -90,24 +93,25 @@ class SDEngine: results = await self.run(url=self.sd_t2i_url, payload=payload, session=session) self._save(results, save_name=f"output_{payload_idx}") await session.close() - -async def run(self, url, payload, session): - # Perform the HTTP POST request to the SD API - async with session.post(url, json=payload, timeout=600) as rsp: - data = await rsp.read() - - rsp_json = json.loads(data) - imgs = rsp_json['images'] - logger.info(f"callback rsp json is {rsp_json.keys()}") - return imgs -async def run_i2i(self): - # todo: Add image-to-image interface call - raise NotImplementedError + async def run(self, url, payload, session): + # Perform the HTTP POST request to the SD API + async with session.post(url, json=payload, timeout=600) as rsp: + data = await rsp.read() + + rsp_json = json.loads(data) + imgs = rsp_json["images"] + logger.info(f"callback rsp json is {rsp_json.keys()}") + return imgs + + async def run_i2i(self): + # todo: 添加图生图接口调用 + raise NotImplementedError + + async def run_sam(self): + # todo:添加SAM接口调用 + raise NotImplementedError -async def run_sam(self): - # todo: Add SAM interface call - raise NotImplementedError def decode_base64_to_image(img, save_name): image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0]))) @@ -122,12 +126,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): decode_base64_to_image(_img, save_name=save_name) if __name__ == "__main__": - import asyncio - engine = SDEngine() prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary" - + engine.construct_payload(prompt) - + event_loop = asyncio.get_event_loop() event_loop.run_until_complete(engine.run_t2i(prompt)) diff --git a/metagpt/tools/search_engine.py b/metagpt/tools/search_engine.py index 1668dfb5c..d28700054 100644 --- a/metagpt/tools/search_engine.py +++ b/metagpt/tools/search_engine.py @@ -7,118 +7,76 @@ """ from __future__ import annotations -import json +import importlib +from typing import Callable, Coroutine, Literal, overload -from metagpt.config import Config -from metagpt.logs import logger -from metagpt.tools.search_engine_serpapi import SerpAPIWrapper -from metagpt.tools.search_engine_serper import SerperWrapper - -config = Config() +from metagpt.config import CONFIG from metagpt.tools import SearchEngineType class SearchEngine: - """ - TODO: Integrate Google Search and reverse proxy. - Note: Google here requires a Proxifier or similar global proxy. - - DDG: https://pypi.org/project/duckduckgo-search/ - - GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9 - """ - def __init__(self, engine=None, run_func=None): - self.config = Config() - self.run_func = run_func - self.engine = engine or self.config.search_engine + """Class representing a search engine. - @classmethod - def run_google(cls, query, max_results=8): - # results = ddg(query, max_results=max_results) - results = google_official_search(query, num_results=max_results) - logger.info(results) - return results + Args: + engine: The search engine type. Defaults to the search engine specified in the config. + run_func: The function to run the search. Defaults to None. - async def run(self, query: str, max_results=8): - if self.engine == SearchEngineType.SERPAPI_GOOGLE: - api = SerpAPIWrapper() - rsp = await api.run(query) - elif self.engine == SearchEngineType.DIRECT_GOOGLE: - rsp = SearchEngine.run_google(query, max_results) - elif self.engine == SearchEngineType.SERPER_GOOGLE: - api = SerperWrapper() - rsp = await api.run(query) - elif self.engine == SearchEngineType.CUSTOM_ENGINE: - rsp = self.run_func(query) + Attributes: + run_func: The function to run the search. + engine: The search engine type. + """ + def __init__( + self, + engine: SearchEngineType | None = None, + run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None, + ): + engine = engine or CONFIG.search_engine + if engine == SearchEngineType.SERPAPI_GOOGLE: + module = "metagpt.tools.search_engine_serpapi" + run_func = importlib.import_module(module).SerpAPIWrapper().run + elif engine == SearchEngineType.SERPER_GOOGLE: + module = "metagpt.tools.search_engine_serper" + run_func = importlib.import_module(module).SerperWrapper().run + elif engine == SearchEngineType.DIRECT_GOOGLE: + module = "metagpt.tools.search_engine_googleapi" + run_func = importlib.import_module(module).GoogleAPIWrapper().run + elif engine == SearchEngineType.DUCK_DUCK_GO: + module = "metagpt.tools.search_engine_ddg" + run_func = importlib.import_module(module).DDGAPIWrapper().run + elif engine == SearchEngineType.CUSTOM_ENGINE: + pass # run_func = run_func else: raise NotImplementedError - return rsp + self.engine = engine + self.run_func = run_func -def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]: - """Return the results of a Google search using the official Google API + @overload + def run( + self, + query: str, + max_results: int = 8, + as_string: Literal[True] = True, + ) -> str: + ... - Args: - query (str): The search query. - num_results (int): The number of results to return. + @overload + def run( + self, + query: str, + max_results: int = 8, + as_string: Literal[False] = False, + ) -> list[dict[str, str]]: + ... - Returns: - str: The results of the search. - """ + async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]: + """Run a search query. - from googleapiclient.discovery import build - from googleapiclient.errors import HttpError + Args: + query: The search query. + max_results: The maximum number of results to return. Defaults to 8. + as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True. - try: - api_key = config.google_api_key - custom_search_engine_id = config.google_cse_id - - with build("customsearch", "v1", developerKey=api_key) as service: - - result = ( - service.cse() - .list(q=query, cx=custom_search_engine_id, num=num_results) - .execute() - ) - logger.info(result) - # Extract the search result items from the response - search_results = result.get("items", []) - - # Create a list of only the URLs from the search results - search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results] - - except HttpError as e: - # Handle errors in the API call - error_details = json.loads(e.content.decode()) - - # Check if the error is related to an invalid or missing API key - if error_details.get("error", {}).get( - "code" - ) == 403 and "invalid API key" in error_details.get("error", {}).get( - "message", "" - ): - return "Error: The provided Google API key is invalid or missing." - else: - return f"Error: {e}" - - # Return the list of search result URLs - return search_results_details - -def safe_google_results(results: str | list) -> str: - """ - Return the results of a google search in a safe format. - - Args: - results (str | list): The search results. - - Returns: - str: The results of the search. - """ - if isinstance(results, list): - safe_message = json.dumps( - [result for result in results] - ) - else: - safe_message = results.encode("utf-8", "ignore").decode("utf-8") - return safe_message - -if __name__ == '__main__': - SearchEngine.run(query='wtf') - \ No newline at end of file + Returns: + The search results as a string or a list of dictionaries. + """ + return await self.run_func(query, max_results=max_results, as_string=as_string) diff --git a/metagpt/tools/search_engine_serpapi.py b/metagpt/tools/search_engine_serpapi.py index 2bf07b342..3d2d7cfe4 100644 --- a/metagpt/tools/search_engine_serpapi.py +++ b/metagpt/tools/search_engine_serpapi.py @@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel): class Config: arbitrary_types_allowed = True - async def run(self, query: str, **kwargs: Any) -> str: + async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through SerpAPI and parse result async.""" - return self._process_response(await self.results(query)) + return self._process_response(await self.results(query, max_results), as_string=as_string) - async def results(self, query: str) -> dict: + async def results(self, query: str, max_results: int) -> dict: """Use aiohttp to run query through SerpAPI and return the results async.""" def construct_url_and_params() -> Tuple[str, Dict[str, str]]: params = self.get_params(query) params["source"] = "python" + params["num"] = max_results if self.serpapi_api_key: params["serp_api_key"] = self.serpapi_api_key params["output"] = "json" @@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel): return params @staticmethod - def _process_response(res: dict) -> str: + def _process_response(res: dict, as_string: bool) -> str: """Process response from SerpAPI.""" # logger.debug(res) - focus = ['title', 'snippet', 'link'] + focus = ["title", "snippet", "link"] get_focused = lambda x: {i: j for i, j in x.items() if i in focus} if "error" in res.keys(): @@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel): toret = res["answer_box"]["answer"] elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys(): toret = res["answer_box"]["snippet"] - elif ( - "answer_box" in res.keys() - and "snippet_highlighted_words" in res["answer_box"].keys() - ): + elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys(): toret = res["answer_box"]["snippet_highlighted_words"][0] - elif ( - "sports_results" in res.keys() - and "game_spotlight" in res["sports_results"].keys() - ): + elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys(): toret = res["sports_results"]["game_spotlight"] - elif ( - "knowledge_graph" in res.keys() - and "description" in res["knowledge_graph"].keys() - ): + elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys(): toret = res["knowledge_graph"]["description"] elif "snippet" in res["organic_results"][0].keys(): toret = res["organic_results"][0]["snippet"] @@ -112,5 +104,10 @@ class SerpAPIWrapper(BaseModel): if res.get("organic_results"): toret_l += [get_focused(i) for i in res.get("organic_results")] - return str(toret) + '\n' + str(toret_l) - \ No newline at end of file + return str(toret) + '\n' + str(toret_l) if as_string else toret_l + + +if __name__ == "__main__": + import fire + + fire.Fire(SerpAPIWrapper().run) diff --git a/metagpt/tools/search_engine_serper.py b/metagpt/tools/search_engine_serper.py index 45c19090c..2ae2c3b7d 100644 --- a/metagpt/tools/search_engine_serper.py +++ b/metagpt/tools/search_engine_serper.py @@ -36,16 +36,19 @@ class SerperWrapper(BaseModel): class Config: arbitrary_types_allowed = True - async def run(self, query: str, **kwargs: Any) -> str: + async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str: """Run query through Serper and parse result async.""" - queries = query.split("\n") - return "\n".join([self._process_response(res) for res in await self.results(queries)]) + if isinstance(query, str): + return self._process_response((await self.results([query], max_results))[0], as_string=as_string) + else: + results = [self._process_response(res, as_string) for res in await self.results(query, max_results)] + return "\n".join(results) if as_string else results - async def results(self, queries: list[str]) -> dict: + async def results(self, queries: list[str], max_results: int = 8) -> dict: """Use aiohttp to run query through Serper and return the results async.""" def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]: - payloads = self.get_payloads(queries) + payloads = self.get_payloads(queries, max_results) url = "https://google.serper.dev/search" headers = self.get_headers() return url, payloads, headers @@ -61,12 +64,13 @@ class SerperWrapper(BaseModel): return res - def get_payloads(self, queries: list[str]) -> Dict[str, str]: + def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]: """Get payloads for Serper.""" payloads = [] for query in queries: _payload = { "q": query, + "num": max_results, } payloads.append({**self.payload, **_payload}) return json.dumps(payloads, sort_keys=True) @@ -79,7 +83,7 @@ class SerperWrapper(BaseModel): return headers @staticmethod - def _process_response(res: dict) -> str: + def _process_response(res: dict, as_string: bool = False) -> str: """Process response from SerpAPI.""" # logger.debug(res) focus = ['title', 'snippet', 'link'] @@ -117,5 +121,10 @@ class SerperWrapper(BaseModel): if res.get("organic"): toret_l += [get_focused(i) for i in res.get("organic")] - return str(toret) + '\n' + str(toret_l) - \ No newline at end of file + return str(toret) + '\n' + str(toret_l) if as_string else toret_l + + +if __name__ == "__main__": + import fire + + fire.Fire(SerperWrapper().run) diff --git a/metagpt/tools/web_browser_engine.py b/metagpt/tools/web_browser_engine.py index 67b794dd1..453d87f31 100644 --- a/metagpt/tools/web_browser_engine.py +++ b/metagpt/tools/web_browser_engine.py @@ -1,22 +1,20 @@ #!/usr/bin/env python from __future__ import annotations -import asyncio -import importlib -from typing import Any, Callable, Coroutine, overload +import importlib +from typing import Any, Callable, Coroutine, Literal, overload from metagpt.config import CONFIG from metagpt.tools import WebBrowserEngineType -from bs4 import BeautifulSoup +from metagpt.utils.parse_html import WebPage class WebBrowserEngine: def __init__( self, engine: WebBrowserEngineType | None = None, - run_func: Callable[..., Coroutine[Any, Any, str | list[str]]] | None = None, - parse_func: Callable[[str], str] | None = None, + run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None, ): engine = engine or CONFIG.web_browser_engine @@ -30,31 +28,25 @@ class WebBrowserEngine: run_func = run_func else: raise NotImplementedError - self.parse_func = parse_func or get_page_content self.run_func = run_func self.engine = engine @overload - async def run(self, url: str) -> str: + async def run(self, url: str) -> WebPage: ... @overload - async def run(self, url: str, *urls: str) -> list[str]: + async def run(self, url: str, *urls: str) -> list[WebPage]: ... - async def run(self, url: str, *urls: str) -> str | list[str]: - page = await self.run_func(url, *urls) - if isinstance(page, str): - return self.parse_func(page) - return [self.parse_func(i) for i in page] - - -def get_page_content(page: str): - soup = BeautifulSoup(page, "html.parser") - return "\n".join(i.text.strip() for i in soup.find_all(["h1", "h2", "h3", "h4", "h5", "p", "pre"])) + async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: + return await self.run_func(url, *urls) if __name__ == "__main__": - text = asyncio.run(WebBrowserEngine().run("https://fuzhi.ai/")) - print(text) - \ No newline at end of file + import fire + + async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs): + return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls) + + fire.Fire(main) diff --git a/metagpt/tools/web_browser_engine_selenium.py b/metagpt/tools/web_browser_engine_selenium.py index 78533e05a..d727709b8 100644 --- a/metagpt/tools/web_browser_engine_selenium.py +++ b/metagpt/tools/web_browser_engine_selenium.py @@ -2,16 +2,17 @@ from __future__ import annotations import asyncio -from copy import deepcopy import importlib +from concurrent import futures +from copy import deepcopy from typing import Literal -from metagpt.config import CONFIG -import asyncio from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.wait import WebDriverWait -from concurrent import futures + +from metagpt.config import CONFIG +from metagpt.utils.parse_html import WebPage class SeleniumWrapper: @@ -48,7 +49,7 @@ class SeleniumWrapper: self.loop = loop self.executor = executor - async def run(self, url: str, *urls: str) -> str | list[str]: + async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]: await self._run_precheck() _scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url) @@ -69,9 +70,15 @@ class SeleniumWrapper: def _scrape_website(self, url): with self._get_driver() as driver: - driver.get(url) - WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) - return driver.page_source + try: + driver.get(url) + WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body"))) + inner_text = driver.execute_script("return document.body.innerText;") + html = driver.page_source + except Exception as e: + inner_text = f"Fail to load page content for {e}" + html = "" + return WebPage(inner_text=inner_text, html=html, url=url) _webdriver_manager_types = { @@ -97,6 +104,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): def _get_driver(): options = Options() options.add_argument("--headless") + options.add_argument("--enable-javascript") if browser_type == "chrome": options.add_argument("--no-sandbox") for i in args: @@ -107,6 +115,9 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None): if __name__ == "__main__": - text = asyncio.run(SeleniumWrapper("chrome").run("https://fuzhi.ai/")) - print(text) - \ No newline at end of file + import fire + + async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs): + return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls) + + fire.Fire(main) diff --git a/metagpt/utils/mermaid.py b/metagpt/utils/mermaid.py index 17ac0db4a..24aabe8ae 100644 --- a/metagpt/utils/mermaid.py +++ b/metagpt/utils/mermaid.py @@ -5,9 +5,9 @@ @Author : alexanderwu @File : mermaid.py """ -import os import subprocess from pathlib import Path + from metagpt.config import CONFIG from metagpt.const import PROJECT_ROOT from metagpt.logs import logger @@ -24,25 +24,36 @@ def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height :return: 0 if succed, -1 if failed """ # Write the Mermaid code to a temporary file - tmp = Path(f'{output_file_without_suffix}.mmd') - tmp.write_text(mermaid_code, encoding='utf-8') + tmp = Path(f"{output_file_without_suffix}.mmd") + tmp.write_text(mermaid_code, encoding="utf-8") - if check_cmd_exists('mmdc') != 0: - logger.warning( - "RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc") + if check_cmd_exists("mmdc") != 0: + logger.warning("RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc") return -1 - for suffix in ['pdf', 'svg', 'png']: - output_file = f'{output_file_without_suffix}.{suffix}' + for suffix in ["pdf", "svg", "png"]: + output_file = f"{output_file_without_suffix}.{suffix}" # Call the `mmdc` command to convert the Mermaid code to a PNG logger.info(f"Generating {output_file}..") if CONFIG.puppeteer_config: - subprocess.run([CONFIG.mmdc, '-p', CONFIG.puppeteer_config, '-i', str(tmp), '-o', - output_file, '-w', str(width), '-H', str(height)]) + subprocess.run( + [ + CONFIG.mmdc, + "-p", + CONFIG.puppeteer_config, + "-i", + str(tmp), + "-o", + output_file, + "-w", + str(width), + "-H", + str(height), + ] + ) else: - subprocess.run([CONFIG.mmdc, '-i', str(tmp), '-o', - output_file, '-w', str(width), '-H', str(height)]) + subprocess.run([CONFIG.mmdc, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)]) return 0 @@ -97,8 +108,7 @@ MMC2 = """sequenceDiagram SE-->>M: return summary""" -if __name__ == '__main__': +if __name__ == "__main__": # logger.info(print_members(print_members)) - mermaid_to_file(MMC1, PROJECT_ROOT / 'tmp/1.png') - mermaid_to_file(MMC2, PROJECT_ROOT / 'tmp/2.png') - \ No newline at end of file + mermaid_to_file(MMC1, PROJECT_ROOT / "tmp/1.png") + mermaid_to_file(MMC2, PROJECT_ROOT / "tmp/2.png")