diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py
index cabab784f..cad8112d2 100644
--- a/metagpt/actions/action.py
+++ b/metagpt/actions/action.py
@@ -14,8 +14,6 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 from metagpt.actions.action_node import ActionNode
 from metagpt.context import ContextMixin
-from metagpt.llm import LLM
-from metagpt.provider.base_llm import BaseLLM
 from metagpt.schema import (
     CodeSummarizeContext,
     CodingContext,
@@ -30,7 +28,6 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
 
     name: str = ""
-    llm: BaseLLM = Field(default_factory=LLM, exclude=True)
     i_context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
     prefix: str = ""  # prefix is prepended as the system_message when calling aask*
     desc: str = ""  # for skill manager
diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py
index a3406ff65..60939d2eb 100644
--- a/metagpt/actions/invoice_ocr.py
+++ b/metagpt/actions/invoice_ocr.py
@@ -133,7 +133,6 @@ class GenerateTable(Action):
 
     name: str = "GenerateTable"
     i_context: Optional[str] = None
-    llm: BaseLLM = Field(default_factory=LLM)
     language: str = "ch"
 
     async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py
index 84067ad92..ce366e3d2 100644
--- a/metagpt/actions/research.py
+++ b/metagpt/actions/research.py
@@ -178,7 +178,6 @@ class WebBrowseAndSummarize(Action):
 
     name: str = "WebBrowseAndSummarize"
     i_context: Optional[str] = None
-    llm: BaseLLM = Field(default_factory=LLM)
     desc: str = "Explore the web and provide summaries of articles and webpages."
     browse_func: Union[Callable[[list[str]], None], None] = None
     web_browser_engine: Optional[WebBrowserEngine] = None
diff --git a/metagpt/context.py b/metagpt/context.py
index 4083a1696..bd86fb039 100644
--- a/metagpt/context.py
+++ b/metagpt/context.py
@@ -42,28 +42,6 @@ class AttrDict(BaseModel):
         raise AttributeError(f"No such attribute: {key}")
 
 
-class LLMInstance:
-    """Mixin class for LLM"""
-
-    # _config: Optional[Config] = None
-    _llm_config: Optional[LLMConfig] = None
-    _llm_instance: Optional[BaseLLM] = None
-
-    def __init__(self, config: Config, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI):
-        """Use a LLM provider"""
-        # update the LLM config
-        self._llm_config = config.get_llm_config(name, provider)
-        # reset the LLM instance
-        self._llm_instance = None
-
-    @property
-    def instance(self) -> BaseLLM:
-        """Return the LLM instance"""
-        if not self._llm_instance and self._llm_config:
-            self._llm_instance = create_llm_instance(self._llm_config)
-        return self._llm_instance
-
-
 class Context(BaseModel):
     """Env context for MetaGPT"""
 
@@ -74,7 +52,8 @@
     git_repo: Optional[GitRepository] = None
     src_workspace: Optional[Path] = None
     cost_manager: CostManager = CostManager()
-    _llm: Optional[LLMInstance] = None
+
+    _llm: Optional[BaseLLM] = None
 
     @property
     def file_repo(self):
@@ -92,12 +71,19 @@
             env.update({k: v for k, v in i.items() if isinstance(v, str)})
         return env
 
+    # def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
+    #     """Use an LLM instance"""
+    #     self._llm_config = self.config.get_llm_config(name, provider)
+    #     self._llm = None
+    #     return self._llm
+
     def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
-        """Return a LLM instance"""
-        llm = LLMInstance(self.config, name, provider).instance
-        if llm.cost_manager is None:
-            llm.cost_manager = self.cost_manager
-        return llm
+        """Return an LLM instance, fixme: support multiple llm instances"""
+        if self._llm is None:
+            self._llm = create_llm_instance(self.config.get_llm_config(name, provider))
+        if self._llm.cost_manager is None:
+            self._llm.cost_manager = self.cost_manager
+        return self._llm
 
 
 class ContextMixin(BaseModel):
@@ -108,11 +94,22 @@
     # Env/Role/Action will use this config as private config, or use self.context.config as public config
     _config: Optional[Config] = None
 
-    def __init__(self, context: Optional[Context] = None, config: Optional[Config] = None, **kwargs):
+    # Env/Role/Action will use this llm as private llm, or use self.context._llm instance
+    _llm_config: Optional[LLMConfig] = None
+    _llm: Optional[BaseLLM] = None
+
+    def __init__(
+        self,
+        context: Optional[Context] = None,
+        config: Optional[Config] = None,
+        llm: Optional[BaseLLM] = None,
+        **kwargs,
+    ):
         """Initialize with config"""
         super().__init__(**kwargs)
         self.set_context(context)
         self.set_config(config)
+        self.set_llm(llm)
 
     def set(self, k, v, override=False):
         """Set attribute"""
@@ -127,30 +124,56 @@
         """Set config"""
         self.set("_config", config, override)
 
+    def set_llm_config(self, llm_config: LLMConfig, override=False):
+        """Set llm config"""
+        self.set("_llm_config", llm_config, override)
+
+    def set_llm(self, llm: BaseLLM, override=False):
+        """Set llm"""
+        self.set("_llm", llm, override)
+
+    def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
+        """Use an LLM instance"""
+        self._llm_config = self.config.get_llm_config(name, provider)
+        self._llm = None
+        return self.llm
+
     @property
-    def config(self):
+    def config(self) -> Config:
         """Role config: role config > context config"""
         if self._config:
             return self._config
         return self.context.config
 
     @config.setter
-    def config(self, config: Config):
+    def config(self, config: Config) -> None:
         """Set config"""
         self.set_config(config)
 
     @property
-    def context(self):
+    def context(self) -> Context:
         """Role context: role context > context"""
         if self._context:
             return self._context
         return CONTEXT
 
     @context.setter
-    def context(self, context: Context):
+    def context(self, context: Context) -> None:
         """Set context"""
         self.set_context(context)
 
+    @property
+    def llm(self) -> BaseLLM:
+        """Role llm: role llm > context llm"""
+        if self._llm_config and not self._llm:
+            self._llm = self.context.llm(self._llm_config.name, self._llm_config.provider)
+        return self._llm or self.context.llm()
+
+    @llm.setter
+    def llm(self, llm: BaseLLM) -> None:
+        """Set llm"""
+        self._llm = llm
+
 
 # Global context, not in Env
 CONTEXT = Context()
diff --git a/metagpt/roles/engineer.py b/metagpt/roles/engineer.py
index dc9f31686..364566b37 100644
--- a/metagpt/roles/engineer.py
+++ b/metagpt/roles/engineer.py
@@ -109,7 +109,7 @@
             coding_context = await todo.run()
             # Code review
             if review:
-                action = WriteCodeReview(context=coding_context, g_context=self.context, llm=self.llm)
+                action = WriteCodeReview(context=coding_context, _context=self.context, llm=self.llm)
                 self._init_action_system_message(action)
                 coding_context = await action.run()
             await src_file_repo.save(
diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py
index 98cc05234..9c6832d8f 100644
--- a/metagpt/roles/role.py
+++ b/metagpt/roles/role.py
@@ -31,11 +31,9 @@ from metagpt.actions import Action, ActionOutput
 from metagpt.actions.action_node import ActionNode
 from metagpt.actions.add_requirement import UserRequirement
 from metagpt.context import ContextMixin
-from metagpt.llm import LLM
 from metagpt.logs import logger
 from metagpt.memory import Memory
 from metagpt.provider import HumanProvider
-from metagpt.provider.base_llm import BaseLLM
 from metagpt.schema import Message, MessageQueue, SerializationMixin
 from metagpt.utils.common import any_to_name, any_to_str, role_raise_decorator
 from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
@@ -131,7 +129,6 @@ class Role(SerializationMixin, ContextMixin, BaseModel):
     desc: str = ""
     is_human: bool = False
 
-    llm: BaseLLM = Field(default_factory=LLM, exclude=True)  # Each role has its own LLM, use different system message
     role_id: str = ""
     states: list[str] = []
     actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
diff --git a/metagpt/roles/sk_agent.py b/metagpt/roles/sk_agent.py
index 468905fce..200ed5051 100644
--- a/metagpt/roles/sk_agent.py
+++ b/metagpt/roles/sk_agent.py
@@ -17,9 +17,7 @@ from semantic_kernel.planning.basic_planner import BasicPlanner, Plan
 
 from metagpt.actions import UserRequirement
 from metagpt.actions.execute_task import ExecuteTask
-from metagpt.llm import LLM
 from metagpt.logs import logger
-from metagpt.provider.base_llm import BaseLLM
 from metagpt.roles import Role
 from metagpt.schema import Message
 from metagpt.utils.make_sk_kernel import make_sk_kernel
@@ -44,7 +42,6 @@
     plan: Plan = Field(default=None, exclude=True)
     planner_cls: Any = None
     planner: Union[BasicPlanner, SequentialPlanner, ActionPlanner] = None
-    llm: BaseLLM = Field(default_factory=LLM)
     kernel: Kernel = Field(default_factory=Kernel)
     import_semantic_skill_from_directory: Callable = Field(default=None, exclude=True)
     import_skill: Callable = Field(default=None, exclude=True)
diff --git a/metagpt/tools/moderation.py b/metagpt/tools/moderation.py
index cda164ec5..f00b0e1f2 100644
--- a/metagpt/tools/moderation.py
+++ b/metagpt/tools/moderation.py
@@ -7,12 +7,12 @@
 """
 from typing import Union
 
-from metagpt.llm import LLM
+from metagpt.provider.base_llm import BaseLLM
 
 
 class Moderation:
-    def __init__(self):
-        self.llm = LLM()
+    def __init__(self, llm: BaseLLM):
+        self.llm = llm
 
     def handle_moderation_results(self, results):
         resp = []
diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py
index fc31b95f7..bf7c5e799 100644
--- a/metagpt/tools/openai_text_to_image.py
+++ b/metagpt/tools/openai_text_to_image.py
@@ -16,9 +16,6 @@ from metagpt.provider.base_llm import BaseLLM
 
 class OpenAIText2Image:
     def __init__(self, llm: BaseLLM):
-        """
-        :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys`
-        """
         self.llm = llm
 
     async def text_2_image(self, text, size_type="1024x1024"):
diff --git a/tests/metagpt/test_config.py b/tests/metagpt/test_config.py
index c74b16930..cfde7a04c 100644
--- a/tests/metagpt/test_config.py
+++ b/tests/metagpt/test_config.py
@@ -79,3 +79,6 @@ def test_config_mixin_3():
     assert obj.b == "b"
     assert obj.c == "c"
     assert obj.d == "d"
+
+    print(obj.__dict__.keys())
+    assert "_config" in obj.__dict__.keys()
diff --git a/tests/metagpt/test_context.py b/tests/metagpt/test_context.py
index f1c9da4e7..255794c41 100644
--- a/tests/metagpt/test_context.py
+++ b/tests/metagpt/test_context.py
@@ -66,7 +66,4 @@ def test_context_2():
 def test_context_3():
     ctx = Context()
-    ctx.use_llm(provider=LLMType.OPENAI)
-    assert ctx.llm_config is not None
-    assert ctx.llm_config.api_type == LLMType.OPENAI
-    assert ctx.llm is not None
-    assert "gpt" in ctx.llm.model
+    assert ctx.llm() is not None
+    assert "gpt" in ctx.llm().model
diff --git a/tests/metagpt/tools/test_moderation.py b/tests/metagpt/tools/test_moderation.py
index 534fe812a..d265c3f78 100644
--- a/tests/metagpt/tools/test_moderation.py
+++ b/tests/metagpt/tools/test_moderation.py
@@ -9,6 +9,7 @@
 import pytest
 
 from metagpt.config import CONFIG
+from metagpt.context import CONTEXT
 from metagpt.tools.moderation import Moderation
 
 
@@ -27,7 +28,7 @@ async def test_amoderation(content):
     assert not CONFIG.OPENAI_API_TYPE
     assert CONFIG.OPENAI_API_MODEL
 
-    moderation = Moderation()
+    moderation = Moderation(CONTEXT.llm())
     results = await moderation.amoderation(content=content)
     assert isinstance(results, list)
    assert len(results) == len(content)
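
Net effect of this patch: `Action`, `Role`, and the tools no longer construct their own `LLM()`. Everything resolves through `Context` (one lazily created, cost-tracked instance) or `ContextMixin` (private llm > llm derived from a private `_llm_config` > context llm), and tools receive the instance by constructor injection. Below is a minimal usage sketch of that wiring; it assumes a checkout of this branch with a valid LLM config, and the commented calls and sample strings are illustrative, not part of the patch.

# Minimal usage sketch for the refactored LLM wiring (assumes a valid LLM config).
import asyncio

from metagpt.context import CONTEXT
from metagpt.roles import Role
from metagpt.tools.moderation import Moderation


async def main():
    # Context builds its LLM lazily on first access and caches it
    # (the fixme in Context.llm notes that multiple instances are future work).
    llm = CONTEXT.llm()
    assert CONTEXT.llm() is llm

    # Roles and Actions no longer own an `llm` field; the ContextMixin.llm
    # property falls back to the context's instance when no private llm is set.
    role = Role()
    assert role.llm is llm

    # A private LLM can still be injected per object...
    # role.set_llm(my_llm)                  # my_llm: a hypothetical BaseLLM instance
    # ...or rebuilt lazily from a named config:
    # role.use_llm(provider=LLMType.OPENAI)

    # Tools now receive the LLM instead of constructing LLM() themselves.
    moderation = Moderation(CONTEXT.llm())
    results = await moderation.amoderation(content=["sample text to screen"])
    print(results)


asyncio.run(main())

Note the symmetry with the existing config lookup: `config` resolves role config > context config, and `llm` now resolves role llm > context llm, so both override mechanisms behave the same way.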