diff --git a/config/config.yaml b/config/config.yaml
index 88cca08e5..7c3d212f6 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -77,4 +77,10 @@ MODEL_FOR_RESEARCHER_SUMMARY: gpt-3.5-turbo
 MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k
 
 ### Meta Models
-#METAGPT_TEXT_TO_IMAGE_MODEL: MODEL_URL
\ No newline at end of file
+#METAGPT_TEXT_TO_IMAGE_MODEL: MODEL_URL
+
+### S3 config
+S3:
+  access_key: "YOUR_S3_ACCESS_KEY"
+  secret_key: "YOUR_S3_SECRET_KEY"
+  endpoint_url: "YOUR_S3_ENDPOINT_URL"
diff --git a/metagpt/utils/s3.py b/metagpt/utils/s3.py
new file mode 100644
index 000000000..2b4b8cb5f
--- /dev/null
+++ b/metagpt/utils/s3.py
@@ -0,0 +1,130 @@
+
+from typing import Optional
+
+import aioboto3
+from metagpt.logs import logger
+from metagpt.config import Config
+
+
+class S3:
+    """A class for interacting with Amazon S3 storage."""
+
+    def __init__(self):
+        self.session = aioboto3.Session()
+        self.s3_config = Config().get("S3")
+        if not self.s3_config:
+            # Fail fast with a clear message instead of a TypeError on the key lookups below.
+            raise ValueError("S3 config is missing: add an `S3` section to config/config.yaml")
+        self.auth_config = {
+            "service_name": "s3",
+            "aws_access_key_id": self.s3_config["access_key"],
+            "aws_secret_access_key": self.s3_config["secret_key"],
+            "endpoint_url": self.s3_config["endpoint_url"]
+        }
+
+    async def upload_file(
+        self,
+        bucket: str,
+        local_path: str,
+        object_name: str,
+    ) -> None:
+        """Upload a file from the local path to the specified path of the storage bucket specified in s3.
+
+        Args:
+            bucket: The name of the S3 storage bucket.
+            local_path: The local file path, including the file name.
+            object_name: The complete path of the uploaded file to be stored in S3, including the file name.
+
+        Raises:
+            Exception: If an error occurs during the upload process, an exception is raised.
+        """
+        try:
+            async with self.session.client(**self.auth_config) as client:
+                with open(local_path, "rb") as file:
+                    await client.put_object(Body=file, Bucket=bucket, Key=object_name)
+                logger.info(f"Successfully uploaded the file to path {object_name} in bucket {bucket} of s3.")
+        except Exception as e:
+            logger.error(f"Failed to upload the file to path {object_name} in bucket {bucket} of s3: {e}")
+            raise
+
+    async def get_object_url(
+        self,
+        bucket: str,
+        object_name: str,
+    ) -> str:
+        """Get the URL for a downloadable or preview file stored in the specified S3 bucket.
+
+        Args:
+            bucket: The name of the S3 storage bucket.
+            object_name: The complete path of the file stored in S3, including the file name.
+
+        Returns:
+            The URL for the downloadable or preview file.
+
+        Raises:
+            Exception: If an error occurs while retrieving the URL, an exception is raised.
+        """
+        try:
+            async with self.session.client(**self.auth_config) as client:
+                # A presigned URL works for private buckets and avoids fetching the object body.
+                return await client.generate_presigned_url("get_object", Params={"Bucket": bucket, "Key": object_name})
+        except Exception as e:
+            logger.error(f"Failed to get the url for a downloadable or preview file: {e}")
+            raise
+
+    async def get_object(
+        self,
+        bucket: str,
+        object_name: str,
+    ) -> bytes:
+        """Get the binary data of a file stored in the specified S3 bucket.
+
+        Args:
+            bucket: The name of the S3 storage bucket.
+            object_name: The complete path of the file stored in S3, including the file name.
+
+        Returns:
+            The binary data of the requested file.
+
+        Raises:
+            Exception: If an error occurs while retrieving the file data, an exception is raised.
+        """
+        try:
+            async with self.session.client(**self.auth_config) as client:
+                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
+                return await s3_object["Body"].read()
+        except Exception as e:
+            logger.error(f"Failed to get the binary data of the file: {e}")
+            raise
+
+    async def download_file(
+        self,
+        bucket: str,
+        object_name: str,
+        local_path: str,
+        chunk_size: Optional[int] = 128 * 1024
+    ) -> None:
+        """Download an S3 object to a local file.
+
+        Args:
+            bucket: The name of the S3 storage bucket.
+            object_name: The complete path of the file stored in S3, including the file name.
+            local_path: The local file path where the S3 object will be downloaded.
+            chunk_size: The size of data chunks to read and write at a time. Default is 128 KB.
+
+        Raises:
+            Exception: If an error occurs during the download process, an exception is raised.
+        """
+        try:
+            async with self.session.client(**self.auth_config) as client:
+                s3_object = await client.get_object(Bucket=bucket, Key=object_name)
+                stream = s3_object["Body"]
+                with open(local_path, 'wb') as local_file:
+                    while True:
+                        file_data = await stream.read(chunk_size)
+                        if not file_data:
+                            break
+                        local_file.write(file_data)
+        except Exception as e:
+            logger.error(f"Failed to download the file from S3: {e}")
+            raise
diff --git a/requirements.txt b/requirements.txt
index ca7fcbfda..2e5112aba 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -40,4 +40,6 @@ libcst==1.0.1
 qdrant-client==1.4.0
 connexion[swagger-ui]
 aiohttp_jinja2
-azure-cognitiveservices-speech==1.31.0
\ No newline at end of file
+azure-cognitiveservices-speech==1.31.0
+aioboto3~=11.3.0
+pytest-asyncio~=0.21.1
diff --git a/tests/conftest.py b/tests/conftest.py
index 8f5069bbe..0bc17bd6a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -9,7 +9,6 @@
 from unittest.mock import Mock
 
 import pytest
-import pytest_asyncio
 
 from metagpt.config import Config
 from metagpt.logs import logger
@@ -17,6 +16,8 @@ from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
 import asyncio
 import re
 
+from metagpt.utils.s3 import S3
+
 
 class Context:
     def __init__(self):
@@ -74,3 +75,7 @@ def proxy():
 @pytest.fixture(scope="session", autouse=True)
 def init_config():
     Config()
+
+@pytest.fixture(scope="session")
+def s3():
+    return S3()
diff --git a/tests/metagpt/utils/test_s3.py b/tests/metagpt/utils/test_s3.py
new file mode 100644
index 000000000..760a976b0
--- /dev/null
+++ b/tests/metagpt/utils/test_s3.py
@@ -0,0 +1,55 @@
+import os
+import pytest
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ["bucket", "local_path", "object_name"],
+    [
+        (
+            "agent-store",
+            "/code/send18-MetaGPT/workspace/resources/SD_Output/Flappy Bird_output_0.png",
+            "ui-designer/2023-09-01/1.png"
+        )
+    ]
+)
+async def test_upload_file(s3, bucket, local_path, object_name):
+    await s3.upload_file(bucket=bucket, local_path=local_path, object_name=object_name)
+    s3_object = await s3.get_object(bucket=bucket, object_name=object_name)
+    assert s3_object
+    assert isinstance(s3_object, bytes)
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ["bucket", "object_name"],
+    [("agent-store", "ui-designer/2023-09-01/1.png")]
+)
+async def test_get_object_url(s3, bucket, object_name):
+    url = await s3.get_object_url(bucket=bucket, object_name=object_name)
+    assert bucket in url
+    assert object_name in url
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ["bucket", "object_name"],
+    [("agent-store", "ui-designer/2023-09-01/1.png")]
+)
+async def test_get_object(s3, bucket, object_name):
+    s3_object = await s3.get_object(bucket=bucket, object_name=object_name)
+    assert s3_object
+    assert isinstance(s3_object, bytes)
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ["bucket", "local_path", "object_name"],
+    [
+        (
+            "agent-store",
+            "/code/send18-MetaGPT/workspace/resources/SD_Output/Flappy Bird_output_0.png",
+            "ui-designer/2023-09-01/1.png"
+        )
+    ]
+)
+async def test_download_file(s3, bucket, local_path, object_name):
+    await s3.download_file(bucket=bucket, object_name=object_name, local_path=local_path)
+    assert os.path.getsize(local_path) > 0