diff --git a/metagpt/config2.py b/metagpt/config2.py index 92dd98bad..5a556cc52 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -38,6 +38,7 @@ class CLIParams(BaseModel): if self.project_path: self.inc = True self.project_name = self.project_name or Path(self.project_path).name + return self class Config(CLIParams, YamlModel): diff --git a/metagpt/context.py b/metagpt/context.py index 8e9749d66..3dfd52d58 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -95,7 +95,3 @@ class Context(BaseModel): if llm.cost_manager is None: llm.cost_manager = self.cost_manager return llm - - -# Global context, not in Env -CONTEXT = Context() diff --git a/metagpt/context_mixin.py b/metagpt/context_mixin.py index 1d239d2e4..bdf2d0734 100644 --- a/metagpt/context_mixin.py +++ b/metagpt/context_mixin.py @@ -10,7 +10,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict, Field from metagpt.config2 import Config -from metagpt.context import CONTEXT, Context +from metagpt.context import Context from metagpt.provider.base_llm import BaseLLM @@ -34,7 +34,7 @@ class ContextMixin(BaseModel): def __init__( self, - context: Optional[Context] = CONTEXT, + context: Optional[Context] = None, config: Optional[Config] = None, llm: Optional[BaseLLM] = None, **kwargs, @@ -81,7 +81,7 @@ class ContextMixin(BaseModel): """Role context: role context > context""" if self.private_context: return self.private_context - return CONTEXT + return Context() @context.setter def context(self, context: Context) -> None: diff --git a/metagpt/learn/skill_loader.py b/metagpt/learn/skill_loader.py index ddcd7ccba..bcf28bb87 100644 --- a/metagpt/learn/skill_loader.py +++ b/metagpt/learn/skill_loader.py @@ -13,7 +13,7 @@ import aiofiles import yaml from pydantic import BaseModel, Field -from metagpt.context import CONTEXT, Context +from metagpt.context import Context class Example(BaseModel): @@ -73,14 +73,15 @@ class SkillsDeclaration(BaseModel): skill_data = yaml.safe_load(data) return SkillsDeclaration(**skill_data) - def get_skill_list(self, entity_name: str = "Assistant", context: Context = CONTEXT) -> Dict: + def get_skill_list(self, entity_name: str = "Assistant", context: Context = None) -> Dict: """Return the skill name based on the skill description.""" entity = self.entities.get(entity_name) if not entity: return {} # List of skills that the agent chooses to activate. - agent_skills = context.kwargs.agent_skills + ctx = context or Context() + agent_skills = ctx.kwargs.agent_skills if not agent_skills: return {} diff --git a/metagpt/llm.py b/metagpt/llm.py index 30ced25d2..a3fc5613a 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -8,12 +8,13 @@ from typing import Optional from metagpt.configs.llm_config import LLMConfig -from metagpt.context import CONTEXT +from metagpt.context import Context from metagpt.provider.base_llm import BaseLLM -def LLM(llm_config: Optional[LLMConfig] = None) -> BaseLLM: +def LLM(llm_config: Optional[LLMConfig] = None, context: Context = None) -> BaseLLM: """get the default llm provider if name is None""" + ctx = context or Context() if llm_config is not None: - CONTEXT.llm_with_cost_manager_from_llm_config(llm_config) - return CONTEXT.llm() + ctx.llm_with_cost_manager_from_llm_config(llm_config) + return ctx.llm() diff --git a/metagpt/roles/assistant.py b/metagpt/roles/assistant.py index 2e9ec9bf7..2774bd9b6 100644 --- a/metagpt/roles/assistant.py +++ b/metagpt/roles/assistant.py @@ -22,7 +22,6 @@ from pydantic import Field from metagpt.actions.skill_action import ArgumentsParingAction, SkillAction from metagpt.actions.talk_action import TalkAction -from metagpt.context import CONTEXT from metagpt.learn.skill_loader import SkillsDeclaration from metagpt.logs import logger from metagpt.memory.brain_memory import BrainMemory @@ -48,7 +47,7 @@ class Assistant(Role): def __init__(self, **kwargs): super().__init__(**kwargs) - language = kwargs.get("language") or self.context.kwargs.language or CONTEXT.kwargs.language + language = kwargs.get("language") or self.context.kwargs.language self.constraints = self.constraints.format(language=language) async def think(self) -> bool: diff --git a/metagpt/startup.py b/metagpt/startup.py index 771cde80c..000b3c5d4 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -8,6 +8,7 @@ import typer from metagpt.config2 import config from metagpt.const import CONFIG_ROOT, METAGPT_ROOT +from metagpt.context import Context app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False) @@ -37,9 +38,10 @@ def generate_repo( from metagpt.team import Team config.update_via_cli(project_path, project_name, inc, reqa_file, max_auto_summarize_code) + ctx = Context(config=config) if not recover_path: - company = Team() + company = Team(context=ctx) company.hire( [ ProductManager(), @@ -58,7 +60,7 @@ def generate_repo( if not stg_path.exists() or not str(stg_path).endswith("team"): raise FileNotFoundError(f"{recover_path} not exists or not endswith `team`") - company = Team.deserialize(stg_path=stg_path) + company = Team.deserialize(stg_path=stg_path, context=ctx) idea = company.idea company.invest(investment) diff --git a/metagpt/team.py b/metagpt/team.py index aec72970b..35f987b57 100644 --- a/metagpt/team.py +++ b/metagpt/team.py @@ -10,12 +10,13 @@ import warnings from pathlib import Path -from typing import Any +from typing import Any, Optional from pydantic import BaseModel, ConfigDict, Field from metagpt.actions import UserRequirement from metagpt.const import MESSAGE_ROUTE_TO_ALL, SERDESER_PATH +from metagpt.context import Context from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Role @@ -36,12 +37,17 @@ class Team(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - env: Environment = Field(default_factory=Environment) + env: Optional[Environment] = None investment: float = Field(default=10.0) idea: str = Field(default="") - def __init__(self, **data: Any): + def __init__(self, context: Context = None, **data: Any): super(Team, self).__init__(**data) + ctx = context or Context() + if not self.env: + self.env = Environment(context=ctx) + else: + self.env.context = ctx # The `env` object is allocated by deserialization if "roles" in data: self.hire(data["roles"]) if "env_desc" in data: @@ -54,7 +60,7 @@ class Team(BaseModel): write_json_file(team_info_path, self.model_dump()) @classmethod - def deserialize(cls, stg_path: Path) -> "Team": + def deserialize(cls, stg_path: Path, context: Context = None) -> "Team": """stg_path = ./storage/team""" # recover team_info team_info_path = stg_path.joinpath("team.json") @@ -64,7 +70,8 @@ class Team(BaseModel): ) team_info: dict = read_json_file(team_info_path) - team = Team(**team_info) + ctx = context or Context() + team = Team(**team_info, context=ctx) return team def hire(self, roles: list[Role]): diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index e09d49d84..f1bd1a8e5 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -28,7 +28,7 @@ from typing import Any, List, Tuple, Union import aiofiles import loguru from pydantic_core import to_jsonable_python -from tenacity import RetryCallState, _utils +from tenacity import RetryCallState, RetryError, _utils from metagpt.const import MESSAGE_ROUTE_TO_ALL from metagpt.logs import logger @@ -505,7 +505,7 @@ def role_raise_decorator(func): self.rc.memory.delete(self.latest_observed_msg) # raise again to make it captured outside raise Exception(format_trackback_info(limit=None)) - except Exception: + except Exception as e: if self.latest_observed_msg: logger.warning( "There is a exception in role's execution, in order to resume, " @@ -514,6 +514,12 @@ def role_raise_decorator(func): # remove role newest observed msg to make it observed again self.rc.memory.delete(self.latest_observed_msg) # raise again to make it captured outside + if isinstance(e, RetryError): + last_error = e.last_attempt._exception + name = any_to_str(last_error) + if re.match(r"^openai\.", name) or re.match(r"^httpx\.", name): + raise last_error + raise Exception(format_trackback_info(limit=None)) return wrapper diff --git a/tests/data/audio/hello.mp3 b/tests/data/audio/hello.mp3 new file mode 100644 index 000000000..7b3aab0a4 Binary files /dev/null and b/tests/data/audio/hello.mp3 differ diff --git a/tests/metagpt/serialize_deserialize/test_team.py b/tests/metagpt/serialize_deserialize/test_team.py index d45c8cf21..dbd38422d 100644 --- a/tests/metagpt/serialize_deserialize/test_team.py +++ b/tests/metagpt/serialize_deserialize/test_team.py @@ -144,3 +144,7 @@ async def test_team_recover_multi_roles_save(mocker, context): assert new_company.env.get_role(role_b.profile).rc.state == 1 await new_company.run(n_round=4) + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_context.py b/tests/metagpt/test_context.py index d90d0b686..f8218c44d 100644 --- a/tests/metagpt/test_context.py +++ b/tests/metagpt/test_context.py @@ -6,7 +6,7 @@ @File : test_context.py """ from metagpt.configs.llm_config import LLMType -from metagpt.context import CONTEXT, AttrDict, Context +from metagpt.context import AttrDict, Context def test_attr_dict_1(): @@ -51,11 +51,12 @@ def test_context_1(): def test_context_2(): - llm = CONTEXT.config.get_openai_llm() + ctx = Context() + llm = ctx.config.get_openai_llm() assert llm is not None assert llm.api_type == LLMType.OPENAI - kwargs = CONTEXT.kwargs + kwargs = ctx.kwargs assert kwargs is not None kwargs.test_key = "test_value" diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index 10839a2a5..7559655d3 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -11,7 +11,6 @@ from pathlib import Path import pytest from metagpt.actions import UserRequirement -from metagpt.context import CONTEXT from metagpt.environment import Environment from metagpt.logs import logger from metagpt.roles import Architect, ProductManager, Role @@ -44,9 +43,9 @@ def test_get_roles(env: Environment): @pytest.mark.asyncio async def test_publish_and_process_message(env: Environment): - if CONTEXT.git_repo: - CONTEXT.git_repo.delete_repository() - CONTEXT.git_repo = None + if env.context.git_repo: + env.context.git_repo.delete_repository() + env.context.git_repo = None product_manager = ProductManager(name="Alice", profile="Product Manager", goal="做AI Native产品", constraints="资源有限") architect = Architect(