From cc893914c4d8465cb368ff6c353b2881050485df Mon Sep 17 00:00:00 2001 From: geekan Date: Tue, 9 Jan 2024 15:56:40 +0800 Subject: [PATCH] llm config mixin update --- metagpt/config2.py | 23 ++++++++-- metagpt/context.py | 51 +++++++++++++---------- metagpt/provider/base_llm.py | 1 + metagpt/provider/llm_provider_registry.py | 2 +- tests/metagpt/test_context.py | 9 ++++ 5 files changed, 61 insertions(+), 25 deletions(-) diff --git a/metagpt/config2.py b/metagpt/config2.py index 9c809e559..230e090af 100644 --- a/metagpt/config2.py +++ b/metagpt/config2.py @@ -101,7 +101,7 @@ class Config(CLIParams, YamlModel): self.reqa_file = reqa_file self.max_auto_summarize_code = max_auto_summarize_code - def get_llm_config(self, name: Optional[str] = None) -> LLMConfig: + def _get_llm_config(self, name: Optional[str] = None) -> LLMConfig: """Get LLM instance by name""" if name is None: # Use the first LLM as default @@ -121,6 +121,21 @@ class Config(CLIParams, YamlModel): return llm[0] return None + def get_llm_config(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> LLMConfig: + """Return a LLMConfig instance""" + if provider: + llm_configs = self.get_llm_configs_by_type(provider) + if name: + llm_configs = [c for c in llm_configs if c.name == name] + + if len(llm_configs) == 0: + raise ValueError(f"Cannot find llm config with name {name} and provider {provider}") + # return the first one if name is None, or return the only one + llm_config = llm_configs[0] + else: + llm_config = self._get_llm_config(name) + return llm_config + def get_openai_llm(self) -> Optional[LLMConfig]: """Get OpenAI LLMConfig by name. If no OpenAI, raise Exception""" return self.get_llm_config_by_type(LLMType.OPENAI) @@ -138,10 +153,12 @@ def merge_dict(dicts: Iterable[Dict]) -> Dict: return result -class ConfigurableMixin: +class ConfigMixin: """Mixin class for configurable objects""" - def __init__(self, config=None): + _config: Optional[Config] = None + + def __init__(self, config: Optional[Config] = None): self._config = config def try_set_parent_config(self, parent_config): diff --git a/metagpt/context.py b/metagpt/context.py index e396de7e1..3505614bb 100644 --- a/metagpt/context.py +++ b/metagpt/context.py @@ -12,10 +12,10 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from metagpt.config2 import Config -from metagpt.configs.llm_config import LLMType +from metagpt.configs.llm_config import LLMConfig, LLMType from metagpt.const import OPTIONS from metagpt.provider.base_llm import BaseLLM -from metagpt.provider.llm_provider_registry import get_llm +from metagpt.provider.llm_provider_registry import create_llm_instance from metagpt.utils.cost_manager import CostManager from metagpt.utils.git_repository import GitRepository @@ -42,7 +42,26 @@ class AttrDict(BaseModel): raise AttributeError(f"No such attribute: {key}") -class Context(BaseModel): +class LLMMixin: + config: Optional[Config] = None + llm_config: Optional[LLMConfig] = None + _llm_instance: Optional[BaseLLM] = None + + def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI): + # 更新LLM配置 + self.llm_config = self.config.get_llm_config(name, provider) + # 重置LLM实例 + self._llm_instance = None + + @property + def llm(self) -> BaseLLM: + # 实例化LLM,如果尚未实例化 + if not self._llm_instance and self.llm_config: + self._llm_instance = create_llm_instance(self.llm_config) + return self._llm_instance + + +class Context(LLMMixin, BaseModel): """Env context for MetaGPT""" model_config = ConfigDict(arbitrary_types_allowed=True) @@ -69,24 +88,14 @@ class Context(BaseModel): env.update({k: v for k, v in i.items() if isinstance(v, str)}) return env - def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM: - """Return a LLM instance""" - if provider: - llm_configs = self.config.get_llm_configs_by_type(provider) - if name: - llm_configs = [c for c in llm_configs if c.name == name] - - if len(llm_configs) == 0: - raise ValueError(f"Cannot find llm config with name {name} and provider {provider}") - # return the first one if name is None, or return the only one - llm_config = llm_configs[0] - else: - llm_config = self.config.get_llm_config(name) - - llm = get_llm(llm_config) - if llm.cost_manager is None: - llm.cost_manager = self.cost_manager - return llm + # def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM: + # """Return a LLM instance""" + # llm_config = self.config.get_llm_config(name, provider) + # + # llm = create_llm_instance(llm_config) + # if llm.cost_manager is None: + # llm.cost_manager = self.cost_manager + # return llm # Global context, not in Env diff --git a/metagpt/provider/base_llm.py b/metagpt/provider/base_llm.py index 3c6c464dc..b9847850e 100644 --- a/metagpt/provider/base_llm.py +++ b/metagpt/provider/base_llm.py @@ -27,6 +27,7 @@ class BaseLLM(ABC): # OpenAI / Azure / Others aclient: Optional[Union[AsyncOpenAI]] = None cost_manager: Optional[CostManager] = None + model: Optional[str] = None @abstractmethod def __init__(self, config: LLMConfig): diff --git a/metagpt/provider/llm_provider_registry.py b/metagpt/provider/llm_provider_registry.py index 2f68f27c8..df89d36aa 100644 --- a/metagpt/provider/llm_provider_registry.py +++ b/metagpt/provider/llm_provider_registry.py @@ -31,7 +31,7 @@ def register_provider(key): return decorator -def get_llm(config: LLMConfig) -> BaseLLM: +def create_llm_instance(config: LLMConfig) -> BaseLLM: """get the default llm provider""" return LLM_REGISTRY.get_provider(config.api_type)(config) diff --git a/tests/metagpt/test_context.py b/tests/metagpt/test_context.py index d4f29e352..2d52325bc 100644 --- a/tests/metagpt/test_context.py +++ b/tests/metagpt/test_context.py @@ -61,3 +61,12 @@ def test_context_2(): kwargs.test_key = "test_value" assert kwargs.test_key == "test_value" + + +def test_context_3(): + ctx = Context() + ctx.use_llm(provider=LLMType.OPENAI) + assert ctx.llm_config is not None + assert ctx.llm_config.api_type == LLMType.OPENAI + assert ctx.llm is not None + assert "gpt" in ctx.llm.model