add openai dall-e support

This commit is contained in:
better629 2024-01-30 14:31:00 +08:00
parent 310687258e
commit e7cd90f7f8
4 changed files with 68 additions and 3 deletions

View file

@ -28,6 +28,7 @@ from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.common import decode_image
from metagpt.utils.cost_manager import CostManager, Costs
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.token_counter import (
@ -243,3 +244,27 @@ class OpenAILLM(BaseLLM):
async def aspeech_to_text(self, **kwargs):
"""speech to text"""
return await self.aclient.audio.transcriptions.create(**kwargs)
async def gen_image(
self,
prompt: str,
size: str = "1024x1024",
quality: str = "standard",
model: str = None,
resp_format: str = "url",
) -> list["Image"]:
"""image generate"""
assert resp_format in ["url", "b64_json"]
if not model:
model = self.model
res = await self.aclient.images.generate(
model=model, prompt=prompt, size=size, quality=quality, n=1, response_format=resp_format
)
imgs = []
for item in res.data:
if resp_format == "url":
img_url_or_b64 = item.url
else:
img_url_or_b64 = item.b64_json
imgs.append(decode_image(img_url_or_b64))
return imgs

View file

@ -24,11 +24,14 @@ import re
import sys
import traceback
import typing
from io import BytesIO
from pathlib import Path
from typing import Any, List, Tuple, Union
import aiofiles
import loguru
import requests
from PIL import Image
from pydantic_core import to_jsonable_python
from tenacity import RetryCallState, RetryError, _utils
@ -600,6 +603,29 @@ def list_files(root: str | Path) -> List[Path]:
return files
def encode_image(image_path: Path, encoding: str = "utf-8") -> str:
with open(str(image_path), "rb") as image_file:
return base64.b64encode(image_file.read()).decode(encoding)
def encode_image(image_path_or_pil: Union[Path, Image], encoding: str = "utf-8") -> str:
"""encode image from file or PIL.Image into base64"""
if isinstance(image_path_or_pil, Image):
buffer = BytesIO()
image_path_or_pil.save(buffer, format="JPEG")
bytes_data = buffer.getvalue()
else:
if not image_path_or_pil.exists():
raise FileNotFoundError(f"{image_path_or_pil} not exists")
with open(str(image_path_or_pil), "rb") as image_file:
bytes_data = image_file.read()
return base64.b64encode(bytes_data).decode(encoding)
def decode_image(img_url_or_b64: str) -> Image:
"""decode image from url or base64 into PIL.Image"""
if img_url_or_b64.startswith("http"):
# image http(s) url
resp = requests.get(img_url_or_b64)
img = Image.open(BytesIO(resp.content))
else:
# image b64_json
b64_data = re.sub("^data:image/.+;base64,", "", img_url_or_b64)
img_data = BytesIO(base64.b64decode(b64_data))
img = Image.open(img_data)
return img

View file

@ -59,3 +59,4 @@ networkx~=3.2.1
google-generativeai==0.3.2
# playwright==1.40.0 # playwright extras require
anytree
Pillow

View file

@ -1,4 +1,5 @@
import pytest
from PIL import Image
from metagpt.const import TEST_DATA_PATH
from metagpt.llm import LLM
@ -62,6 +63,18 @@ async def test_speech_to_text():
assert "你好" == resp.text
@pytest.mark.asyncio
async def test_gen_image():
llm = LLM()
model = "dall-e-3"
prompt = 'a logo with word "MetaGPT"'
images: list[Image] = await llm.gen_image(model=model, prompt=prompt)
assert images[0].size == (1024, 1024)
images: list[Image] = await llm.gen_image(model=model, prompt=prompt, resp_format="b64_json")
assert images[0].size == (1024, 1024)
class TestOpenAI:
def test_make_client_kwargs_without_proxy(self):
instance = OpenAILLM(mock_llm_config)