mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge branch 'sd_and_debugcode_ut' into 'code_intepreter'
单测及简单优化:增加sd作为工具使用 & 简化debugcode See merge request agents/data_agents_opt!42
This commit is contained in:
commit
4a28d66680
9 changed files with 192 additions and 51 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -177,4 +177,3 @@ htmlcov.*
|
|||
*.pkl
|
||||
*-structure.csv
|
||||
*-structure.json
|
||||
|
||||
|
|
|
|||
21
examples/sd_tool_usage.py
Normal file
21
examples/sd_tool_usage.py
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/11/2024 7:06 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
|
||||
from metagpt.roles.code_interpreter import CodeInterpreter
|
||||
|
||||
|
||||
async def main(requirement: str = ""):
|
||||
code_interpreter = CodeInterpreter(use_tools=True, goal=requirement)
|
||||
await code_interpreter.run(requirement)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sd_url = "http://your.sd.service.ip:port"
|
||||
requirement = (
|
||||
f"I want to generate an image of a beautiful girl using the stable diffusion text2image tool, sd_url={sd_url}"
|
||||
)
|
||||
|
||||
asyncio.run(main(requirement))
|
||||
|
|
@ -85,20 +85,14 @@ class DebugCode(BaseWriteAnalysisCode):
|
|||
|
||||
async def run_reflection(
|
||||
self,
|
||||
# goal,
|
||||
# finished_code,
|
||||
# finished_code_result,
|
||||
context: List[Message],
|
||||
code,
|
||||
runtime_result,
|
||||
) -> dict:
|
||||
info = []
|
||||
# finished_code_and_result = finished_code + "\n [finished results]\n\n" + finished_code_result
|
||||
reflection_prompt = REFLECTION_PROMPT.format(
|
||||
debug_example=DEBUG_REFLECTION_EXAMPLE,
|
||||
context=context,
|
||||
# goal=goal,
|
||||
# finished_code=finished_code_and_result,
|
||||
code=code,
|
||||
runtime_result=runtime_result,
|
||||
)
|
||||
|
|
@ -106,33 +100,13 @@ class DebugCode(BaseWriteAnalysisCode):
|
|||
info.append(Message(role="system", content=system_prompt))
|
||||
info.append(Message(role="user", content=reflection_prompt))
|
||||
|
||||
# msg = messages_to_str(info)
|
||||
# resp = await self.llm.aask(msg=msg)
|
||||
resp = await self.llm.aask_code(messages=info, **create_func_config(CODE_REFLECTION))
|
||||
logger.info(f"reflection is {resp}")
|
||||
return resp
|
||||
|
||||
# async def rewrite_code(self, reflection: str = "", context: List[Message] = None) -> str:
|
||||
# """
|
||||
# 根据reflection重写代码
|
||||
# """
|
||||
# info = context
|
||||
# # info.append(Message(role="assistant", content=f"[code context]:{code_context}"
|
||||
# # f"finished code are executable, and you should based on the code to continue your current code debug and improvement"
|
||||
# # f"[reflection]: \n {reflection}"))
|
||||
# info.append(Message(role="assistant", content=f"[reflection]: \n {reflection}"))
|
||||
# info.append(Message(role="user", content=f"[improved impl]:\n Return in Python block"))
|
||||
# msg = messages_to_str(info)
|
||||
# resp = await self.llm.aask(msg=msg)
|
||||
# improv_code = CodeParser.parse_code(block=None, text=resp)
|
||||
# return improv_code
|
||||
|
||||
async def run(
|
||||
self,
|
||||
context: List[Message] = None,
|
||||
plan: str = "",
|
||||
# finished_code: str = "",
|
||||
# finished_code_result: str = "",
|
||||
code: str = "",
|
||||
runtime_result: str = "",
|
||||
) -> str:
|
||||
|
|
@ -140,14 +114,10 @@ class DebugCode(BaseWriteAnalysisCode):
|
|||
根据当前运行代码和报错信息进行reflection和纠错
|
||||
"""
|
||||
reflection = await self.run_reflection(
|
||||
# plan,
|
||||
# finished_code=finished_code,
|
||||
# finished_code_result=finished_code_result,
|
||||
code=code,
|
||||
context=context,
|
||||
runtime_result=runtime_result,
|
||||
)
|
||||
# 根据reflection结果重写代码
|
||||
# improv_code = await self.rewrite_code(reflection, context=context)
|
||||
improv_code = reflection["improved_impl"]
|
||||
return improv_code
|
||||
|
|
|
|||
|
|
@ -60,7 +60,6 @@ class MLEngineer(CodeInterpreter):
|
|||
if code_execution_count > 0:
|
||||
logger.warning("We got a bug code, now start to debug...")
|
||||
code = await DebugCode().run(
|
||||
plan=self.planner.current_task.instruction,
|
||||
code=self.latest_code,
|
||||
runtime_result=self.working_memory.get(),
|
||||
context=self.debug_context,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
@File : __init__.py
|
||||
"""
|
||||
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -71,6 +70,12 @@ TOOL_TYPE_MAPPINGS = {
|
|||
desc="Only for evaluating model.",
|
||||
usage_prompt=MODEL_EVALUATE_PROMPT,
|
||||
),
|
||||
"stable_diffusion": ToolType(
|
||||
name="stable_diffusion",
|
||||
module="metagpt.tools.sd_engine",
|
||||
desc="Related to text2image, image2image using stable diffusion model.",
|
||||
usage_prompt="",
|
||||
),
|
||||
"other": ToolType(
|
||||
name="other",
|
||||
module="",
|
||||
|
|
|
|||
58
metagpt/tools/functions/schemas/stable_diffusion.yml
Normal file
58
metagpt/tools/functions/schemas/stable_diffusion.yml
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
SDEngine:
|
||||
type: class
|
||||
description: "Generate image using stable diffusion model"
|
||||
methods:
|
||||
__init__:
|
||||
description: "Initialize the SDEngine instance."
|
||||
parameters:
|
||||
properties:
|
||||
sd_url:
|
||||
type: str
|
||||
description: "URL of the stable diffusion service."
|
||||
simple_run_t2i:
|
||||
description: "Run the stable diffusion API for multiple prompts, calling the stable diffusion API to generate images."
|
||||
parameters:
|
||||
properties:
|
||||
payload:
|
||||
type: dict
|
||||
description: "Dictionary of input parameters for the stable diffusion API."
|
||||
auto_save:
|
||||
type: bool
|
||||
description: "Save generated images automatically."
|
||||
required:
|
||||
- prompts
|
||||
run_t2i:
|
||||
type: async function
|
||||
description: "Run the stable diffusion API for multiple prompts, calling the stable diffusion API to generate images."
|
||||
parameters:
|
||||
properties:
|
||||
payloads:
|
||||
type: list
|
||||
description: "List of payload, each payload is a dictionary of input parameters for the stable diffusion API."
|
||||
required:
|
||||
- payloads
|
||||
construct_payload:
|
||||
description: "Modify and set the API parameters for image generation."
|
||||
parameters:
|
||||
properties:
|
||||
prompt:
|
||||
type: str
|
||||
description: "Text input for image generation."
|
||||
required:
|
||||
- prompt
|
||||
returns:
|
||||
payload:
|
||||
type: dict
|
||||
description: "Updated parameters for the stable diffusion API."
|
||||
save:
|
||||
description: "Save generated images to the output directory."
|
||||
parameters:
|
||||
properties:
|
||||
imgs:
|
||||
type: str
|
||||
description: "Generated images."
|
||||
save_name:
|
||||
type: str
|
||||
description: "Output image name. Default is empty."
|
||||
required:
|
||||
- imgs
|
||||
|
|
@ -2,13 +2,14 @@
|
|||
# @Date : 2023/7/19 16:28
|
||||
# @Author : stellahong (stellahong@deepwisdom.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
import base64
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
from os.path import join
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from aiohttp import ClientSession
|
||||
from PIL import Image, PngImagePlugin
|
||||
|
||||
|
|
@ -51,9 +52,9 @@ default_negative_prompt = "(easynegative:0.8),black, dark,Low resolution"
|
|||
|
||||
|
||||
class SDEngine:
|
||||
def __init__(self):
|
||||
def __init__(self, sd_url=""):
|
||||
# Initialize the SDEngine with configuration
|
||||
self.sd_url = CONFIG.get("SD_URL")
|
||||
self.sd_url = sd_url if sd_url else CONFIG.get("SD_URL")
|
||||
self.sd_t2i_url = f"{self.sd_url}{CONFIG.get('SD_T2I_API')}"
|
||||
# Define default payload settings for SD API
|
||||
self.payload = payload
|
||||
|
|
@ -69,25 +70,36 @@ class SDEngine:
|
|||
):
|
||||
# Configure the payload with provided inputs
|
||||
self.payload["prompt"] = prompt
|
||||
self.payload["negtive_prompt"] = negtive_prompt
|
||||
self.payload["negative_prompt"] = negtive_prompt
|
||||
self.payload["width"] = width
|
||||
self.payload["height"] = height
|
||||
self.payload["override_settings"]["sd_model_checkpoint"] = sd_model
|
||||
logger.info(f"call sd payload is {self.payload}")
|
||||
return self.payload
|
||||
|
||||
def _save(self, imgs, save_name=""):
|
||||
def save(self, imgs, save_name=""):
|
||||
save_dir = CONFIG.workspace_path / SD_OUTPUT_FILE_REPO
|
||||
if not save_dir.exists():
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
batch_decode_base64_to_image(imgs, str(save_dir), save_name=save_name)
|
||||
|
||||
async def run_t2i(self, prompts: List):
|
||||
def simple_run_t2i(self, payload: dict, auto_save: bool = True):
|
||||
with requests.Session() as session:
|
||||
logger.debug(self.sd_t2i_url)
|
||||
rsp = session.post(self.sd_t2i_url, json=payload, timeout=600)
|
||||
|
||||
results = rsp.json()["images"]
|
||||
if auto_save:
|
||||
save_name = hashlib.sha256(payload["prompt"][:10].encode()).hexdigest()[:6]
|
||||
self.save(results, save_name=f"output_{save_name}")
|
||||
return results
|
||||
|
||||
async def run_t2i(self, payloads: List):
|
||||
# Asynchronously run the SD API for multiple prompts
|
||||
session = ClientSession()
|
||||
for payload_idx, payload in enumerate(prompts):
|
||||
for payload_idx, payload in enumerate(payloads):
|
||||
results = await self.run(url=self.sd_t2i_url, payload=payload, session=session)
|
||||
self._save(results, save_name=f"output_{payload_idx}")
|
||||
self.save(results, save_name=f"output_{payload_idx}")
|
||||
await session.close()
|
||||
|
||||
async def run(self, url, payload, session):
|
||||
|
|
@ -121,13 +133,3 @@ def batch_decode_base64_to_image(imgs, save_dir="", save_name=""):
|
|||
for idx, _img in enumerate(imgs):
|
||||
save_name = join(save_dir, save_name)
|
||||
decode_base64_to_image(_img, save_name=save_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
engine = SDEngine()
|
||||
prompt = "pixel style, game design, a game interface should be minimalistic and intuitive with the score and high score displayed at the top. The snake and its food should be easily distinguishable. The game should have a simple color scheme, with a contrasting color for the snake and its food. Complete interface boundary"
|
||||
|
||||
engine.construct_payload(prompt)
|
||||
|
||||
event_loop = asyncio.get_event_loop()
|
||||
event_loop.run_until_complete(engine.run_t2i(prompt))
|
||||
|
|
|
|||
57
tests/metagpt/actions/test_debug_code.py
Normal file
57
tests/metagpt/actions/test_debug_code.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/11/2024 8:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.debug_code import DebugCode, messages_to_str
|
||||
from metagpt.schema import Message
|
||||
|
||||
ErrorStr = """Tested passed:
|
||||
|
||||
Tests failed:
|
||||
assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5]
|
||||
"""
|
||||
|
||||
CODE = """
|
||||
def sort_array(arr):
|
||||
# Helper function to count the number of ones in the binary representation
|
||||
def count_ones(n):
|
||||
return bin(n).count('1')
|
||||
|
||||
# Sort the array using a custom key function
|
||||
# The key function returns a tuple (number of ones, value) for each element
|
||||
# This ensures that if two elements have the same number of ones, they are sorted by their value
|
||||
sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x))
|
||||
|
||||
return sorted_arr
|
||||
```
|
||||
"""
|
||||
|
||||
DebugContext = '''Solve the problem in Python:
|
||||
def sort_array(arr):
|
||||
"""
|
||||
In this Kata, you have to sort an array of non-negative integers according to
|
||||
number of ones in their binary representation in ascending order.
|
||||
For similar number of ones, sort based on decimal value.
|
||||
|
||||
It must be implemented like this:
|
||||
>>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
|
||||
>>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
|
||||
>>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
|
||||
"""
|
||||
'''
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_code():
|
||||
debug_context = Message(content=DebugContext)
|
||||
new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
|
||||
assert "def sort_array(arr)" in new_code
|
||||
|
||||
|
||||
def test_messages_to_str():
|
||||
debug_context = Message(content=DebugContext)
|
||||
msg_str = messages_to_str([debug_context])
|
||||
assert "user: Solve the problem in Python" in msg_str
|
||||
30
tests/metagpt/tools/functions/test_sd.py
Normal file
30
tests/metagpt/tools/functions/test_sd.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/10/2024 10:07 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import pytest
|
||||
|
||||
from metagpt.tools.sd_engine import SDEngine
|
||||
|
||||
|
||||
def test_sd_tools():
|
||||
engine = SDEngine()
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
engine.simple_run_t2i(engine.payload)
|
||||
|
||||
|
||||
def test_sd_construct_payload():
|
||||
engine = SDEngine()
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
assert "negative_prompt" in engine.payload
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sd_asyn_t2i():
|
||||
engine = SDEngine()
|
||||
prompt = "1boy, hansom"
|
||||
engine.construct_payload(prompt)
|
||||
await engine.run_t2i([engine.payload])
|
||||
assert "negative_prompt" in engine.payload
|
||||
Loading…
Add table
Add a link
Reference in a new issue