diff --git a/.gitignore b/.gitignore index b5dafc3fc..87c7b3120 100644 --- a/.gitignore +++ b/.gitignore @@ -177,4 +177,3 @@ htmlcov.* *.pkl *-structure.csv *-structure.json - diff --git a/examples/sd_tool_usage.py b/examples/sd_tool_usage.py new file mode 100644 index 000000000..92f4cd5b0 --- /dev/null +++ b/examples/sd_tool_usage.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# @Date : 1/11/2024 7:06 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import asyncio + +from metagpt.roles.code_interpreter import CodeInterpreter + + +async def main(requirement: str = ""): + code_interpreter = CodeInterpreter(use_tools=True, goal=requirement) + await code_interpreter.run(requirement) + + +if __name__ == "__main__": + sd_url = "http://your.sd.service.ip:port" + requirement = ( + f"I want to generate an image of a beautiful girl using the stable diffusion text2image tool, sd_url={sd_url}" + ) + + asyncio.run(main(requirement)) diff --git a/metagpt/actions/debug_code.py b/metagpt/actions/debug_code.py index 26a84bcf2..e5e0ac5d4 100644 --- a/metagpt/actions/debug_code.py +++ b/metagpt/actions/debug_code.py @@ -85,20 +85,14 @@ class DebugCode(BaseWriteAnalysisCode): async def run_reflection( self, - # goal, - # finished_code, - # finished_code_result, context: List[Message], code, runtime_result, ) -> dict: info = [] - # finished_code_and_result = finished_code + "\n [finished results]\n\n" + finished_code_result reflection_prompt = REFLECTION_PROMPT.format( debug_example=DEBUG_REFLECTION_EXAMPLE, context=context, - # goal=goal, - # finished_code=finished_code_and_result, code=code, runtime_result=runtime_result, ) @@ -106,33 +100,13 @@ class DebugCode(BaseWriteAnalysisCode): info.append(Message(role="system", content=system_prompt)) info.append(Message(role="user", content=reflection_prompt)) - # msg = messages_to_str(info) - # resp = await self.llm.aask(msg=msg) resp = await self.llm.aask_code(messages=info, **create_func_config(CODE_REFLECTION)) 
logger.info(f"reflection is {resp}") return resp - # async def rewrite_code(self, reflection: str = "", context: List[Message] = None) -> str: - # """ - # 根据reflection重写代码 - # """ - # info = context - # # info.append(Message(role="assistant", content=f"[code context]:{code_context}" - # # f"finished code are executable, and you should based on the code to continue your current code debug and improvement" - # # f"[reflection]: \n {reflection}")) - # info.append(Message(role="assistant", content=f"[reflection]: \n {reflection}")) - # info.append(Message(role="user", content=f"[improved impl]:\n Return in Python block")) - # msg = messages_to_str(info) - # resp = await self.llm.aask(msg=msg) - # improv_code = CodeParser.parse_code(block=None, text=resp) - # return improv_code - async def run( self, context: List[Message] = None, - plan: str = "", - # finished_code: str = "", - # finished_code_result: str = "", code: str = "", runtime_result: str = "", ) -> str: @@ -140,14 +114,10 @@ class DebugCode(BaseWriteAnalysisCode): 根据当前运行代码和报错信息进行reflection和纠错 """ reflection = await self.run_reflection( - # plan, - # finished_code=finished_code, - # finished_code_result=finished_code_result, code=code, context=context, runtime_result=runtime_result, ) # 根据reflection结果重写代码 - # improv_code = await self.rewrite_code(reflection, context=context) improv_code = reflection["improved_impl"] return improv_code diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index cf903347d..a60642bff 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -60,7 +60,6 @@ class MLEngineer(CodeInterpreter): if code_execution_count > 0: logger.warning("We got a bug code, now start to debug...") code = await DebugCode().run( - plan=self.planner.current_task.instruction, code=self.latest_code, runtime_result=self.working_memory.get(), context=self.debug_context, diff --git a/metagpt/tools/__init__.py b/metagpt/tools/__init__.py index 4b3528795..41c8708b2 100644 
--- a/metagpt/tools/__init__.py +++ b/metagpt/tools/__init__.py @@ -6,7 +6,6 @@ @File : __init__.py """ - from enum import Enum from pydantic import BaseModel @@ -71,6 +70,12 @@ TOOL_TYPE_MAPPINGS = { desc="Only for evaluating model.", usage_prompt=MODEL_EVALUATE_PROMPT, ), + "stable_diffusion": ToolType( + name="stable_diffusion", + module="metagpt.tools.sd_engine", + desc="Related to text2image, image2image using stable diffusion model.", + usage_prompt="", + ), "other": ToolType( name="other", module="", diff --git a/metagpt/tools/functions/schemas/stable_diffusion.yml b/metagpt/tools/functions/schemas/stable_diffusion.yml new file mode 100644 index 000000000..a93742a1d --- /dev/null +++ b/metagpt/tools/functions/schemas/stable_diffusion.yml @@ -0,0 +1,58 @@ +SDEngine: + type: class + description: "Generate image using stable diffusion model" + methods: + __init__: + description: "Initialize the SDEngine instance." + parameters: + properties: + sd_url: + type: str + description: "URL of the stable diffusion service." + simple_run_t2i: + description: "Synchronously call the stable diffusion API with a single payload to generate images, optionally saving them." + parameters: + properties: + payload: + type: dict + description: "Dictionary of input parameters for the stable diffusion API." + auto_save: + type: bool + description: "Save generated images automatically." + required: + - payload + run_t2i: + type: async function + description: "Run the stable diffusion API for multiple prompts, calling the stable diffusion API to generate images." + parameters: + properties: + payloads: + type: list + description: "List of payloads; each payload is a dictionary of input parameters for the stable diffusion API." + required: + - payloads + construct_payload: + description: "Modify and set the API parameters for image generation." + parameters: + properties: + prompt: + type: str + description: "Text input for image generation." 
+ required: + - prompt + returns: + payload: + type: dict + description: "Updated parameters for the stable diffusion API." + save: + description: "Save generated images to the output directory." + parameters: + properties: + imgs: + type: list + description: "Generated images (list of base64-encoded strings)." + save_name: + type: str + description: "Output image name. Default is empty." + required: + - imgs diff --git a/metagpt/tools/sd_engine.py b/metagpt/tools/sd_engine.py index c4d9d2df4..ba61fd496 100644 --- a/metagpt/tools/sd_engine.py +++ b/metagpt/tools/sd_engine.py @@ -2,13 +2,14 @@ # @Date : 2023/7/19 16:28 # @Author : stellahong (stellahong@deepwisdom.ai) # @Desc : -import asyncio import base64 +import hashlib import io import json from os.path import join from typing import List +import requests from aiohttp import ClientSession from PIL import Image, PngImagePlugin @@ -51,9 +52,9 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution" class SDEngine: - def __init__(self): + def __init__(self, sd_url=""): # Initialize the SDEngine with configuration - self.sd_url = CONFIG.get("SD_URL") + self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL") self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}" # Define default payload settings for SD API self.payload = payload @@ -69,25 +70,36 @@ class SDEngine: ): # Configure the payload with provided inputs self.payload["prompt"] = prompt - self.payload["negtive_prompt"] = negtive_prompt + self.payload["negative_prompt"] = negtive_prompt self.payload["width"] = width self.payload["height"] = height self.payload["override_settings"]["sd_model_checkpoint"] = sd_model logger.info(f"call sd payload is {self.payload}") return self.payload - def _save(self, imgs, save_name=""): + def save(self, imgs, save_name=""): save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO if not save_dir.exists(): save_dir.mkdir(parents=True, exist_ok=True) batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name) - async def 
run_t2i(self, prompts: List): + def simple_run_t2i(self, payload: dict, auto_save: bool = True): + with requests.Session() as session: + logger.debug(self.sd_t2i_url) + rsp = session.post(self.sd_t2i_url, json=payload, timeout=600) + + results = rsp.json()["images"] + if auto_save: + save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6] + self.save(results, save_name=f"output_{save_name}") + return results + + async def run_t2i(self, payloads: List): # Asynchronously run the SD API for multiple prompts session = ClientSession() - for payload_idx, payload in enumerate(prompts): + for payload_idx, payload in enumerate(payloads): results = await self.run(url=self.sd_t2i_url, payload=payload, session=session) - self._save(results, save_name=f"output_{payload_idx}") + self.save(results, save_name=f"output_{payload_idx}") await session.close() async def run(self, url, payload, session): @@ -121,13 +133,3 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""): for idx, _img in enumerate(imgs): save_name = join(save_dir, save_name) decode_base64_to_image(_img, save_name=save_name) - - -if __name__ == "__main__": - engine = SDEngine() - prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. 
Complete interface boundary" - - engine.construct_payload(prompt) - - event_loop = asyncio.get_event_loop() - event_loop.run_until_complete(engine.run_t2i(prompt)) diff --git a/tests/metagpt/actions/test_debug_code.py b/tests/metagpt/actions/test_debug_code.py new file mode 100644 index 000000000..262f2e60d --- /dev/null +++ b/tests/metagpt/actions/test_debug_code.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# @Date : 1/11/2024 8:51 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import pytest + +from metagpt.actions.debug_code import DebugCode, messages_to_str +from metagpt.schema import Message + +ErrorStr = """Tested passed: + +Tests failed: +assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5] +""" + +CODE = """ +def sort_array(arr): + # Helper function to count the number of ones in the binary representation + def count_ones(n): + return bin(n).count('1') + + # Sort the array using a custom key function + # The key function returns a tuple (number of ones, value) for each element + # This ensures that if two elements have the same number of ones, they are sorted by their value + sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x)) + + return sorted_arr +``` +""" + +DebugContext = '''Solve the problem in Python: +def sort_array(arr): + """ + In this Kata, you have to sort an array of non-negative integers according to + number of ones in their binary representation in ascending order. + For similar number of ones, sort based on decimal value. 
+ + It must be implemented like this: + >>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] + >>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2] + >>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4] + """ +''' + + +@pytest.mark.asyncio +async def test_debug_code(): + debug_context = Message(content=DebugContext) + new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr) + assert "def sort_array(arr)" in new_code + + +def test_messages_to_str(): + debug_context = Message(content=DebugContext) + msg_str = messages_to_str([debug_context]) + assert "user: Solve the problem in Python" in msg_str diff --git a/tests/metagpt/tools/functions/test_sd.py b/tests/metagpt/tools/functions/test_sd.py new file mode 100644 index 000000000..142101cad --- /dev/null +++ b/tests/metagpt/tools/functions/test_sd.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# @Date : 1/10/2024 10:07 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : +import pytest + +from metagpt.tools.sd_engine import SDEngine + + +def test_sd_tools(): + engine = SDEngine() + prompt = "1boy, hansom" + engine.construct_payload(prompt) + engine.simple_run_t2i(engine.payload) + + +def test_sd_construct_payload(): + engine = SDEngine() + prompt = "1boy, hansom" + engine.construct_payload(prompt) + assert "negative_prompt" in engine.payload + + +@pytest.mark.asyncio +async def test_sd_asyn_t2i(): + engine = SDEngine() + prompt = "1boy, hansom" + engine.construct_payload(prompt) + await engine.run_t2i([engine.payload]) + assert "negative_prompt" in engine.payload