From 312d327d55d56c437cd3e7863d5f3d71389046b5 Mon Sep 17 00:00:00 2001 From: geekan Date: Thu, 11 Jan 2024 22:30:26 +0800 Subject: [PATCH] test for mixin --- metagpt/config2.py | 7 ++++- metagpt/const.py | 2 +- metagpt/roles/role.py | 7 +++-- metagpt/startup.py | 4 +-- tests/metagpt/provider/test_spark_api.py | 3 +- tests/metagpt/test_context_mixin.py | 39 ++++++++++++++++-------- 6 files changed, 41 insertions(+), 21 deletions(-) diff --git a/metagpt/config2.py b/metagpt/config2.py index c916b9b60..2d4ac0930 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -18,7 +18,7 @@ from metagpt.configs.redis_config import RedisConfig from metagpt.configs.s3_config import S3Config from metagpt.configs.search_config import SearchConfig from metagpt.configs.workspace_config import WorkspaceConfig -from metagpt.const import METAGPT_ROOT +from metagpt.const import CONFIG_ROOT, METAGPT_ROOT from metagpt.utils.yaml_model import YamlModel @@ -81,6 +81,11 @@ class Config(CLIParams, YamlModel): AZURE_TTS_REGION: str = "" mermaid_engine: str = "nodejs" + @classmethod + def from_home(cls, path): + """Load config from ~/.metagpt/config.yaml""" + return Config.model_validate_yaml(CONFIG_ROOT / path) + @classmethod def default(cls): """Load default config diff --git a/metagpt/const.py b/metagpt/const.py index 811ff9516..8e89b0526 100644 --- a/metagpt/const.py +++ b/metagpt/const.py @@ -47,7 +47,7 @@ def get_metagpt_root(): # METAGPT PROJECT ROOT AND VARS - +CONFIG_ROOT = Path.home() / ".metagpt" METAGPT_ROOT = get_metagpt_root() # Dependent on METAGPT_PROJECT_ROOT DEFAULT_WORKSPACE_ROOT = METAGPT_ROOT / "workspace" diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index cc432d81f..6e05937a7 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -155,6 +155,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel): if self.is_human: self.llm = HumanProvider(None) + self._check_actions() self.llm.system_prompt = self._get_prefix() self._watch(data.get("watch") or [UserRequirement]) @@ -229,14 +230,16 @@ class Role(SerializationMixin, ContextMixin, BaseModel): def _setting(self): return f"{self.name}({self.profile})" - @model_validator(mode="after") def _check_actions(self): """Check actions and set llm and prefix for each action.""" self.set_actions(self.actions) return self def _init_action(self, action: Action): - action.set_llm(self.llm, override=False) + if not action.private_config: + action.set_llm(self.llm, override=True) + else: + action.set_llm(self.llm, override=False) action.set_prefix(self._get_prefix()) def set_action(self, action: Action): diff --git a/metagpt/startup.py b/metagpt/startup.py index 14092edd2..771cde80c 100644 --- a/metagpt/startup.py +++ b/metagpt/startup.py @@ -7,7 +7,7 @@ from pathlib import Path import typer from metagpt.config2 import config -from metagpt.const import METAGPT_ROOT +from metagpt.const import CONFIG_ROOT, METAGPT_ROOT app = typer.Typer(add_completion=False, pretty_exceptions_show_locals=False) @@ -118,7 +118,7 @@ def startup( def copy_config_to(config_path=METAGPT_ROOT / "config" / "config2.yaml"): """Initialize the configuration file for MetaGPT.""" - target_path = Path.home() / ".metagpt" / "config2.yaml" + target_path = CONFIG_ROOT / "config2.yaml" # 创建目标目录(如果不存在) target_path.parent.mkdir(parents=True, exist_ok=True) diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 2cb6bf559..f5a6f66fd 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -1,7 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : the unittest of spark api -from pathlib import Path import pytest @@ -37,7 +36,7 @@ def mock_spark_get_msg_from_web_run(self) -> str: @pytest.mark.asyncio async def test_spark_aask(): - llm = SparkLLM(Config.model_validate_yaml(Path.home() / ".metagpt" / "spark.yaml").llm) + llm = SparkLLM(Config.from_home("spark.yaml").llm) resp = await llm.aask("Hello!") print(resp) diff --git a/tests/metagpt/test_context_mixin.py b/tests/metagpt/test_context_mixin.py index a1222c125..472d67a27 100644 --- a/tests/metagpt/test_context_mixin.py +++ b/tests/metagpt/test_context_mixin.py @@ -99,17 +99,30 @@ def test_config_mixin_4_multi_inheritance_override_config(): @pytest.mark.asyncio -async def test_debate_two_roles(): - config = Config.default() - config.llm.model = "gpt-4-1106-preview" - action1 = Action(config=config, name="AlexSay", instruction="Say your opinion with emotion and don't repeat it") - action2 = Action(name="BobSay", instruction="Say your opinion with emotion and don't repeat it") - alex = Role( - name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2] - ) - bob = Role(name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1]) - env = Environment(desc="US election live broadcast") - team = Team(investment=10.0, env=env, roles=[alex, bob]) +async def test_config_priority(): + """If action's config is set, then its llm will be set, otherwise, it will use the role's llm""" + gpt4t = Config.from_home("gpt-4-1106-preview.yaml") + gpt35 = Config.default() + gpt4 = Config.default() + gpt4.llm.model = "gpt-4-0613" - history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3) - assert "Alex" in history + a1 = Action(config=gpt4t, name="Say", instruction="Say your opinion with emotion and don't repeat it") + a2 = Action(name="Say", instruction="Say your opinion with emotion and don't repeat it") + a3 = Action(name="Vote", instruction="Vote for the candidate, and say why you vote for him/her") + + # it will not work for a1 because the config is already set + A = Role(name="A", profile="Democratic candidate", goal="Win the election", actions=[a1], watch=[a2], config=gpt4) + # it will work for a2 because the config is not set + B = Role(name="B", profile="Republican candidate", goal="Win the election", actions=[a2], watch=[a1], config=gpt4) + # ditto + C = Role(name="C", profile="Voter", goal="Vote for the candidate", actions=[a3], watch=[a1, a2], config=gpt35) + + env = Environment(desc="US election live broadcast") + Team(investment=10.0, env=env, roles=[A, B, C]) + + assert a1.llm.model == "gpt-4-1106-preview" + assert a2.llm.model == "gpt-4-0613" + assert a3.llm.model == "gpt-3.5-turbo-1106" + + # history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="a1", n_round=3) + # assert "Alex" in history