mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-15 11:02:36 +02:00
update debug_code ut
update sd_tool_usage example
This commit is contained in:
parent
a98edada1a
commit
af26fe06cf
5 changed files with 68 additions and 65 deletions
|
|
@ -3,38 +3,17 @@
|
|||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
from metagpt.const import METAGPT_ROOT
|
||||
from metagpt.actions.write_analysis_code import WriteCodeWithTools
|
||||
from metagpt.plan.planner import Planner
|
||||
from metagpt.actions.execute_code import ExecutePyCode
|
||||
|
||||
from metagpt.roles.code_interpreter import CodeInterpreter
|
||||
|
||||
sd_url = 'http://106.75.10.65:19094/sdapi/v1/txt2img'
|
||||
requirement = f"i have a text2image tool, generate a girl image use it, sd_url={sd_url}"
|
||||
|
||||
async def main(requirement: str = ""):
|
||||
code_interpreter = CodeInterpreter(use_tools=True, goal=requirement)
|
||||
await code_interpreter.run(requirement)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
code_interpreter = CodeInterpreter(use_tools=True, goal=requirement)
|
||||
asyncio.run(code_interpreter.run(requirement))
|
||||
# planner = Planner(
|
||||
# goal="i have a sdt2i tool, generate a girl image use it, sd_url='http://106.75.10.65:19094/sdapi/v1/txt2img'",
|
||||
# auto_run=True)
|
||||
# asyncio.run(planner.update_plan())
|
||||
|
||||
# schema_path = METAGPT_ROOT / "metagpt/tools/functions/schemas"
|
||||
# #
|
||||
# prompt = "1girl, beautiful"
|
||||
# planner = Planner(
|
||||
# goal="i have a sdt2i tool, generate a girl image use it, sd_url='http://106.75.10.65:19094/sdapi/v1/txt2img'",
|
||||
# auto_run=True)
|
||||
# asyncio.run(planner.update_plan())
|
||||
# planner.plan.current_task.task_type = "sd"
|
||||
# planner.plan.current_task.instruction = "Use the sdt2i tool with the provided API endpoint to generate the girl image."
|
||||
# executor = ExecutePyCode()
|
||||
#
|
||||
# tool_context, code = asyncio.run(WriteCodeWithTools(schema_path=schema_path).run(
|
||||
# context=f"task prompt: {prompt}",
|
||||
# plan=planner.plan,
|
||||
# column_info="",
|
||||
# ))
|
||||
# print(code)
|
||||
# asyncio.run(executor.run(code))
|
||||
sd_url = 'http://106.75.10.65:19094'
|
||||
requirement = f"I want to generate an image of a beautiful girl using the stable diffusion text2image tool, sd_url={sd_url}"
|
||||
|
||||
asyncio.run(main(requirement))
|
||||
|
|
|
|||
|
|
@ -85,20 +85,14 @@ class DebugCode(BaseWriteAnalysisCode):
|
|||
|
||||
async def run_reflection(
|
||||
self,
|
||||
# goal,
|
||||
# finished_code,
|
||||
# finished_code_result,
|
||||
context: List[Message],
|
||||
code,
|
||||
runtime_result,
|
||||
) -> dict:
|
||||
info = []
|
||||
# finished_code_and_result = finished_code + "\n [finished results]\n\n" + finished_code_result
|
||||
reflection_prompt = REFLECTION_PROMPT.format(
|
||||
debug_example=DEBUG_REFLECTION_EXAMPLE,
|
||||
context=context,
|
||||
# goal=goal,
|
||||
# finished_code=finished_code_and_result,
|
||||
code=code,
|
||||
runtime_result=runtime_result,
|
||||
)
|
||||
|
|
@ -106,33 +100,14 @@ class DebugCode(BaseWriteAnalysisCode):
|
|||
info.append(Message(role="system", content=system_prompt))
|
||||
info.append(Message(role="user", content=reflection_prompt))
|
||||
|
||||
# msg = messages_to_str(info)
|
||||
# resp = await self.llm.aask(msg=msg)
|
||||
resp = await self.llm.aask_code(messages=info, **create_func_config(CODE_REFLECTION))
|
||||
logger.info(f"reflection is {resp}")
|
||||
return resp
|
||||
|
||||
# async def rewrite_code(self, reflection: str = "", context: List[Message] = None) -> str:
|
||||
# """
|
||||
# 根据reflection重写代码
|
||||
# """
|
||||
# info = context
|
||||
# # info.append(Message(role="assistant", content=f"[code context]:{code_context}"
|
||||
# # f"finished code are executable, and you should based on the code to continue your current code debug and improvement"
|
||||
# # f"[reflection]: \n {reflection}"))
|
||||
# info.append(Message(role="assistant", content=f"[reflection]: \n {reflection}"))
|
||||
# info.append(Message(role="user", content=f"[improved impl]:\n Return in Python block"))
|
||||
# msg = messages_to_str(info)
|
||||
# resp = await self.llm.aask(msg=msg)
|
||||
# improv_code = CodeParser.parse_code(block=None, text=resp)
|
||||
# return improv_code
|
||||
|
||||
async def run(
|
||||
self,
|
||||
context: List[Message] = None,
|
||||
plan: str = "",
|
||||
# finished_code: str = "",
|
||||
# finished_code_result: str = "",
|
||||
code: str = "",
|
||||
runtime_result: str = "",
|
||||
) -> str:
|
||||
|
|
@ -140,14 +115,10 @@ class DebugCode(BaseWriteAnalysisCode):
|
|||
根据当前运行代码和报错信息进行reflection和纠错
|
||||
"""
|
||||
reflection = await self.run_reflection(
|
||||
# plan,
|
||||
# finished_code=finished_code,
|
||||
# finished_code_result=finished_code_result,
|
||||
code=code,
|
||||
context=context,
|
||||
runtime_result=runtime_result,
|
||||
)
|
||||
# 根据reflection结果重写代码
|
||||
# improv_code = await self.rewrite_code(reflection, context=context)
|
||||
improv_code = reflection["improved_impl"]
|
||||
return improv_code
|
||||
|
|
|
|||
|
|
@ -60,7 +60,6 @@ class MLEngineer(CodeInterpreter):
|
|||
if code_execution_count > 0:
|
||||
logger.warning("We got a bug code, now start to debug...")
|
||||
code = await DebugCode().run(
|
||||
plan=self.planner.current_task.instruction,
|
||||
code=self.latest_code,
|
||||
runtime_result=self.working_memory.get(),
|
||||
context=self.debug_context,
|
||||
|
|
|
|||
|
|
@ -34,14 +34,14 @@ def rsp_cache():
|
|||
rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache.json" # read repo-provided
|
||||
new_rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache_new.json" # exporting a new copy
|
||||
if os.path.exists(rsp_cache_file_path):
|
||||
with open(rsp_cache_file_path, "r") as f1:
|
||||
with open(rsp_cache_file_path, "r", encoding="utf-8") as f1:
|
||||
rsp_cache_json = json.load(f1)
|
||||
else:
|
||||
rsp_cache_json = {}
|
||||
yield rsp_cache_json
|
||||
with open(rsp_cache_file_path, "w") as f2:
|
||||
with open(rsp_cache_file_path, "w", encoding="utf-8") as f2:
|
||||
json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False)
|
||||
with open(new_rsp_cache_file_path, "w") as f2:
|
||||
with open(new_rsp_cache_file_path, "w", encoding="utf-8") as f2:
|
||||
json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
|
|
@ -139,7 +139,7 @@ def loguru_caplog(caplog):
|
|||
|
||||
|
||||
# init & dispose git repo
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
@pytest.fixture(scope="function", autouse=False)
|
||||
def setup_and_teardown_git_repo(request):
|
||||
CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}")
|
||||
CONFIG.git_reinit = True
|
||||
|
|
|
|||
54
tests/metagpt/actions/test_debug_code.py
Normal file
54
tests/metagpt/actions/test_debug_code.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 1/11/2024 8:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.actions.debug_code import DebugCode, messages_to_str
|
||||
from metagpt.schema import Message
|
||||
|
||||
ErrorStr = '''Tested passed:
|
||||
|
||||
Tests failed:
|
||||
assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5]
|
||||
'''
|
||||
|
||||
CODE = '''
|
||||
def sort_array(arr):
|
||||
# Helper function to count the number of ones in the binary representation
|
||||
def count_ones(n):
|
||||
return bin(n).count('1')
|
||||
|
||||
# Sort the array using a custom key function
|
||||
# The key function returns a tuple (number of ones, value) for each element
|
||||
# This ensures that if two elements have the same number of ones, they are sorted by their value
|
||||
sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x))
|
||||
|
||||
return sorted_arr
|
||||
```
|
||||
'''
|
||||
|
||||
DebugContext = '''Solve the problem in Python:
|
||||
def sort_array(arr):
|
||||
"""
|
||||
In this Kata, you have to sort an array of non-negative integers according to
|
||||
number of ones in their binary representation in ascending order.
|
||||
For similar number of ones, sort based on decimal value.
|
||||
|
||||
It must be implemented like this:
|
||||
>>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5]
|
||||
>>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2]
|
||||
>>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4]
|
||||
"""
|
||||
'''
|
||||
@pytest.mark.asyncio
|
||||
async def test_debug_code():
|
||||
debug_context = Message(content=DebugContext)
|
||||
new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr)
|
||||
assert "def sort_array(arr)" in new_code
|
||||
|
||||
def test_messages_to_str():
|
||||
debug_context = Message(content=DebugContext)
|
||||
msg_str = messages_to_str([debug_context])
|
||||
assert "user: Solve the problem in Python" in msg_str
|
||||
Loading…
Add table
Add a link
Reference in a new issue