feat: +s3

This commit is contained in:
莘权 马 2023-09-02 14:30:45 +08:00
parent 6b66429af8
commit ca60cd0557
5 changed files with 79 additions and 53 deletions

View file

@ -54,3 +54,6 @@ METAGPT_API_KEY = "METAGPT_API_KEY"
METAGPT_API_BASE = "METAGPT_API_BASE"
METAGPT_API_TYPE = "METAGPT_API_TYPE"
METAGPT_API_VERSION = "METAGPT_API_VERSION"
# format
BASE64_FORMAT = "base64"

View file

@ -6,10 +6,13 @@
@File : text_to_image.py
@Desc : Text-to-Image skill, which provides text-to-image functionality.
"""
import openai.error
from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
from metagpt.utils.s3 import S3
async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
@ -23,13 +26,14 @@ async def text_to_image(text, size_type: str = "512x512", openai_api_key="", mod
"""
image_declaration = "data:image/png;base64,"
if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
data = await oas3_metagpt_text_to_image(text, size_type, model_url)
return image_declaration + data if data else ""
if CONFIG.OPENAI_API_KEY or openai_api_key:
data = await oas3_openai_text_to_image(text, size_type, openai_api_key)
return image_declaration + data if data else ""
raise EnvironmentError
base64_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
elif CONFIG.OPENAI_API_KEY or openai_api_key:
base64_data = await oas3_openai_text_to_image(text, size_type, openai_api_key)
else:
raise openai.error.InvalidRequestError("缺少必要的参数")
s3 = S3()
url = await s3.cache(base64_data, BASE64_FORMAT)
if url:
return url
return image_declaration + base64_data if base64_data else ""

View file

@ -6,14 +6,24 @@
@File : text_to_speech.py
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
"""
import openai
from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT
from metagpt.tools.azure_tts import oas3_azsure_tts
from metagpt.utils.s3 import S3
async def text_to_speech(text, lang="zh-CN", voice="zh-CN-XiaomoNeural", style="affectionate", role="Girl",
subscription_key="", region="", **kwargs):
async def text_to_speech(
text,
lang="zh-CN",
voice="zh-CN-XiaomoNeural",
style="affectionate",
role="Girl",
subscription_key="",
region="",
**kwargs
):
"""Text to speech
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
@ -28,9 +38,12 @@ async def text_to_speech(text, lang="zh-CN", voice="zh-CN-XiaomoNeural", style="
"""
audio_declaration = "data:audio/wav;base64,"
if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or \
(subscription_key and region):
data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
return audio_declaration + data if data else data
if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region):
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
s3 = S3()
url = await s3.cache(base64_data, BASE64_FORMAT)
if url:
return url
return audio_declaration + base64_data if base64_data else base64_data
raise EnvironmentError
raise openai.error.InvalidRequestError("缺少必要的参数")

View file

@ -8,18 +8,12 @@
"""
import asyncio
import base64
import os
import sys
from pathlib import Path
from typing import List
import aiohttp
import openai
import requests
from pydantic import BaseModel
from metagpt.config import CONFIG, Config
sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) # fix-bug: No module named 'metagpt'
from metagpt.logs import logger
@ -37,27 +31,21 @@ class OpenAIText2Image:
:param size_type: One of ['256x256', '512x512', '1024x1024']
:return: The image data is returned in Base64 encoding.
"""
class ImageUrl(BaseModel):
url: str
class ImageResult(BaseModel):
data: List[ImageUrl]
created: int
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.openai_api_key}"
}
data = {"prompt": text, "n": 1, "size": size_type}
try:
async with aiohttp.ClientSession() as session:
async with session.post("https://api.openai.com/v1/images/generations", headers=headers, json=data) as response:
result = ImageResult(** await response.json())
except requests.exceptions.RequestException as e:
result = await openai.Image.acreate(
api_key=CONFIG.OPENAI_API_KEY,
api_base=CONFIG.OPENAI_API_BASE,
api_type=None,
api_version=None,
organization=None,
prompt=text,
n=1,
size=size_type,
)
except Exception as e:
logger.error(f"An error occurred:{e}")
return ""
if len(result.data) > 0:
if result and len(result.data) > 0:
return await OpenAIText2Image.get_image_data(result.data[0].url)
return ""

View file

@ -1,9 +1,14 @@
import base64
import traceback
import uuid
from typing import Optional
import aioboto3
import aiofiles
from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT, WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.config import Config
class S3:
@ -11,12 +16,12 @@ class S3:
def __init__(self):
self.session = aioboto3.Session()
self.s3_config = Config().get("S3")
self.s3_config = CONFIG.S3
self.auth_config = {
"service_name": "s3",
"aws_access_key_id": self.s3_config["access_key"],
"aws_secret_access_key": self.s3_config["secret_key"],
"endpoint_url": self.s3_config["endpoint_url"]
"endpoint_url": self.s3_config["endpoint_url"],
}
async def upload_file(
@ -95,11 +100,7 @@ class S3:
raise e
async def download_file(
self,
bucket: str,
object_name: str,
local_path: str,
chunk_size: Optional[int] = 128 * 1024
self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
) -> None:
"""Download an S3 object to a local file.
@ -116,7 +117,7 @@ class S3:
async with self.session.client(**self.auth_config) as client:
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
stream = s3_object["Body"]
with open(local_path, 'wb') as local_file:
with open(local_path, "wb") as local_file:
while True:
file_data = await stream.read(chunk_size)
if not file_data:
@ -124,4 +125,21 @@ class S3:
local_file.write(file_data)
except Exception as e:
logger.error(f"Failed to download the file from S3: {e}")
raise e
raise e
async def cache(self, data: str, format: str = "") -> str:
"""Save data to remote S3 and return url"""
object_name = str(uuid.uuid4()).replace("-", "")
pathname = WORKSPACE_ROOT / "s3_tmp" / object_name
try:
async with aiofiles.open(pathname, mode="w") as file:
if format == BASE64_FORMAT:
data = base64.b64decode(data)
await file.write(data)
bucket = CONFIG.S3.get("bucket")
await self.upload_file(bucket=bucket, local_path=pathname, object_name=object_name)
return await self.get_object_url(bucket=bucket, object_name=object_name)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
return None