mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
add openai dall-e support
This commit is contained in:
parent
310687258e
commit
e7cd90f7f8
4 changed files with 68 additions and 3 deletions
|
|
@ -28,6 +28,7 @@ from metagpt.provider.base_llm import BaseLLM
|
|||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import decode_image
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.token_counter import (
|
||||
|
|
@ -243,3 +244,27 @@ class OpenAILLM(BaseLLM):
|
|||
async def aspeech_to_text(self, **kwargs):
    """Forward ``kwargs`` to the async client's audio transcription call.

    Returns:
        Whatever ``self.aclient.audio.transcriptions.create`` resolves to.
    """
    transcription = await self.aclient.audio.transcriptions.create(**kwargs)
    return transcription
|
||||
|
||||
async def gen_image(
    self,
    prompt: str,
    size: str = "1024x1024",
    quality: str = "standard",
    model: str = None,
    resp_format: str = "url",
) -> list["Image"]:
    """Generate one image from ``prompt`` and return it decoded.

    Args:
        prompt: Text description passed to the image endpoint.
        size: Requested image size, forwarded unchanged.
        quality: Requested quality, forwarded unchanged.
        model: Image model name; falls back to ``self.model`` when falsy.
        resp_format: Either ``"url"`` or ``"b64_json"``; selects which field
            of each response item is handed to ``decode_image``.

    Returns:
        A list with one decoded image per item in the response (the request
        asks for ``n=1``).

    Raises:
        AssertionError: If ``resp_format`` is not one of the two accepted values.
    """
    # Function-scope import so this method does not depend on a module-level one.
    import asyncio

    assert resp_format in ["url", "b64_json"]

    res = await self.aclient.images.generate(
        model=model or self.model,
        prompt=prompt,
        size=size,
        quality=quality,
        n=1,
        response_format=resp_format,
    )

    # decode_image is synchronous and may perform a blocking download, so run
    # it in the default executor instead of stalling the event loop.
    loop = asyncio.get_running_loop()
    imgs = []
    for item in res.data:
        img_url_or_b64 = item.url if resp_format == "url" else item.b64_json
        imgs.append(await loop.run_in_executor(None, decode_image, img_url_or_b64))
    return imgs
|
||||
|
|
|
|||
|
|
@ -24,11 +24,14 @@ import re
|
|||
import sys
|
||||
import traceback
|
||||
import typing
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import aiofiles
|
||||
import loguru
|
||||
import requests
|
||||
from PIL import Image
|
||||
from pydantic_core import to_jsonable_python
|
||||
from tenacity import RetryCallState, RetryError, _utils
|
||||
|
||||
|
|
@ -600,6 +603,29 @@ def list_files(root: str | Path) -> List[Path]:
|
|||
return files
|
||||
|
||||
|
||||
def encode_image(image_path: Path, encoding: str = "utf-8") -> str:
    """Return the base64 text of the raw bytes stored at ``image_path``.

    ``encoding`` is the codec used to turn the base64 bytes into ``str``.
    """
    raw_bytes = Path(image_path).read_bytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode(encoding)
|
||||
def encode_image(image_path_or_pil: Union[Path, str, Image.Image], encoding: str = "utf-8") -> str:
    """Encode an image file or an in-memory PIL image into base64 text.

    Args:
        image_path_or_pil: Location of an image file (``Path`` or ``str``),
            whose bytes are encoded unchanged, or a ``PIL.Image.Image``,
            which is serialized as JPEG first.
        encoding: Codec used to turn the base64 bytes into ``str``.

    Returns:
        The base64 representation of the image bytes.

    Raises:
        FileNotFoundError: If a path is given and it does not exist.
    """
    # ``Image`` is the PIL module; the class to test against is ``Image.Image``.
    # Passing the module itself to isinstance() raises TypeError.
    if isinstance(image_path_or_pil, Image.Image):
        pil_image = image_path_or_pil
        # JPEG cannot store alpha or palette data; drop it before saving.
        if pil_image.mode in ("RGBA", "LA", "P"):
            pil_image = pil_image.convert("RGB")
        buffer = BytesIO()
        pil_image.save(buffer, format="JPEG")
        bytes_data = buffer.getvalue()
    else:
        image_path = Path(image_path_or_pil)
        if not image_path.exists():
            raise FileNotFoundError(f"{image_path_or_pil} not exists")
        bytes_data = image_path.read_bytes()
    return base64.b64encode(bytes_data).decode(encoding)
|
||||
|
||||
|
||||
def decode_image(img_url_or_b64: str, timeout: float = 30.0) -> Image.Image:
    """Decode an image given as an http(s) URL or a base64 string.

    Args:
        img_url_or_b64: Either an ``http://`` / ``https://`` URL to download,
            or base64 image data, optionally prefixed with a
            ``data:image/...;base64,`` header.
        timeout: Seconds to wait for the HTTP download before giving up.

    Returns:
        The image opened with ``PIL.Image.open``.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an error status.
    """
    if img_url_or_b64.startswith(("http://", "https://")):
        # Without a timeout a stalled server would block the caller forever.
        resp = requests.get(img_url_or_b64, timeout=timeout)
        # Fail on 4xx/5xx here rather than feeding an error page to PIL.
        resp.raise_for_status()
        img_data = BytesIO(resp.content)
    else:
        # Strip an optional data-URI header before decoding the payload.
        b64_data = re.sub(r"^data:image/.+;base64,", "", img_url_or_b64)
        img_data = BytesIO(base64.b64decode(b64_data))
    return Image.open(img_data)
|
||||
|
|
|
|||
|
|
@ -59,3 +59,4 @@ networkx~=3.2.1
|
|||
google-generativeai==0.3.2
|
||||
# playwright==1.40.0 # playwright extras require
|
||||
anytree
|
||||
Pillow
|
||||
|
|
@ -1,4 +1,5 @@
|
|||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from metagpt.const import TEST_DATA_PATH
|
||||
from metagpt.llm import LLM
|
||||
|
|
@ -62,6 +63,18 @@ async def test_speech_to_text():
|
|||
assert "你好" == resp.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_gen_image():
    """Check that gen_image returns a 1024x1024 image for both response formats."""
    llm = LLM()
    request = {"model": "dall-e-3", "prompt": 'a logo with word "MetaGPT"'}
    expected_size = (1024, 1024)

    # Default response format.
    url_images = await llm.gen_image(**request)
    assert url_images[0].size == expected_size

    # Explicit base64 response format.
    b64_images = await llm.gen_image(**request, resp_format="b64_json")
    assert b64_images[0].size == expected_size
|
||||
|
||||
|
||||
class TestOpenAI:
|
||||
def test_make_client_kwargs_without_proxy(self):
|
||||
instance = OpenAILLM(mock_llm_config)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue