mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Merge branch 'add-test-for-aask_code-executecode' into 'code_intepreter'
Add test for aask code and executecode See merge request agents/data_agents_opt!57
This commit is contained in:
commit
93538cc848
4 changed files with 143 additions and 165 deletions
|
|
@ -15,14 +15,13 @@ import nbformat
|
|||
from nbclient import NotebookClient
|
||||
from nbclient.exceptions import CellTimeoutError, DeadKernelError
|
||||
from nbformat import NotebookNode
|
||||
from nbformat.v4 import new_code_cell, new_output, new_markdown_cell
|
||||
from rich.console import Console
|
||||
from rich.syntax import Syntax
|
||||
from nbformat.v4 import new_code_cell, new_markdown_cell, new_output
|
||||
from rich.box import MINIMAL
|
||||
from rich.console import Console, Group
|
||||
from rich.live import Live
|
||||
from rich.markdown import Markdown
|
||||
from rich.panel import Panel
|
||||
from rich.box import MINIMAL
|
||||
from rich.live import Live
|
||||
from rich.console import Group
|
||||
from rich.syntax import Syntax
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -166,7 +165,7 @@ class ExecutePyCode(ExecuteCode, Action):
|
|||
# 如果在Python脚本中运行,__file__ 变量存在
|
||||
return False
|
||||
|
||||
def _process_code(self, code: Union[str, Dict, Message], language: str = None) -> Tuple:
|
||||
def _process_code(self, code: Union[str, Dict], language: str = None) -> Tuple:
|
||||
language = language or "python"
|
||||
if isinstance(code, str) and Path(code).suffix in (".py", ".txt"):
|
||||
code = Path(code).read_text(encoding="utf-8")
|
||||
|
|
@ -174,20 +173,10 @@ class ExecutePyCode(ExecuteCode, Action):
|
|||
|
||||
if isinstance(code, str):
|
||||
return code, language
|
||||
|
||||
if isinstance(code, dict):
|
||||
assert "code" in code
|
||||
if "language" not in code:
|
||||
code["language"] = "python"
|
||||
code, language = code["code"], code["language"]
|
||||
elif isinstance(code, Message):
|
||||
if isinstance(code.content, dict) and "language" not in code.content:
|
||||
code.content["language"] = "python"
|
||||
code, language = code.content["code"], code.content["language"]
|
||||
elif isinstance(code.content, str):
|
||||
code, language = code.content, language
|
||||
else:
|
||||
raise ValueError(f"Not support code type {type(code).__name__}.")
|
||||
|
||||
code = code["code"]
|
||||
return code, language
|
||||
|
||||
async def run_cell(self, cell: NotebookNode, cell_index: int) -> Tuple[bool, str]:
|
||||
|
|
@ -229,7 +218,7 @@ class ExecutePyCode(ExecuteCode, Action):
|
|||
# code success
|
||||
outputs = self.parse_outputs(self.nb.cells[-1].outputs)
|
||||
return truncate(remove_escape_and_color_codes(outputs), is_success=success)
|
||||
elif language == 'markdown':
|
||||
elif language == "markdown":
|
||||
# markdown
|
||||
self.add_markdown_cell(code)
|
||||
return code, True
|
||||
|
|
@ -238,28 +227,28 @@ class ExecutePyCode(ExecuteCode, Action):
|
|||
|
||||
|
||||
def truncate(result: str, keep_len: int = 2000, is_success: bool = True):
|
||||
desc = f"Executed code {'successfully' if is_success else 'failed, please reflect the cause of bug and then debug'}"
|
||||
"""对于超出keep_len个字符的result: 执行失败的代码, 展示result后keep_len个字符; 执行成功的代码, 展示result前keep_len个字符。"""
|
||||
desc = f"Executed code {'successfully. ' if is_success else 'failed, please reflect the cause of bug and then debug. '}"
|
||||
is_same_desc = False
|
||||
|
||||
if is_success:
|
||||
desc += f"Truncated to show only {keep_len} characters\n"
|
||||
desc += f"Truncated to show only first {keep_len} characters\n"
|
||||
else:
|
||||
desc += "Show complete information for you."
|
||||
desc += f"Truncated to show only last {keep_len} characters\n"
|
||||
|
||||
if result.startswith(desc):
|
||||
result = result[len(desc) :]
|
||||
is_same_desc = True
|
||||
|
||||
if result.strip().startswith("<coroutine object"):
|
||||
result = "Executed code failed, you need use key word 'await' to run a async code."
|
||||
return result, False
|
||||
|
||||
if len(result) > keep_len:
|
||||
result = result[-keep_len:] if not is_success else result
|
||||
if not result:
|
||||
result = 'No output about your code. Only when importing packages it is normal case. Recap and go ahead.'
|
||||
return result, False
|
||||
result = result[-keep_len:] if not is_success else result[:keep_len]
|
||||
return desc + result, is_success
|
||||
|
||||
if result.strip().startswith("<coroutine object"):
|
||||
result = "Executed code failed, you need use key word 'await' to run a async code."
|
||||
return result, False
|
||||
|
||||
return desc + result[:keep_len+500], is_success
|
||||
|
||||
return result, is_success
|
||||
return result if not is_same_desc else desc + result, is_success
|
||||
|
||||
|
||||
def remove_escape_and_color_codes(input_str):
|
||||
|
|
@ -271,13 +260,13 @@ def remove_escape_and_color_codes(input_str):
|
|||
|
||||
def display_markdown(content: str):
|
||||
# 使用正则表达式逐个匹配代码块
|
||||
matches = re.finditer(r'```(.+?)```', content, re.DOTALL)
|
||||
matches = re.finditer(r"```(.+?)```", content, re.DOTALL)
|
||||
start_index = 0
|
||||
content_panels = []
|
||||
# 逐个打印匹配到的文本和代码
|
||||
for match in matches:
|
||||
text_content = content[start_index:match.start()].strip()
|
||||
code_content = match.group(0).strip()[3:-3] # Remove triple backticks
|
||||
text_content = content[start_index : match.start()].strip()
|
||||
code_content = match.group(0).strip()[3:-3] # Remove triple backticks
|
||||
|
||||
if text_content:
|
||||
content_panels.append(Panel(Markdown(text_content), box=MINIMAL))
|
||||
|
|
|
|||
|
|
@ -237,9 +237,13 @@ class OpenAILLM(BaseLLM):
|
|||
try:
|
||||
return json.loads(message.tool_calls[0].function.arguments, strict=False)
|
||||
except json.decoder.JSONDecodeError as e:
|
||||
logger.debug(
|
||||
f"Got JSONDecodeError for {message.tool_calls[0].function.arguments},\
|
||||
we will use RegExp to parse code, \n {e}"
|
||||
logger.warning(
|
||||
"\n".join(
|
||||
[
|
||||
(f"Got JSONDecodeError for \n{'--'*40} \n{message.tool_calls[0].function.arguments}"),
|
||||
(f"{'--'*40}\nwe will use RegExp to parse code. JSONDecodeError is: {e}"),
|
||||
]
|
||||
)
|
||||
)
|
||||
return self._parse_arguments(message.tool_calls[0].function.arguments)
|
||||
elif message.tool_calls is None and message.content is not None:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions.execute_code import ExecutePyCode, truncate
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -11,9 +10,6 @@ async def test_code_running():
|
|||
assert output[1] is True
|
||||
output = await pi.run({"code": "print('hello world!')", "language": "python"})
|
||||
assert output[1] is True
|
||||
code_msg = Message("print('hello world!')")
|
||||
output = await pi.run(code_msg)
|
||||
assert output[1] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -52,42 +48,30 @@ async def test_plotting_code():
|
|||
|
||||
# 显示图形
|
||||
plt.show()
|
||||
plt.close()
|
||||
"""
|
||||
output = await pi.run(code)
|
||||
assert output[1] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plotting_bug():
|
||||
code = """
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import pandas as pd
|
||||
from sklearn.datasets import load_iris
|
||||
# Load the Iris dataset
|
||||
iris_data = load_iris()
|
||||
# Convert the loaded Iris dataset into a DataFrame for easier manipulation
|
||||
iris_df = pd.DataFrame(iris_data['data'], columns=iris_data['feature_names'])
|
||||
# Add a column for the target
|
||||
iris_df['species'] = pd.Categorical.from_codes(iris_data['target'], iris_data['target_names'])
|
||||
# Set the style of seaborn
|
||||
sns.set(style='whitegrid')
|
||||
# Create a pairplot of the iris dataset
|
||||
plt.figure(figsize=(10, 8))
|
||||
pairplot = sns.pairplot(iris_df, hue='species')
|
||||
# Show the plot
|
||||
plt.show()
|
||||
"""
|
||||
pi = ExecutePyCode()
|
||||
output = await pi.run(code)
|
||||
assert output[1] is True
|
||||
|
||||
|
||||
def test_truncate():
|
||||
output = "hello world"
|
||||
assert truncate(output) == output
|
||||
output = "hello world"
|
||||
assert truncate(output, 5) == "Truncated to show only the last 5 characters\nworld"
|
||||
# 代码执行成功
|
||||
output, is_success = truncate("hello world", 5, True)
|
||||
assert "Truncated to show only first 5 characters\nhello" in output
|
||||
assert is_success
|
||||
# 代码执行失败
|
||||
output, is_success = truncate("hello world", 5, False)
|
||||
assert "Truncated to show only last 5 characters\nworld" in output
|
||||
assert not is_success
|
||||
# 异步
|
||||
output, is_success = truncate("<coroutine object", 5, True)
|
||||
assert not is_success
|
||||
assert "await" in output
|
||||
# 重复的desc
|
||||
result = "Executed code successfully. Truncated to show only first 5 characters\nhello"
|
||||
output, is_success = truncate(result, 5, True)
|
||||
assert is_success
|
||||
assert output == result
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -97,3 +81,41 @@ async def test_run_with_timeout():
|
|||
message, success = await pi.run(code)
|
||||
assert not success
|
||||
assert message.startswith("Cell execution timed out")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_code_text():
|
||||
pi = ExecutePyCode()
|
||||
message, success = await pi.run(code='print("This is a code!")', language="python")
|
||||
assert success
|
||||
assert message == "This is a code!\n"
|
||||
message, success = await pi.run(code="# This is a code!", language="markdown")
|
||||
assert success
|
||||
assert message == "# This is a code!"
|
||||
mix_text = "# Title!\n ```python\n print('This is a code!')```"
|
||||
message, success = await pi.run(code=mix_text, language="markdown")
|
||||
assert success
|
||||
assert message == mix_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_terminate():
|
||||
pi = ExecutePyCode()
|
||||
await pi.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await pi.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await pi.terminate()
|
||||
import time
|
||||
|
||||
time.sleep(2)
|
||||
assert pi.nb_client.km is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reset():
|
||||
pi = ExecutePyCode()
|
||||
await pi.run(code='print("This is a code!")', language="python")
|
||||
is_kernel_alive = await pi.nb_client.km.is_alive()
|
||||
assert is_kernel_alive
|
||||
await pi.reset()
|
||||
assert pi.nb_client.km is None
|
||||
|
|
|
|||
|
|
@ -1,104 +1,21 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
from openai.types.chat import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import UserMessage
|
||||
|
||||
CONFIG.openai_proxy = None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
llm = OpenAILLM()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = OpenAILLM()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_Message():
|
||||
llm = OpenAILLM()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_list_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")]
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
def test_ask_code_list_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
|
||||
rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
|
||||
print(rsp)
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_code_steps2():
|
||||
llm = OpenAIGPTAPI()
|
||||
msg = ["step by setp 生成代码: Step 1. 先生成随机数组a, Step 2. 求a中最大值, Step 3. 绘制数据a的直方图"]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
|
||||
print(rsp)
|
||||
assert "language" in rsp
|
||||
assert "code" in rsp
|
||||
assert len(rsp["code"]) > 0
|
||||
assert "Step 1" in rsp["code"]
|
||||
assert "Step 2" in rsp["code"]
|
||||
assert "Step 3" in rsp["code"]
|
||||
|
||||
|
||||
class TestOpenAI:
|
||||
@pytest.fixture
|
||||
def config(self):
|
||||
|
|
@ -146,6 +63,38 @@ class TestOpenAI:
|
|||
openai_api_type="azure",
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def tool_calls_rsp(self):
|
||||
function_rsps = [
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": \'print("hello world")\'}', name="execute"),
|
||||
Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": ```print("hello world")```}', name="execute"),
|
||||
Function(arguments='{\n"language": "python",\n"code": """print("hello world")"""}', name="execute"),
|
||||
Function(arguments='\nprint("hello world")\\n', name="execute"),
|
||||
# only `{` in arguments
|
||||
Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"),
|
||||
# no `{`, `}` in arguments
|
||||
Function(arguments='\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"),
|
||||
]
|
||||
tool_calls = [
|
||||
ChatCompletionMessageToolCall(type="function", id=f"call_{i}", function=f)
|
||||
for i, f in enumerate(function_rsps)
|
||||
]
|
||||
messages = [ChatCompletionMessage(content=None, role="assistant", tool_calls=[t]) for t in tool_calls]
|
||||
# 添加一个纯文本响应
|
||||
messages.append(
|
||||
ChatCompletionMessage(content="Completed a python code for hello world!", role="assistant", tool_calls=None)
|
||||
)
|
||||
choices = [
|
||||
Choice(finish_reason="tool_calls", logprobs=None, index=i, message=msg) for i, msg in enumerate(messages)
|
||||
]
|
||||
return [
|
||||
ChatCompletion(id=str(i), choices=[c], created=i, model="gpt-4", object="chat.completion")
|
||||
for i, c in enumerate(choices)
|
||||
]
|
||||
|
||||
def test_make_client_kwargs_without_proxy(self, config):
|
||||
instance = OpenAILLM()
|
||||
instance.config = config
|
||||
|
|
@ -171,3 +120,17 @@ class TestOpenAI:
|
|||
instance.config = config_azure_proxy
|
||||
kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
||||
def test_get_choice_function_arguments_for_aask_code(self, tool_calls_rsp):
|
||||
instance = OpenAILLM()
|
||||
for i, rsp in enumerate(tool_calls_rsp):
|
||||
code = instance.get_choice_function_arguments(rsp)
|
||||
logger.info(f"\ntest get function call arguments {i}: {code}")
|
||||
assert "code" in code
|
||||
assert "language" in code
|
||||
assert "hello world" in code["code"]
|
||||
|
||||
if "Completed a python code for hello world!" == code["code"]:
|
||||
code["language"] == "markdown"
|
||||
else:
|
||||
code["language"] == "python"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue