diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index 2ec78317a..8f9d91d6c 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -28,6 +28,7 @@ from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
 from metagpt.provider.llm_provider_registry import register_provider
 from metagpt.schema import Message
+from metagpt.utils.common import decode_image
 from metagpt.utils.cost_manager import CostManager, Costs
 from metagpt.utils.exceptions import handle_exception
 from metagpt.utils.token_counter import (
@@ -243,3 +244,27 @@ class OpenAILLM(BaseLLM):
     async def aspeech_to_text(self, **kwargs):
         """speech to text"""
         return await self.aclient.audio.transcriptions.create(**kwargs)
+
+    async def gen_image(
+        self,
+        prompt: str,
+        size: str = "1024x1024",
+        quality: str = "standard",
+        model: str = None,
+        resp_format: str = "url",
+    ) -> list["Image.Image"]:
+        """image generate"""
+        assert resp_format in ["url", "b64_json"]
+        if not model:
+            model = self.model
+        res = await self.aclient.images.generate(
+            model=model, prompt=prompt, size=size, quality=quality, n=1, response_format=resp_format
+        )
+        imgs = []
+        for item in res.data:
+            if resp_format == "url":
+                img_url_or_b64 = item.url
+            else:
+                img_url_or_b64 = item.b64_json
+            imgs.append(decode_image(img_url_or_b64))
+        return imgs
diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py
index 73017cf77..93921e983 100644
--- a/metagpt/utils/common.py
+++ b/metagpt/utils/common.py
@@ -24,11 +24,14 @@
 import re
 import sys
 import traceback
 import typing
+from io import BytesIO
 from pathlib import Path
 from typing import Any, List, Tuple, Union
 
 import aiofiles
 import loguru
+import requests
+from PIL import Image
 from pydantic_core import to_jsonable_python
 from tenacity import RetryCallState, RetryError, _utils
@@ -600,6 +603,29 @@ def list_files(root: str | Path) -> List[Path]:
     return files
 
 
-def encode_image(image_path: Path, encoding: str = "utf-8") -> str:
-    with open(str(image_path), "rb") as image_file:
-        return base64.b64encode(image_file.read()).decode(encoding)
+def encode_image(image_path_or_pil: Union[Path, Image.Image], encoding: str = "utf-8") -> str:
+    """encode image from file or PIL.Image into base64"""
+    if isinstance(image_path_or_pil, Image.Image):
+        buffer = BytesIO()
+        image_path_or_pil.save(buffer, format="JPEG")
+        bytes_data = buffer.getvalue()
+    else:
+        if not image_path_or_pil.exists():
+            raise FileNotFoundError(f"{image_path_or_pil} not exists")
+        with open(str(image_path_or_pil), "rb") as image_file:
+            bytes_data = image_file.read()
+    return base64.b64encode(bytes_data).decode(encoding)
+
+
+def decode_image(img_url_or_b64: str) -> Image.Image:
+    """decode image from url or base64 into PIL.Image"""
+    if img_url_or_b64.startswith("http"):
+        # image http(s) url
+        resp = requests.get(img_url_or_b64)
+        img = Image.open(BytesIO(resp.content))
+    else:
+        # image b64_json
+        b64_data = re.sub("^data:image/.+;base64,", "", img_url_or_b64)
+        img_data = BytesIO(base64.b64decode(b64_data))
+        img = Image.open(img_data)
+    return img
diff --git a/requirements.txt b/requirements.txt
index d54a1d22e..93091d137 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -59,3 +59,4 @@
 networkx~=3.2.1
 google-generativeai==0.3.2
 # playwright==1.40.0 # playwright extras require anytree
+Pillow
\ No newline at end of file
diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py
index bc7f92f33..82ab091c5 100644
--- a/tests/metagpt/provider/test_openai.py
+++ b/tests/metagpt/provider/test_openai.py
@@ -1,4 +1,5 @@
 import pytest
+from PIL import Image
 
 from metagpt.const import TEST_DATA_PATH
 from metagpt.llm import LLM
@@ -62,6 +63,18 @@ async def test_speech_to_text():
     assert "你好" == resp.text
 
 
+@pytest.mark.asyncio
+async def test_gen_image():
+    llm = LLM()
+    model = "dall-e-3"
+    prompt = 'a logo with word "MetaGPT"'
+    images: list[Image.Image] = await llm.gen_image(model=model, prompt=prompt)
+    assert images[0].size == (1024, 1024)
+
+    images: list[Image.Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json")
+    assert images[0].size == (1024, 1024)
+
+
 class TestOpenAI:
     def test_make_client_kwargs_without_proxy(self):
         instance = OpenAILLM(mock_llm_config)