mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-10 16:22:37 +02:00
Merge branch 'main' into dev_updated
This commit is contained in:
commit
853086924a
429 changed files with 24237 additions and 5835 deletions
|
|
@ -22,3 +22,8 @@ class WebBrowserEngineType(Enum):
|
|||
PLAYWRIGHT = "playwright"
|
||||
SELENIUM = "selenium"
|
||||
CUSTOM = "custom"
|
||||
|
||||
@classmethod
|
||||
def __missing__(cls, key):
|
||||
"""Default type conversion"""
|
||||
return cls.CUSTOM
|
||||
|
|
|
|||
105
metagpt/tools/azure_tts.py
Normal file
105
metagpt/tools/azure_tts.py
Normal file
|
|
@ -0,0 +1,105 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/6/9 22:22
|
||||
@Author : Leo Xiao
|
||||
@File : azure_tts.py
|
||||
@Modified by: mashenquan, 2023/8/17. Azure TTS OAS3 api, which provides text-to-speech functionality
|
||||
"""
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import aiofiles
|
||||
from azure.cognitiveservices.speech import AudioConfig, SpeechConfig, SpeechSynthesizer
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class AzureTTS:
|
||||
"""Azure Text-to-Speech"""
|
||||
|
||||
def __init__(self, subscription_key, region):
|
||||
"""
|
||||
:param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
|
||||
:param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
|
||||
"""
|
||||
self.subscription_key = subscription_key if subscription_key else CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
|
||||
self.region = region if region else CONFIG.AZURE_TTS_REGION
|
||||
|
||||
# 参数参考:https://learn.microsoft.com/zh-cn/azure/cognitive-services/speech-service/language-support?tabs=tts#voice-styles-and-roles
|
||||
async def synthesize_speech(self, lang, voice, text, output_file):
|
||||
speech_config = SpeechConfig(subscription=self.subscription_key, region=self.region)
|
||||
speech_config.speech_synthesis_voice_name = voice
|
||||
audio_config = AudioConfig(filename=output_file)
|
||||
synthesizer = SpeechSynthesizer(speech_config=speech_config, audio_config=audio_config)
|
||||
|
||||
# More detail: https://learn.microsoft.com/en-us/azure/ai-services/speech-service/speech-synthesis-markup-voice
|
||||
ssml_string = (
|
||||
"<speak version='1.0' xmlns='http://www.w3.org/2001/10/synthesis' "
|
||||
f"xml:lang='{lang}' xmlns:mstts='http://www.w3.org/2001/mstts'>"
|
||||
f"<voice name='{voice}'>{text}</voice></speak>"
|
||||
)
|
||||
|
||||
return synthesizer.speak_ssml_async(ssml_string).get()
|
||||
|
||||
@staticmethod
|
||||
def role_style_text(role, style, text):
|
||||
return f'<mstts:express-as role="{role}" style="{style}">{text}</mstts:express-as>'
|
||||
|
||||
@staticmethod
|
||||
def role_text(role, text):
|
||||
return f'<mstts:express-as role="{role}">{text}</mstts:express-as>'
|
||||
|
||||
@staticmethod
|
||||
def style_text(style, text):
|
||||
return f'<mstts:express-as style="{style}">{text}</mstts:express-as>'
|
||||
|
||||
|
||||
# Export
|
||||
async def oas3_azsure_tts(text, lang="", voice="", style="", role="", subscription_key="", region=""):
|
||||
"""Text to speech
|
||||
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
|
||||
|
||||
:param lang: The value can contain a language code such as en (English), or a locale such as en-US (English - United States). For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
|
||||
:param voice: For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`, `https://speech.microsoft.com/portal/voicegallery`
|
||||
:param style: Speaking style to express different emotions like cheerfulness, empathy, and calm. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
|
||||
:param role: With roles, the same voice can act as a different age and gender. For more details, checkout: `https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
|
||||
:param text: The text used for voice conversion.
|
||||
:param subscription_key: key is used to access your Azure AI service API, see: `https://portal.azure.com/` > `Resource Management` > `Keys and Endpoint`
|
||||
:param region: This is the location (or region) of your resource. You may need to use this field when making calls to this API.
|
||||
:return: Returns the Base64-encoded .wav file data if successful, otherwise an empty string.
|
||||
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
if not lang:
|
||||
lang = "zh-CN"
|
||||
if not voice:
|
||||
voice = "zh-CN-XiaomoNeural"
|
||||
if not role:
|
||||
role = "Girl"
|
||||
if not style:
|
||||
style = "affectionate"
|
||||
if not subscription_key:
|
||||
subscription_key = CONFIG.AZURE_TTS_SUBSCRIPTION_KEY
|
||||
if not region:
|
||||
region = CONFIG.AZURE_TTS_REGION
|
||||
|
||||
xml_value = AzureTTS.role_style_text(role=role, style=style, text=text)
|
||||
tts = AzureTTS(subscription_key=subscription_key, region=region)
|
||||
filename = Path(__file__).resolve().parent / (str(uuid4()).replace("-", "") + ".wav")
|
||||
try:
|
||||
await tts.synthesize_speech(lang=lang, voice=voice, text=xml_value, output_file=str(filename))
|
||||
async with aiofiles.open(filename, mode="rb") as reader:
|
||||
data = await reader.read()
|
||||
base64_string = base64.b64encode(data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"text:{text}, error:{e}")
|
||||
return ""
|
||||
finally:
|
||||
filename.unlink(missing_ok=True)
|
||||
|
||||
return base64_string
|
||||
|
|
@ -1,187 +0,0 @@
|
|||
import re
|
||||
from typing import List, Callable, Dict
|
||||
from pathlib import Path
|
||||
|
||||
import wrapt
|
||||
import textwrap
|
||||
import inspect
|
||||
from interpreter.core.core import Interpreter
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.highlight import highlight
|
||||
from metagpt.actions.clone_function import CloneFunction, run_function_code, run_function_script
|
||||
|
||||
|
||||
def extract_python_code(code: str):
|
||||
"""Extract code blocks: If the code comments are the same, only the last code block is kept."""
|
||||
# Use regular expressions to match comment blocks and related code.
|
||||
pattern = r'(#\s[^\n]*)\n(.*?)(?=\n\s*#|$)'
|
||||
matches = re.findall(pattern, code, re.DOTALL)
|
||||
|
||||
# Extract the last code block when encountering the same comment.
|
||||
unique_comments = {}
|
||||
for comment, code_block in matches:
|
||||
unique_comments[comment] = code_block
|
||||
|
||||
# concatenate into functional form
|
||||
result_code = '\n'.join([f"{comment}\n{code_block}" for comment, code_block in unique_comments.items()])
|
||||
header_code = code[:code.find("#")]
|
||||
code = header_code + result_code
|
||||
|
||||
logger.info(f"Extract python code: \n {highlight(code)}")
|
||||
|
||||
return code
|
||||
|
||||
|
||||
class OpenCodeInterpreter(object):
|
||||
"""https://github.com/KillianLucas/open-interpreter"""
|
||||
def __init__(self, auto_run: bool = True) -> None:
|
||||
interpreter = Interpreter()
|
||||
interpreter.auto_run = auto_run
|
||||
interpreter.model = CONFIG.openai_api_model or "gpt-3.5-turbo"
|
||||
interpreter.api_key = CONFIG.openai_api_key
|
||||
# interpreter.api_base = CONFIG.openai_api_base
|
||||
self.interpreter = interpreter
|
||||
|
||||
def chat(self, query: str, reset: bool = True):
|
||||
if reset:
|
||||
self.interpreter.reset()
|
||||
return self.interpreter.chat(query)
|
||||
|
||||
@staticmethod
|
||||
def extract_function(query_respond: List, function_name: str, *, language: str = 'python',
|
||||
function_format: str = None) -> str:
|
||||
"""create a function from query_respond."""
|
||||
if language not in ('python'):
|
||||
raise NotImplementedError(f"Not support to parse language {language}!")
|
||||
|
||||
# set function form
|
||||
if function_format is None:
|
||||
assert language == 'python', f"Expect python language for default function_format, but got {language}."
|
||||
function_format = """def {function_name}():\n{code}"""
|
||||
# Extract the code module in the open-interpreter respond message.
|
||||
# The query_respond of open-interpreter before v0.1.4 is:
|
||||
# [{'role': 'user', 'content': your query string},
|
||||
# {'role': 'assistant', 'content': plan from llm, 'function_call': {
|
||||
# "name": "run_code", "arguments": "{"language": "python", "code": code of first plan},
|
||||
# "parsed_arguments": {"language": "python", "code": code of first plan}
|
||||
# ...]
|
||||
if "function_call" in query_respond[1]:
|
||||
code = [item['function_call']['parsed_arguments']['code'] for item in query_respond
|
||||
if "function_call" in item
|
||||
and "parsed_arguments" in item["function_call"]
|
||||
and 'language' in item["function_call"]['parsed_arguments']
|
||||
and item["function_call"]['parsed_arguments']['language'] == language]
|
||||
# The query_respond of open-interpreter v0.1.7 is:
|
||||
# [{'role': 'user', 'message': your query string},
|
||||
# {'role': 'assistant', 'message': plan from llm, 'language': 'python',
|
||||
# 'code': code of first plan, 'output': output of first plan code},
|
||||
# ...]
|
||||
elif "code" in query_respond[1]:
|
||||
code = [item['code'] for item in query_respond
|
||||
if "code" in item
|
||||
and 'language' in item
|
||||
and item['language'] == language]
|
||||
else:
|
||||
raise ValueError(f"Unexpect message format in query_respond: {query_respond[1].keys()}")
|
||||
# add indent.
|
||||
indented_code_str = textwrap.indent("\n".join(code), ' ' * 4)
|
||||
# Return the code after deduplication.
|
||||
if language == "python":
|
||||
return extract_python_code(function_format.format(function_name=function_name, code=indented_code_str))
|
||||
|
||||
|
||||
def gen_query(func: Callable, args, kwargs) -> str:
|
||||
# Get the annotation of the function as part of the query.
|
||||
desc = func.__doc__
|
||||
signature = inspect.signature(func)
|
||||
# Get the signature of the wrapped function and the assignment of the input parameters as part of the query.
|
||||
bound_args = signature.bind(*args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
query = f"{desc}, {bound_args.arguments}, If you must use a third-party package, use the most popular ones, for example: pandas, numpy, ta, ..."
|
||||
return query
|
||||
|
||||
|
||||
def gen_template_fun(func: Callable) -> str:
|
||||
return f"def {func.__name__}{str(inspect.signature(func))}\n # here is your code ..."
|
||||
|
||||
|
||||
class OpenInterpreterDecorator(object):
|
||||
def __init__(self, save_code: bool = False, code_file_path: str = None, clear_code: bool = False) -> None:
|
||||
self.save_code = save_code
|
||||
self.code_file_path = code_file_path
|
||||
self.clear_code = clear_code
|
||||
|
||||
def _have_code(self, rsp: List[Dict]):
|
||||
# Is there any code generated?
|
||||
return 'code' in rsp[1] and rsp[1]['code'] not in ("", None)
|
||||
|
||||
def _is_faild_plan(self, rsp: List[Dict]):
|
||||
# is faild plan?
|
||||
func_code = OpenCodeInterpreter.extract_function(rsp, 'function')
|
||||
# If there is no more than 1 '\n', the plan execution fails.
|
||||
if isinstance(func_code, str) and func_code.count('\n') <= 1:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _check_respond(self, query: str, interpreter: OpenCodeInterpreter, respond: List[Dict], max_try: int = 3):
|
||||
for _ in range(max_try):
|
||||
# TODO: If no code or faild plan is generated, execute chat again, repeating no more than max_try times.
|
||||
if self._have_code(respond) and not self._is_faild_plan(respond):
|
||||
break
|
||||
elif not self._have_code(respond):
|
||||
logger.warning(f"llm did not return executable code, resend the query: \n{query}")
|
||||
respond = interpreter.chat(query)
|
||||
elif self._is_faild_plan(respond):
|
||||
logger.warning(f"llm did not generate successful plan, resend the query: \n{query}")
|
||||
respond = interpreter.chat(query)
|
||||
|
||||
# Post-processing of respond
|
||||
if not self._have_code(respond):
|
||||
error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if self._is_faild_plan(respond):
|
||||
error_msg = f"OpenCodeInterpreter do not generate code for query: \n{query}"
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
return respond
|
||||
|
||||
def __call__(self, wrapped):
|
||||
@wrapt.decorator
|
||||
async def wrapper(wrapped: Callable, instance, args, kwargs):
|
||||
# Get the decorated function name.
|
||||
func_name = wrapped.__name__
|
||||
# If the script exists locally and clearcode is not required, execute the function from the script.
|
||||
if self.code_file_path and Path(self.code_file_path).is_file() and not self.clear_code:
|
||||
return run_function_script(self.code_file_path, func_name, *args, **kwargs)
|
||||
|
||||
# Auto run generate code by using open-interpreter.
|
||||
interpreter = OpenCodeInterpreter()
|
||||
query = gen_query(wrapped, args, kwargs)
|
||||
logger.info(f"query for OpenCodeInterpreter: \n {query}")
|
||||
respond = interpreter.chat(query)
|
||||
# Make sure the response is as expected.
|
||||
respond = self._check_respond(query, interpreter, respond, 3)
|
||||
# Assemble the code blocks generated by open-interpreter into a function without parameters.
|
||||
func_code = interpreter.extract_function(respond, func_name)
|
||||
# Clone the `func_code` into wrapped, that is,
|
||||
# keep the `func_code` and wrapped functions with the same input parameter and return value types.
|
||||
template_func = gen_template_fun(wrapped)
|
||||
cf = CloneFunction()
|
||||
code = await cf.run(template_func=template_func, source_code=func_code)
|
||||
# Display the generated function in the terminal.
|
||||
logger_code = highlight(code, "python")
|
||||
logger.info(f"Creating following Python function:\n{logger_code}")
|
||||
# execute this function.
|
||||
try:
|
||||
res = run_function_code(code, func_name, *args, **kwargs)
|
||||
if self.save_code and self.code_file_path:
|
||||
cf._save(self.code_file_path, code)
|
||||
except Exception as e:
|
||||
logger.error(f"Could not evaluate Python code \n{logger_code}: \nError: {e}")
|
||||
raise Exception("Could not evaluate Python code", e)
|
||||
return res
|
||||
return wrapper(wrapped)
|
||||
152
metagpt/tools/iflytek_tts.py
Normal file
152
metagpt/tools/iflytek_tts.py
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/17
|
||||
@Author : mashenquan
|
||||
@File : iflytek_tts.py
|
||||
@Desc : iFLYTEK TTS OAS3 api, which provides text-to-speech functionality
|
||||
"""
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from time import mktime
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import aiofiles
|
||||
import websockets as websockets
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class IFlyTekTTSStatus(Enum):
|
||||
STATUS_FIRST_FRAME = 0 # The first frame
|
||||
STATUS_CONTINUE_FRAME = 1 # The intermediate frame
|
||||
STATUS_LAST_FRAME = 2 # The last frame
|
||||
|
||||
|
||||
class AudioData(BaseModel):
|
||||
audio: str
|
||||
status: int
|
||||
ced: str
|
||||
|
||||
|
||||
class IFlyTekTTSResponse(BaseModel):
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[AudioData] = None
|
||||
sid: str
|
||||
|
||||
|
||||
DEFAULT_IFLYTEK_VOICE = "xiaoyan"
|
||||
|
||||
|
||||
class IFlyTekTTS(object):
|
||||
def __init__(self, app_id: str, api_key: str, api_secret: str):
|
||||
"""
|
||||
:param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
|
||||
:param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
|
||||
:param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
|
||||
"""
|
||||
self.app_id = app_id or CONFIG.IFLYTEK_APP_ID
|
||||
self.api_key = api_key or CONFIG.IFLYTEK_API_KEY
|
||||
self.api_secret = api_secret or CONFIG.API_SECRET
|
||||
|
||||
async def synthesize_speech(self, text, output_file: str, voice=DEFAULT_IFLYTEK_VOICE):
|
||||
url = self._create_url()
|
||||
data = {
|
||||
"common": {"app_id": self.app_id},
|
||||
"business": {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "vcn": voice, "tte": "utf8"},
|
||||
"data": {"status": 2, "text": str(base64.b64encode(text.encode("utf-8")), "UTF8")},
|
||||
}
|
||||
req = json.dumps(data)
|
||||
async with websockets.connect(url) as websocket:
|
||||
# send request
|
||||
await websocket.send(req)
|
||||
|
||||
# receive frames
|
||||
async with aiofiles.open(str(output_file), "wb") as writer:
|
||||
while True:
|
||||
v = await websocket.recv()
|
||||
rsp = IFlyTekTTSResponse(**json.loads(v))
|
||||
if rsp.data:
|
||||
binary_data = base64.b64decode(rsp.data.audio)
|
||||
await writer.write(binary_data)
|
||||
if rsp.data.status != IFlyTekTTSStatus.STATUS_LAST_FRAME.value:
|
||||
continue
|
||||
break
|
||||
|
||||
def _create_url(self):
|
||||
"""Create request url"""
|
||||
url = "wss://tts-api.xfyun.cn/v2/tts"
|
||||
# Generate a timestamp in RFC1123 format
|
||||
now = datetime.now()
|
||||
date = format_date_time(mktime(now.timetuple()))
|
||||
|
||||
signature_origin = "host: " + "ws-api.xfyun.cn" + "\n"
|
||||
signature_origin += "date: " + date + "\n"
|
||||
signature_origin += "GET " + "/v2/tts " + "HTTP/1.1"
|
||||
# Perform HMAC-SHA256 encryption
|
||||
signature_sha = hmac.new(
|
||||
self.api_secret.encode("utf-8"), signature_origin.encode("utf-8"), digestmod=hashlib.sha256
|
||||
).digest()
|
||||
signature_sha = base64.b64encode(signature_sha).decode(encoding="utf-8")
|
||||
|
||||
authorization_origin = 'api_key="%s", algorithm="%s", headers="%s", signature="%s"' % (
|
||||
self.api_key,
|
||||
"hmac-sha256",
|
||||
"host date request-line",
|
||||
signature_sha,
|
||||
)
|
||||
authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode(encoding="utf-8")
|
||||
# Combine the authentication parameters of the request into a dictionary.
|
||||
v = {"authorization": authorization, "date": date, "host": "ws-api.xfyun.cn"}
|
||||
# Concatenate the authentication parameters to generate the URL.
|
||||
url = url + "?" + urlencode(v)
|
||||
return url
|
||||
|
||||
|
||||
# Export
|
||||
async def oas3_iflytek_tts(text: str, voice: str = "", app_id: str = "", api_key: str = "", api_secret: str = ""):
|
||||
"""Text to speech
|
||||
For more details, check out:`https://www.xfyun.cn/doc/tts/online_tts/API.html`
|
||||
|
||||
:param voice: Default `xiaoyan`. For more details, checkout: `https://www.xfyun.cn/doc/tts/online_tts/API.html#%E6%8E%A5%E5%8F%A3%E8%B0%83%E7%94%A8%E6%B5%81%E7%A8%8B`
|
||||
:param text: The text used for voice conversion.
|
||||
:param app_id: Application ID is used to access your iFlyTek service API, see: `https://console.xfyun.cn/services/tts`
|
||||
:param api_key: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
|
||||
:param api_secret: WebAPI argument, see: `https://console.xfyun.cn/services/tts`
|
||||
:return: Returns the Base64-encoded .mp3 file data if successful, otherwise an empty string.
|
||||
|
||||
"""
|
||||
if not app_id:
|
||||
app_id = CONFIG.IFLYTEK_APP_ID
|
||||
if not api_key:
|
||||
api_key = CONFIG.IFLYTEK_API_KEY
|
||||
if not api_secret:
|
||||
api_secret = CONFIG.IFLYTEK_API_SECRET
|
||||
if not voice:
|
||||
voice = CONFIG.IFLYTEK_VOICE or DEFAULT_IFLYTEK_VOICE
|
||||
|
||||
filename = Path(__file__).parent / (uuid.uuid4().hex + ".mp3")
|
||||
try:
|
||||
tts = IFlyTekTTS(app_id=app_id, api_key=api_key, api_secret=api_secret)
|
||||
await tts.synthesize_speech(text=text, output_file=str(filename), voice=voice)
|
||||
async with aiofiles.open(str(filename), mode="rb") as reader:
|
||||
data = await reader.read()
|
||||
base64_string = base64.b64encode(data).decode("utf-8")
|
||||
except Exception as e:
|
||||
logger.error(f"text:{text}, error:{e}")
|
||||
base64_string = ""
|
||||
finally:
|
||||
filename.unlink(missing_ok=True)
|
||||
|
||||
return base64_string
|
||||
32
metagpt/tools/metagpt_oas3_api_svc.py
Normal file
32
metagpt/tools/metagpt_oas3_api_svc.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/17
|
||||
@Author : mashenquan
|
||||
@File : metagpt_oas3_api_svc.py
|
||||
@Desc : MetaGPT OpenAPI Specification 3.0 REST API service
|
||||
|
||||
curl -X 'POST' \
|
||||
'http://localhost:8080/openapi/greeting/dave' \
|
||||
-H 'accept: text/plain' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{}'
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import connexion
|
||||
|
||||
|
||||
def oas_http_svc():
|
||||
"""Start the OAS 3.0 OpenAPI HTTP service"""
|
||||
print("http://localhost:8080/oas3/ui/")
|
||||
specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
|
||||
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
|
||||
app.add_api("metagpt_oas3_api.yaml")
|
||||
app.add_api("openapi.yaml")
|
||||
app.run(port=8080)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
oas_http_svc()
|
||||
98
metagpt/tools/metagpt_text_to_image.py
Normal file
98
metagpt/tools/metagpt_text_to_image.py
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : metagpt_text_to_image.py
|
||||
@Desc : MetaGPT Text-to-Image OAS3 api, which provides text-to-image functionality.
|
||||
"""
|
||||
import base64
|
||||
from typing import Dict, List
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class MetaGPTText2Image:
|
||||
def __init__(self, model_url):
|
||||
"""
|
||||
:param model_url: Model reset api url
|
||||
"""
|
||||
self.model_url = model_url if model_url else CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL
|
||||
|
||||
async def text_2_image(self, text, size_type="512x512"):
|
||||
"""Text to image
|
||||
|
||||
:param text: The text used for image conversion.
|
||||
:param size_type: One of ['512x512', '512x768']
|
||||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
dims = size_type.split("x")
|
||||
data = {
|
||||
"prompt": text,
|
||||
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
|
||||
"override_settings": {"sd_model_checkpoint": "galaxytimemachinesGTM_photoV20"},
|
||||
"seed": -1,
|
||||
"batch_size": 1,
|
||||
"n_iter": 1,
|
||||
"steps": 20,
|
||||
"cfg_scale": 11,
|
||||
"width": int(dims[0]),
|
||||
"height": int(dims[1]), # 768,
|
||||
"restore_faces": False,
|
||||
"tiling": False,
|
||||
"do_not_save_samples": False,
|
||||
"do_not_save_grid": False,
|
||||
"enable_hr": False,
|
||||
"hr_scale": 2,
|
||||
"hr_upscaler": "Latent",
|
||||
"hr_second_pass_steps": 0,
|
||||
"hr_resize_x": 0,
|
||||
"hr_resize_y": 0,
|
||||
"hr_upscale_to_x": 0,
|
||||
"hr_upscale_to_y": 0,
|
||||
"truncate_x": 0,
|
||||
"truncate_y": 0,
|
||||
"applied_old_hires_behavior_to": None,
|
||||
"eta": None,
|
||||
"sampler_index": "DPM++ SDE Karras",
|
||||
"alwayson_scripts": {},
|
||||
}
|
||||
|
||||
class ImageResult(BaseModel):
|
||||
images: List
|
||||
parameters: Dict
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(self.model_url, headers=headers, json=data) as response:
|
||||
result = ImageResult(**await response.json())
|
||||
if len(result.images) == 0:
|
||||
return 0
|
||||
data = base64.b64decode(result.images[0])
|
||||
return data
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return 0
|
||||
|
||||
|
||||
# Export
|
||||
async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url=""):
|
||||
"""Text to image
|
||||
|
||||
:param text: The text used for image conversion.
|
||||
:param model_url: Model reset api
|
||||
:param size_type: One of ['512x512', '512x768']
|
||||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
if not model_url:
|
||||
model_url = CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL
|
||||
return await MetaGPTText2Image(model_url).text_2_image(text, size_type=size_type)
|
||||
|
|
@ -14,14 +14,19 @@ class Moderation:
|
|||
def __init__(self):
|
||||
self.llm = LLM()
|
||||
|
||||
def moderation(self, content: Union[str, list[str]]):
|
||||
def handle_moderation_results(self, results):
|
||||
resp = []
|
||||
for item in results:
|
||||
categories = item.categories.dict()
|
||||
true_categories = [category for category, item_flagged in categories.items() if item_flagged]
|
||||
resp.append({"flagged": item.flagged, "true_categories": true_categories})
|
||||
return resp
|
||||
|
||||
async def amoderation_with_categories(self, content: Union[str, list[str]]):
|
||||
resp = []
|
||||
if content:
|
||||
moderation_results = self.llm.moderation(content=content)
|
||||
results = moderation_results.results
|
||||
for item in results:
|
||||
resp.append(item.flagged)
|
||||
|
||||
moderation_results = await self.llm.amoderation(content=content)
|
||||
resp = self.handle_moderation_results(moderation_results.results)
|
||||
return resp
|
||||
|
||||
async def amoderation(self, content: Union[str, list[str]]):
|
||||
|
|
@ -33,8 +38,3 @@ class Moderation:
|
|||
resp.append(item.flagged)
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
moderation = Moderation()
|
||||
print(moderation.moderation(content=["I will kill you", "The weather is really nice today", "I want to hit you"]))
|
||||
|
|
|
|||
87
metagpt/tools/openai_text_to_embedding.py
Normal file
87
metagpt/tools/openai_text_to_embedding.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/18
|
||||
@Author : mashenquan
|
||||
@File : openai_text_to_embedding.py
|
||||
@Desc : OpenAI Text-to-Embedding OAS3 api, which provides text-to-embedding functionality.
|
||||
For more details, checkout: `https://platform.openai.com/docs/api-reference/embeddings/object`
|
||||
"""
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class Embedding(BaseModel):
|
||||
"""Represents an embedding vector returned by embedding endpoint."""
|
||||
|
||||
object: str # The object type, which is always "embedding".
|
||||
embedding: List[
|
||||
float
|
||||
] # The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the embedding guide.
|
||||
index: int # The index of the embedding in the list of embeddings.
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
prompt_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
|
||||
|
||||
class ResultEmbedding(BaseModel):
|
||||
class Config:
|
||||
alias = {"object_": "object"}
|
||||
|
||||
object_: str = ""
|
||||
data: List[Embedding] = []
|
||||
model: str = ""
|
||||
usage: Usage = Field(default_factory=Usage)
|
||||
|
||||
|
||||
class OpenAIText2Embedding:
|
||||
def __init__(self, openai_api_key):
|
||||
"""
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
"""
|
||||
self.openai_api_key = openai_api_key or CONFIG.OPENAI_API_KEY
|
||||
|
||||
async def text_2_embedding(self, text, model="text-embedding-ada-002"):
|
||||
"""Text to embedding
|
||||
|
||||
:param text: The text used for embedding.
|
||||
:param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
|
||||
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
|
||||
"""
|
||||
|
||||
proxies = {"proxy": CONFIG.openai_proxy} if CONFIG.openai_proxy else {}
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.openai_api_key}"}
|
||||
data = {"input": text, "model": model}
|
||||
url = "https://api.openai.com/v1/embeddings"
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, headers=headers, json=data, **proxies) as response:
|
||||
data = await response.json()
|
||||
return ResultEmbedding(**data)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return ResultEmbedding()
|
||||
|
||||
|
||||
# Export
|
||||
async def oas3_openai_text_to_embedding(text, model="text-embedding-ada-002", openai_api_key=""):
|
||||
"""Text to embedding
|
||||
|
||||
:param text: The text used for embedding.
|
||||
:param model: One of ['text-embedding-ada-002'], ID of the model to use. For more details, checkout: `https://api.openai.com/v1/models`.
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
:return: A json object of :class:`ResultEmbedding` class if successful, otherwise `{}`.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
if not openai_api_key:
|
||||
openai_api_key = CONFIG.OPENAI_API_KEY
|
||||
return await OpenAIText2Embedding(openai_api_key).text_2_embedding(text, model=model)
|
||||
69
metagpt/tools/openai_text_to_image.py
Normal file
69
metagpt/tools/openai_text_to_image.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/8/17
|
||||
@Author : mashenquan
|
||||
@File : openai_text_to_image.py
|
||||
@Desc : OpenAI Text-to-Image OAS3 api, which provides text-to-image functionality.
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class OpenAIText2Image:
|
||||
def __init__(self):
|
||||
"""
|
||||
:param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
|
||||
"""
|
||||
self._llm = LLM()
|
||||
|
||||
async def text_2_image(self, text, size_type="1024x1024"):
|
||||
"""Text to image
|
||||
|
||||
:param text: The text used for image conversion.
|
||||
:param size_type: One of ['256x256', '512x512', '1024x1024']
|
||||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
try:
|
||||
result = await self._llm.aclient.images.generate(prompt=text, n=1, size=size_type)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return ""
|
||||
if result and len(result.data) > 0:
|
||||
return await OpenAIText2Image.get_image_data(result.data[0].url)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
async def get_image_data(url):
|
||||
"""Fetch image data from a URL and encode it as Base64
|
||||
|
||||
:param url: Image url
|
||||
:return: Base64-encoded image data.
|
||||
"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url) as response:
|
||||
response.raise_for_status() # 如果是 4xx 或 5xx 响应,会引发异常
|
||||
image_data = await response.read()
|
||||
return image_data
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return 0
|
||||
|
||||
|
||||
# Export
|
||||
async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"):
|
||||
"""Text to image
|
||||
|
||||
:param text: The text used for image conversion.
|
||||
:param size_type: One of ['256x256', '512x512', '1024x1024']
|
||||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
if not text:
|
||||
return ""
|
||||
return await OpenAIText2Image().text_2_image(text, size_type=size_type)
|
||||
29
metagpt/tools/openapi_v3_hello.py
Normal file
29
metagpt/tools/openapi_v3_hello.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/2 16:03
|
||||
@Author : mashenquan
|
||||
@File : openapi_v3_hello.py
|
||||
@Desc : Implement the OpenAPI Specification 3.0 demo and use the following command to test the HTTP service:
|
||||
|
||||
curl -X 'POST' \
|
||||
'http://localhost:8082/openapi/greeting/dave' \
|
||||
-H 'accept: text/plain' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{}'
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
import connexion
|
||||
|
||||
|
||||
# openapi implement
|
||||
async def post_greeting(name: str) -> str:
|
||||
return f"Hello {name}\n"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
specification_dir = Path(__file__).parent.parent.parent / "docs/.well-known"
|
||||
app = connexion.AsyncApp(__name__, specification_dir=str(specification_dir))
|
||||
app.add_api("openapi.yaml", arguments={"title": "Hello World Example"})
|
||||
app.run(port=8082)
|
||||
|
|
@ -10,8 +10,9 @@ from typing import Union
|
|||
|
||||
class GPTPromptGenerator:
|
||||
"""Using LLM, given an output, request LLM to provide input (supporting instruction, chatbot, and query styles)"""
|
||||
|
||||
def __init__(self):
|
||||
self._generators = {i: getattr(self, f"gen_{i}_style") for i in ['instruction', 'chatbot', 'query']}
|
||||
self._generators = {i: getattr(self, f"gen_{i}_style") for i in ["instruction", "chatbot", "query"]}
|
||||
|
||||
def gen_instruction_style(self, example):
|
||||
"""Instruction style: Given an output, request LLM to provide input"""
|
||||
|
|
@ -35,7 +36,7 @@ Query: X
|
|||
Document: {example} What is the detailed query X?
|
||||
X:"""
|
||||
|
||||
def gen(self, example: str, style: str = 'all') -> Union[list[str], str]:
|
||||
def gen(self, example: str, style: str = "all") -> Union[list[str], str]:
|
||||
"""
|
||||
Generate one or multiple outputs using the example, allowing LLM to reply with the corresponding input
|
||||
|
||||
|
|
@ -43,7 +44,7 @@ X:"""
|
|||
:param style: (all|instruction|chatbot|query)
|
||||
:return: Expected LLM input sample (one or multiple)
|
||||
"""
|
||||
if style != 'all':
|
||||
if style != "all":
|
||||
return self._generators[style](example)
|
||||
return [f(example) for f in self._generators.values()]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,24 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 2023/7/19 16:28
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
from os.path import join
|
||||
from typing import List
|
||||
|
||||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
from metagpt.config import Config
|
||||
from metagpt.const import WORKSPACE_ROOT
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SD_OUTPUT_FILE_REPO
|
||||
from metagpt.logs import logger
|
||||
|
||||
config = Config()
|
||||
|
||||
payload = {
|
||||
"prompt": "",
|
||||
"negative_prompt": "(easynegative:0.8),black, dark,Low resolution",
|
||||
|
|
@ -56,9 +53,8 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
|||
class SDEngine:
|
||||
def __init__(self):
|
||||
# Initialize the SDEngine with configuration
|
||||
self.config = Config()
|
||||
self.sd_url = self.config.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{self.config.get('SD_T2I_API')}"
|
||||
self.sd_url = CONFIG.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
|
||||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
logger.info(self.sd_t2i_url)
|
||||
|
|
@ -81,10 +77,10 @@ class SDEngine:
|
|||
return self.payload
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
save_dir = WORKSPACE_ROOT / "resources" / "SD_Output"
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, save_dir, save_name=save_name)
|
||||
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
|
||||
|
||||
async def run_t2i(self, prompts: List):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
|
|
@ -120,11 +116,13 @@ def decode_base64_to_image(img, save_name):
|
|||
image.save(f"{save_name}.png", pnginfo=pnginfo)
|
||||
return pnginfo, image
|
||||
|
||||
|
||||
def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
||||
for idx, _img in enumerate(imgs):
|
||||
save_name = join(save_dir, save_name)
|
||||
decode_base64_to_image(_img, save_name=save_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
engine = SDEngine()
|
||||
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
@File : search_engine.py
|
||||
"""
|
||||
import importlib
|
||||
from typing import Callable, Coroutine, Literal, overload, Optional, Union
|
||||
from typing import Callable, Coroutine, Literal, Optional, Union, overload
|
||||
|
||||
from semantic_kernel.skill_definition import sk_function
|
||||
|
||||
|
|
@ -43,8 +43,8 @@ class SearchEngine:
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Optional[SearchEngineType] = None,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
|
||||
engine: Optional[SearchEngineType] = None,
|
||||
run_func: Callable[[str, int, bool], Coroutine[None, None, Union[str, list[str]]]] = None,
|
||||
):
|
||||
engine = engine or CONFIG.search_engine
|
||||
if engine == SearchEngineType.SERPAPI_GOOGLE:
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Optional
|
|||
from urllib.parse import urlparse
|
||||
|
||||
import httplib2
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -25,15 +25,14 @@ except ImportError:
|
|||
|
||||
|
||||
class GoogleAPIWrapper(BaseModel):
|
||||
google_api_key: Optional[str] = None
|
||||
google_cse_id: Optional[str] = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
google_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
google_cse_id: Optional[str] = Field(default=None, validate_default=True)
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("google_api_key", always=True)
|
||||
@field_validator("google_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_google_api_key(cls, val: str):
|
||||
val = val or CONFIG.google_api_key
|
||||
|
|
@ -45,7 +44,7 @@ class GoogleAPIWrapper(BaseModel):
|
|||
)
|
||||
return val
|
||||
|
||||
@validator("google_cse_id", always=True)
|
||||
@field_validator("google_cse_id", mode="before")
|
||||
@classmethod
|
||||
def check_google_cse_id(cls, val: str):
|
||||
val = val or CONFIG.google_cse_id
|
||||
|
|
|
|||
|
|
@ -11,6 +11,8 @@ from typing import List
|
|||
import meilisearch
|
||||
from meilisearch.index import Index
|
||||
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
class DataSource:
|
||||
def __init__(self, name: str, url: str):
|
||||
|
|
@ -29,16 +31,12 @@ class MeilisearchEngine:
|
|||
def add_documents(self, data_source: DataSource, documents: List[dict]):
|
||||
index_name = f"{data_source.name}_index"
|
||||
if index_name not in self.client.get_indexes():
|
||||
self.client.create_index(uid=index_name, options={'primaryKey': 'id'})
|
||||
self.client.create_index(uid=index_name, options={"primaryKey": "id"})
|
||||
index = self.client.get_index(index_name)
|
||||
index.add_documents(documents)
|
||||
self.set_index(index)
|
||||
|
||||
@handle_exception(exception_type=Exception, default_return=[])
|
||||
def search(self, query):
|
||||
try:
|
||||
search_results = self._index.search(query)
|
||||
return search_results['hits']
|
||||
except Exception as e:
|
||||
# Handle MeiliSearch API errors
|
||||
print(f"MeiliSearch API error: {e}")
|
||||
return []
|
||||
search_results = self._index.search(query)
|
||||
return search_results["hits"]
|
||||
|
|
|
|||
|
|
@ -8,28 +8,28 @@
|
|||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class SerpAPIWrapper(BaseModel):
|
||||
search_engine: Any #: :meta private:
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
search_engine: Any = None #: :meta private:
|
||||
params: dict = Field(
|
||||
default={
|
||||
default_factory=lambda: {
|
||||
"engine": "google",
|
||||
"google_domain": "google.com",
|
||||
"gl": "us",
|
||||
"hl": "en",
|
||||
}
|
||||
)
|
||||
serpapi_api_key: Optional[str] = None
|
||||
# should add `validate_default=True` to check with default value
|
||||
serpapi_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("serpapi_api_key", always=True)
|
||||
@field_validator("serpapi_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_serpapi_api_key(cls, val: str):
|
||||
val = val or CONFIG.serpapi_api_key
|
||||
|
|
@ -43,7 +43,8 @@ class SerpAPIWrapper(BaseModel):
|
|||
|
||||
async def run(self, query, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
"""Run query through SerpAPI and parse result async."""
|
||||
return self._process_response(await self.results(query, max_results), as_string=as_string)
|
||||
result = await self.results(query, max_results)
|
||||
return self._process_response(result, as_string=as_string)
|
||||
|
||||
async def results(self, query: str, max_results: int) -> dict:
|
||||
"""Use aiohttp to run query through SerpAPI and return the results async."""
|
||||
|
|
|
|||
|
|
@ -9,21 +9,20 @@ import json
|
|||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class SerperWrapper(BaseModel):
|
||||
search_engine: Any #: :meta private:
|
||||
payload: dict = Field(default={"page": 1, "num": 10})
|
||||
serper_api_key: Optional[str] = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
search_engine: Any = None #: :meta private:
|
||||
payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10})
|
||||
serper_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("serper_api_key", always=True)
|
||||
@field_validator("serper_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_serper_api_key(cls, val: str):
|
||||
val = val or CONFIG.serper_api_key
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
@File : translator.py
|
||||
"""
|
||||
|
||||
prompt = '''
|
||||
prompt = """
|
||||
# 指令
|
||||
接下来,作为一位拥有20年翻译经验的翻译专家,当我给出英文句子或段落时,你将提供通顺且具有可读性的{LANG}翻译。注意以下要求:
|
||||
1. 确保翻译结果流畅且易于理解
|
||||
|
|
@ -17,11 +17,10 @@ prompt = '''
|
|||
{ORIGINAL}
|
||||
|
||||
# 译文
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class Translator:
|
||||
|
||||
@classmethod
|
||||
def translate_prompt(cls, original, lang='中文'):
|
||||
return prompt.format(LANG=lang, ORIGINAL=original)
|
||||
def translate_prompt(cls, original, lang="中文"):
|
||||
return prompt.format(LANG=lang, ORIGINAL=original)
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
|
||||
from metagpt.provider.openai_api import OpenAILLM as GPTAPI
|
||||
from metagpt.utils.common import awrite
|
||||
|
||||
ICL_SAMPLE = '''Interface definition:
|
||||
ICL_SAMPLE = """Interface definition:
|
||||
```text
|
||||
Interface Name: Element Tagging
|
||||
Interface Path: /projects/{project_key}/node-tags
|
||||
|
|
@ -60,20 +61,20 @@ def test_node_tags(project_key, nodes, operations, expected_msg):
|
|||
# 3. If comments are needed, use Chinese.
|
||||
|
||||
# If you understand, please wait for me to give the interface definition and just answer "Understood" to save tokens.
|
||||
'''
|
||||
"""
|
||||
|
||||
ACT_PROMPT_PREFIX = '''Refer to the test types: such as missing request parameters, field boundary verification, incorrect field type.
|
||||
ACT_PROMPT_PREFIX = """Refer to the test types: such as missing request parameters, field boundary verification, incorrect field type.
|
||||
Please output 10 test cases within one `@pytest.mark.parametrize` scope.
|
||||
```text
|
||||
'''
|
||||
"""
|
||||
|
||||
YFT_PROMPT_PREFIX = '''Refer to the test types: such as SQL injection, cross-site scripting (XSS), unauthorized access and privilege escalation,
|
||||
YFT_PROMPT_PREFIX = """Refer to the test types: such as SQL injection, cross-site scripting (XSS), unauthorized access and privilege escalation,
|
||||
authentication and authorization, parameter verification, exception handling, file upload and download.
|
||||
Please output 10 test cases within one `@pytest.mark.parametrize` scope.
|
||||
```text
|
||||
'''
|
||||
"""
|
||||
|
||||
OCR_API_DOC = '''```text
|
||||
OCR_API_DOC = """```text
|
||||
Interface Name: OCR recognition
|
||||
Interface Path: /api/v1/contract/treaty/task/ocr
|
||||
Method: POST
|
||||
|
|
@ -96,14 +97,20 @@ code integer Yes
|
|||
message string Yes
|
||||
data object Yes
|
||||
```
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
class UTGenerator:
|
||||
"""UT Generator: Construct UT through API documentation"""
|
||||
|
||||
def __init__(self, swagger_file: str, ut_py_path: str, questions_path: str,
|
||||
chatgpt_method: str = "API", template_prefix=YFT_PROMPT_PREFIX) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
swagger_file: str,
|
||||
ut_py_path: str,
|
||||
questions_path: str,
|
||||
chatgpt_method: str = "API",
|
||||
template_prefix=YFT_PROMPT_PREFIX,
|
||||
) -> None:
|
||||
"""Initialize UT Generator
|
||||
|
||||
Args:
|
||||
|
|
@ -168,6 +175,9 @@ class UTGenerator:
|
|||
return doc
|
||||
|
||||
for name, prop in node.items():
|
||||
if not isinstance(prop, dict):
|
||||
doc += f'{" " * level}{self._para_to_str(node)}\n'
|
||||
break
|
||||
doc += f'{" " * level}{self.para_to_str(name, prop, prop_object_required)}\n'
|
||||
doc += dive_into_object(prop)
|
||||
if prop["type"] == "array":
|
||||
|
|
@ -196,12 +206,12 @@ class UTGenerator:
|
|||
|
||||
return tags
|
||||
|
||||
def generate_ut(self, include_tags) -> bool:
|
||||
async def generate_ut(self, include_tags) -> bool:
|
||||
"""Generate test case files"""
|
||||
tags = self.get_tags_mapping()
|
||||
for tag, paths in tags.items():
|
||||
if include_tags is None or tag in include_tags:
|
||||
self._generate_ut(tag, paths)
|
||||
await self._generate_ut(tag, paths)
|
||||
return True
|
||||
|
||||
def build_api_doc(self, node: dict, path: str, method: str) -> str:
|
||||
|
|
@ -244,21 +254,16 @@ class UTGenerator:
|
|||
|
||||
return doc
|
||||
|
||||
def _store(self, data, base, folder, fname):
|
||||
"""Store data in a file."""
|
||||
file_path = self.get_file_path(Path(base) / folder, fname)
|
||||
with open(file_path, "w", encoding="utf-8") as file:
|
||||
file.write(data)
|
||||
|
||||
def ask_gpt_and_save(self, question: str, tag: str, fname: str):
|
||||
async def ask_gpt_and_save(self, question: str, tag: str, fname: str):
|
||||
"""Generate questions and store both questions and answers"""
|
||||
messages = [self.icl_sample, question]
|
||||
result = self.gpt_msgs_to_code(messages=messages)
|
||||
result = await self.gpt_msgs_to_code(messages=messages)
|
||||
|
||||
self._store(question, self.questions_path, tag, f"{fname}.txt")
|
||||
self._store(result, self.ut_py_path, tag, f"{fname}.py")
|
||||
await awrite(Path(self.questions_path) / tag / f"{fname}.txt", question)
|
||||
data = result.get("code", "") if result else ""
|
||||
await awrite(Path(self.ut_py_path) / tag / f"{fname}.py", data)
|
||||
|
||||
def _generate_ut(self, tag, paths):
|
||||
async def _generate_ut(self, tag, paths):
|
||||
"""Process the structure under a data path
|
||||
|
||||
Args:
|
||||
|
|
@ -270,24 +275,12 @@ class UTGenerator:
|
|||
summary = node["summary"]
|
||||
question = self.template_prefix
|
||||
question += self.build_api_doc(node, path, method)
|
||||
self.ask_gpt_and_save(question, tag, summary)
|
||||
await self.ask_gpt_and_save(question, tag, summary)
|
||||
|
||||
def gpt_msgs_to_code(self, messages: list) -> str:
|
||||
async def gpt_msgs_to_code(self, messages: list) -> str:
|
||||
"""Choose based on different calling methods"""
|
||||
result = ''
|
||||
result = ""
|
||||
if self.chatgpt_method == "API":
|
||||
result = GPTAPI().ask_code(msgs=messages)
|
||||
result = await GPTAPI().aask_code(messages=messages)
|
||||
|
||||
return result
|
||||
|
||||
def get_file_path(self, base: Path, fname: str):
|
||||
"""Save different file paths
|
||||
|
||||
Args:
|
||||
base (str): Path
|
||||
fname (str): File name
|
||||
"""
|
||||
path = Path(base)
|
||||
path.mkdir(parents=True, exist_ok=True)
|
||||
file_path = path / fname
|
||||
return str(file_path)
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import Any, Callable, Coroutine, Literal, overload
|
||||
from typing import Any, Callable, Coroutine, overload
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.tools import WebBrowserEngineType
|
||||
|
|
@ -17,14 +20,16 @@ class WebBrowserEngine:
|
|||
run_func: Callable[..., Coroutine[Any, Any, WebPage | list[WebPage]]] | None = None,
|
||||
):
|
||||
engine = engine or CONFIG.web_browser_engine
|
||||
if engine is None:
|
||||
raise NotImplementedError
|
||||
|
||||
if engine == WebBrowserEngineType.PLAYWRIGHT:
|
||||
if WebBrowserEngineType(engine) is WebBrowserEngineType.PLAYWRIGHT:
|
||||
module = "metagpt.tools.web_browser_engine_playwright"
|
||||
run_func = importlib.import_module(module).PlaywrightWrapper().run
|
||||
elif engine == WebBrowserEngineType.SELENIUM:
|
||||
elif WebBrowserEngineType(engine) is WebBrowserEngineType.SELENIUM:
|
||||
module = "metagpt.tools.web_browser_engine_selenium"
|
||||
run_func = importlib.import_module(module).SeleniumWrapper().run
|
||||
elif engine == WebBrowserEngineType.CUSTOM:
|
||||
elif WebBrowserEngineType(engine) is WebBrowserEngineType.CUSTOM:
|
||||
run_func = run_func
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
|
@ -41,12 +46,3 @@ class WebBrowserEngine:
|
|||
|
||||
async def run(self, url: str, *urls: str) -> WebPage | list[WebPage]:
|
||||
return await self.run_func(url, *urls)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, engine_type: Literal["playwright", "selenium"] = "playwright", **kwargs):
|
||||
return await WebBrowserEngine(WebBrowserEngineType(engine_type), **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -138,12 +142,3 @@ async def _log_stream(sr, log_func):
|
|||
|
||||
_install_lock: asyncio.Lock = None
|
||||
_install_cache = set()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chromium", **kwargs):
|
||||
return await PlaywrightWrapper(browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
|
@ -10,6 +14,8 @@ from typing import Literal
|
|||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions as EC
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
from webdriver_manager.core.download_manager import WDMDownloadManager
|
||||
from webdriver_manager.core.http import WDMHttpClient
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.utils.parse_html import WebPage
|
||||
|
|
@ -89,6 +95,13 @@ _webdriver_manager_types = {
|
|||
}
|
||||
|
||||
|
||||
class WDMHttpProxyClient(WDMHttpClient):
|
||||
def get(self, url, **kwargs):
|
||||
if "proxies" not in kwargs and CONFIG.global_proxy:
|
||||
kwargs["proxies"] = {"all_proxy": CONFIG.global_proxy}
|
||||
return super().get(url, **kwargs)
|
||||
|
||||
|
||||
def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
||||
WebDriver = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.webdriver"), "WebDriver")
|
||||
Service = getattr(importlib.import_module(f"selenium.webdriver.{browser_type}.service"), "Service")
|
||||
|
|
@ -97,7 +110,7 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
if not executable_path:
|
||||
module_name, type_name = _webdriver_manager_types[browser_type]
|
||||
DriverManager = getattr(importlib.import_module(module_name), type_name)
|
||||
driver_manager = DriverManager()
|
||||
driver_manager = DriverManager(download_manager=WDMDownloadManager(http_client=WDMHttpProxyClient()))
|
||||
# driver_manager.driver_cache.find_driver(driver_manager.driver))
|
||||
executable_path = driver_manager.install()
|
||||
|
||||
|
|
@ -106,18 +119,11 @@ def _gen_get_driver_func(browser_type, *args, executable_path=None):
|
|||
options.add_argument("--headless")
|
||||
options.add_argument("--enable-javascript")
|
||||
if browser_type == "chrome":
|
||||
options.add_argument("--disable-gpu") # This flag can help avoid renderer issue
|
||||
options.add_argument("--disable-dev-shm-usage") # Overcome limited resource problems
|
||||
options.add_argument("--no-sandbox")
|
||||
for i in args:
|
||||
options.add_argument(i)
|
||||
return WebDriver(options=deepcopy(options), service=Service(executable_path=executable_path))
|
||||
|
||||
return _get_driver
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import fire
|
||||
|
||||
async def main(url: str, *urls: str, browser_type: str = "chrome", **kwargs):
|
||||
return await SeleniumWrapper(browser_type, **kwargs).run(url, *urls)
|
||||
|
||||
fire.Fire(main)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue