From ca60cd0557effda735c4850b0f3b36fadd555fdf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Sat, 2 Sep 2023 14:30:45 +0800 Subject: [PATCH] feat: +s3 --- metagpt/const.py | 3 ++ metagpt/learn/text_to_image.py | 22 +++++++++------ metagpt/learn/text_to_speech.py | 29 +++++++++++++------ metagpt/tools/openai_text_to_image.py | 38 +++++++++---------------- metagpt/utils/s3.py | 40 +++++++++++++++++++-------- 5 files changed, 79 insertions(+), 53 deletions(-) diff --git a/metagpt/const.py b/metagpt/const.py index f2f1b4837..fbc2c928a 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -54,3 +54,6 @@ METAGPT_API_KEY = "METAGPT_API_KEY" METAGPT_API_BASE = "METAGPT_API_BASE" METAGPT_API_TYPE = "METAGPT_API_TYPE" METAGPT_API_VERSION = "METAGPT_API_VERSION" + +# format +BASE64_FORMAT = "base64" diff --git a/metagpt/learn/text_to_image.py b/metagpt/learn/text_to_image.py index 620e58180..c5f554ef3 100644 --- a/metagpt/learn/text_to_image.py +++ b/metagpt/learn/text_to_image.py @@ -6,10 +6,13 @@ @File : text_to_image.py @Desc : Text-to-Image skill, which provides text-to-image functionality. 
""" +import openai.error from metagpt.config import CONFIG +from metagpt.const import BASE64_FORMAT from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image +from metagpt.utils.s3 import S3 async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs): @@ -23,13 +26,14 @@ async def text_to_image(text, size_type: str = "512x512", openai_api_key="", mod """ image_declaration = "data:image/png;base64," if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url: - data = await oas3_metagpt_text_to_image(text, size_type, model_url) - return image_declaration + data if data else "" - - if CONFIG.OPENAI_API_KEY or openai_api_key: - data = await oas3_openai_text_to_image(text, size_type, openai_api_key) - return image_declaration + data if data else "" - - raise EnvironmentError - + base64_data = await oas3_metagpt_text_to_image(text, size_type, model_url) + elif CONFIG.OPENAI_API_KEY or openai_api_key: + base64_data = await oas3_openai_text_to_image(text, size_type, openai_api_key) + else: + raise openai.error.InvalidRequestError("缺少必要的参数", param="openai_api_key") + s3 = S3() + url = await s3.cache(base64_data, BASE64_FORMAT) if base64_data else None + if url: + return url + return image_declaration + base64_data if base64_data else "" diff --git a/metagpt/learn/text_to_speech.py b/metagpt/learn/text_to_speech.py index 66fbba5be..7883ae9f3 100644 --- a/metagpt/learn/text_to_speech.py +++ b/metagpt/learn/text_to_speech.py @@ -6,14 +6,24 @@ @File : text_to_speech.py @Desc : Text-to-Speech skill, which provides text-to-speech functionality """ +import openai from metagpt.config import CONFIG - +from metagpt.const import BASE64_FORMAT from metagpt.tools.azure_tts import oas3_azsure_tts +from metagpt.utils.s3 import S3 -async def text_to_speech(text, lang="zh-CN", voice="zh-CN-XiaomoNeural", style="affectionate", role="Girl", - subscription_key="", region="", **kwargs): +async def text_to_speech( 
text, + lang="zh-CN", + voice="zh-CN-XiaomoNeural", + style="affectionate", + role="Girl", + subscription_key="", + region="", + **kwargs +): """Text to speech For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts` @@ -28,9 +38,12 @@ async def text_to_speech(text, lang="zh-CN", voice="zh-CN-XiaomoNeural", style=" """ audio_declaration = "data:audio/wav;base64," - if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or \ - (subscription_key and region): - data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region) - return audio_declaration + data if data else data + if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region): + base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region) + s3 = S3() + url = await s3.cache(base64_data, BASE64_FORMAT) if base64_data else None + if url: + return url + return audio_declaration + base64_data if base64_data else base64_data - raise EnvironmentError + raise openai.error.InvalidRequestError("缺少必要的参数", param="subscription_key") diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py index 395fa8133..6025f04ba 100644 --- a/metagpt/tools/openai_text_to_image.py +++ b/metagpt/tools/openai_text_to_image.py @@ -8,18 +8,12 @@ """ import asyncio import base64 -import os -import sys -from pathlib import Path -from typing import List import aiohttp +import openai import requests -from pydantic import BaseModel from metagpt.config import CONFIG, Config - -sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) # fix-bug: No module named 'metagpt' from metagpt.logs import logger @@ -37,27 +31,21 @@ class OpenAIText2Image: :param size_type: One of ['256x256', '512x512', '1024x1024'] :return: The image data is returned in Base64 encoding. 
""" - - class ImageUrl(BaseModel): - url: str - - class ImageResult(BaseModel): - data: List[ImageUrl] - created: int - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self.openai_api_key}" - } - data = {"prompt": text, "n": 1, "size": size_type} try: - async with aiohttp.ClientSession() as session: - async with session.post("https://api.openai.com/v1/images/generations", headers=headers, json=data) as response: - result = ImageResult(** await response.json()) - except requests.exceptions.RequestException as e: + result = await openai.Image.acreate( + api_key=self.openai_api_key or CONFIG.OPENAI_API_KEY, + api_base=CONFIG.OPENAI_API_BASE, + api_type=None, + api_version=None, + organization=None, + prompt=text, + n=1, + size=size_type, + ) + except Exception as e: logger.error(f"An error occurred:{e}") return "" - if len(result.data) > 0: + if result and len(result.data) > 0: return await OpenAIText2Image.get_image_data(result.data[0].url) return "" diff --git a/metagpt/utils/s3.py b/metagpt/utils/s3.py index 2b4b8cb5f..85837fedb 100644 --- a/metagpt/utils/s3.py +++ b/metagpt/utils/s3.py @@ -1,9 +1,14 @@ - +import base64 +import traceback +import uuid from typing import Optional import aioboto3 +import aiofiles + +from metagpt.config import CONFIG +from metagpt.const import BASE64_FORMAT, WORKSPACE_ROOT from metagpt.logs import logger -from metagpt.config import Config class S3: @@ -11,12 +16,12 @@ class S3: def __init__(self): self.session = aioboto3.Session() - self.s3_config = Config().get("S3") + self.s3_config = CONFIG.S3 self.auth_config = { "service_name": "s3", "aws_access_key_id": self.s3_config["access_key"], "aws_secret_access_key": self.s3_config["secret_key"], - "endpoint_url": self.s3_config["endpoint_url"] + "endpoint_url": self.s3_config["endpoint_url"], } async def upload_file( @@ -95,11 +100,7 @@ class S3: raise e async def download_file( - self, - bucket: str, - object_name: str, - local_path: str, - chunk_size: Optional[int] = 128 * 
1024 + self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024 ) -> None: """Download an S3 object to a local file. @@ -116,7 +117,7 @@ class S3: async with self.session.client(**self.auth_config) as client: s3_object = await client.get_object(Bucket=bucket, Key=object_name) stream = s3_object["Body"] - with open(local_path, 'wb') as local_file: + with open(local_path, "wb") as local_file: while True: file_data = await stream.read(chunk_size) if not file_data: @@ -124,4 +125,21 @@ class S3: local_file.write(file_data) except Exception as e: logger.error(f"Failed to download the file from S3: {e}") - raise e \ No newline at end of file + raise e + + async def cache(self, data: str, format: str = "") -> Optional[str]: + """Save data to remote S3 and return its URL, or None when caching fails (payload is written as bytes).""" + object_name = str(uuid.uuid4()).replace("-", "") + pathname = WORKSPACE_ROOT / "s3_tmp" / object_name + try: + pathname.parent.mkdir(parents=True, exist_ok=True) + payload = base64.b64decode(data) if format == BASE64_FORMAT else data.encode("utf-8") + async with aiofiles.open(pathname, mode="wb") as file: + await file.write(payload) + + bucket = CONFIG.S3.get("bucket") + await self.upload_file(bucket=bucket, local_path=pathname, object_name=object_name) + return await self.get_object_url(bucket=bucket, object_name=object_name) + except Exception as e: + logger.exception(f"{e}, stack:{traceback.format_exc()}") + return None