feat: +s3

This commit is contained in:
莘权 马 2023-09-02 14:30:45 +08:00
parent 6b66429af8
commit ca60cd0557
5 changed files with 79 additions and 53 deletions

View file

@ -1,9 +1,14 @@
import base64
import traceback
import uuid
from typing import Optional
import aioboto3
import aiofiles
from metagpt.config import CONFIG
from metagpt.const import BASE64_FORMAT, WORKSPACE_ROOT
from metagpt.logs import logger
from metagpt.config import Config
class S3:
@ -11,12 +16,12 @@ class S3:
def __init__(self):
self.session = aioboto3.Session()
self.s3_config = Config().get("S3")
self.s3_config = CONFIG.S3
self.auth_config = {
"service_name": "s3",
"aws_access_key_id": self.s3_config["access_key"],
"aws_secret_access_key": self.s3_config["secret_key"],
"endpoint_url": self.s3_config["endpoint_url"]
"endpoint_url": self.s3_config["endpoint_url"],
}
async def upload_file(
@ -95,11 +100,7 @@ class S3:
raise e
async def download_file(
self,
bucket: str,
object_name: str,
local_path: str,
chunk_size: Optional[int] = 128 * 1024
self, bucket: str, object_name: str, local_path: str, chunk_size: Optional[int] = 128 * 1024
) -> None:
"""Download an S3 object to a local file.
@ -116,7 +117,7 @@ class S3:
async with self.session.client(**self.auth_config) as client:
s3_object = await client.get_object(Bucket=bucket, Key=object_name)
stream = s3_object["Body"]
with open(local_path, 'wb') as local_file:
with open(local_path, "wb") as local_file:
while True:
file_data = await stream.read(chunk_size)
if not file_data:
@ -124,4 +125,21 @@ class S3:
local_file.write(file_data)
except Exception as e:
logger.error(f"Failed to download the file from S3: {e}")
raise e
raise e
async def cache(self, data: str, format: str = "") -> str:
"""Save data to remote S3 and return url"""
object_name = str(uuid.uuid4()).replace("-", "")
pathname = WORKSPACE_ROOT / "s3_tmp" / object_name
try:
async with aiofiles.open(pathname, mode="w") as file:
if format == BASE64_FORMAT:
data = base64.b64decode(data)
await file.write(data)
bucket = CONFIG.S3.get("bucket")
await self.upload_file(bucket=bucket, local_path=pathname, object_name=object_name)
return await self.get_object_url(bucket=bucket, object_name=object_name)
except Exception as e:
logger.exception(f"{e}, stack:{traceback.format_exc()}")
return None