mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
feat: +s3
This commit is contained in:
parent
6b66429af8
commit
ca60cd0557
5 changed files with 79 additions and 53 deletions
|
|
@ -54,3 +54,6 @@ METAGPT_API_KEY = "METAGPT_API_KEY"
|
|||
METAGPT_API_BASE = "METAGPT_API_BASE"
|
||||
METAGPT_API_TYPE = "METAGPT_API_TYPE"
|
||||
METAGPT_API_VERSION = "METAGPT_API_VERSION"
|
||||
|
||||
# format
|
||||
BASE64_FORMAT = "base64"
|
||||
|
|
|
|||
|
|
@ -6,10 +6,13 @@
|
|||
@File : text_to_image.py
|
||||
@Desc : Text-to-Image skill, which provides text-to-image functionality.
|
||||
"""
|
||||
import openai.error
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image
|
||||
from metagpt.tools.openai_text_to_image import oas3_openai_text_to_image
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
async def text_to_image(text, size_type: str = "512x512", openai_api_key="", model_url="", **kwargs):
|
||||
|
|
@ -23,13 +26,14 @@ async def text_to_image(text, size_type: str = "512x512", openai_api_key="", mod
|
|||
"""
|
||||
image_declaration = "data:image/png;base64,"
|
||||
if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url:
|
||||
data = await oas3_metagpt_text_to_image(text, size_type, model_url)
|
||||
return image_declaration + data if data else ""
|
||||
|
||||
if CONFIG.OPENAI_API_KEY or openai_api_key:
|
||||
data = await oas3_openai_text_to_image(text, size_type, openai_api_key)
|
||||
return image_declaration + data if data else ""
|
||||
|
||||
raise EnvironmentError
|
||||
|
||||
base64_data = await oas3_metagpt_text_to_image(text, size_type, model_url)
|
||||
elif CONFIG.OPENAI_API_KEY or openai_api_key:
|
||||
base64_data = await oas3_openai_text_to_image(text, size_type, openai_api_key)
|
||||
else:
|
||||
raise openai.error.InvalidRequestError("缺少必要的参数")
|
||||
|
||||
s3 = S3()
|
||||
url = await s3.cache(base64_data, BASE64_FORMAT)
|
||||
if url:
|
||||
return url
|
||||
return image_declaration + base64_data if base64_data else ""
|
||||
|
|
|
|||
|
|
@ -6,14 +6,24 @@
|
|||
@File : text_to_speech.py
|
||||
@Desc : Text-to-Speech skill, which provides text-to-speech functionality
|
||||
"""
|
||||
import openai
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
from metagpt.const import BASE64_FORMAT
|
||||
from metagpt.tools.azure_tts import oas3_azsure_tts
|
||||
from metagpt.utils.s3 import S3
|
||||
|
||||
|
||||
async def text_to_speech(text, lang="zh-CN", voice="zh-CN-XiaomoNeural", style="affectionate", role="Girl",
|
||||
subscription_key="", region="", **kwargs):
|
||||
async def text_to_speech(
|
||||
text,
|
||||
lang="zh-CN",
|
||||
voice="zh-CN-XiaomoNeural",
|
||||
style="affectionate",
|
||||
role="Girl",
|
||||
subscription_key="",
|
||||
region="",
|
||||
**kwargs
|
||||
):
|
||||
"""Text to speech
|
||||
For more details, check out:`https://learn.microsoft.com/en-us/azure/ai-services/speech-service/language-support?tabs=tts`
|
||||
|
||||
|
|
@ -28,9 +38,12 @@ async def text_to_speech(text, lang="zh-CN", voice="zh-CN-XiaomoNeural", style="
|
|||
|
||||
"""
|
||||
audio_declaration = "data:audio/wav;base64,"
|
||||
if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or \
|
||||
(subscription_key and region):
|
||||
data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
|
||||
return audio_declaration + data if data else data
|
||||
if (CONFIG.AZURE_TTS_SUBSCRIPTION_KEY and CONFIG.AZURE_TTS_REGION) or (subscription_key and region):
|
||||
base64_data = await oas3_azsure_tts(text, lang, voice, style, role, subscription_key, region)
|
||||
s3 = S3()
|
||||
url = await s3.cache(base64_data, BASE64_FORMAT)
|
||||
if url:
|
||||
return url
|
||||
return audio_declaration + base64_data if base64_data else base64_data
|
||||
|
||||
raise EnvironmentError
|
||||
raise openai.error.InvalidRequestError("缺少必要的参数")
|
||||
|
|
|
|||
|
|
@ -8,18 +8,12 @@
|
|||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from metagpt.config import CONFIG, Config
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) # fix-bug: No module named 'metagpt'
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
|
|
@ -37,27 +31,21 @@ class OpenAIText2Image:
|
|||
:param size_type: One of ['256x256', '512x512', '1024x1024']
|
||||
:return: The image data is returned in Base64 encoding.
|
||||
"""
|
||||
|
||||
class ImageUrl(BaseModel):
|
||||
url: str
|
||||
|
||||
class ImageResult(BaseModel):
|
||||
data: List[ImageUrl]
|
||||
created: int
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.openai_api_key}"
|
||||
}
|
||||
data = {"prompt": text, "n": 1, "size": size_type}
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post("https://api.openai.com/v1/images/generations", headers=headers, json=data) as response:
|
||||
result = ImageResult(** await response.json())
|
||||
except requests.exceptions.RequestException as e:
|
||||
result = await openai.Image.acreate(
|
||||
api_key=CONFIG.OPENAI_API_KEY,
|
||||
api_base=CONFIG.OPENAI_API_BASE,
|
||||
api_type=None,
|
||||
api_version=None,
|
||||
organization=None,
|
||||
prompt=text,
|
||||
n=1,
|
||||
size=size_type,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred:{e}")
|
||||
return ""
|
||||
if len(result.data) > 0:
|
||||
if result and len(result.data) > 0:
|
||||
return await OpenAIText2Image.get_image_data(result.data[0].url)
|
||||
return ""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
|
||||
import base64
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
import aioboto3
|
||||
import aiofiles
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import BASE64_FORMAT, WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.config import Config
|
||||
|
||||
|
||||
class S3:
|
||||
|
|
@ -11,12 +16,12 @@ class S3:
|
|||
|
||||
def __init__(self):
|
||||
self.session = aioboto3.Session()
|
||||
self.s3_config = Config().get("S3")
|
||||
self.s3_config = CONFIG.S3
|
||||
self.auth_config = {
|
||||
"service_name": "s3",
|
||||
"aws_access_key_id": self.s3_config["access_key"],
|
||||
"aws_secret_access_key": self.s3_config["secret_key"],
|
||||
"endpoint_url": self.s3_config["endpoint_url"]
|
||||
"endpoint_url": self.s3_config["endpoint_url"],
|
||||
}
|
||||
|
||||
async def upload_file(
|
||||
|
|
@ -95,11 +100,7 @@ class S3:
|
|||
raise e
|
||||
|
||||
async def download_file(
|
||||
self,
|
||||
bucket: str,
|
||||
object_name: str,
|
||||
local_path: str,
|
||||
chunk_size: Optional[int] = 128 * 1024
|
||||
self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
|
||||
) -> None:
|
||||
"""Download an S3 object to a local file.
|
||||
|
||||
|
|
@ -116,7 +117,7 @@ class S3:
|
|||
async with self.session.client(**self.auth_config) as client:
|
||||
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
|
||||
stream = s3_object["Body"]
|
||||
with open(local_path, 'wb') as local_file:
|
||||
with open(local_path, "wb") as local_file:
|
||||
while True:
|
||||
file_data = await stream.read(chunk_size)
|
||||
if not file_data:
|
||||
|
|
@ -124,4 +125,21 @@ class S3:
|
|||
local_file.write(file_data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to download the file from S3: {e}")
|
||||
raise e
|
||||
raise e
|
||||
|
||||
async def cache(self, data: str, format: str = "") -> str:
|
||||
"""Save data to remote S3 and return url"""
|
||||
object_name = str(uuid.uuid4()).replace("-", "")
|
||||
pathname = WORKSPACE_ROOT / "s3_tmp" / object_name
|
||||
try:
|
||||
async with aiofiles.open(pathname, mode="w") as file:
|
||||
if format == BASE64_FORMAT:
|
||||
data = base64.b64decode(data)
|
||||
await file.write(data)
|
||||
|
||||
bucket = CONFIG.S3.get("bucket")
|
||||
await self.upload_file(bucket=bucket, local_path=pathname, object_name=object_name)
|
||||
return await self.get_object_url(bucket=bucket, object_name=object_name)
|
||||
except Exception as e:
|
||||
logger.exception(f"{e}, stack:{traceback.format_exc()}")
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue