mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
resolved current conflicts
This commit is contained in:
parent
730e2f912f
commit
193178b7d1
11 changed files with 290 additions and 269 deletions
16
Dockerfile
16
Dockerfile
|
|
@ -1,10 +1,10 @@
|
|||
# Use a base image with Python3.9 and Nodejs20 slim version
|
||||
FROM nikolaik/python-nodejs:python3.9-nodejs20-slim
|
||||
|
||||
# Install Debian software needed by MetaGPT
|
||||
# Install Debian software needed by MetaGPT and clean up in one RUN command to reduce image size
|
||||
RUN apt update &&\
|
||||
apt install -y git chromium fonts-ipafont-gothic fonts-wqy-zenhei fonts-thai-tlwg fonts-kacst fonts-freefont-ttf libxss1 --no-install-recommends &&\
|
||||
apt clean
|
||||
apt clean && rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Install Mermaid CLI globally
|
||||
ENV CHROME_BIN="/usr/bin/chromium" \
|
||||
|
|
@ -15,13 +15,11 @@ RUN npm install -g @mermaid-js/mermaid-cli &&\
|
|||
|
||||
# Install Python dependencies and install MetaGPT
|
||||
COPY . /app/metagpt
|
||||
RUN cd /app/metagpt &&\
|
||||
mkdir workspace &&\
|
||||
pip install -r requirements.txt &&\
|
||||
pip cache purge &&\
|
||||
WORKDIR /app/metagpt
|
||||
RUN mkdir workspace &&\
|
||||
pip install --no-cache-dir -r requirements.txt &&\
|
||||
python setup.py install
|
||||
|
||||
WORKDIR /app/metagpt
|
||||
|
||||
# Running with an infinite loop using the tail command
|
||||
CMD ["sh", "-c", "tail -f /dev/null"]
|
||||
CMD ["sh", "-c", "tail -f /dev/null"]
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from metagpt.actions.design_api import WriteDesign
|
|||
from metagpt.actions.design_api_review import DesignReview
|
||||
from metagpt.actions.design_filenames import DesignFilenames
|
||||
from metagpt.actions.project_management import AssignTasks, WriteTasks
|
||||
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
|
||||
from metagpt.actions.run_code import RunCode
|
||||
from metagpt.actions.search_and_summarize import SearchAndSummarize
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
|
|
@ -26,6 +27,7 @@ from metagpt.actions.write_test import WriteTest
|
|||
|
||||
class ActionType(Enum):
|
||||
"""All types of Actions, used for indexing."""
|
||||
|
||||
ADD_REQUIREMENT = BossRequirement
|
||||
WRITE_PRD = WritePRD
|
||||
WRITE_PRD_REVIEW = WritePRDReview
|
||||
|
|
@ -40,4 +42,13 @@ class ActionType(Enum):
|
|||
WRITE_TASKS = WriteTasks
|
||||
ASSIGN_TASKS = AssignTasks
|
||||
SEARCH_AND_SUMMARIZE = SearchAndSummarize
|
||||
|
||||
COLLECT_LINKS = CollectLinks
|
||||
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
|
||||
CONDUCT_RESEARCH = ConductResearch
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ActionType",
|
||||
"Action",
|
||||
"ActionOutput",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@
|
|||
Provide configuration, singleton
|
||||
"""
|
||||
import os
|
||||
import openai
|
||||
|
||||
import openai
|
||||
import yaml
|
||||
|
||||
from metagpt.const import PROJECT_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.tools import SearchEngineType, WebBrowserEngineType
|
||||
from metagpt.utils.singleton import Singleton
|
||||
|
||||
|
||||
class NotConfiguredException(Exception):
|
||||
|
|
@ -46,7 +46,6 @@ class Config(metaclass=Singleton):
|
|||
self.openai_api_key = self._get("OPENAI_API_KEY")
|
||||
if not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key:
|
||||
raise NotConfiguredException("Set OPENAI_API_KEY first")
|
||||
|
||||
self.openai_api_base = self._get("OPENAI_API_BASE")
|
||||
if not self.openai_api_base or "YOUR_API_BASE" == self.openai_api_base:
|
||||
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
|
|
@ -67,21 +66,22 @@ class Config(metaclass=Singleton):
|
|||
self.google_api_key = self._get("GOOGLE_API_KEY")
|
||||
self.google_cse_id = self._get("GOOGLE_CSE_ID")
|
||||
self.search_engine = self._get("SEARCH_ENGINE", SearchEngineType.SERPAPI_GOOGLE)
|
||||
|
||||
|
||||
self.web_browser_engine = WebBrowserEngineType(self._get("WEB_BROWSER_ENGINE", "playwright"))
|
||||
self.playwright_browser_type = self._get("PLAYWRIGHT_BROWSER_TYPE", "chromium")
|
||||
self.selenium_browser_type = self._get("SELENIUM_BROWSER_TYPE", "chrome")
|
||||
|
||||
|
||||
self.long_term_memory = self._get('LONG_TERM_MEMORY', False)
|
||||
if self.long_term_memory:
|
||||
logger.warning("LONG_TERM_MEMORY is True")
|
||||
self.max_budget = self._get("MAX_BUDGET", 10.0)
|
||||
self.total_cost = 0.0
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG","")
|
||||
self.mmdc = self._get("MMDC","mmdc")
|
||||
self.update_costs = self._get("UPDATE_COSTS",True)
|
||||
self.calc_usage = self._get("CALC_USAGE",True)
|
||||
|
||||
self.puppeteer_config = self._get("PUPPETEER_CONFIG", "")
|
||||
self.mmdc = self._get("MMDC", "mmdc")
|
||||
self.update_costs = self._get("UPDATE_COSTS", True)
|
||||
self.calc_usage = self._get("CALC_USAGE", True)
|
||||
self.model_for_researcher_summary = self._get("MODEL_FOR_RESEARCHER_SUMMARY")
|
||||
self.model_for_researcher_report = self._get("MODEL_FOR_RESEARCHER_REPORT")
|
||||
|
||||
def _init_with_config_files_and_env(self, configs: dict, yaml_file):
|
||||
"""Load from config/key.yaml, config/config.yaml, and env in decreasing order of priority"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/5 23:08
|
||||
|
|
@ -7,10 +6,11 @@
|
|||
"""
|
||||
import asyncio
|
||||
import time
|
||||
from functools import wraps
|
||||
from typing import NamedTuple
|
||||
|
||||
import openai
|
||||
from openai.error import APIConnectionError
|
||||
from tenacity import retry, stop_after_attempt, after_log, wait_fixed, retry_if_exception_type
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -20,8 +20,10 @@ from metagpt.utils.token_counter import (
|
|||
TOKEN_COSTS,
|
||||
count_message_tokens,
|
||||
count_string_tokens,
|
||||
get_max_completion_tokens,
|
||||
)
|
||||
|
||||
<<<<<<< main
|
||||
def retry(max_retries):
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
|
|
@ -41,10 +43,21 @@ class RateLimiter:
|
|||
def __init__(self, rpm):
|
||||
self.last_call_time = 0
|
||||
self.interval = 1.1 * 60 / rpm # Here 1.1 is used because even if the calls are made strictly on time, they will still be QOS'd; consider switching to simple error retry later
|
||||
=======
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate control class, each call goes through wait_if_needed, sleep if rate control is needed"""
|
||||
|
||||
def __init__(self, rpm):
|
||||
self.last_call_time = 0
|
||||
# Here 1.1 is used because even if the calls are made strictly according to time,
|
||||
# they will still be QOS'd; consider switching to simple error retry later
|
||||
self.interval = 1.1 * 60 / rpm
|
||||
>>>>>>> main
|
||||
self.rpm = rpm
|
||||
|
||||
def split_batches(self, batch):
|
||||
return [batch[i:i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
|
||||
async def wait_if_needed(self, num_requests):
|
||||
current_time = time.time()
|
||||
|
|
@ -64,7 +77,8 @@ class Costs(NamedTuple):
|
|||
total_budget: float
|
||||
|
||||
class CostManager(metaclass=Singleton):
|
||||
"""Calculate the cost of using the interface"""
|
||||
"""计算使用接口的开销"""
|
||||
|
||||
def __init__(self):
|
||||
self.total_prompt_tokens = 0
|
||||
self.total_completion_tokens = 0
|
||||
|
|
@ -82,13 +96,12 @@ class CostManager(metaclass=Singleton):
|
|||
"""
|
||||
self.total_prompt_tokens += prompt_tokens
|
||||
self.total_completion_tokens += completion_tokens
|
||||
cost = (
|
||||
prompt_tokens * TOKEN_COSTS[model]["prompt"]
|
||||
+ completion_tokens * TOKEN_COSTS[model]["completion"]
|
||||
) / 1000
|
||||
cost = (prompt_tokens * TOKEN_COSTS[model]["prompt"] + completion_tokens * TOKEN_COSTS[model]["completion"]) / 1000
|
||||
self.total_cost += cost
|
||||
logger.info(f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.3f}, {prompt_tokens=}, {completion_tokens=}")
|
||||
logger.info(
|
||||
f"Total running cost: ${self.total_cost:.3f} | Max budget: ${CONFIG.max_budget:.3f} | "
|
||||
f"Current cost: ${cost:.3f}, prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
def get_total_prompt_tokens(self):
|
||||
|
|
@ -122,14 +135,25 @@ def get_costs(self) -> Costs:
|
|||
"""Get all costs"""
|
||||
return Costs(self.total_prompt_tokens, self.total_completion_tokens, self.total_cost, self.total_budget)
|
||||
|
||||
def log_and_reraise(retry_state):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning("""
|
||||
Recommend going to https://deepwisdom.feishu.cn/wiki/MsGnwQBjiif9c3koSJNcYaoSnu4#part-XdatdVlhEojeAfxaaEZcMV3ZniQ
|
||||
See FAQ 5.8
|
||||
""")
|
||||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
||||
"""
|
||||
Check https://platform.openai.com/examples for examples
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.__init_openai(CONFIG)
|
||||
self.llm = openai
|
||||
self.model = CONFIG.openai_api_model
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = CostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
|
|
@ -143,10 +167,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
self.rpm = int(config.get("RPM", 10))
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
response = await openai.ChatCompletion.acreate(
|
||||
**self._cons_kwargs(messages),
|
||||
stream=True
|
||||
)
|
||||
response = await openai.ChatCompletion.acreate(**self._cons_kwargs(messages), stream=True)
|
||||
|
||||
# create variables to collect the stream of chunks
|
||||
collected_chunks = []
|
||||
|
|
@ -154,41 +175,42 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
# iterate through the stream of events
|
||||
async for chunk in response:
|
||||
collected_chunks.append(chunk) # save the event response
|
||||
chunk_message = chunk['choices'][0]['delta'] # extract the message
|
||||
chunk_message = chunk["choices"][0]["delta"] # extract the message
|
||||
collected_messages.append(chunk_message) # save the message
|
||||
if "content" in chunk_message:
|
||||
print(chunk_message["content"], end="")
|
||||
print()
|
||||
|
||||
full_reply_content = ''.join([m.get('content', '') for m in collected_messages])
|
||||
full_reply_content = "".join([m.get("content", "") for m in collected_messages])
|
||||
usage = self._calc_usage(messages, full_reply_content)
|
||||
self._update_costs(usage)
|
||||
return full_reply_content
|
||||
|
||||
def _cons_kwargs(self, messages: list[dict]) -> dict:
|
||||
if CONFIG.openai_api_type == 'azure':
|
||||
if CONFIG.openai_api_type == "azure":
|
||||
kwargs = {
|
||||
"deployment_id": CONFIG.deployment_id,
|
||||
"messages": messages,
|
||||
"max_tokens": CONFIG.max_tokens_rsp,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3
|
||||
"temperature": 0.3,
|
||||
}
|
||||
else:
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"max_tokens": CONFIG.max_tokens_rsp,
|
||||
"max_tokens": self.get_max_tokens(messages),
|
||||
"n": 1,
|
||||
"stop": None,
|
||||
"temperature": 0.3
|
||||
"temperature": 0.3,
|
||||
}
|
||||
kwargs["timeout"] = 3
|
||||
return kwargs
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
rsp = await self.llm.ChatCompletion.acreate(**self._cons_kwargs(messages))
|
||||
self._update_costs(rsp.get('usage'))
|
||||
self._update_costs(rsp.get("usage"))
|
||||
return rsp
|
||||
|
||||
def _chat_completion(self, messages: list[dict]) -> dict:
|
||||
|
|
@ -206,7 +228,13 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
# messages = self.messages_to_dict(messages)
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
@retry(max_retries=6)
|
||||
@retry(
|
||||
stop=stop_after_attempt(3),
|
||||
wait=wait_fixed(1),
|
||||
after=after_log(logger, logger.level('WARNING').name),
|
||||
retry=retry_if_exception_type(APIConnectionError),
|
||||
retry_error_callback=log_and_reraise,
|
||||
)
|
||||
async def acompletion_text(self, messages: list[dict], stream=False) -> str:
|
||||
"""when streaming, print each token in place."""
|
||||
if stream:
|
||||
|
|
@ -257,3 +285,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
|
||||
def get_costs(self) -> Costs:
|
||||
return self._cost_manager.get_costs()
|
||||
|
||||
def get_max_tokens(self, messages: list[dict]):
|
||||
if not self.auto_max_tokens:
|
||||
return CONFIG.max_tokens_rsp
|
||||
return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
|
||||
|
|
|
|||
|
|
@ -2,29 +2,27 @@
|
|||
# @Date : 2023/7/19 16:28
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import os
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from os.path import join
|
||||
from typing import List
|
||||
import json
|
||||
import io
|
||||
import base64
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import Config
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
|
||||
config = Config()
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
|
||||
"override_settings": {
|
||||
"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"
|
||||
},
|
||||
"override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
|
||||
"seed": -1,
|
||||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
|
|
@ -36,21 +34,20 @@ payload = {
|
|||
"tiling": False,
|
||||
"do_not_save_samples": False,
|
||||
"do_not_save_grid": False,
|
||||
'enable_hr': False,
|
||||
'hr_scale': 2,
|
||||
'hr_upscaler': 'Latent',
|
||||
'hr_second_pass_steps': 0,
|
||||
'hr_resize_x': 0,
|
||||
'hr_resize_y': 0,
|
||||
'hr_upscale_to_x': 0,
|
||||
'hr_upscale_to_y': 0,
|
||||
'truncate_x': 0,
|
||||
'truncate_y': 0,
|
||||
'applied_old_hires_behavior_to': None,
|
||||
"enable_hr": False,
|
||||
"hr_scale": 2,
|
||||
"hr_upscaler": "Latent",
|
||||
"hr_second_pass_steps": 0,
|
||||
"hr_resize_x": 0,
|
||||
"hr_resize_y": 0,
|
||||
"hr_upscale_to_x": 0,
|
||||
"hr_upscale_to_y": 0,
|
||||
"truncate_x": 0,
|
||||
"truncate_y": 0,
|
||||
"applied_old_hires_behavior_to": None,
|
||||
"eta": None,
|
||||
|
||||
"sampler_index": "DPM++ SDE Karras",
|
||||
"alwayson_scripts": {}
|
||||
"alwayson_scripts": {},
|
||||
}
|
||||
|
||||
default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
||||
|
|
@ -60,14 +57,20 @@ class SDEngine:
|
|||
def __init__(self):
|
||||
# Initialize the SDEngine with configuration
|
||||
self.config = Config()
|
||||
self.sd_url = self.config.get('SD_URL')
|
||||
self.sd_url = self.config.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{self.config.get('SD_T2I_API')}"
|
||||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
||||
def construct_payload(self, prompt, negtive_prompt=default_negative_prompt, width=512, height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20"):
|
||||
|
||||
def construct_payload(
|
||||
self,
|
||||
prompt,
|
||||
negtive_prompt=default_negative_prompt,
|
||||
width=512,
|
||||
height=512,
|
||||
sd_model="galaxytimemachinesGTM_photoV20",
|
||||
):
|
||||
# Configure the payload with provided inputs
|
||||
self.payload["prompt"] = prompt
|
||||
self.payload["negtive_prompt"] = negtive_prompt
|
||||
|
|
@ -76,13 +79,13 @@ class SDEngine:
|
|||
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
|
||||
logger.info(f"call sd payload is {self.payload}")
|
||||
return self.payload
|
||||
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
save_dir = WORKSPACE_ROOT / "resources"/"SD_Output"
|
||||
save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
|
||||
|
||||
|
||||
async def run_t2i(self, prompts: List):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
session = ClientSession()
|
||||
|
|
@ -90,24 +93,25 @@ class SDEngine:
|
|||
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
|
||||
self._save(results, save_name=f"output_{payload_idx}")
|
||||
await session.close()
|
||||
|
||||
async def run(self, url, payload, session):
|
||||
# Perform the HTTP POST request to the SD API
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json['images']
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: Add image-to-image interface call
|
||||
raise NotImplementedError
|
||||
async def run(self, url, payload, session):
|
||||
# Perform the HTTP POST request to the SD API
|
||||
async with session.post(url, json=payload, timeout=600) as rsp:
|
||||
data = await rsp.read()
|
||||
|
||||
rsp_json = json.loads(data)
|
||||
imgs = rsp_json["images"]
|
||||
logger.info(f"callback rsp json is {rsp_json.keys()}")
|
||||
return imgs
|
||||
|
||||
async def run_i2i(self):
|
||||
# todo: 添加图生图接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_sam(self):
|
||||
# todo:添加SAM接口调用
|
||||
raise NotImplementedError
|
||||
|
||||
async def run_sam(self):
|
||||
# todo: Add SAM interface call
|
||||
raise NotImplementedError
|
||||
|
||||
def decode_base64_to_image(img, save_name):
|
||||
image = Image.open(io.BytesIO(base64.b64decode(img.split(",", 1)[0])))
|
||||
|
|
@ -122,12 +126,10 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
|||
decode_base64_to_image(_img, save_name=save_name)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
engine = SDEngine()
|
||||
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
|
||||
|
||||
|
||||
engine.construct_payload(prompt)
|
||||
|
||||
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(engine.run_t2i(prompt))
|
||||
|
|
|
|||
|
|
@ -7,118 +7,76 @@
|
|||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import importlib
|
||||
from typing import Callable, Coroutine, Literal, overload
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.search_engine_serpapi import SerpAPIWrapper
|
||||
from metagpt.tools.search_engine_serper import SerperWrapper
|
||||
|
||||
config = Config()
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import SearchEngineType
|
||||
|
||||
|
||||
class SearchEngine:
|
||||
"""
|
||||
TODO: Integrate Google Search and reverse proxy.
|
||||
Note: Google here requires a Proxifier or similar global proxy.
|
||||
- DDG: https://pypi.org/project/duckduckgo-search/
|
||||
- GOOGLE: https://programmablesearchengine.google.com/controlpanel/overview?cx=63f9de531d0e24de9
|
||||
"""
|
||||
def __init__(self, engine=None, run_func=None):
|
||||
self.config = Config()
|
||||
self.run_func = run_func
|
||||
self.engine = engine or self.config.search_engine
|
||||
"""Class representing a search engine.
|
||||
|
||||
@classmethod
|
||||
def run_google(cls, query, max_results=8):
|
||||
# results = ddg(query, max_results=max_results)
|
||||
results = google_official_search(query, num_results=max_results)
|
||||
logger.info(results)
|
||||
return results
|
||||
Args:
|
||||
engine: The search engine type. Defaults to the search engine specified in the config.
|
||||
run_func: The function to run the search. Defaults to None.
|
||||
|
||||
async def run(self, query: str, max_results=8):
|
||||
if self.engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
api = SerpAPIWrapper()
|
||||
rsp = await api.run(query)
|
||||
elif self.engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
rsp = SearchEngine.run_google(query, max_results)
|
||||
elif self.engine == SearchEngineType.SERPER_GOOGLE:
|
||||
api = SerperWrapper()
|
||||
rsp = await api.run(query)
|
||||
elif self.engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
rsp = self.run_func(query)
|
||||
Attributes:
|
||||
run_func: The function to run the search.
|
||||
engine: The search engine type.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
engine: SearchEngineType | None = None,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, str | list[str]]] = None,
|
||||
):
|
||||
engine = engine or CONFIG.search_engine
|
||||
if engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serpapi"
|
||||
run_func = importlib.import_module(module).SerpAPIWrapper().run
|
||||
elif engine == SearchEngineType.SERPER_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_serper"
|
||||
run_func = importlib.import_module(module).SerperWrapper().run
|
||||
elif engine == SearchEngineType.DIRECT_GOOGLE:
|
||||
module = "metagpt.tools.search_engine_googleapi"
|
||||
run_func = importlib.import_module(module).GoogleAPIWrapper().run
|
||||
elif engine == SearchEngineType.DUCK_DUCK_GO:
|
||||
module = "metagpt.tools.search_engine_ddg"
|
||||
run_func = importlib.import_module(module).DDGAPIWrapper().run
|
||||
elif engine == SearchEngineType.CUSTOM_ENGINE:
|
||||
pass # run_func = run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return rsp
|
||||
self.engine = engine
|
||||
self.run_func = run_func
|
||||
|
||||
def google_official_search(query: str, num_results: int = 8, focus=['snippet', 'link', 'title']) -> dict | list[dict]:
|
||||
"""Return the results of a Google search using the official Google API
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[True] = True,
|
||||
) -> str:
|
||||
...
|
||||
|
||||
Args:
|
||||
query (str): The search query.
|
||||
num_results (int): The number of results to return.
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
query: str,
|
||||
max_results: int = 8,
|
||||
as_string: Literal[False] = False,
|
||||
) -> list[dict[str, str]]:
|
||||
...
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True) -> str | list[dict[str, str]]:
|
||||
"""Run a search query.
|
||||
|
||||
from googleapiclient.discovery import build
|
||||
from googleapiclient.errors import HttpError
|
||||
Args:
|
||||
query: The search query.
|
||||
max_results: The maximum number of results to return. Defaults to 8.
|
||||
as_string: Whether to return the results as a string or a list of dictionaries. Defaults to True.
|
||||
|
||||
try:
|
||||
api_key = config.google_api_key
|
||||
custom_search_engine_id = config.google_cse_id
|
||||
|
||||
with build("customsearch", "v1", developerKey=api_key) as service:
|
||||
|
||||
result = (
|
||||
service.cse()
|
||||
.list(q=query, cx=custom_search_engine_id, num=num_results)
|
||||
.execute()
|
||||
)
|
||||
logger.info(result)
|
||||
# Extract the search result items from the response
|
||||
search_results = result.get("items", [])
|
||||
|
||||
# Create a list of only the URLs from the search results
|
||||
search_results_details = [{i: j for i, j in item_dict.items() if i in focus} for item_dict in search_results]
|
||||
|
||||
except HttpError as e:
|
||||
# Handle errors in the API call
|
||||
error_details = json.loads(e.content.decode())
|
||||
|
||||
# Check if the error is related to an invalid or missing API key
|
||||
if error_details.get("error", {}).get(
|
||||
"code"
|
||||
) == 403 and "invalid API key" in error_details.get("error", {}).get(
|
||||
"message", ""
|
||||
):
|
||||
return "Error: The provided Google API key is invalid or missing."
|
||||
else:
|
||||
return f"Error: {e}"
|
||||
|
||||
# Return the list of search result URLs
|
||||
return search_results_details
|
||||
|
||||
def safe_google_results(results: str | list) -> str:
|
||||
"""
|
||||
Return the results of a google search in a safe format.
|
||||
|
||||
Args:
|
||||
results (str | list): The search results.
|
||||
|
||||
Returns:
|
||||
str: The results of the search.
|
||||
"""
|
||||
if isinstance(results, list):
|
||||
safe_message = json.dumps(
|
||||
[result for result in results]
|
||||
)
|
||||
else:
|
||||
safe_message = results.encode("utf-8", "ignore").decode("utf-8")
|
||||
return safe_message
|
||||
|
||||
if __name__ == '__main__':
|
||||
SearchEngine.run(query='wtf')
|
||||
|
||||
Returns:
|
||||
The search results as a string or a list of dictionaries.
|
||||
"""
|
||||
return await self.run_func(query, max_results=max_results, as_string=as_string)
|
||||
|
|
|
|||
|
|
@ -37,16 +37,17 @@ class SerpAPIWrapper(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(self, query: str, **kwargs: Any) -> str:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result async."""
|
||||
return self._process_response(await self.results(query))
|
||||
return self._process_response(await self.results(query, max_results), as_string=as_string)
|
||||
|
||||
async def results(self, query: str) -> dict:
|
||||
async def results(self, query: str, max_results: int) -> dict:
|
||||
"""Use aiohttp to run query through SerpAPI and return the results async."""
|
||||
|
||||
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
|
||||
params = self.get_params(query)
|
||||
params["source"] = "python"
|
||||
params["num"] = max_results
|
||||
if self.serpapi_api_key:
|
||||
params["serp_api_key"] = self.serpapi_api_key
|
||||
params["output"] = "json"
|
||||
|
|
@ -74,10 +75,10 @@ class SerpAPIWrapper(BaseModel):
|
|||
return params
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict) -> str:
|
||||
def _process_response(res: dict, as_string: bool) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
# logger.debug(res)
|
||||
focus = ['title', 'snippet', 'link']
|
||||
focus = ["title", "snippet", "link"]
|
||||
get_focused = lambda x: {i: j for i, j in x.items() if i in focus}
|
||||
|
||||
if "error" in res.keys():
|
||||
|
|
@ -86,20 +87,11 @@ class SerpAPIWrapper(BaseModel):
|
|||
toret = res["answer_box"]["answer"]
|
||||
elif "answer_box" in res.keys() and "snippet" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet"]
|
||||
elif (
|
||||
"answer_box" in res.keys()
|
||||
and "snippet_highlighted_words" in res["answer_box"].keys()
|
||||
):
|
||||
elif "answer_box" in res.keys() and "snippet_highlighted_words" in res["answer_box"].keys():
|
||||
toret = res["answer_box"]["snippet_highlighted_words"][0]
|
||||
elif (
|
||||
"sports_results" in res.keys()
|
||||
and "game_spotlight" in res["sports_results"].keys()
|
||||
):
|
||||
elif "sports_results" in res.keys() and "game_spotlight" in res["sports_results"].keys():
|
||||
toret = res["sports_results"]["game_spotlight"]
|
||||
elif (
|
||||
"knowledge_graph" in res.keys()
|
||||
and "description" in res["knowledge_graph"].keys()
|
||||
):
|
||||
elif "knowledge_graph" in res.keys() and "description" in res["knowledge_graph"].keys():
|
||||
toret = res["knowledge_graph"]["description"]
|
||||
elif "snippet" in res["organic_results"][0].keys():
|
||||
toret = res["organic_results"][0]["snippet"]
|
||||
|
|
@ -112,5 +104,10 @@ class SerpAPIWrapper(BaseModel):
|
|||
if res.get("organic_results"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic_results")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l)
|
||||
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(SerpAPIWrapper().run)
|
||||
|
|
|
|||
|
|
@ -36,16 +36,19 @@ class SerperWrapper(BaseModel):
|
|||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
async def run(self, query: str, **kwargs: Any) -> str:
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through Serper and parse result async."""
|
||||
queries = query.split("\n")
|
||||
return "\n".join([self._process_response(res) for res in await self.results(queries)])
|
||||
if isinstance(query, str):
|
||||
return self._process_response((await self.results([query], max_results))[0], as_string=as_string)
|
||||
else:
|
||||
results = [self._process_response(res, as_string) for res in await self.results(query, max_results)]
|
||||
return "\n".join(results) if as_string else results
|
||||
|
||||
async def results(self, queries: list[str]) -> dict:
|
||||
async def results(self, queries: list[str], max_results: int = 8) -> dict:
|
||||
"""Use aiohttp to run query through Serper and return the results async."""
|
||||
|
||||
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
|
||||
payloads = self.get_payloads(queries)
|
||||
payloads = self.get_payloads(queries, max_results)
|
||||
url = "https://google.serper.dev/search"
|
||||
headers = self.get_headers()
|
||||
return url, payloads, headers
|
||||
|
|
@ -61,12 +64,13 @@ class SerperWrapper(BaseModel):
|
|||
|
||||
return res
|
||||
|
||||
def get_payloads(self, queries: list[str]) -> Dict[str, str]:
|
||||
def get_payloads(self, queries: list[str], max_results: int) -> Dict[str, str]:
|
||||
"""Get payloads for Serper."""
|
||||
payloads = []
|
||||
for query in queries:
|
||||
_payload = {
|
||||
"q": query,
|
||||
"num": max_results,
|
||||
}
|
||||
payloads.append({**self.payload, **_payload})
|
||||
return json.dumps(payloads, sort_keys=True)
|
||||
|
|
@ -79,7 +83,7 @@ class SerperWrapper(BaseModel):
|
|||
return headers
|
||||
|
||||
@staticmethod
|
||||
def _process_response(res: dict) -> str:
|
||||
def _process_response(res: dict, as_string: bool = False) -> str:
|
||||
"""Process response from SerpAPI."""
|
||||
# logger.debug(res)
|
||||
focus = ['title', 'snippet', 'link']
|
||||
|
|
@ -117,5 +121,10 @@ class SerperWrapper(BaseModel):
|
|||
if res.get("organic"):
|
||||
toret_l += [get_focused(i) for i in res.get("organic")]
|
||||
|
||||
return str(toret) + '\n' + str(toret_l)
|
||||
|
||||
return str(toret) + '\n' + str(toret_l) if as_string else toret_l
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
fire.Fire(SerperWrapper().run)
|
||||
|
|
|
|||
|
|
@ -1,22 +1,20 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import importlib
|
||||
|
||||
from typing import Any, Callable, Coroutine, overload
|
||||
import importlib
|
||||
from typing import Any, Callable, Coroutine, Literal, overload
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
from bs4 import BeautifulSoup
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class WebBrowserEngine:
|
||||
def __init__(
|
||||
self,
|
||||
engine: WebBrowserEngineType | None = None,
|
||||
run_func: Callable[..., Coroutine[Any, Any, str | list[str]]] | None = None,
|
||||
parse_func: Callable[[str], str] | None = None,
|
||||
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
|
||||
):
|
||||
engine = engine or CONFIG.web_browser_engine
|
||||
|
||||
|
|
@ -30,31 +28,25 @@ class WebBrowserEngine:
|
|||
run_func = run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
self.parse_func = parse_func or get_page_content
|
||||
self.run_func = run_func
|
||||
self.engine = engine
|
||||
|
||||
@overload
|
||||
async def run(self, url: str) -> str:
|
||||
async def run(self, url: str) -> WebPage:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def run(self, url: str, *urls: str) -> list[str]:
|
||||
async def run(self, url: str, *urls: str) -> list[WebPage]:
|
||||
...
|
||||
|
||||
async def run(self, url: str, *urls: str) -> str | list[str]:
|
||||
page = await self.run_func(url, *urls)
|
||||
if isinstance(page, str):
|
||||
return self.parse_func(page)
|
||||
return [self.parse_func(i) for i in page]
|
||||
|
||||
|
||||
def get_page_content(page: str):
|
||||
soup = BeautifulSoup(page, "html.parser")
|
||||
return "\n".join(i.text.strip() for i in soup.find_all(["h1", "h2", "h3", "h4", "h5", "p", "pre"]))
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
return await self.run_func(url, *urls)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = asyncio.run(WebBrowserEngine().run("https://fuzhi.ai/"))
|
||||
print(text)
|
||||
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
|
||||
return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -2,16 +2,17 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from copy import deepcopy
|
||||
import importlib
|
||||
from concurrent import futures
|
||||
from copy import deepcopy
|
||||
from typing import Literal
|
||||
from metagpt.config import CONFIG
|
||||
import asyncio
|
||||
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
from concurrent import futures
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
||||
|
||||
class SeleniumWrapper:
|
||||
|
|
@ -48,7 +49,7 @@ class SeleniumWrapper:
|
|||
self.loop = loop
|
||||
self.executor = executor
|
||||
|
||||
async def run(self, url: str, *urls: str) -> str | list[str]:
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
await self._run_precheck()
|
||||
|
||||
_scrape = lambda url: self.loop.run_in_executor(self.executor, self._scrape_website, url)
|
||||
|
|
@ -69,9 +70,15 @@ class SeleniumWrapper:
|
|||
|
||||
def _scrape_website(self, url):
|
||||
with self._get_driver() as driver:
|
||||
driver.get(url)
|
||||
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
|
||||
return driver.page_source
|
||||
try:
|
||||
driver.get(url)
|
||||
WebDriverWait(driver, 30).until(EC.presence_of_element_located((By.TAG_NAME, "body")))
|
||||
inner_text = driver.execute_script("return document.body.innerText;")
|
||||
html = driver.page_source
|
||||
except Exception as e:
|
||||
inner_text = f"Fail to load page content for {e}"
|
||||
html = ""
|
||||
return WebPage(inner_text=inner_text, html=html, url=url)
|
||||
|
||||
|
||||
_webdriver_manager_types = {
|
||||
|
|
@ -97,6 +104,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
def _get_driver():
|
||||
options = Options()
|
||||
options.add_argument("--headless")
|
||||
options.add_argument("--enable-javascript")
|
||||
if browser_type == "chrome":
|
||||
options.add_argument("--no-sandbox")
|
||||
for i in args:
|
||||
|
|
@ -107,6 +115,9 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
|
||||
|
||||
if __name__ == "__main__":
|
||||
text = asyncio.run(SeleniumWrapper("chrome").run("https://fuzhi.ai/"))
|
||||
print(text)
|
||||
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
|
||||
return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@
|
|||
@Author : alexanderwu
|
||||
@File : mermaid.py
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import PROJECT_ROOT
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -24,25 +24,36 @@ def mermaid_to_file(mermaid_code, output_file_without_suffix, width=2048, height
|
|||
:return: 0 if succeeded, -1 if failed
|
||||
"""
|
||||
# Write the Mermaid code to a temporary file
|
||||
tmp = Path(f'{output_file_without_suffix}.mmd')
|
||||
tmp.write_text(mermaid_code, encoding='utf-8')
|
||||
tmp = Path(f"{output_file_without_suffix}.mmd")
|
||||
tmp.write_text(mermaid_code, encoding="utf-8")
|
||||
|
||||
if check_cmd_exists('mmdc') != 0:
|
||||
logger.warning(
|
||||
"RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc")
|
||||
if check_cmd_exists("mmdc") != 0:
|
||||
logger.warning("RUN `npm install -g @mermaid-js/mermaid-cli` to install mmdc")
|
||||
return -1
|
||||
|
||||
for suffix in ['pdf', 'svg', 'png']:
|
||||
output_file = f'{output_file_without_suffix}.{suffix}'
|
||||
for suffix in ["pdf", "svg", "png"]:
|
||||
output_file = f"{output_file_without_suffix}.{suffix}"
|
||||
# Call the `mmdc` command to convert the Mermaid code to the current output format (pdf, svg, or png)
|
||||
logger.info(f"Generating {output_file}..")
|
||||
|
||||
if CONFIG.puppeteer_config:
|
||||
subprocess.run([CONFIG.mmdc, '-p', CONFIG.puppeteer_config, '-i', str(tmp), '-o',
|
||||
output_file, '-w', str(width), '-H', str(height)])
|
||||
subprocess.run(
|
||||
[
|
||||
CONFIG.mmdc,
|
||||
"-p",
|
||||
CONFIG.puppeteer_config,
|
||||
"-i",
|
||||
str(tmp),
|
||||
"-o",
|
||||
output_file,
|
||||
"-w",
|
||||
str(width),
|
||||
"-H",
|
||||
str(height),
|
||||
]
|
||||
)
|
||||
else:
|
||||
subprocess.run([CONFIG.mmdc, '-i', str(tmp), '-o',
|
||||
output_file, '-w', str(width), '-H', str(height)])
|
||||
subprocess.run([CONFIG.mmdc, "-i", str(tmp), "-o", output_file, "-w", str(width), "-H", str(height)])
|
||||
return 0
|
||||
|
||||
|
||||
|
|
@ -97,8 +108,7 @@ MMC2 = """sequenceDiagram
|
|||
SE-->>M: return summary"""
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# logger.info(print_members(print_members))
|
||||
mermaid_to_file(MMC1, PROJECT_ROOT / 'tmp/1.png')
|
||||
mermaid_to_file(MMC2, PROJECT_ROOT / 'tmp/2.png')
|
||||
|
||||
mermaid_to_file(MMC1, PROJECT_ROOT / "tmp/1.png")
|
||||
mermaid_to_file(MMC2, PROJECT_ROOT / "tmp/2.png")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue