From 02d78c3a010b36a231a6be13ce5ee859cee97992 Mon Sep 17 00:00:00 2001 From: shenchucheng Date: Thu, 8 Feb 2024 14:35:11 +0800 Subject: [PATCH] add config for text to image skill --- metagpt/learn/text_to_image.py | 6 +++--- metagpt/tools/metagpt_text_to_image.py | 2 +- metagpt/tools/openai_text_to_image.py | 25 ++++++++++++++++++++----- 3 files changed, 24 insertions(+), 9 deletions(-) diff --git a/metagpt/learn/text_to_image.py b/metagpt/learn/text_to_image.py index c3c62fb67..17c90dcc8 100644 --- a/metagpt/learn/text_to_image.py +++ b/metagpt/learn/text_to_image.py @@ -26,9 +26,9 @@ async def text_to_image(text, size_type: str = "512x512", openai_api_key="", mod """ image_declaration = "data:image/png;base64," if CONFIG.METAGPT_TEXT_TO_IMAGE_MODEL_URL or model_url: - binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url) - elif CONFIG.OPENAI_API_KEY or openai_api_key: - binary_data = await oas3_openai_text_to_image(text, size_type) + binary_data = await oas3_metagpt_text_to_image(text, size_type, model_url, **kwargs) + elif CONFIG.OPENAI_TEXT_TO_IMAGE_API_KEY or openai_api_key: + binary_data = await oas3_openai_text_to_image(text, size_type, openai_api_key, **kwargs) else: raise ValueError("Missing necessary parameters.") base64_data = base64.b64encode(binary_data).decode("utf-8") diff --git a/metagpt/tools/metagpt_text_to_image.py b/metagpt/tools/metagpt_text_to_image.py index 9a84e69eb..2b79fd3c4 100644 --- a/metagpt/tools/metagpt_text_to_image.py +++ b/metagpt/tools/metagpt_text_to_image.py @@ -83,7 +83,7 @@ class MetaGPTText2Image: # Export -async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url=""): +async def oas3_metagpt_text_to_image(text, size_type: str = "512x512", model_url="", **kwargs): """Text to image :param text: The text used for image conversion. diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py index aa00abdcc..4fbc458a3 100644 --- a/metagpt/tools/openai_text_to_image.py +++ b/metagpt/tools/openai_text_to_image.py @@ -9,17 +9,25 @@ import aiohttp import requests +from openai import AsyncOpenAI +from metagpt.config import CONFIG from metagpt.llm import LLM from metagpt.logs import logger class OpenAIText2Image: - def __init__(self): + def __init__(self, api_key: str = "", **kwargs): """ :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` """ - self._llm = LLM() + if not api_key: + api_key = CONFIG.OPENAI_TEXT_TO_IMAGE_API_KEY + + if not api_key: + self._client = LLM().aclient + else: + self._client = AsyncOpenAI(api_key=api_key, base_url=CONFIG.OPENAI_TEXT_TO_IMAGE_BASE_URL) async def text_2_image(self, text, size_type="1024x1024"): """Text to image @@ -29,7 +37,13 @@ class OpenAIText2Image: :return: The image data is returned in Base64 encoding. """ try: - result = await self._llm.aclient.images.generate(prompt=text, n=1, size=size_type) + params = { + "n": 1, + "size": size_type, + } + if CONFIG.OPENAI_TEXT_TO_IMAGE_API_MODEL: + params["model"] = CONFIG.OPENAI_TEXT_TO_IMAGE_API_MODEL + result = await self._client.images.generate(prompt=text, **params) except Exception as e: logger.error(f"An error occurred:{e}") return "" @@ -57,7 +71,7 @@ class OpenAIText2Image: # Export -async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"): +async def oas3_openai_text_to_image(text, size_type: str = "1024x1024", openai_api_key: str = "", **kwargs): """Text to image :param text: The text used for image conversion. @@ -66,4 +80,5 @@ async def oas3_openai_text_to_image(text, size_type: str = "1024x1024"): """ if not text: return "" - return await OpenAIText2Image().text_2_image(text, size_type=size_type) + + return await OpenAIText2Image(openai_api_key).text_2_image(text, size_type=size_type)