From f3c41b6fb5b72b5f687b0b50ee9386178d24afa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 16:51:39 +0800 Subject: [PATCH] update default config int --- examples/android_assistant/run_assistant.py | 3 +- tests/metagpt/learn/test_text_to_embedding.py | 3 +- tests/metagpt/learn/test_text_to_image.py | 4 +- tests/metagpt/learn/test_text_to_speech.py | 4 +- .../roles/di/run_swe_agent_for_benchmark.py | 3 +- tests/metagpt/test_document.py | 4 +- tests/metagpt/tools/test_azure_tts.py | 4 +- tests/metagpt/tools/test_iflytek_tts.py | 3 +- .../tools/test_metagpt_text_to_image.py | 4 +- tests/metagpt/tools/test_moderation.py | 4 +- .../tools/test_openai_text_to_embedding.py | 3 +- .../tools/test_openai_text_to_image.py | 4 +- tests/metagpt/tools/test_ut_writer.py | 4 +- .../utils/test_repair_llm_raw_output.py | 4 +- tests/metagpt/utils/test_sanitize.py | 246 ++++++++++++++++++ tests/mock/mock_llm.py | 4 +- 16 files changed, 261 insertions(+), 40 deletions(-) create mode 100644 tests/metagpt/utils/test_sanitize.py diff --git a/examples/android_assistant/run_assistant.py b/examples/android_assistant/run_assistant.py index dbd1dc6ff..7d5d4d5c8 100644 --- a/examples/android_assistant/run_assistant.py +++ b/examples/android_assistant/run_assistant.py @@ -9,7 +9,7 @@ from pathlib import Path import typer -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.environment.android.android_env import AndroidEnv from metagpt.ext.android_assistant.roles.android_assistant import AndroidAssistant from metagpt.team import Team @@ -41,7 +41,6 @@ def startup( ), device_id: str = typer.Option(default="emulator-5554", help="The Android device_id"), ): - config = Config.default() config.extra = { "stage": stage, "mode": mode, diff --git a/tests/metagpt/learn/test_text_to_embedding.py b/tests/metagpt/learn/test_text_to_embedding.py index f50f6a7aa..3b5486c5d 100644 --- a/tests/metagpt/learn/test_text_to_embedding.py +++ b/tests/metagpt/learn/test_text_to_embedding.py @@ -11,7 +11,7 @@ from pathlib import Path import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.learn.text_to_embedding import text_to_embedding from metagpt.utils.common import aread @@ -19,7 +19,6 @@ from metagpt.utils.common import aread @pytest.mark.asyncio async def test_text_to_embedding(mocker): # mock - config = Config.default() mock_post = mocker.patch("aiohttp.ClientSession.post") mock_response = mocker.AsyncMock() mock_response.status = 200 diff --git a/tests/metagpt/learn/test_text_to_image.py b/tests/metagpt/learn/test_text_to_image.py index d3272dadd..eb252589b 100644 --- a/tests/metagpt/learn/test_text_to_image.py +++ b/tests/metagpt/learn/test_text_to_image.py @@ -12,7 +12,7 @@ import openai import pytest from pydantic import BaseModel -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.learn.text_to_image import text_to_image from metagpt.tools.metagpt_text_to_image import MetaGPTText2Image from metagpt.tools.openai_text_to_image import OpenAIText2Image @@ -26,7 +26,6 @@ async def test_text_to_image(mocker): mocker.patch.object(OpenAIText2Image, "text_2_image", return_value=b"mock OpenAIText2Image") mocker.patch.object(S3, "cache", return_value="http://mock/s3") - config = Config.default() assert config.metagpt_tti_url data = await text_to_image("Panda emoji", size_type="512x512", config=config) @@ -51,7 +50,6 @@ async def test_openai_text_to_image(mocker): mock_post.return_value.__aenter__.return_value = mock_response mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png") - config = Config.default() config.metagpt_tti_url = None assert config.get_openai_llm() diff --git a/tests/metagpt/learn/test_text_to_speech.py b/tests/metagpt/learn/test_text_to_speech.py index f01e5d132..480e35f7a 100644 --- a/tests/metagpt/learn/test_text_to_speech.py +++ b/tests/metagpt/learn/test_text_to_speech.py @@ -10,7 +10,7 @@ import pytest from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.learn.text_to_speech import text_to_speech from metagpt.tools.iflytek_tts import IFlyTekTTS from metagpt.utils.s3 import S3 @@ -19,7 +19,6 @@ from metagpt.utils.s3 import S3 @pytest.mark.asyncio async def test_azure_text_to_speech(mocker): # mock - config = Config.default() config.iflytek_api_key = None config.iflytek_api_secret = None config.iflytek_app_id = None @@ -47,7 +46,6 @@ async def test_azure_text_to_speech(mocker): @pytest.mark.asyncio async def test_iflytek_text_to_speech(mocker): # mock - config = Config.default() config.azure_tts_subscription_key = None config.azure_tts_region = None mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None) diff --git a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py index 5ceba6dcc..ce4ef94a4 100644 --- a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py +++ b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py @@ -7,7 +7,7 @@ import sys from datetime import datetime from pathlib import Path -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT from metagpt.logs import logger from metagpt.roles.di.engineer2 import Engineer2 @@ -15,7 +15,6 @@ from metagpt.tools.libs.editor import Editor from metagpt.tools.libs.terminal import Terminal from metagpt.tools.swe_agent_commands.swe_agent_utils import load_hf_dataset -config = Config.default() # Specify by yourself TEST_REPO_DIR = METAGPT_ROOT / "data" / "test_repo" DATA_DIR = METAGPT_ROOT / "data/hugging_face" diff --git a/tests/metagpt/test_document.py b/tests/metagpt/test_document.py index 29393bb13..9c076f4e6 100644 --- a/tests/metagpt/test_document.py +++ b/tests/metagpt/test_document.py @@ -5,12 +5,10 @@ @Author : alexanderwu @File : test_document.py """ -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.document import Repo from metagpt.logs import logger -config = Config.default() - def set_existing_repo(path): repo1 = Repo.from_path(path) diff --git a/tests/metagpt/tools/test_azure_tts.py b/tests/metagpt/tools/test_azure_tts.py index ee55616d2..f72b5663b 100644 --- a/tests/metagpt/tools/test_azure_tts.py +++ b/tests/metagpt/tools/test_azure_tts.py @@ -12,11 +12,9 @@ from pathlib import Path import pytest from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.tools.azure_tts import AzureTTS -config = Config.default() - @pytest.mark.asyncio async def test_azure_tts(mocker): diff --git a/tests/metagpt/tools/test_iflytek_tts.py b/tests/metagpt/tools/test_iflytek_tts.py index c51f62b8e..b4bcadb89 100644 --- a/tests/metagpt/tools/test_iflytek_tts.py +++ b/tests/metagpt/tools/test_iflytek_tts.py @@ -7,14 +7,13 @@ """ import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.tools.iflytek_tts import IFlyTekTTS, oas3_iflytek_tts @pytest.mark.asyncio async def test_iflytek_tts(mocker): # mock - config = Config.default() config.azure_tts_subscription_key = None config.azure_tts_region = None mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None) diff --git a/tests/metagpt/tools/test_metagpt_text_to_image.py b/tests/metagpt/tools/test_metagpt_text_to_image.py index bd0fcaf8b..d3797a460 100644 --- a/tests/metagpt/tools/test_metagpt_text_to_image.py +++ b/tests/metagpt/tools/test_metagpt_text_to_image.py @@ -10,11 +10,9 @@ from unittest.mock import AsyncMock import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image -config = Config.default() - @pytest.mark.asyncio async def test_draw(mocker): diff --git a/tests/metagpt/tools/test_moderation.py b/tests/metagpt/tools/test_moderation.py index 0f921887f..8dc9e9d5e 100644 --- a/tests/metagpt/tools/test_moderation.py +++ b/tests/metagpt/tools/test_moderation.py @@ -8,12 +8,10 @@ import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.llm import LLM from metagpt.tools.moderation import Moderation -config = Config.default() - @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/metagpt/tools/test_openai_text_to_embedding.py b/tests/metagpt/tools/test_openai_text_to_embedding.py index 81b3895c3..047206d48 100644 --- a/tests/metagpt/tools/test_openai_text_to_embedding.py +++ b/tests/metagpt/tools/test_openai_text_to_embedding.py @@ -10,7 +10,7 @@ from pathlib import Path import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding from metagpt.utils.common import aread @@ -18,7 +18,6 @@ from metagpt.utils.common import aread @pytest.mark.asyncio async def test_embedding(mocker): # mock - config = Config.default() mock_post = mocker.patch("aiohttp.ClientSession.post") mock_response = mocker.AsyncMock() mock_response.status = 200 diff --git a/tests/metagpt/tools/test_openai_text_to_image.py b/tests/metagpt/tools/test_openai_text_to_image.py index 4856342d1..3f9169ddd 100644 --- a/tests/metagpt/tools/test_openai_text_to_image.py +++ b/tests/metagpt/tools/test_openai_text_to_image.py @@ -11,7 +11,7 @@ import openai import pytest from pydantic import BaseModel -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.llm import LLM from metagpt.tools.openai_text_to_image import ( OpenAIText2Image, @@ -19,8 +19,6 @@ from metagpt.tools.openai_text_to_image import ( ) from metagpt.utils.s3 import S3 -config = Config.default() - @pytest.mark.asyncio async def test_draw(mocker): diff --git a/tests/metagpt/tools/test_ut_writer.py b/tests/metagpt/tools/test_ut_writer.py index ebb8c8aa2..557067191 100644 --- a/tests/metagpt/tools/test_ut_writer.py +++ b/tests/metagpt/tools/test_ut_writer.py @@ -20,12 +20,10 @@ from openai.types.chat.chat_completion_message_tool_call import ( Function, ) -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.const import API_QUESTIONS_PATH, UT_PY_PATH from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator -config = Config.default() - class TestUTWriter: @pytest.mark.asyncio diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index 75bd9f165..7a29ea3ee 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -2,9 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : unittest of repair_llm_raw_output -from metagpt.config2 import Config - -config = Config.default() +from metagpt.config2 import config """ CONFIG.repair_llm_output should be True before retry_parse_json_text imported. diff --git a/tests/metagpt/utils/test_sanitize.py b/tests/metagpt/utils/test_sanitize.py new file mode 100644 index 000000000..c229af173 --- /dev/null +++ b/tests/metagpt/utils/test_sanitize.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +from unittest.mock import Mock, patch + +import pytest + +from metagpt.utils.sanitize import ( + NodeType, + code_extract, + get_definition_name, + get_deps, + get_function_dependency, + has_return_statement, + sanitize, + syntax_check, + traverse_tree, +) + + +@pytest.fixture +def mock_node(): + node = Mock() + node.type = "test_node" + node.text = b"test_text" + node.children = [] + return node + + +def test_node_type_enum(): + assert NodeType.CLASS.value == "class_definition" + assert NodeType.FUNCTION.value == "function_definition" + assert isinstance(NodeType.IMPORT.value, list) + + +@patch("tree_sitter.Node") +def test_traverse_tree(mock_node_class): + # 测试基本情况:没有子节点的情况 + root = Mock() + cursor = Mock() + cursor.node = root + cursor.goto_first_child.return_value = False + cursor.goto_next_sibling.return_value = False + cursor.goto_parent.return_value = False + root.walk.return_value = cursor + + nodes = list(traverse_tree(root)) + assert len(nodes) == 1 + assert nodes[0] == root + + # 测试有子节点和兄弟节点的情况 + cursor2 = Mock() + cursor2.node = Mock() + + # 模拟遍历行为 + first_child_calls = [True, False] + next_sibling_calls = [False] + parent_calls = [True, False] + + cursor2.goto_first_child.side_effect = lambda: first_child_calls.pop(0) if first_child_calls else False + cursor2.goto_next_sibling.side_effect = lambda: next_sibling_calls.pop(0) if next_sibling_calls else False + cursor2.goto_parent.side_effect = lambda: parent_calls.pop(0) if parent_calls else False + + root2 = Mock() + root2.walk.return_value = cursor2 + nodes = list(traverse_tree(root2)) + assert len(nodes) > 1 + + +def test_syntax_check(): + # 测试有效代码 + assert syntax_check("def test(): return True") is True + + # 测试无效代码 + assert syntax_check("def test() return True") is False + + # 测试无效代码(带verbose) + assert syntax_check("def test() return True", verbose=True) is False + + # 测试内存错误情况 + with patch("ast.parse", side_effect=MemoryError): + assert syntax_check("large_code", verbose=True) is False + + +def test_code_extract(): + # 测试基本情况 + text = "def valid_function():\n return True\n" + result = code_extract(text) + assert syntax_check(result) + assert "def valid_function" in result + + # 测试空字符串 + assert code_extract("") == "" + + # 测试单行有效语法 + single_line = "x = 1" + result = code_extract(single_line) + assert syntax_check(result) + assert "x = 1" in result + + # 测试完全无效的代码 + assert code_extract("invalid!!!!") == "" or code_extract("invalid!!!!") == "invalid!!!!" + + # 测试带有嵌套结构的有效代码 + nested_code = """def outer():\n def inner():\n return True\n""" + result = code_extract(nested_code) + assert syntax_check(result) + assert "def outer" in result + + +def test_get_definition_name(): + # 基本测试 + mock_identifier = Mock() + mock_identifier.type = NodeType.IDENTIFIER.value + mock_identifier.text = b"test_function" + + mock_node = Mock() + mock_node.children = [mock_identifier] + assert get_definition_name(mock_node) == "test_function" + + # 测试空children + mock_node.children = [] + assert get_definition_name(mock_node) is None + + # 测试children中没有identifier + mock_node.children = [Mock(type="not_identifier")] + assert get_definition_name(mock_node) is None + + +@pytest.mark.parametrize( + "node_type,expected", + [ + (NodeType.RETURN.value, True), + ("other_type", False), + ], +) +def test_has_return_statement(node_type, expected): + mock_node = Mock() + cursor = Mock() + cursor.node = Mock() + cursor.node.type = node_type + cursor.goto_first_child.return_value = False + cursor.goto_next_sibling.return_value = False + cursor.goto_parent.return_value = False + mock_node.walk.return_value = cursor + + assert has_return_statement(mock_node) is expected + + +def test_get_deps(): + mock_id1 = Mock(type=NodeType.IDENTIFIER.value, text=b"dep1") + mock_id2 = Mock(type=NodeType.IDENTIFIER.value, text=b"dep2") + mock_node = Mock(children=[mock_id1, mock_id2]) + + nodes = [("test_func", mock_node)] + result = get_deps(nodes) + + assert "test_func" in result + assert result["test_func"] == {"dep1", "dep2"} + + # 测试嵌套结构 + nested_node = Mock(children=[Mock(type="not_identifier", children=[mock_id1])]) + nodes = [("nested_func", nested_node)] + result = get_deps(nodes) + assert result["nested_func"] == {"dep1"} + + +def test_get_function_dependency(): + call_graph = {"main": {"helper1", "helper2"}, "helper1": {"helper3"}, "helper2": set(), "helper3": set()} + + result = get_function_dependency("main", call_graph) + assert result == {"main", "helper1", "helper2", "helper3"} + + assert get_function_dependency("non_existent", call_graph) == {"non_existent"} + + +@patch("tree_sitter.Parser") +@patch("tree_sitter.Language") +def test_sanitize(mock_language, mock_parser): + test_code = """import math +from os import path + +class TestClass: + def method(self): return True + +def test_function(): + return True + +x = 1""" + + mock_root = Mock() + mock_nodes = [] + + # 添加导入语句 + import_node = Mock(type="import_statement", start_byte=0, end_byte=11) + import_from_node = Mock(type="import_from_statement", start_byte=12, end_byte=30) + mock_nodes.extend([import_node, import_from_node]) + + # 添加类定义 + class_node = Mock(type="class_definition", start_byte=32, end_byte=80) + class_id = Mock(type="identifier", text=b"TestClass") + class_node.children = [class_id] + mock_nodes.append(class_node) + + # 添加函数定义 + func_node = Mock(type="function_definition", start_byte=82, end_byte=110) + func_id = Mock(type="identifier", text=b"test_function") + return_node = Mock(type="return_statement") + func_node.children = [func_id, return_node] + mock_nodes.append(func_node) + + # 添加赋值语句 + assign_node = Mock(type="expression_statement", start_byte=112, end_byte=117) + assign_child = Mock(type="assignment") + var_id = Mock(type="identifier", text=b"x") + assign_child.children = [var_id] + assign_node.children = [assign_child] + mock_nodes.append(assign_node) + + mock_root.children = mock_nodes + mock_tree = Mock(root_node=mock_root) + mock_parser.return_value.parse.return_value = mock_tree + + # 测试无entrypoint情况 + result = sanitize(test_code) + assert isinstance(result, str) + assert len(result) > 0 + + # 测试有entrypoint情况 + result = sanitize(test_code, entrypoint="test_function") + assert isinstance(result, str) + assert len(result) > 0 + + # 测试空代码 + assert sanitize("") == "" + + # 测试无效代码 + assert sanitize("invalid code") == "invalid!!!!" or sanitize("invalid code") == "" + + # 测试函数依赖 + mock_nodes = [func_node] # 只保留函数节点 + mock_root.children = mock_nodes + result = sanitize(test_code, entrypoint="test_function") + assert isinstance(result, str) + + +if __name__ == "__main__": + pytest.main(["-v"]) diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index e58ce4120..704403e64 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -1,7 +1,7 @@ import json from typing import Optional, Union -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.configs.llm_config import LLMType from metagpt.const import LLM_API_TIMEOUT from metagpt.logs import logger @@ -10,8 +10,6 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import Message -config = Config.default() - OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM