From db8b2976a122e43f49b2f4e0f2fc2d74b27df413 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 14:55:52 +0800 Subject: [PATCH 1/9] fix bug --- metagpt/environment/android/env_space.py | 2 +- metagpt/environment/werewolf/env_space.py | 2 +- metagpt/ext/android_assistant/actions/manual_record.py | 3 ++- metagpt/ext/android_assistant/actions/parse_record.py | 3 ++- metagpt/ext/android_assistant/actions/screenshot_parse.py | 3 ++- .../ext/android_assistant/actions/self_learn_and_reflect.py | 3 ++- metagpt/ext/android_assistant/roles/android_assistant.py | 4 ++-- metagpt/ext/android_assistant/utils/utils.py | 3 ++- metagpt/rag/engines/simple.py | 3 ++- setup.py | 2 +- 10 files changed, 17 insertions(+), 11 deletions(-) diff --git a/metagpt/environment/android/env_space.py b/metagpt/environment/android/env_space.py index 9580e3a7d..8225f0127 100644 --- a/metagpt/environment/android/env_space.py +++ b/metagpt/environment/android/env_space.py @@ -10,7 +10,7 @@ import numpy.typing as npt from gymnasium import spaces from pydantic import ConfigDict, Field, field_validator -from metagpt.environment.base_env_space import ( +from metagpt.base.base_env_space import ( BaseEnvAction, BaseEnvActionType, BaseEnvObsParams, diff --git a/metagpt/environment/werewolf/env_space.py b/metagpt/environment/werewolf/env_space.py index e6243b10f..dd6ceeabe 100644 --- a/metagpt/environment/werewolf/env_space.py +++ b/metagpt/environment/werewolf/env_space.py @@ -5,7 +5,7 @@ from gymnasium import spaces from pydantic import ConfigDict, Field -from metagpt.environment.base_env_space import BaseEnvAction, BaseEnvActionType +from metagpt.base.base_env_space import BaseEnvAction, BaseEnvActionType from metagpt.environment.werewolf.const import STEP_INSTRUCTIONS diff --git a/metagpt/ext/android_assistant/actions/manual_record.py b/metagpt/ext/android_assistant/actions/manual_record.py index bcfb2ed89..71e4d5e82 100644 --- a/metagpt/ext/android_assistant/actions/manual_record.py +++ b/metagpt/ext/android_assistant/actions/manual_record.py @@ -7,7 +7,7 @@ from pathlib import Path import cv2 from metagpt.actions.action import Action -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.environment.android.android_env import AndroidEnv from metagpt.environment.android.const import ADB_EXEC_FAIL from metagpt.environment.android.env_space import ( @@ -55,6 +55,7 @@ class ManualRecord(Action): self.task_desc_path.write_text(task_desc) step = 0 + config = Config.default() extra_config = config.extra while True: step += 1 diff --git a/metagpt/ext/android_assistant/actions/parse_record.py b/metagpt/ext/android_assistant/actions/parse_record.py index 304daf655..f96ad0a19 100644 --- a/metagpt/ext/android_assistant/actions/parse_record.py +++ b/metagpt/ext/android_assistant/actions/parse_record.py @@ -8,7 +8,7 @@ import re from pathlib import Path from metagpt.actions.action import Action -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.ext.android_assistant.actions.parse_record_an import RECORD_PARSE_NODE from metagpt.ext.android_assistant.prompts.operation_prompt import ( long_press_doc_template, @@ -45,6 +45,7 @@ class ParseRecord(Action): path.mkdir(parents=True, exist_ok=True) task_desc = self.task_desc_path.read_text() + config = Config.default() extra_config = config.extra with open(self.record_path, "r") as record_file: diff --git a/metagpt/ext/android_assistant/actions/screenshot_parse.py b/metagpt/ext/android_assistant/actions/screenshot_parse.py index 4d8bb0e1e..8cb738522 100644 --- a/metagpt/ext/android_assistant/actions/screenshot_parse.py +++ b/metagpt/ext/android_assistant/actions/screenshot_parse.py @@ -6,7 +6,7 @@ import ast from pathlib import Path from metagpt.actions.action import Action -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.environment.android.android_env import AndroidEnv from metagpt.environment.android.const import ADB_EXEC_FAIL from metagpt.environment.android.env_space import ( @@ -101,6 +101,7 @@ next action. You should always prioritize these documented elements for interact grid_on: bool, env: AndroidEnv, ): + config = Config.default() extra_config = config.extra for path in [task_dir, docs_dir]: path.mkdir(parents=True, exist_ok=True) diff --git a/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py b/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py index 5e9cfbb45..b783217ff 100644 --- a/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py +++ b/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py @@ -6,7 +6,7 @@ import ast from pathlib import Path from metagpt.actions.action import Action -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.environment.android.android_env import AndroidEnv from metagpt.environment.android.const import ADB_EXEC_FAIL from metagpt.environment.android.env_space import ( @@ -80,6 +80,7 @@ class SelfLearnAndReflect(Action): async def run_self_learn( self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: + config = Config.default() extra_config = config.extra screenshot_path: Path = env.observe( EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_before", local_save_dir=task_dir) diff --git a/metagpt/ext/android_assistant/roles/android_assistant.py b/metagpt/ext/android_assistant/roles/android_assistant.py index 45636f519..6462e30a2 100644 --- a/metagpt/ext/android_assistant/roles/android_assistant.py +++ b/metagpt/ext/android_assistant/roles/android_assistant.py @@ -9,7 +9,7 @@ from typing import Optional from pydantic import Field from metagpt.actions.add_requirement import UserRequirement -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.const import EXAMPLE_PATH from metagpt.ext.android_assistant.actions.manual_record import ManualRecord from metagpt.ext.android_assistant.actions.parse_record import ParseRecord @@ -38,7 +38,7 @@ class AndroidAssistant(Role): def __init__(self, **data): super().__init__(**data) - + config = Config.default() self._watch([UserRequirement, AndroidActionOutput]) extra_config = config.extra self.task_desc = extra_config.get("task_desc", "Just explore any app in this phone!") diff --git a/metagpt/ext/android_assistant/utils/utils.py b/metagpt/ext/android_assistant/utils/utils.py index f1fa13869..168369054 100644 --- a/metagpt/ext/android_assistant/utils/utils.py +++ b/metagpt/ext/android_assistant/utils/utils.py @@ -10,7 +10,7 @@ from xml.etree.ElementTree import Element, iterparse import cv2 import pyshine as ps -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.ext.android_assistant.utils.schema import ( ActionOp, AndroidElement, @@ -48,6 +48,7 @@ def get_id_from_element(elem: Element) -> str: def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: str, add_index=False): path = [] + config = Config.default() extra_config = config.extra for event, elem in iterparse(str(xml_path), ["start", "end"]): if event == "start": diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 61200a295..1c0834c96 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -31,7 +31,7 @@ from llama_index.core.schema import ( TransformComponent, ) -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.rag.factories import ( get_index, get_rag_embedding, @@ -400,6 +400,7 @@ class SimpleEngine(RetrieverQueryEngine): dict[file_type: BaseReader] """ file_extractor: dict[str:BaseReader] = {} + config = Config.default() if config.omniparse.base_url: pdf_parser = OmniParse( api_key=config.omniparse.api_key, diff --git a/setup.py b/setup.py index 2ffc09ee8..658c82219 100644 --- a/setup.py +++ b/setup.py @@ -43,7 +43,7 @@ extras_require = { "llama-index-postprocessor-cohere-rerank==0.1.4", "llama-index-postprocessor-colbert-rerank==0.1.1", "llama-index-postprocessor-flag-embedding-reranker==0.1.2", - # "llama-index-vector-stores-milvus==0.1.23", + "llama-index-vector-stores-milvus==0.1.23", "docx2txt==0.8", ], } From b2aac3ebd5ce2fced7f7abe2727dc8485e0d004b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 14:59:44 +0800 Subject: [PATCH 2/9] fix bug --- examples/android_assistant/run_assistant.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/android_assistant/run_assistant.py b/examples/android_assistant/run_assistant.py index 7d5d4d5c8..dbd1dc6ff 100644 --- a/examples/android_assistant/run_assistant.py +++ b/examples/android_assistant/run_assistant.py @@ -9,7 +9,7 @@ from pathlib import Path import typer -from metagpt.config2 import config +from metagpt.config2 import Config from metagpt.environment.android.android_env import AndroidEnv from metagpt.ext.android_assistant.roles.android_assistant import AndroidAssistant from metagpt.team import Team @@ -41,6 +41,7 @@ def startup( ), device_id: str = typer.Option(default="emulator-5554", help="The Android device_id"), ): + config = Config.default() config.extra = { "stage": stage, "mode": mode, From 9e989af31eb1d0723fba7f75304c2ba8371d1f69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 15:22:42 +0800 Subject: [PATCH 3/9] fix bug --- examples/ui_with_chainlit/app.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/ui_with_chainlit/app.py b/examples/ui_with_chainlit/app.py index 3b449a12c..1522bd478 100644 --- a/examples/ui_with_chainlit/app.py +++ b/examples/ui_with_chainlit/app.py @@ -1,3 +1,5 @@ +from pathlib import Path + import chainlit as cl from init_setup import ChainlitEnv @@ -67,8 +69,8 @@ async def startup(message: cl.Message) -> None: await company.run(n_round=5) - workdir = company.env.context.git_repo.workdir - files = company.env.context.git_repo.get_files(workdir) + workdir = Path(company.env.context.config.project_path) + files = [file.name for file in workdir.iterdir() if file.is_file()] files = "\n".join([f"{workdir}/{file}" for file in files if not file.startswith(".git")]) await cl.Message( From c4f169462f48c31617067f36c9b8c557e0e1957d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 16:06:26 +0800 Subject: [PATCH 4/9] update examples/rag/rag_bm.py --- examples/rag/rag_bm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/rag/rag_bm.py b/examples/rag/rag_bm.py index a6a1145b5..99a546010 100644 --- a/examples/rag/rag_bm.py +++ b/examples/rag/rag_bm.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- """RAG benchmark pipeline""" import asyncio From 632e14d415d3f8f20a7a4341d36b6808dd7daf79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 16:36:29 +0800 Subject: [PATCH 5/9] update config int --- examples/agent_creator.py | 3 +-- metagpt/actions/requirement_analysis/framework/__init__.py | 4 ++-- metagpt/config2.py | 1 + metagpt/exp_pool/decorator.py | 4 +--- metagpt/ext/android_assistant/actions/manual_record.py | 3 +-- metagpt/ext/android_assistant/actions/parse_record.py | 3 +-- metagpt/ext/android_assistant/actions/screenshot_parse.py | 3 +-- .../ext/android_assistant/actions/self_learn_and_reflect.py | 3 +-- metagpt/ext/android_assistant/roles/android_assistant.py | 3 +-- metagpt/ext/android_assistant/utils/utils.py | 3 +-- metagpt/ext/stanford_town/utils/utils.py | 3 +-- metagpt/rag/engines/simple.py | 3 +-- metagpt/rag/factories/llm.py | 3 +-- metagpt/rag/schema.py | 4 +--- metagpt/software_company.py | 4 +--- metagpt/tools/libs/index_repo.py | 3 +-- metagpt/tools/ut_writer.py | 3 +-- metagpt/utils/embedding.py | 3 +-- metagpt/utils/file.py | 3 +-- metagpt/utils/make_sk_kernel.py | 3 +-- 20 files changed, 21 insertions(+), 41 deletions(-) diff --git a/examples/agent_creator.py b/examples/agent_creator.py index 34160d398..bd58840ce 100644 --- a/examples/agent_creator.py +++ b/examples/agent_creator.py @@ -6,13 +6,12 @@ Author: garylin2099 import re from metagpt.actions import Action -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.const import METAGPT_ROOT from metagpt.logs import logger from metagpt.roles import Role from metagpt.schema import Message -config = Config.default() EXAMPLE_CODE_FILE = METAGPT_ROOT / "examples/build_customized_agent.py" MULTI_ACTION_AGENT_CODE_EXAMPLE = EXAMPLE_CODE_FILE.read_text() diff --git a/metagpt/actions/requirement_analysis/framework/__init__.py b/metagpt/actions/requirement_analysis/framework/__init__.py index 5e0653088..968effd86 100644 --- a/metagpt/actions/requirement_analysis/framework/__init__.py +++ b/metagpt/actions/requirement_analysis/framework/__init__.py @@ -16,7 +16,7 @@ from pydantic import BaseModel from metagpt.actions.requirement_analysis.framework.evaluate_framework import EvaluateFramework from metagpt.actions.requirement_analysis.framework.write_framework import WriteFramework -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.utils.common import awrite @@ -54,7 +54,7 @@ async def save_framework( output_dir = ( Path(output_dir) if output_dir - else Config.default().workspace.path / (datetime.now().strftime("%Y%m%d%H%M%ST") + uuid.uuid4().hex[0:8]) + else config.workspace.path / (datetime.now().strftime("%Y%m%d%H%M%ST") + uuid.uuid4().hex[0:8]) ) output_dir.mkdir(parents=True, exist_ok=True) diff --git a/metagpt/config2.py b/metagpt/config2.py index 1942e5ef5..84b6059ef 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -178,3 +178,4 @@ def merge_dict(dicts: Iterable[Dict]) -> Dict: _CONFIG_CACHE = {} +config = Config.default() diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index d49c13e95..bb285d31c 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, TypeVar from pydantic import BaseModel, ConfigDict, model_validator -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder from metagpt.exp_pool.manager import ExperienceManager, get_exp_manager from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge @@ -60,8 +60,6 @@ def exp_cache( def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: @functools.wraps(func) async def get_or_create(args: Any, kwargs: Any) -> ReturnType: - config = Config.default() - if not config.exp_pool.enabled: rsp = func(*args, **kwargs) return await rsp if asyncio.iscoroutine(rsp) else rsp diff --git a/metagpt/ext/android_assistant/actions/manual_record.py b/metagpt/ext/android_assistant/actions/manual_record.py index 71e4d5e82..bcfb2ed89 100644 --- a/metagpt/ext/android_assistant/actions/manual_record.py +++ b/metagpt/ext/android_assistant/actions/manual_record.py @@ -7,7 +7,7 @@ from pathlib import Path import cv2 from metagpt.actions.action import Action -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.environment.android.android_env import AndroidEnv from metagpt.environment.android.const import ADB_EXEC_FAIL from metagpt.environment.android.env_space import ( @@ -55,7 +55,6 @@ class ManualRecord(Action): self.task_desc_path.write_text(task_desc) step = 0 - config = Config.default() extra_config = config.extra while True: step += 1 diff --git a/metagpt/ext/android_assistant/actions/parse_record.py b/metagpt/ext/android_assistant/actions/parse_record.py index f96ad0a19..304daf655 100644 --- a/metagpt/ext/android_assistant/actions/parse_record.py +++ b/metagpt/ext/android_assistant/actions/parse_record.py @@ -8,7 +8,7 @@ import re from pathlib import Path from metagpt.actions.action import Action -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.ext.android_assistant.actions.parse_record_an import RECORD_PARSE_NODE from metagpt.ext.android_assistant.prompts.operation_prompt import ( long_press_doc_template, @@ -45,7 +45,6 @@ class ParseRecord(Action): path.mkdir(parents=True, exist_ok=True) task_desc = self.task_desc_path.read_text() - config = Config.default() extra_config = config.extra with open(self.record_path, "r") as record_file: diff --git a/metagpt/ext/android_assistant/actions/screenshot_parse.py b/metagpt/ext/android_assistant/actions/screenshot_parse.py index 8cb738522..4d8bb0e1e 100644 --- a/metagpt/ext/android_assistant/actions/screenshot_parse.py +++ b/metagpt/ext/android_assistant/actions/screenshot_parse.py @@ -6,7 +6,7 @@ import ast from pathlib import Path from metagpt.actions.action import Action -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.environment.android.android_env import AndroidEnv from metagpt.environment.android.const import ADB_EXEC_FAIL from metagpt.environment.android.env_space import ( @@ -101,7 +101,6 @@ next action. You should always prioritize these documented elements for interact grid_on: bool, env: AndroidEnv, ): - config = Config.default() extra_config = config.extra for path in [task_dir, docs_dir]: path.mkdir(parents=True, exist_ok=True) diff --git a/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py b/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py index b783217ff..5e9cfbb45 100644 --- a/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py +++ b/metagpt/ext/android_assistant/actions/self_learn_and_reflect.py @@ -6,7 +6,7 @@ import ast from pathlib import Path from metagpt.actions.action import Action -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.environment.android.android_env import AndroidEnv from metagpt.environment.android.const import ADB_EXEC_FAIL from metagpt.environment.android.env_space import ( @@ -80,7 +80,6 @@ class SelfLearnAndReflect(Action): async def run_self_learn( self, round_count: int, task_desc: str, last_act: str, task_dir: Path, env: AndroidEnv ) -> AndroidActionOutput: - config = Config.default() extra_config = config.extra screenshot_path: Path = env.observe( EnvObsParams(obs_type=EnvObsType.GET_SCREENSHOT, ss_name=f"{round_count}_before", local_save_dir=task_dir) diff --git a/metagpt/ext/android_assistant/roles/android_assistant.py b/metagpt/ext/android_assistant/roles/android_assistant.py index 6462e30a2..97d66d30e 100644 --- a/metagpt/ext/android_assistant/roles/android_assistant.py +++ b/metagpt/ext/android_assistant/roles/android_assistant.py @@ -9,7 +9,7 @@ from typing import Optional from pydantic import Field from metagpt.actions.add_requirement import UserRequirement -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.const import EXAMPLE_PATH from metagpt.ext.android_assistant.actions.manual_record import ManualRecord from metagpt.ext.android_assistant.actions.parse_record import ParseRecord @@ -38,7 +38,6 @@ class AndroidAssistant(Role): def __init__(self, **data): super().__init__(**data) - config = Config.default() self._watch([UserRequirement, AndroidActionOutput]) extra_config = config.extra self.task_desc = extra_config.get("task_desc", "Just explore any app in this phone!") diff --git a/metagpt/ext/android_assistant/utils/utils.py b/metagpt/ext/android_assistant/utils/utils.py index 168369054..f1fa13869 100644 --- a/metagpt/ext/android_assistant/utils/utils.py +++ b/metagpt/ext/android_assistant/utils/utils.py @@ -10,7 +10,7 @@ from xml.etree.ElementTree import Element, iterparse import cv2 import pyshine as ps -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.ext.android_assistant.utils.schema import ( ActionOp, AndroidElement, @@ -48,7 +48,6 @@ def get_id_from_element(elem: Element) -> str: def traverse_xml_tree(xml_path: Path, elem_list: list[AndroidElement], attrib: str, add_index=False): path = [] - config = Config.default() extra_config = config.extra for event, elem in iterparse(str(xml_path), ["start", "end"]): if event == "start": diff --git a/metagpt/ext/stanford_town/utils/utils.py b/metagpt/ext/stanford_town/utils/utils.py index b4e15f485..4e81298c9 100644 --- a/metagpt/ext/stanford_town/utils/utils.py +++ b/metagpt/ext/stanford_town/utils/utils.py @@ -13,7 +13,7 @@ from typing import Union from openai import OpenAI -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.logs import logger @@ -48,7 +48,6 @@ def read_csv_to_list(curr_file: str, header=False, strip_trail=True): def get_embedding(text, model: str = "text-embedding-ada-002"): - config = Config.default() text = text.replace("\n", " ") embedding = None if not text: diff --git a/metagpt/rag/engines/simple.py b/metagpt/rag/engines/simple.py index 1c0834c96..61200a295 100644 --- a/metagpt/rag/engines/simple.py +++ b/metagpt/rag/engines/simple.py @@ -31,7 +31,7 @@ from llama_index.core.schema import ( TransformComponent, ) -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.rag.factories import ( get_index, get_rag_embedding, @@ -400,7 +400,6 @@ class SimpleEngine(RetrieverQueryEngine): dict[file_type: BaseReader] """ file_extractor: dict[str:BaseReader] = {} - config = Config.default() if config.omniparse.base_url: pdf_parser = OmniParse( api_key=config.omniparse.api_key, diff --git a/metagpt/rag/factories/llm.py b/metagpt/rag/factories/llm.py index e936e3a45..36b17da36 100644 --- a/metagpt/rag/factories/llm.py +++ b/metagpt/rag/factories/llm.py @@ -12,7 +12,7 @@ from llama_index.core.llms import ( from llama_index.core.llms.callbacks import llm_completion_callback from pydantic import Field -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.provider.base_llm import BaseLLM from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.token_counter import TOKEN_MAX @@ -41,7 +41,6 @@ class RAGLLM(CustomLLM): **kwargs ): super().__init__(*args, **kwargs) - config = Config.default() if context_window < 0: context_window = TOKEN_MAX.get(config.llm.model, DEFAULT_CONTEXT_WINDOW) diff --git a/metagpt/rag/schema.py b/metagpt/rag/schema.py index 12e635a1a..5c63b09df 100644 --- a/metagpt/rag/schema.py +++ b/metagpt/rag/schema.py @@ -11,7 +11,7 @@ from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import VectorStoreQueryMode from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.configs.embedding_config import EmbeddingType from metagpt.logs import logger from metagpt.rag.interface import RAGObject @@ -47,7 +47,6 @@ class FAISSRetrieverConfig(IndexRetrieverConfig): @model_validator(mode="after") def check_dimensions(self): if self.dimensions == 0: - config = Config.default() self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( config.embedding.api_type, 1536 ) @@ -89,7 +88,6 @@ class MilvusRetrieverConfig(IndexRetrieverConfig): @model_validator(mode="after") def check_dimensions(self): if self.dimensions == 0: - config = Config.default() self.dimensions = config.embedding.dimensions or self._embedding_type_to_dimensions.get( config.embedding.api_type, 1536 ) diff --git a/metagpt/software_company.py b/metagpt/software_company.py index 508a6a5f3..73dcdd5db 100644 --- a/metagpt/software_company.py +++ b/metagpt/software_company.py @@ -27,7 +27,7 @@ def generate_repo( recover_path=None, ): """Run the startup logic. Can be called from CLI or other Python scripts.""" - from metagpt.config2 import Config + from metagpt.config2 import config from metagpt.context import Context from metagpt.roles import ( Architect, @@ -38,8 +38,6 @@ def generate_repo( ) from metagpt.team import Team - config = Config.default() - config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) ctx = Context(config=config) diff --git a/metagpt/tools/libs/index_repo.py b/metagpt/tools/libs/index_repo.py index 4c4e6c59b..5f95f9c02 100644 --- a/metagpt/tools/libs/index_repo.py +++ b/metagpt/tools/libs/index_repo.py @@ -11,7 +11,7 @@ from llama_index.core.base.embeddings.base import BaseEmbedding from llama_index.core.schema import NodeWithScore from pydantic import BaseModel, Field, model_validator -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.context import Context from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine @@ -142,7 +142,6 @@ class IndexRepo(BaseModel): return flat_nodes if not self.embedding: - config = Config.default() if self.model: config.embedding.model = self.model factory = RAGEmbeddingFactory(config) diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index 9e67a3585..243871aff 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -4,7 +4,7 @@ import json from pathlib import Path -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.provider.openai_api import OpenAILLM as GPTAPI from metagpt.utils.common import awrite @@ -282,7 +282,6 @@ class UTGenerator: """Choose based on different calling methods""" result = "" if self.chatgpt_method == "API": - config = Config.default() result = await GPTAPI(config.get_openai_llm()).aask_code(messages=messages) return result diff --git a/metagpt/utils/embedding.py b/metagpt/utils/embedding.py index 3fcf1f25b..3d53a314c 100644 --- a/metagpt/utils/embedding.py +++ b/metagpt/utils/embedding.py @@ -7,11 +7,10 @@ """ from llama_index.embeddings.openai import OpenAIEmbedding -from metagpt.config2 import Config +from metagpt.config2 import config def get_embedding() -> OpenAIEmbedding: - config = Config.default() llm = config.get_openai_llm() if llm is None: raise ValueError("To use OpenAIEmbedding, please ensure that config.llm.api_type is correctly set to 'openai'.") diff --git a/metagpt/utils/file.py b/metagpt/utils/file.py index 75107c8be..d4cfc4d0a 100644 --- a/metagpt/utils/file.py +++ b/metagpt/utils/file.py @@ -13,7 +13,7 @@ from typing import Optional, Tuple, Union import aiofiles from fsspec.implementations.memory import MemoryFileSystem as _MemoryFileSystem -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.logs import logger from metagpt.utils import read_docx from metagpt.utils.common import aread, aread_bin, awrite_bin, check_http_endpoint @@ -190,7 +190,6 @@ class File: @staticmethod async def _read_omniparse_config() -> Tuple[str, int]: - config = Config.default() if config.omniparse and config.omniparse.base_url: return config.omniparse.base_url, config.omniparse.timeout return "", 0 diff --git a/metagpt/utils/make_sk_kernel.py b/metagpt/utils/make_sk_kernel.py index f0c55b07c..283a682d6 100644 --- a/metagpt/utils/make_sk_kernel.py +++ b/metagpt/utils/make_sk_kernel.py @@ -13,11 +13,10 @@ from semantic_kernel.connectors.ai.open_ai.services.open_ai_chat_completion impo OpenAIChatCompletion, ) -from metagpt.config2 import Config +from metagpt.config2 import config def make_sk_kernel(): - config = Config.default() kernel = sk.Kernel() if llm := config.get_azure_llm(): kernel.add_chat_service( From f3c41b6fb5b72b5f687b0b50ee9386178d24afa2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 16:51:39 +0800 Subject: [PATCH 6/9] update default config int --- examples/android_assistant/run_assistant.py | 3 +- tests/metagpt/learn/test_text_to_embedding.py | 3 +- tests/metagpt/learn/test_text_to_image.py | 4 +- tests/metagpt/learn/test_text_to_speech.py | 4 +- .../roles/di/run_swe_agent_for_benchmark.py | 3 +- tests/metagpt/test_document.py | 4 +- tests/metagpt/tools/test_azure_tts.py | 4 +- tests/metagpt/tools/test_iflytek_tts.py | 3 +- .../tools/test_metagpt_text_to_image.py | 4 +- tests/metagpt/tools/test_moderation.py | 4 +- .../tools/test_openai_text_to_embedding.py | 3 +- .../tools/test_openai_text_to_image.py | 4 +- tests/metagpt/tools/test_ut_writer.py | 4 +- .../utils/test_repair_llm_raw_output.py | 4 +- tests/metagpt/utils/test_sanitize.py | 246 ++++++++++++++++++ tests/mock/mock_llm.py | 4 +- 16 files changed, 261 insertions(+), 40 deletions(-) create mode 100644 tests/metagpt/utils/test_sanitize.py diff --git a/examples/android_assistant/run_assistant.py b/examples/android_assistant/run_assistant.py index dbd1dc6ff..7d5d4d5c8 100644 --- a/examples/android_assistant/run_assistant.py +++ b/examples/android_assistant/run_assistant.py @@ -9,7 +9,7 @@ from pathlib import Path import typer -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.environment.android.android_env import AndroidEnv from metagpt.ext.android_assistant.roles.android_assistant import AndroidAssistant from metagpt.team import Team @@ -41,7 +41,6 @@ def startup( ), device_id: str = typer.Option(default="emulator-5554", help="The Android device_id"), ): - config = Config.default() config.extra = { "stage": stage, "mode": mode, diff --git a/tests/metagpt/learn/test_text_to_embedding.py b/tests/metagpt/learn/test_text_to_embedding.py index f50f6a7aa..3b5486c5d 100644 --- a/tests/metagpt/learn/test_text_to_embedding.py +++ b/tests/metagpt/learn/test_text_to_embedding.py @@ -11,7 +11,7 @@ from pathlib import Path import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.learn.text_to_embedding import text_to_embedding from metagpt.utils.common import aread @@ -19,7 +19,6 @@ from metagpt.utils.common import aread @pytest.mark.asyncio async def test_text_to_embedding(mocker): # mock - config = Config.default() mock_post = mocker.patch("aiohttp.ClientSession.post") mock_response = mocker.AsyncMock() mock_response.status = 200 diff --git a/tests/metagpt/learn/test_text_to_image.py b/tests/metagpt/learn/test_text_to_image.py index d3272dadd..eb252589b 100644 --- a/tests/metagpt/learn/test_text_to_image.py +++ b/tests/metagpt/learn/test_text_to_image.py @@ -12,7 +12,7 @@ import openai import pytest from pydantic import BaseModel -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.learn.text_to_image import text_to_image from metagpt.tools.metagpt_text_to_image import MetaGPTText2Image from metagpt.tools.openai_text_to_image import OpenAIText2Image @@ -26,7 +26,6 @@ async def test_text_to_image(mocker): mocker.patch.object(OpenAIText2Image, "text_2_image", return_value=b"mock OpenAIText2Image") mocker.patch.object(S3, "cache", return_value="http://mock/s3") - config = Config.default() assert config.metagpt_tti_url data = await text_to_image("Panda emoji", size_type="512x512", config=config) @@ -51,7 +50,6 @@ async def test_openai_text_to_image(mocker): mock_post.return_value.__aenter__.return_value = mock_response mocker.patch.object(S3, "cache", return_value="http://mock.s3.com/0.png") - config = Config.default() config.metagpt_tti_url = None assert config.get_openai_llm() diff --git a/tests/metagpt/learn/test_text_to_speech.py b/tests/metagpt/learn/test_text_to_speech.py index f01e5d132..480e35f7a 100644 --- a/tests/metagpt/learn/test_text_to_speech.py +++ b/tests/metagpt/learn/test_text_to_speech.py @@ -10,7 +10,7 @@ import pytest from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.learn.text_to_speech import text_to_speech from metagpt.tools.iflytek_tts import IFlyTekTTS from metagpt.utils.s3 import S3 @@ -19,7 +19,6 @@ from metagpt.utils.s3 import S3 @pytest.mark.asyncio async def test_azure_text_to_speech(mocker): # mock - config = Config.default() config.iflytek_api_key = None config.iflytek_api_secret = None config.iflytek_app_id = None @@ -47,7 +46,6 @@ async def test_azure_text_to_speech(mocker): @pytest.mark.asyncio async def test_iflytek_text_to_speech(mocker): # mock - config = Config.default() config.azure_tts_subscription_key = None config.azure_tts_region = None mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None) diff --git a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py index 5ceba6dcc..ce4ef94a4 100644 --- a/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py +++ b/tests/metagpt/roles/di/run_swe_agent_for_benchmark.py @@ -7,7 +7,7 @@ import sys from datetime import datetime from pathlib import Path -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.const import DEFAULT_WORKSPACE_ROOT, METAGPT_ROOT from metagpt.logs import logger from metagpt.roles.di.engineer2 import Engineer2 @@ -15,7 +15,6 @@ from metagpt.tools.libs.editor import Editor from metagpt.tools.libs.terminal import Terminal from metagpt.tools.swe_agent_commands.swe_agent_utils import load_hf_dataset -config = Config.default() # Specify by yourself TEST_REPO_DIR = METAGPT_ROOT / "data" / "test_repo" DATA_DIR = METAGPT_ROOT / "data/hugging_face" diff --git a/tests/metagpt/test_document.py b/tests/metagpt/test_document.py index 29393bb13..9c076f4e6 100644 --- a/tests/metagpt/test_document.py +++ b/tests/metagpt/test_document.py @@ -5,12 +5,10 @@ @Author : alexanderwu @File : test_document.py """ -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.document import Repo from metagpt.logs import logger -config = Config.default() - def set_existing_repo(path): repo1 = Repo.from_path(path) diff --git a/tests/metagpt/tools/test_azure_tts.py b/tests/metagpt/tools/test_azure_tts.py index ee55616d2..f72b5663b 100644 --- a/tests/metagpt/tools/test_azure_tts.py +++ b/tests/metagpt/tools/test_azure_tts.py @@ -12,11 +12,9 @@ from pathlib import Path import pytest from azure.cognitiveservices.speech import ResultReason, SpeechSynthesizer -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.tools.azure_tts import AzureTTS -config = Config.default() - @pytest.mark.asyncio async def test_azure_tts(mocker): diff --git a/tests/metagpt/tools/test_iflytek_tts.py b/tests/metagpt/tools/test_iflytek_tts.py index c51f62b8e..b4bcadb89 100644 --- a/tests/metagpt/tools/test_iflytek_tts.py +++ b/tests/metagpt/tools/test_iflytek_tts.py @@ -7,14 +7,13 @@ """ import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.tools.iflytek_tts import IFlyTekTTS, oas3_iflytek_tts @pytest.mark.asyncio async def test_iflytek_tts(mocker): # mock - config = Config.default() config.azure_tts_subscription_key = None config.azure_tts_region = None mocker.patch.object(IFlyTekTTS, "synthesize_speech", return_value=None) diff --git a/tests/metagpt/tools/test_metagpt_text_to_image.py b/tests/metagpt/tools/test_metagpt_text_to_image.py index bd0fcaf8b..d3797a460 100644 --- a/tests/metagpt/tools/test_metagpt_text_to_image.py +++ b/tests/metagpt/tools/test_metagpt_text_to_image.py @@ -10,11 +10,9 @@ from unittest.mock import AsyncMock import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.tools.metagpt_text_to_image import oas3_metagpt_text_to_image -config = Config.default() - @pytest.mark.asyncio async def test_draw(mocker): diff --git a/tests/metagpt/tools/test_moderation.py b/tests/metagpt/tools/test_moderation.py index 0f921887f..8dc9e9d5e 100644 --- a/tests/metagpt/tools/test_moderation.py +++ b/tests/metagpt/tools/test_moderation.py @@ -8,12 +8,10 @@ import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.llm import LLM from metagpt.tools.moderation import Moderation -config = Config.default() - @pytest.mark.asyncio @pytest.mark.parametrize( diff --git a/tests/metagpt/tools/test_openai_text_to_embedding.py b/tests/metagpt/tools/test_openai_text_to_embedding.py index 81b3895c3..047206d48 100644 --- a/tests/metagpt/tools/test_openai_text_to_embedding.py +++ b/tests/metagpt/tools/test_openai_text_to_embedding.py @@ -10,7 +10,7 @@ from pathlib import Path import pytest -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.tools.openai_text_to_embedding import oas3_openai_text_to_embedding from metagpt.utils.common import aread @@ -18,7 +18,6 @@ from metagpt.utils.common import aread @pytest.mark.asyncio async def test_embedding(mocker): # mock - config = Config.default() mock_post = mocker.patch("aiohttp.ClientSession.post") mock_response = mocker.AsyncMock() mock_response.status = 200 diff --git a/tests/metagpt/tools/test_openai_text_to_image.py b/tests/metagpt/tools/test_openai_text_to_image.py index 4856342d1..3f9169ddd 100644 --- a/tests/metagpt/tools/test_openai_text_to_image.py +++ b/tests/metagpt/tools/test_openai_text_to_image.py @@ -11,7 +11,7 @@ import openai import pytest from pydantic import BaseModel -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.llm import LLM from metagpt.tools.openai_text_to_image import ( OpenAIText2Image, @@ -19,8 +19,6 @@ from metagpt.tools.openai_text_to_image import ( ) from metagpt.utils.s3 import S3 -config = Config.default() - @pytest.mark.asyncio async def test_draw(mocker): diff --git a/tests/metagpt/tools/test_ut_writer.py b/tests/metagpt/tools/test_ut_writer.py index ebb8c8aa2..557067191 100644 --- a/tests/metagpt/tools/test_ut_writer.py +++ b/tests/metagpt/tools/test_ut_writer.py @@ -20,12 +20,10 @@ from openai.types.chat.chat_completion_message_tool_call import ( Function, ) -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.const import API_QUESTIONS_PATH, UT_PY_PATH from metagpt.tools.ut_writer import YFT_PROMPT_PREFIX, UTGenerator -config = Config.default() - class TestUTWriter: @pytest.mark.asyncio diff --git a/tests/metagpt/utils/test_repair_llm_raw_output.py b/tests/metagpt/utils/test_repair_llm_raw_output.py index 75bd9f165..7a29ea3ee 100644 --- a/tests/metagpt/utils/test_repair_llm_raw_output.py +++ b/tests/metagpt/utils/test_repair_llm_raw_output.py @@ -2,9 +2,7 @@ # -*- coding: utf-8 -*- # @Desc : unittest of repair_llm_raw_output -from metagpt.config2 import Config - -config = Config.default() +from metagpt.config2 import config """ CONFIG.repair_llm_output should be True before retry_parse_json_text imported. diff --git a/tests/metagpt/utils/test_sanitize.py b/tests/metagpt/utils/test_sanitize.py new file mode 100644 index 000000000..c229af173 --- /dev/null +++ b/tests/metagpt/utils/test_sanitize.py @@ -0,0 +1,246 @@ +# -*- coding: utf-8 -*- +from unittest.mock import Mock, patch + +import pytest + +from metagpt.utils.sanitize import ( + NodeType, + code_extract, + get_definition_name, + get_deps, + get_function_dependency, + has_return_statement, + sanitize, + syntax_check, + traverse_tree, +) + + +@pytest.fixture +def mock_node(): + node = Mock() + node.type = "test_node" + node.text = b"test_text" + node.children = [] + return node + + +def test_node_type_enum(): + assert NodeType.CLASS.value == "class_definition" + assert NodeType.FUNCTION.value == "function_definition" + assert isinstance(NodeType.IMPORT.value, list) + + +@patch("tree_sitter.Node") +def test_traverse_tree(mock_node_class): + # 测试基本情况:没有子节点的情况 + root = Mock() + cursor = Mock() + cursor.node = root + cursor.goto_first_child.return_value = False + cursor.goto_next_sibling.return_value = False + cursor.goto_parent.return_value = False + root.walk.return_value = cursor + + nodes = list(traverse_tree(root)) + assert len(nodes) == 1 + assert nodes[0] == root + + # 测试有子节点和兄弟节点的情况 + cursor2 = Mock() + cursor2.node = Mock() + + # 模拟遍历行为 + first_child_calls = [True, False] + next_sibling_calls = [False] + parent_calls = [True, False] + + cursor2.goto_first_child.side_effect = lambda: first_child_calls.pop(0) if first_child_calls else False + cursor2.goto_next_sibling.side_effect = lambda: next_sibling_calls.pop(0) if next_sibling_calls else False + cursor2.goto_parent.side_effect = lambda: parent_calls.pop(0) if parent_calls else False + + root2 = Mock() + root2.walk.return_value = cursor2 + nodes = list(traverse_tree(root2)) + assert len(nodes) > 1 + + +def test_syntax_check(): + # 测试有效代码 + assert syntax_check("def test(): return True") is True + + # 测试无效代码 + assert syntax_check("def test() return True") is False + + # 测试无效代码(带verbose) + assert syntax_check("def test() return True", verbose=True) is False + + # 测试内存错误情况 + with patch("ast.parse", side_effect=MemoryError): + assert syntax_check("large_code", verbose=True) is False + + +def test_code_extract(): + # 测试基本情况 + text = "def valid_function():\n return True\n" + result = code_extract(text) + assert syntax_check(result) + assert "def valid_function" in result + + # 测试空字符串 + assert code_extract("") == "" + + # 测试单行有效语法 + single_line = "x = 1" + result = code_extract(single_line) + assert syntax_check(result) + assert "x = 1" in result + + # 测试完全无效的代码 + assert code_extract("invalid!!!!") == "" or code_extract("invalid!!!!") == "invalid!!!!" + + # 测试带有嵌套结构的有效代码 + nested_code = """def outer():\n def inner():\n return True\n""" + result = code_extract(nested_code) + assert syntax_check(result) + assert "def outer" in result + + +def test_get_definition_name(): + # 基本测试 + mock_identifier = Mock() + mock_identifier.type = NodeType.IDENTIFIER.value + mock_identifier.text = b"test_function" + + mock_node = Mock() + mock_node.children = [mock_identifier] + assert get_definition_name(mock_node) == "test_function" + + # 测试空children + mock_node.children = [] + assert get_definition_name(mock_node) is None + + # 测试children中没有identifier + mock_node.children = [Mock(type="not_identifier")] + assert get_definition_name(mock_node) is None + + +@pytest.mark.parametrize( + "node_type,expected", + [ + (NodeType.RETURN.value, True), + ("other_type", False), + ], +) +def test_has_return_statement(node_type, expected): + mock_node = Mock() + cursor = Mock() + cursor.node = Mock() + cursor.node.type = node_type + cursor.goto_first_child.return_value = False + cursor.goto_next_sibling.return_value = False + cursor.goto_parent.return_value = False + mock_node.walk.return_value = cursor + + assert has_return_statement(mock_node) is expected + + +def test_get_deps(): + mock_id1 = Mock(type=NodeType.IDENTIFIER.value, text=b"dep1") + mock_id2 = Mock(type=NodeType.IDENTIFIER.value, text=b"dep2") + mock_node = Mock(children=[mock_id1, mock_id2]) + + nodes = [("test_func", mock_node)] + result = get_deps(nodes) + + assert "test_func" in result + assert result["test_func"] == {"dep1", "dep2"} + + # 测试嵌套结构 + nested_node = Mock(children=[Mock(type="not_identifier", children=[mock_id1])]) + nodes = [("nested_func", nested_node)] + result = get_deps(nodes) + assert result["nested_func"] == {"dep1"} + + +def test_get_function_dependency(): + call_graph = {"main": {"helper1", "helper2"}, "helper1": {"helper3"}, "helper2": set(), "helper3": set()} + + result = get_function_dependency("main", call_graph) + assert result == {"main", "helper1", "helper2", "helper3"} + + assert get_function_dependency("non_existent", call_graph) == {"non_existent"} + + +@patch("tree_sitter.Parser") +@patch("tree_sitter.Language") +def test_sanitize(mock_language, mock_parser): + test_code = """import math +from os import path + +class TestClass: + def method(self): return True + +def test_function(): + return True + +x = 1""" + + mock_root = Mock() + mock_nodes = [] + + # 添加导入语句 + import_node = Mock(type="import_statement", start_byte=0, end_byte=11) + import_from_node = Mock(type="import_from_statement", start_byte=12, end_byte=30) + mock_nodes.extend([import_node, import_from_node]) + + # 添加类定义 + class_node = Mock(type="class_definition", start_byte=32, end_byte=80) + class_id = Mock(type="identifier", text=b"TestClass") + class_node.children = [class_id] + mock_nodes.append(class_node) + + # 添加函数定义 + func_node = Mock(type="function_definition", start_byte=82, end_byte=110) + func_id = Mock(type="identifier", text=b"test_function") + return_node = Mock(type="return_statement") + func_node.children = [func_id, return_node] + mock_nodes.append(func_node) + + # 添加赋值语句 + assign_node = Mock(type="expression_statement", start_byte=112, end_byte=117) + assign_child = Mock(type="assignment") + var_id = Mock(type="identifier", text=b"x") + assign_child.children = [var_id] + assign_node.children = [assign_child] + mock_nodes.append(assign_node) + + mock_root.children = mock_nodes + mock_tree = Mock(root_node=mock_root) + mock_parser.return_value.parse.return_value = mock_tree + + # 测试无entrypoint情况 + result = sanitize(test_code) + assert isinstance(result, str) + assert len(result) > 0 + + # 测试有entrypoint情况 + result = sanitize(test_code, entrypoint="test_function") + assert isinstance(result, str) + assert len(result) > 0 + + # 测试空代码 + assert sanitize("") == "" + + # 测试无效代码 + assert sanitize("invalid code") == "invalid!!!!" or sanitize("invalid code") == "" + + # 测试函数依赖 + mock_nodes = [func_node] # 只保留函数节点 + mock_root.children = mock_nodes + result = sanitize(test_code, entrypoint="test_function") + assert isinstance(result, str) + + +if __name__ == "__main__": + pytest.main(["-v"]) diff --git a/tests/mock/mock_llm.py b/tests/mock/mock_llm.py index e58ce4120..704403e64 100644 --- a/tests/mock/mock_llm.py +++ b/tests/mock/mock_llm.py @@ -1,7 +1,7 @@ import json from typing import Optional, Union -from metagpt.config2 import Config +from metagpt.config2 import config from metagpt.configs.llm_config import LLMType from metagpt.const import LLM_API_TIMEOUT from metagpt.logs import logger @@ -10,8 +10,6 @@ from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import Message -config = Config.default() - OriginalLLM = OpenAILLM if config.llm.api_type == LLMType.OPENAI else AzureOpenAILLM From b652b5f97d14d454a2c891584606bbb15d06f9c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 17:05:02 +0800 Subject: [PATCH 7/9] update default config int --- tests/metagpt/utils/test_sanitize.py | 246 --------------------------- 1 file changed, 246 deletions(-) delete mode 100644 tests/metagpt/utils/test_sanitize.py diff --git a/tests/metagpt/utils/test_sanitize.py b/tests/metagpt/utils/test_sanitize.py deleted file mode 100644 index c229af173..000000000 --- a/tests/metagpt/utils/test_sanitize.py +++ /dev/null @@ -1,246 +0,0 @@ -# -*- coding: utf-8 -*- -from unittest.mock import Mock, patch - -import pytest - -from metagpt.utils.sanitize import ( - NodeType, - code_extract, - get_definition_name, - get_deps, - get_function_dependency, - has_return_statement, - sanitize, - syntax_check, - traverse_tree, -) - - -@pytest.fixture -def mock_node(): - node = Mock() - node.type = "test_node" - node.text = b"test_text" - node.children = [] - return node - - -def test_node_type_enum(): - assert NodeType.CLASS.value == "class_definition" - assert NodeType.FUNCTION.value == "function_definition" - assert isinstance(NodeType.IMPORT.value, list) - - -@patch("tree_sitter.Node") -def test_traverse_tree(mock_node_class): - # 测试基本情况:没有子节点的情况 - root = Mock() - cursor = Mock() - cursor.node = root - cursor.goto_first_child.return_value = False - cursor.goto_next_sibling.return_value = False - cursor.goto_parent.return_value = False - root.walk.return_value = cursor - - nodes = list(traverse_tree(root)) - assert len(nodes) == 1 - assert nodes[0] == root - - # 测试有子节点和兄弟节点的情况 - cursor2 = Mock() - cursor2.node = Mock() - - # 模拟遍历行为 - first_child_calls = [True, False] - next_sibling_calls = [False] - parent_calls = [True, False] - - cursor2.goto_first_child.side_effect = lambda: first_child_calls.pop(0) if first_child_calls else False - cursor2.goto_next_sibling.side_effect = lambda: next_sibling_calls.pop(0) if next_sibling_calls else False - cursor2.goto_parent.side_effect = lambda: parent_calls.pop(0) if parent_calls else False - - root2 = Mock() - root2.walk.return_value = cursor2 - nodes = list(traverse_tree(root2)) - assert len(nodes) > 1 - - -def test_syntax_check(): - # 测试有效代码 - assert syntax_check("def test(): return True") is True - - # 测试无效代码 - assert syntax_check("def test() return True") is False - - # 测试无效代码(带verbose) - assert syntax_check("def test() return True", verbose=True) is False - - # 测试内存错误情况 - with patch("ast.parse", side_effect=MemoryError): - assert syntax_check("large_code", verbose=True) is False - - -def test_code_extract(): - # 测试基本情况 - text = "def valid_function():\n return True\n" - result = code_extract(text) - assert syntax_check(result) - assert "def valid_function" in result - - # 测试空字符串 - assert code_extract("") == "" - - # 测试单行有效语法 - single_line = "x = 1" - result = code_extract(single_line) - assert syntax_check(result) - assert "x = 1" in result - - # 测试完全无效的代码 - assert code_extract("invalid!!!!") == "" or code_extract("invalid!!!!") == "invalid!!!!" - - # 测试带有嵌套结构的有效代码 - nested_code = """def outer():\n def inner():\n return True\n""" - result = code_extract(nested_code) - assert syntax_check(result) - assert "def outer" in result - - -def test_get_definition_name(): - # 基本测试 - mock_identifier = Mock() - mock_identifier.type = NodeType.IDENTIFIER.value - mock_identifier.text = b"test_function" - - mock_node = Mock() - mock_node.children = [mock_identifier] - assert get_definition_name(mock_node) == "test_function" - - # 测试空children - mock_node.children = [] - assert get_definition_name(mock_node) is None - - # 测试children中没有identifier - mock_node.children = [Mock(type="not_identifier")] - assert get_definition_name(mock_node) is None - - -@pytest.mark.parametrize( - "node_type,expected", - [ - (NodeType.RETURN.value, True), - ("other_type", False), - ], -) -def test_has_return_statement(node_type, expected): - mock_node = Mock() - cursor = Mock() - cursor.node = Mock() - cursor.node.type = node_type - cursor.goto_first_child.return_value = False - cursor.goto_next_sibling.return_value = False - cursor.goto_parent.return_value = False - mock_node.walk.return_value = cursor - - assert has_return_statement(mock_node) is expected - - -def test_get_deps(): - mock_id1 = Mock(type=NodeType.IDENTIFIER.value, text=b"dep1") - mock_id2 = Mock(type=NodeType.IDENTIFIER.value, text=b"dep2") - mock_node = Mock(children=[mock_id1, mock_id2]) - - nodes = [("test_func", mock_node)] - result = get_deps(nodes) - - assert "test_func" in result - assert result["test_func"] == {"dep1", "dep2"} - - # 测试嵌套结构 - nested_node = Mock(children=[Mock(type="not_identifier", children=[mock_id1])]) - nodes = [("nested_func", nested_node)] - result = get_deps(nodes) - assert result["nested_func"] == {"dep1"} - - -def test_get_function_dependency(): - call_graph = {"main": {"helper1", "helper2"}, "helper1": {"helper3"}, "helper2": set(), "helper3": set()} - - result = get_function_dependency("main", call_graph) - assert result == {"main", "helper1", "helper2", "helper3"} - - assert get_function_dependency("non_existent", call_graph) == {"non_existent"} - - -@patch("tree_sitter.Parser") -@patch("tree_sitter.Language") -def test_sanitize(mock_language, mock_parser): - test_code = """import math -from os import path - -class TestClass: - def method(self): return True - -def test_function(): - return True - -x = 1""" - - mock_root = Mock() - mock_nodes = [] - - # 添加导入语句 - import_node = Mock(type="import_statement", start_byte=0, end_byte=11) - import_from_node = Mock(type="import_from_statement", start_byte=12, end_byte=30) - mock_nodes.extend([import_node, import_from_node]) - - # 添加类定义 - class_node = Mock(type="class_definition", start_byte=32, end_byte=80) - class_id = Mock(type="identifier", text=b"TestClass") - class_node.children = [class_id] - mock_nodes.append(class_node) - - # 添加函数定义 - func_node = Mock(type="function_definition", start_byte=82, end_byte=110) - func_id = Mock(type="identifier", text=b"test_function") - return_node = Mock(type="return_statement") - func_node.children = [func_id, return_node] - mock_nodes.append(func_node) - - # 添加赋值语句 - assign_node = Mock(type="expression_statement", start_byte=112, end_byte=117) - assign_child = Mock(type="assignment") - var_id = Mock(type="identifier", text=b"x") - assign_child.children = [var_id] - assign_node.children = [assign_child] - mock_nodes.append(assign_node) - - mock_root.children = mock_nodes - mock_tree = Mock(root_node=mock_root) - mock_parser.return_value.parse.return_value = mock_tree - - # 测试无entrypoint情况 - result = sanitize(test_code) - assert isinstance(result, str) - assert len(result) > 0 - - # 测试有entrypoint情况 - result = sanitize(test_code, entrypoint="test_function") - assert isinstance(result, str) - assert len(result) > 0 - - # 测试空代码 - assert sanitize("") == "" - - # 测试无效代码 - assert sanitize("invalid code") == "invalid!!!!" or sanitize("invalid code") == "" - - # 测试函数依赖 - mock_nodes = [func_node] # 只保留函数节点 - mock_root.children = mock_nodes - result = sanitize(test_code, entrypoint="test_function") - assert isinstance(result, str) - - -if __name__ == "__main__": - pytest.main(["-v"]) From 11b0f649fabeccf2298ce2bffd9d281986a21677 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 17:20:45 +0800 Subject: [PATCH 8/9] pre-commit --- tests/metagpt/roles/di/test_swe_agent.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/metagpt/roles/di/test_swe_agent.py b/tests/metagpt/roles/di/test_swe_agent.py index 325e7bed8..3dbfa910b 100644 --- a/tests/metagpt/roles/di/test_swe_agent.py +++ b/tests/metagpt/roles/di/test_swe_agent.py @@ -1,22 +1,17 @@ import pytest +from metagpt.environment.mgx.mgx_env import MGXEnv from metagpt.roles.di.swe_agent import SWEAgent +from metagpt.roles.di.team_leader import TeamLeader from metagpt.schema import Message from metagpt.tools.libs.terminal import Bash -from metagpt.environment.mgx.mgx_env import MGXEnv -from metagpt.roles.di.team_leader import TeamLeader @pytest.fixture def env(): test_env = MGXEnv() tl = TeamLeader() - test_env.add_roles( - [ - tl, - SWEAgent() - ] - ) + test_env.add_roles([tl, SWEAgent()]) return test_env From 25016d2b2226d31de0c9c22c1293e7d8bb1496db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BB=BA=E7=94=9F?= Date: Fri, 28 Feb 2025 17:29:03 +0800 Subject: [PATCH 9/9] pre-commit --- examples/di/interacting_with_human.py | 1 + examples/write_design.py | 4 ++-- tests/metagpt/roles/test_architect.py | 5 +++-- tests/metagpt/roles/test_engineer.py | 4 ++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/di/interacting_with_human.py b/examples/di/interacting_with_human.py index c0ce02a40..a02f7c3bc 100644 --- a/examples/di/interacting_with_human.py +++ b/examples/di/interacting_with_human.py @@ -1,4 +1,5 @@ import fire + from metagpt.environment.mgx.mgx_env import MGXEnv from metagpt.logs import logger from metagpt.roles.di.team_leader import TeamLeader diff --git a/examples/write_design.py b/examples/write_design.py index 7eaa1a87b..fbbe9e1f0 100644 --- a/examples/write_design.py +++ b/examples/write_design.py @@ -1,10 +1,10 @@ import asyncio +from metagpt.environment.mgx.mgx_env import MGXEnv from metagpt.logs import logger from metagpt.roles.architect import Architect -from metagpt.environment.mgx.mgx_env import MGXEnv -from metagpt.schema import Message from metagpt.roles.di.team_leader import TeamLeader +from metagpt.schema import Message async def main(): diff --git a/tests/metagpt/roles/test_architect.py b/tests/metagpt/roles/test_architect.py index 9a834ffbe..cb636b6a1 100644 --- a/tests/metagpt/roles/test_architect.py +++ b/tests/metagpt/roles/test_architect.py @@ -8,18 +8,19 @@ distribution feature for message handling. """ import uuid +from pathlib import Path import pytest from metagpt.actions import WritePRD +from metagpt.actions.di.run_command import RunCommand from metagpt.const import PRDS_FILE_REPO from metagpt.logs import logger from metagpt.roles import Architect from metagpt.schema import Message from metagpt.utils.common import any_to_str, awrite from tests.metagpt.roles.mock import MockMessages -from pathlib import Path -from metagpt.actions.di.run_command import RunCommand + @pytest.mark.asyncio async def test_architect(context): diff --git a/tests/metagpt/roles/test_engineer.py b/tests/metagpt/roles/test_engineer.py index c340e910f..18b297ae5 100644 --- a/tests/metagpt/roles/test_engineer.py +++ b/tests/metagpt/roles/test_engineer.py @@ -9,6 +9,7 @@ """ import json from pathlib import Path +from types import SimpleNamespace import pytest @@ -19,9 +20,8 @@ from metagpt.roles.engineer import Engineer from metagpt.schema import CodingContext, Message from metagpt.utils.common import CodeParser, any_to_name, any_to_str, aread, awrite from metagpt.utils.git_repository import ChangeType -from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages from metagpt.utils.project_repo import ProjectRepo -from types import SimpleNamespace +from tests.metagpt.roles.mock import STRS_FOR_PARSING, TASKS, MockMessages @pytest.mark.asyncio