From 6b1f3ee39821e839293a30c762533c5a0ec95896 Mon Sep 17 00:00:00 2001 From: better629 Date: Tue, 30 Jan 2024 20:39:56 +0800 Subject: [PATCH] add dalle/gpt4v example --- examples/dalle_gpt4v_agent.py | 80 +++++++++++++++++++++++++++++++++++ metagpt/utils/common.py | 2 +- 2 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 examples/dalle_gpt4v_agent.py diff --git a/examples/dalle_gpt4v_agent.py b/examples/dalle_gpt4v_agent.py new file mode 100644 index 000000000..3974f5d9b --- /dev/null +++ b/examples/dalle_gpt4v_agent.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : use gpt4v to improve prompt and draw image with dall-e-3 + +""" +set the configuration in `config2.yaml` like below +``` +llm: + base_url: "xxx" + api_key: "sk-xxx" + model: "gpt-4-vision-preview" +``` +""" + +import asyncio + +from PIL import Image + +from metagpt.roles.role import Role +from metagpt.actions.action import Action +from metagpt.utils.common import encode_image +from metagpt.schema import Message +from metagpt.logs import logger + + +class GenAndImproveImageAction(Action): + save_image: bool = True + + async def generate_image(self, prompt: str) -> Image: + imgs = await self.llm.gen_image(model="dall-e-3", prompt=prompt) + return imgs[0] + + async def refine_prompt(self, old_prompt: str, image: Image) -> str: + msg = f"You are a creative painter, with the given generated image and old prompt: {old_prompt}, " \ + f"please refine the prompt and generate new one. Just output the new prompt." + b64_img = encode_image(image) + new_prompt = await self.llm.aask(msg=msg, images=[b64_img]) + return new_prompt + + async def evaluate_images(self, old_prompt: str, images: list[Image]) -> str: + msg = "With the prompt and two generated image, to judge if the second one is better than the first one. " \ + "If so, just output True else output False" + b64_imgs = [encode_image(img) for img in images] + res = await self.llm.aask(msg=msg, images=b64_imgs) + return res + + async def run(self, messages: list[Message]) -> str: + prompt = messages[-1].content + + old_img: Image = await self.generate_image(prompt) + new_prompt = await self.refine_prompt(old_prompt=prompt, image=old_img) + logger.info(f"original prompt: {prompt}") + logger.info(f"refined prompt: {new_prompt}") + new_img: Image = await self.generate_image(new_prompt) + if self.save_image: + old_img.save("./img_by-dall-e_old.png") + new_img.save("./img_by-dall-e_new.png") + res = await self.evaluate_images(old_prompt=prompt, images=[old_img, new_img]) + opinion = f"The second generated image is better than the first one: {res}" + logger.info(f"evaluate opinion: {opinion}") + return opinion + + +class Painter(Role): + name: str = "MaLiang" + profile: str = "Painter" + goal: str = "to generate fine painting" + + def __init__(self, **data): + super().__init__(**data) + + self.set_actions([GenAndImproveImageAction]) + + +async def main(): + role = Painter() + await role.run(with_message="a girl with flowers") + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index 79e7de5b6..2e05afa74 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -619,7 +619,7 @@ def load_mc_skills_code(skill_names: list[str] = None, skills_dir: Path = None) def encode_image(image_path_or_pil: Union[Path, Image], encoding: str = "utf-8") -> str: """encode image from file or PIL.Image into base64""" - if isinstance(image_path_or_pil, Image): + if isinstance(image_path_or_pil, Image.Image): buffer = BytesIO() image_path_or_pil.save(buffer, format="JPEG") bytes_data = buffer.getvalue()