From 5ddaaaa3471e096b5ea02f2e2b8a4cc34f50332a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Mon, 22 Jan 2024 14:56:44 +0800 Subject: [PATCH 1/8] add test: test_get_choice_function_arguments_for_aask_code. --- tests/metagpt/provider/test_openai.py | 139 +++++++++----------------- 1 file changed, 48 insertions(+), 91 deletions(-) diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 8de29c11b..7af2f6892 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -1,104 +1,21 @@ from unittest.mock import Mock import pytest +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessage, + ChatCompletionMessageToolCall, +) +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message_tool_call import Function from metagpt.config import CONFIG +from metagpt.logs import logger from metagpt.provider.openai_api import OpenAILLM -from metagpt.schema import UserMessage CONFIG.openai_proxy = None -@pytest.mark.asyncio -async def test_aask_code(): - llm = OpenAILLM() - msg = [{"role": "user", "content": "Write a python hello world code."}] - rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -@pytest.mark.asyncio -async def test_aask_code_str(): - llm = OpenAILLM() - msg = "Write a python hello world code." - rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -@pytest.mark.asyncio -async def test_aask_code_Message(): - llm = OpenAILLM() - msg = UserMessage("Write a python hello world code.") - rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code(): - llm = OpenAIGPTAPI() - msg = [{"role": "user", "content": "Write a python hello world code."}] - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_str(): - llm = OpenAIGPTAPI() - msg = "Write a python hello world code." - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_Message(): - llm = OpenAIGPTAPI() - msg = UserMessage("Write a python hello world code.") - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_list_Message(): - llm = OpenAIGPTAPI() - msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")] - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_list_str(): - llm = OpenAIGPTAPI() - msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} - print(rsp) - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -@pytest.mark.asyncio -async def test_ask_code_steps2(): - llm = OpenAIGPTAPI() - msg = ["step by setp 生成代码: Step 1. 先生成随机数组a, Step 2. 求a中最大值, Step 3. 绘制数据a的直方图"] - rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} - print(rsp) - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - assert "Step 1" in rsp["code"] - assert "Step 2" in rsp["code"] - assert "Step 3" in rsp["code"] - - class TestOpenAI: @pytest.fixture def config(self): @@ -146,6 +63,32 @@ class TestOpenAI: openai_api_type="azure", ) + @pytest.fixture + def tool_calls_rsp(self): + function_rsps = [ + Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"), + Function(arguments='{\n"language": "python",\n"code": ```print("hello world")```', name="execute"), + Function(arguments='{\n"language": "python",\n"code": \'print("hello world")\'', name="execute"), + Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"', name="execute"), + Function(arguments='\nprint("hello world")\\n', name="execute"), + ] + tool_calls = [ + ChatCompletionMessageToolCall(type="function", id=f"call_{i}", function=f) + for i, f in enumerate(function_rsps) + ] + messages = [ChatCompletionMessage(content=None, role="assistant", tool_calls=[t]) for t in tool_calls] + # 添加一个纯文本响应 + messages.append( + ChatCompletionMessage(content="Completed a python code for hello world!", role="assistant", tool_calls=None) + ) + choices = [ + Choice(finish_reason="tool_calls", logprobs=None, index=i, message=msg) for i, msg in enumerate(messages) + ] + return [ + ChatCompletion(id=str(i), choices=[c], created=i, model="gpt-4", object="chat.completion") + for i, c in enumerate(choices) + ] + def test_make_client_kwargs_without_proxy(self, config): instance = OpenAILLM() instance.config = config @@ -171,3 +114,17 @@ class TestOpenAI: instance.config = config_azure_proxy kwargs = instance._make_client_kwargs() assert "http_client" in kwargs + + def test_get_choice_function_arguments_for_aask_code(self, tool_calls_rsp): + instance = OpenAILLM() + for i, rsp in enumerate(tool_calls_rsp): + code = instance.get_choice_function_arguments(rsp) + logger.info(f"\ntest get function call arguments {i}: {code}") + assert "code" in code + assert "language" in code + assert "hello world" in code["code"] + + if "Completed a python code for hello world!" == code["code"]: + code["language"] == "markdown" + else: + code["language"] == "python" From 6cb2910d144c56ccd2ef84c223cd9125cbf22a62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Mon, 22 Jan 2024 15:29:28 +0800 Subject: [PATCH 2/8] fix: now present the results of failure and success code in different ways. --- metagpt/actions/execute_code.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/metagpt/actions/execute_code.py b/metagpt/actions/execute_code.py index 5b6cba57d..851794b91 100644 --- a/metagpt/actions/execute_code.py +++ b/metagpt/actions/execute_code.py @@ -15,14 +15,13 @@ import nbformat from nbclient import NotebookClient from nbclient.exceptions import CellTimeoutError, DeadKernelError from nbformat import NotebookNode -from nbformat.v4 import new_code_cell, new_output, new_markdown_cell -from rich.console import Console -from rich.syntax import Syntax +from nbformat.v4 import new_code_cell, new_markdown_cell, new_output +from rich.box import MINIMAL +from rich.console import Console, Group +from rich.live import Live from rich.markdown import Markdown from rich.panel import Panel -from rich.box import MINIMAL -from rich.live import Live -from rich.console import Group +from rich.syntax import Syntax from metagpt.actions import Action from metagpt.logs import logger @@ -229,7 +228,7 @@ class ExecutePyCode(ExecuteCode, Action): # code success outputs = self.parse_outputs(self.nb.cells[-1].outputs) return truncate(remove_escape_and_color_codes(outputs), is_success=success) - elif language == 'markdown': + elif language == "markdown": # markdown self.add_markdown_cell(code) return code, True @@ -238,26 +237,27 @@ class ExecutePyCode(ExecuteCode, Action): def truncate(result: str, keep_len: int = 2000, is_success: bool = True): - desc = f"Executed code {'successfully' if is_success else 'failed, please reflect the cause of bug and then debug'}" + """执行失败的代码, 展示result后keep_len个字符; 执行成功的代码, 展示result前keep_len个字符。""" + desc = f"Executed code {'successfully. ' if is_success else 'failed, please reflect the cause of bug and then debug. '}" if is_success: - desc += f"Truncated to show only {keep_len} characters\n" + desc += f"Truncated to show only first {keep_len} characters\n" else: - desc += "Show complete information for you." + desc += f"Truncated to show only last {keep_len} characters\n" if result.startswith(desc): result = result[len(desc) :] if len(result) > keep_len: - result = result[-keep_len:] if not is_success else result + result = result[-keep_len:] if not is_success else result[:keep_len] if not result: - result = 'No output about your code. Only when importing packages it is normal case. Recap and go ahead.' + result = "No output about your code. Only when importing packages it is normal case. Recap and go ahead." return result, False if result.strip().startswith(" Date: Mon, 22 Jan 2024 15:36:25 +0800 Subject: [PATCH 3/8] add test. --- tests/metagpt/actions/test_execute_code.py | 50 +++++++++------------- 1 file changed, 20 insertions(+), 30 deletions(-) diff --git a/tests/metagpt/actions/test_execute_code.py b/tests/metagpt/actions/test_execute_code.py index 904cc3c58..ecddccf6f 100644 --- a/tests/metagpt/actions/test_execute_code.py +++ b/tests/metagpt/actions/test_execute_code.py @@ -52,42 +52,21 @@ async def test_plotting_code(): # 显示图形 plt.show() + plt.close() """ output = await pi.run(code) assert output[1] is True -@pytest.mark.asyncio -async def test_plotting_bug(): - code = """ - import matplotlib.pyplot as plt - import seaborn as sns - import pandas as pd - from sklearn.datasets import load_iris - # Load the Iris dataset - iris_data = load_iris() - # Convert the loaded Iris dataset into a DataFrame for easier manipulation - iris_df = pd.DataFrame(iris_data['data'], columns=iris_data['feature_names']) - # Add a column for the target - iris_df['species'] = pd.Categorical.from_codes(iris_data['target'], iris_data['target_names']) - # Set the style of seaborn - sns.set(style='whitegrid') - # Create a pairplot of the iris dataset - plt.figure(figsize=(10, 8)) - pairplot = sns.pairplot(iris_df, hue='species') - # Show the plot - plt.show() - """ - pi = ExecutePyCode() - output = await pi.run(code) - assert output[1] is True - - def test_truncate(): - output = "hello world" - assert truncate(output) == output - output = "hello world" - assert truncate(output, 5) == "Truncated to show only the last 5 characters\nworld" + # 代码执行成功 + output, is_success = truncate("hello world", 5, True) + assert "Truncated to show only first 5 characters\nhello" in output + assert is_success + # 代码执行失败 + output, is_success = truncate("hello world", 5, False) + assert "Truncated to show only last 5 characters\nworld" in output + assert not is_success @pytest.mark.asyncio @@ -97,3 +76,14 @@ async def test_run_with_timeout(): message, success = await pi.run(code) assert not success assert message.startswith("Cell execution timed out") + + +@pytest.mark.asyncio +async def test_run_code_text(): + pi = ExecutePyCode() + message, success = await pi.run(code='print("This is a code!")', language="python") + assert success + assert message == "This is a code!\n" + message, success = await pi.run(code="# This is a code!", language="markdown") + assert success + assert message == "# This is a code!" From 1793a5fec64cdb624b4b425f7c6798ea7a5627af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Mon, 22 Jan 2024 16:09:23 +0800 Subject: [PATCH 4/8] update function_rsps. --- tests/metagpt/provider/test_openai.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 7af2f6892..2e5799475 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -66,11 +66,17 @@ class TestOpenAI: @pytest.fixture def tool_calls_rsp(self): function_rsps = [ - Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"), - Function(arguments='{\n"language": "python",\n"code": ```print("hello world")```', name="execute"), - Function(arguments='{\n"language": "python",\n"code": \'print("hello world")\'', name="execute"), - Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"', name="execute"), + Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"), + Function(arguments='{\n"language": "python",\n"code": \'print("hello world")\'}', name="execute"), + Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"}', name="execute"), + Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"), + Function(arguments='{\n"language": "python",\n"code": ```print("hello world")```}', name="execute"), + Function(arguments='{\n"language": "python",\n"code": """print("hello world")"""}', name="execute"), Function(arguments='\nprint("hello world")\\n', name="execute"), + # only `{` in arguments + Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"), + # no `{`, `}` in arguments + Function(arguments='\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"), ] tool_calls = [ ChatCompletionMessageToolCall(type="function", id=f"call_{i}", function=f) From 64a296a29d321e4d05c1b0473a073dc05ee2bb1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Mon, 22 Jan 2024 16:14:59 +0800 Subject: [PATCH 5/8] update logger warning for JSONDecodeError. --- metagpt/provider/openai_api.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 72af5f40a..3358b3aad 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -237,9 +237,13 @@ class OpenAILLM(BaseLLM): try: return json.loads(message.tool_calls[0].function.arguments, strict=False) except json.decoder.JSONDecodeError as e: - logger.debug( - f"Got JSONDecodeError for {message.tool_calls[0].function.arguments},\ - we will use RegExp to parse code, \n {e}" + logger.warning( + "\n".join( + [ + (f"Got JSONDecodeError for \n{'--'*40} \n{message.tool_calls[0].function.arguments}"), + (f"{'--'*40}\nwe will use RegExp to parse code. JSONDecodeError is: {e}"), + ] + ) ) return self._parse_arguments(message.tool_calls[0].function.arguments) elif message.tool_calls is None and message.content is not None: From 3bfd0c8dadaf0e016e374d5aca28550d8f635f98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Mon, 22 Jan 2024 18:36:23 +0800 Subject: [PATCH 6/8] update truncate. --- metagpt/actions/execute_code.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/metagpt/actions/execute_code.py b/metagpt/actions/execute_code.py index 851794b91..a5a766ab2 100644 --- a/metagpt/actions/execute_code.py +++ b/metagpt/actions/execute_code.py @@ -237,8 +237,10 @@ class ExecutePyCode(ExecuteCode, Action): def truncate(result: str, keep_len: int = 2000, is_success: bool = True): - """执行失败的代码, 展示result后keep_len个字符; 执行成功的代码, 展示result前keep_len个字符。""" + """对于超出keep_len个字符的result: 执行失败的代码, 展示result后keep_len个字符; 执行成功的代码, 展示result前keep_len个字符。""" desc = f"Executed code {'successfully. ' if is_success else 'failed, please reflect the cause of bug and then debug. '}" + is_same_desc = False + if is_success: desc += f"Truncated to show only first {keep_len} characters\n" else: @@ -246,20 +248,17 @@ def truncate(result: str, keep_len: int = 2000, is_success: bool = True): if result.startswith(desc): result = result[len(desc) :] + is_same_desc = True + + if result.strip().startswith(" keep_len: result = result[-keep_len:] if not is_success else result[:keep_len] - if not result: - result = "No output about your code. Only when importing packages it is normal case. Recap and go ahead." - return result, False + return desc + result, is_success - if result.strip().startswith(" Date: Mon, 22 Jan 2024 18:53:40 +0800 Subject: [PATCH 7/8] add new test. --- tests/metagpt/actions/test_execute_code.py | 40 +++++++++++++++++++--- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/tests/metagpt/actions/test_execute_code.py b/tests/metagpt/actions/test_execute_code.py index ecddccf6f..21627e6f9 100644 --- a/tests/metagpt/actions/test_execute_code.py +++ b/tests/metagpt/actions/test_execute_code.py @@ -1,7 +1,6 @@ import pytest from metagpt.actions.execute_code import ExecutePyCode, truncate -from metagpt.schema import Message @pytest.mark.asyncio @@ -11,9 +10,6 @@ async def test_code_running(): assert output[1] is True output = await pi.run({"code": "print('hello world!')", "language": "python"}) assert output[1] is True - code_msg = Message("print('hello world!')") - output = await pi.run(code_msg) - assert output[1] is True @pytest.mark.asyncio @@ -67,6 +63,15 @@ def test_truncate(): output, is_success = truncate("hello world", 5, False) assert "Truncated to show only last 5 characters\nworld" in output assert not is_success + # 异步 + output, is_success = truncate(" Date: Mon, 22 Jan 2024 18:54:21 +0800 Subject: [PATCH 8/8] update _process_code. --- metagpt/actions/execute_code.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/metagpt/actions/execute_code.py b/metagpt/actions/execute_code.py index a5a766ab2..6591f479f 100644 --- a/metagpt/actions/execute_code.py +++ b/metagpt/actions/execute_code.py @@ -165,7 +165,7 @@ class ExecutePyCode(ExecuteCode, Action): # 如果在Python脚本中运行,__file__ 变量存在 return False - def _process_code(self, code: Union[str, Dict, Message], language: str = None) -> Tuple: + def _process_code(self, code: Union[str, Dict], language: str = None) -> Tuple: language = language or "python" if isinstance(code, str) and Path(code).suffix in (".py", ".txt"): code = Path(code).read_text(encoding="utf-8") @@ -173,20 +173,10 @@ class ExecutePyCode(ExecuteCode, Action): if isinstance(code, str): return code, language + if isinstance(code, dict): assert "code" in code - if "language" not in code: - code["language"] = "python" - code, language = code["code"], code["language"] - elif isinstance(code, Message): - if isinstance(code.content, dict) and "language" not in code.content: - code.content["language"] = "python" - code, language = code.content["code"], code.content["language"] - elif isinstance(code.content, str): - code, language = code.content, language - else: - raise ValueError(f"Not support code type {type(code).__name__}.") - + code = code["code"] return code, language async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]: