From 5ddaaaa3471e096b5ea02f2e2b8a4cc34f50332a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=A3=92=E6=A3=92?= Date: Mon, 22 Jan 2024 14:56:44 +0800 Subject: [PATCH] add test: test_get_choice_function_arguments_for_aask_code. --- tests/metagpt/provider/test_openai.py | 139 +++++++++----------------- 1 file changed, 48 insertions(+), 91 deletions(-) diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 8de29c11b..7af2f6892 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -1,104 +1,21 @@ from unittest.mock import Mock import pytest +from openai.types.chat import ( + ChatCompletion, + ChatCompletionMessage, + ChatCompletionMessageToolCall, +) +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message_tool_call import Function from metagpt.config import CONFIG +from metagpt.logs import logger from metagpt.provider.openai_api import OpenAILLM -from metagpt.schema import UserMessage CONFIG.openai_proxy = None -@pytest.mark.asyncio -async def test_aask_code(): - llm = OpenAILLM() - msg = [{"role": "user", "content": "Write a python hello world code."}] - rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -@pytest.mark.asyncio -async def test_aask_code_str(): - llm = OpenAILLM() - msg = "Write a python hello world code." - rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -@pytest.mark.asyncio -async def test_aask_code_Message(): - llm = OpenAILLM() - msg = UserMessage("Write a python hello world code.") - rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code(): - llm = OpenAIGPTAPI() - msg = [{"role": "user", "content": "Write a python hello world code."}] - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_str(): - llm = OpenAIGPTAPI() - msg = "Write a python hello world code." - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_Message(): - llm = OpenAIGPTAPI() - msg = UserMessage("Write a python hello world code.") - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_list_Message(): - llm = OpenAIGPTAPI() - msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")] - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_list_str(): - llm = OpenAIGPTAPI() - msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} - print(rsp) - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -@pytest.mark.asyncio -async def test_ask_code_steps2(): - llm = OpenAIGPTAPI() - msg = ["step by setp 生成代码: Step 1. 先生成随机数组a, Step 2. 求a中最大值, Step 3. 绘制数据a的直方图"] - rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} - print(rsp) - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - assert "Step 1" in rsp["code"] - assert "Step 2" in rsp["code"] - assert "Step 3" in rsp["code"] - - class TestOpenAI: @pytest.fixture def config(self): @@ -146,6 +63,32 @@ class TestOpenAI: openai_api_type="azure", ) + @pytest.fixture + def tool_calls_rsp(self): + function_rsps = [ + Function(arguments='{\n"language": "python",\n"code": "print(\'hello world\')"', name="execute"), + Function(arguments='{\n"language": "python",\n"code": ```print("hello world")```', name="execute"), + Function(arguments='{\n"language": "python",\n"code": \'print("hello world")\'', name="execute"), + Function(arguments='{\n"language": \'python\',\n"code": "print(\'hello world\')"', name="execute"), + Function(arguments='\nprint("hello world")\\n', name="execute"), + ] + tool_calls = [ + ChatCompletionMessageToolCall(type="function", id=f"call_{i}", function=f) + for i, f in enumerate(function_rsps) + ] + messages = [ChatCompletionMessage(content=None, role="assistant", tool_calls=[t]) for t in tool_calls] + # 添加一个纯文本响应 + messages.append( + ChatCompletionMessage(content="Completed a python code for hello world!", role="assistant", tool_calls=None) + ) + choices = [ + Choice(finish_reason="tool_calls", logprobs=None, index=i, message=msg) for i, msg in enumerate(messages) + ] + return [ + ChatCompletion(id=str(i), choices=[c], created=i, model="gpt-4", object="chat.completion") + for i, c in enumerate(choices) + ] + def test_make_client_kwargs_without_proxy(self, config): instance = OpenAILLM() instance.config = config @@ -171,3 +114,17 @@ class TestOpenAI: instance.config = config_azure_proxy kwargs = instance._make_client_kwargs() assert "http_client" in kwargs + + def test_get_choice_function_arguments_for_aask_code(self, tool_calls_rsp): + instance = OpenAILLM() + for i, rsp in enumerate(tool_calls_rsp): + code = instance.get_choice_function_arguments(rsp) + logger.info(f"\ntest get function call arguments {i}: {code}") + assert "code" in code + assert "language" in code + assert "hello world" in code["code"] + + if "Completed a python code for hello world!" == code["code"]: + code["language"] == "markdown" + else: + code["language"] == "python"