diff --git a/metagpt/llm.py b/metagpt/llm.py index 4c9993441..30ced25d2 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -5,13 +5,15 @@ @Author : alexanderwu @File : llm.py """ +from typing import Optional - +from metagpt.configs.llm_config import LLMConfig from metagpt.context import CONTEXT from metagpt.provider.base_llm import BaseLLM -def LLM() -> BaseLLM: +def LLM(llm_config: Optional[LLMConfig] = None) -> BaseLLM: """get the default llm provider if name is None""" - # context.use_llm(name=name, provider=provider) + if llm_config is not None: + CONTEXT.llm_with_cost_manager_from_llm_config(llm_config) return CONTEXT.llm() diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py index bc56ca813..8b0895a69 100644 --- a/metagpt/roles/engineer.py +++ b/metagpt/roles/engineer.py @@ -110,7 +110,7 @@ class Engineer(Role): # Code review if review: action = WriteCodeReview(i_context=coding_context, context=self.context, llm=self.llm) - self._init_action_system_message(action) + self._init_action(action) coding_context = await action.run() await src_file_repo.save( coding_context.filename, diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 10b60d30e..e7e5ead84 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -146,7 +146,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel): super().__init__(**data) if self.is_human: - self.llm = HumanProvider() + self.llm = HumanProvider(None) self.llm.system_prompt = self._get_prefix() self._watch(data.get("watch") or [UserRequirement]) @@ -222,7 +222,8 @@ class Role(SerializationMixin, ContextMixin, BaseModel): def _setting(self): return f"{self.name}({self.profile})" - def _init_action_system_message(self, action: Action): + def _init_action(self, action: Action): + action.set_llm(self.llm, override=False) action.set_prefix(self._get_prefix()) def set_action(self, action: Action): @@ -238,7 +239,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel): self._reset() for action in actions: if not isinstance(action, Action): - i = action(name="", llm=self.llm) + i = action() else: if self.is_human and not isinstance(action.llm, HumanProvider): logger.warning( @@ -247,7 +248,7 @@ class Role(SerializationMixin, ContextMixin, BaseModel): f"try passing in Action classes instead of initialized instances" ) i = action - self._init_action_system_message(i) + self._init_action(i) self.actions.append(i) self.states.append(f"{len(self.actions)}. {action}") diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 213c19676..2cb6bf559 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -1,9 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # @Desc : the unittest of spark api +from pathlib import Path import pytest +from metagpt.config2 import Config from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM from tests.metagpt.provider.mock_llm_config import mock_llm_config @@ -33,6 +35,14 @@ def mock_spark_get_msg_from_web_run(self) -> str: return resp_content +@pytest.mark.asyncio +async def test_spark_aask(): + llm = SparkLLM(Config.model_validate_yaml(Path.home() / ".metagpt" / "spark.yaml").llm) + + resp = await llm.aask("Hello!") + print(resp) + + @pytest.mark.asyncio async def test_spark_acompletion(mocker): mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) diff --git a/tests/metagpt/test_context_mixin.py b/tests/metagpt/test_context_mixin.py index 2c237b2ec..cc202a473 100644 --- a/tests/metagpt/test_context_mixin.py +++ b/tests/metagpt/test_context_mixin.py @@ -5,10 +5,15 @@ @Author : alexanderwu @File : test_context_mixin.py """ +import pytest from pydantic import BaseModel +from metagpt.actions import Action from metagpt.config2 import Config from metagpt.context_mixin import ContextMixin +from metagpt.environment import Environment +from metagpt.roles import Role +from metagpt.team import Team from tests.metagpt.provider.mock_llm_config import ( mock_llm_config, mock_llm_config_proxy, @@ -91,3 +96,27 @@ def test_config_mixin_4_multi_inheritance_override_config(): print(obj.__dict__.keys()) assert "private_config" in obj.__dict__.keys() assert obj.llm.model == "mock_zhipu_model" + + +@pytest.mark.asyncio +async def test_debate_two_roles(): + config = Config.default() + config.llm.model = "gpt-4-1106-preview" + action1 = Action(config=config, name="AlexSay", instruction="Say your opinion with emotion and don't repeat it") + action2 = Action(name="BobSay", instruction="Say your opinion with emotion and don't repeat it") + biden = Role( + name="Alex", profile="Democratic candidate", goal="Win the election", actions=[action1], watch=[action2] + ) + trump = Role( + name="Bob", profile="Republican candidate", goal="Win the election", actions=[action2], watch=[action1] + ) + env = Environment(desc="US election live broadcast") + team = Team(investment=10.0, env=env, roles=[biden, trump]) + + print(action1.llm.system_prompt) + print(action2.llm.system_prompt) + print(biden.llm.system_prompt) + print(trump.llm.system_prompt) + + history = await team.run(idea="Topic: climate change. Under 80 words per message.", send_to="Alex", n_round=3) + assert "Alex" in history