diff --git a/examples/sd_tool_usage.py b/examples/sd_tool_usage.py index 59fddb85d..82ee6a709 100644 --- a/examples/sd_tool_usage.py +++ b/examples/sd_tool_usage.py @@ -3,38 +3,17 @@ # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : import asyncio -from metagpt.const import METAGPT_ROOT -from metagpt.actions.write_analysis_code import WriteCodeWithTools -from metagpt.plan.planner import Planner -from metagpt.actions.execute_code import ExecutePyCode + from metagpt.roles.code_interpreter import CodeInterpreter -sd_url = 'http://106.75.10.65:19094/sdapi/v1/txt2img' -requirement = f"i have a text2image tool, generate a girl image use it, sd_url={sd_url}" + +async def main(requirement: str = ""): + code_interpreter = CodeInterpreter(use_tools=True, goal=requirement) + await code_interpreter.run(requirement) + if __name__ == "__main__": - code_interpreter = CodeInterpreter(use_tools=True, goal=requirement) - asyncio.run(code_interpreter.run(requirement)) - # planner = Planner( - # goal="i have a sdt2i tool, generate a girl image use it, sd_url='http://106.75.10.65:19094/sdapi/v1/txt2img'", - # auto_run=True) - # asyncio.run(planner.update_plan()) - -# schema_path = METAGPT_ROOT / "metagpt/tools/functions/schemas" -# # -# prompt = "1girl, beautiful" -# planner = Planner( -# goal="i have a sdt2i tool, generate a girl image use it, sd_url='http://106.75.10.65:19094/sdapi/v1/txt2img'", -# auto_run=True) -# asyncio.run(planner.update_plan()) -# planner.plan.current_task.task_type = "sd" -# planner.plan.current_task.instruction = "Use the sdt2i tool with the provided API endpoint to generate the girl image." -# executor = ExecutePyCode() -# -# tool_context, code = asyncio.run(WriteCodeWithTools(schema_path=schema_path).run( -# context=f"task prompt: {prompt}", -# plan=planner.plan, -# column_info="", -# )) -# print(code) -# asyncio.run(executor.run(code)) + sd_url = 'http://106.75.10.65:19094' + requirement = f"I want to generate an image of a beautiful girl using the stable diffusion text2image tool, sd_url={sd_url}" + + asyncio.run(main(requirement)) diff --git a/metagpt/actions/debug_code.py b/metagpt/actions/debug_code.py index 26a84bcf2..74a188e9f 100644 --- a/metagpt/actions/debug_code.py +++ b/metagpt/actions/debug_code.py @@ -85,20 +85,14 @@ class DebugCode(BaseWriteAnalysisCode): async def run_reflection( self, - # goal, - # finished_code, - # finished_code_result, context: List[Message], code, runtime_result, ) -> dict: info = [] - # finished_code_and_result = finished_code + "\n [finished results]\n\n" + finished_code_result reflection_prompt = REFLECTION_PROMPT.format( debug_example=DEBUG_REFLECTION_EXAMPLE, context=context, - # goal=goal, - # finished_code=finished_code_and_result, code=code, runtime_result=runtime_result, ) @@ -106,33 +100,14 @@ class DebugCode(BaseWriteAnalysisCode): info.append(Message(role="system", content=system_prompt)) info.append(Message(role="user", content=reflection_prompt)) - # msg = messages_to_str(info) - # resp = await self.llm.aask(msg=msg) resp = await self.llm.aask_code(messages=info, **create_func_config(CODE_REFLECTION)) logger.info(f"reflection is {resp}") return resp - # async def rewrite_code(self, reflection: str = "", context: List[Message] = None) -> str: - # """ - # 根据reflection重写代码 - # """ - # info = context - # # info.append(Message(role="assistant", content=f"[code context]:{code_context}" - # # f"finished code are executable, and you should based on the code to continue your current code debug and improvement" - # # f"[reflection]: \n {reflection}")) - # info.append(Message(role="assistant", content=f"[reflection]: \n {reflection}")) - # info.append(Message(role="user", content=f"[improved impl]:\n Return in Python block")) - # msg = messages_to_str(info) - # resp = await self.llm.aask(msg=msg) - # improv_code = CodeParser.parse_code(block=None, text=resp) - # return improv_code async def run( self, context: List[Message] = None, - plan: str = "", - # finished_code: str = "", - # finished_code_result: str = "", code: str = "", runtime_result: str = "", ) -> str: @@ -140,14 +115,10 @@ class DebugCode(BaseWriteAnalysisCode): 根据当前运行代码和报错信息进行reflection和纠错 """ reflection = await self.run_reflection( - # plan, - # finished_code=finished_code, - # finished_code_result=finished_code_result, code=code, context=context, runtime_result=runtime_result, ) # 根据reflection结果重写代码 - # improv_code = await self.rewrite_code(reflection, context=context) improv_code = reflection["improved_impl"] return improv_code diff --git a/metagpt/roles/ml_engineer.py b/metagpt/roles/ml_engineer.py index cf903347d..a60642bff 100644 --- a/metagpt/roles/ml_engineer.py +++ b/metagpt/roles/ml_engineer.py @@ -60,7 +60,6 @@ class MLEngineer(CodeInterpreter): if code_execution_count > 0: logger.warning("We got a bug code, now start to debug...") code = await DebugCode().run( - plan=self.planner.current_task.instruction, code=self.latest_code, runtime_result=self.working_memory.get(), context=self.debug_context, diff --git a/tests/conftest.py b/tests/conftest.py index 6f5c04f06..dc89e897f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,14 +34,14 @@ def rsp_cache(): rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache.json" # read repo-provided new_rsp_cache_file_path = TEST_DATA_PATH / "rsp_cache_new.json" # exporting a new copy if os.path.exists(rsp_cache_file_path): - with open(rsp_cache_file_path, "r") as f1: + with open(rsp_cache_file_path, "r", encoding="utf-8") as f1: rsp_cache_json = json.load(f1) else: rsp_cache_json = {} yield rsp_cache_json - with open(rsp_cache_file_path, "w") as f2: + with open(rsp_cache_file_path, "w", encoding="utf-8") as f2: json.dump(rsp_cache_json, f2, indent=4, ensure_ascii=False) - with open(new_rsp_cache_file_path, "w") as f2: + with open(new_rsp_cache_file_path, "w", encoding="utf-8") as f2: json.dump(RSP_CACHE_NEW, f2, indent=4, ensure_ascii=False) @@ -139,7 +139,7 @@ def loguru_caplog(caplog): # init & dispose git repo -@pytest.fixture(scope="function", autouse=True) +@pytest.fixture(scope="function", autouse=False) def setup_and_teardown_git_repo(request): CONFIG.git_repo = GitRepository(local_path=DEFAULT_WORKSPACE_ROOT / f"unittest/{uuid.uuid4().hex}") CONFIG.git_reinit = True diff --git a/tests/metagpt/actions/test_debug_code.py b/tests/metagpt/actions/test_debug_code.py new file mode 100644 index 000000000..675c07f78 --- /dev/null +++ b/tests/metagpt/actions/test_debug_code.py @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- +# @Date : 1/11/2024 8:51 PM +# @Author : stellahong (stellahong@fuzhi.ai) +# @Desc : + +import pytest + +from metagpt.actions.debug_code import DebugCode, messages_to_str +from metagpt.schema import Message + +ErrorStr = '''Tested passed: + +Tests failed: +assert sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] # output: [1, 2, 4, 3, 5] +''' + +CODE = ''' +def sort_array(arr): + # Helper function to count the number of ones in the binary representation + def count_ones(n): + return bin(n).count('1') + + # Sort the array using a custom key function + # The key function returns a tuple (number of ones, value) for each element + # This ensures that if two elements have the same number of ones, they are sorted by their value + sorted_arr = sorted(arr, key=lambda x: (count_ones(x), x)) + + return sorted_arr +``` +''' + +DebugContext = '''Solve the problem in Python: +def sort_array(arr): + """ + In this Kata, you have to sort an array of non-negative integers according to + number of ones in their binary representation in ascending order. + For similar number of ones, sort based on decimal value. + + It must be implemented like this: + >>> sort_array([1, 5, 2, 3, 4]) == [1, 2, 3, 4, 5] + >>> sort_array([-2, -3, -4, -5, -6]) == [-6, -5, -4, -3, -2] + >>> sort_array([1, 0, 2, 3, 4]) [0, 1, 2, 3, 4] + """ +''' +@pytest.mark.asyncio +async def test_debug_code(): + debug_context = Message(content=DebugContext) + new_code = await DebugCode().run(context=debug_context, code=CODE, runtime_result=ErrorStr) + assert "def sort_array(arr)" in new_code + +def test_messages_to_str(): + debug_context = Message(content=DebugContext) + msg_str = messages_to_str([debug_context]) + assert "user: Solve the problem in Python" in msg_str \ No newline at end of file