mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
llm config mixin update
This commit is contained in:
parent
62677c37b7
commit
cc893914c4
5 changed files with 61 additions and 25 deletions
|
|
@ -101,7 +101,7 @@ class Config(CLIParams, YamlModel):
|
|||
self.reqa_file = reqa_file
|
||||
self.max_auto_summarize_code = max_auto_summarize_code
|
||||
|
||||
def get_llm_config(self, name: Optional[str] = None) -> LLMConfig:
|
||||
def _get_llm_config(self, name: Optional[str] = None) -> LLMConfig:
|
||||
"""Get LLM instance by name"""
|
||||
if name is None:
|
||||
# Use the first LLM as default
|
||||
|
|
@ -121,6 +121,21 @@ class Config(CLIParams, YamlModel):
|
|||
return llm[0]
|
||||
return None
|
||||
|
||||
def get_llm_config(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> LLMConfig:
|
||||
"""Return a LLMConfig instance"""
|
||||
if provider:
|
||||
llm_configs = self.get_llm_configs_by_type(provider)
|
||||
if name:
|
||||
llm_configs = [c for c in llm_configs if c.name == name]
|
||||
|
||||
if len(llm_configs) == 0:
|
||||
raise ValueError(f"Cannot find llm config with name {name} and provider {provider}")
|
||||
# return the first one if name is None, or return the only one
|
||||
llm_config = llm_configs[0]
|
||||
else:
|
||||
llm_config = self._get_llm_config(name)
|
||||
return llm_config
|
||||
|
||||
def get_openai_llm(self) -> Optional[LLMConfig]:
|
||||
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
|
||||
return self.get_llm_config_by_type(LLMType.OPENAI)
|
||||
|
|
@ -138,10 +153,12 @@ def merge_dict(dicts: Iterable[Dict]) -> Dict:
|
|||
return result
|
||||
|
||||
|
||||
class ConfigurableMixin:
|
||||
class ConfigMixin:
|
||||
"""Mixin class for configurable objects"""
|
||||
|
||||
def __init__(self, config=None):
|
||||
_config: Optional[Config] = None
|
||||
|
||||
def __init__(self, config: Optional[Config] = None):
|
||||
self._config = config
|
||||
|
||||
def try_set_parent_config(self, parent_config):
|
||||
|
|
|
|||
|
|
@ -12,10 +12,10 @@ from typing import Optional
|
|||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from metagpt.config2 import Config
|
||||
from metagpt.configs.llm_config import LLMType
|
||||
from metagpt.configs.llm_config import LLMConfig, LLMType
|
||||
from metagpt.const import OPTIONS
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import get_llm
|
||||
from metagpt.provider.llm_provider_registry import create_llm_instance
|
||||
from metagpt.utils.cost_manager import CostManager
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
||||
|
|
@ -42,7 +42,26 @@ class AttrDict(BaseModel):
|
|||
raise AttributeError(f"No such attribute: {key}")
|
||||
|
||||
|
||||
class Context(BaseModel):
|
||||
class LLMMixin:
|
||||
config: Optional[Config] = None
|
||||
llm_config: Optional[LLMConfig] = None
|
||||
_llm_instance: Optional[BaseLLM] = None
|
||||
|
||||
def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI):
|
||||
# 更新LLM配置
|
||||
self.llm_config = self.config.get_llm_config(name, provider)
|
||||
# 重置LLM实例
|
||||
self._llm_instance = None
|
||||
|
||||
@property
|
||||
def llm(self) -> BaseLLM:
|
||||
# 实例化LLM,如果尚未实例化
|
||||
if not self._llm_instance and self.llm_config:
|
||||
self._llm_instance = create_llm_instance(self.llm_config)
|
||||
return self._llm_instance
|
||||
|
||||
|
||||
class Context(LLMMixin, BaseModel):
|
||||
"""Env context for MetaGPT"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
|
@ -69,24 +88,14 @@ class Context(BaseModel):
|
|||
env.update({k: v for k, v in i.items() if isinstance(v, str)})
|
||||
return env
|
||||
|
||||
def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
"""Return a LLM instance"""
|
||||
if provider:
|
||||
llm_configs = self.config.get_llm_configs_by_type(provider)
|
||||
if name:
|
||||
llm_configs = [c for c in llm_configs if c.name == name]
|
||||
|
||||
if len(llm_configs) == 0:
|
||||
raise ValueError(f"Cannot find llm config with name {name} and provider {provider}")
|
||||
# return the first one if name is None, or return the only one
|
||||
llm_config = llm_configs[0]
|
||||
else:
|
||||
llm_config = self.config.get_llm_config(name)
|
||||
|
||||
llm = get_llm(llm_config)
|
||||
if llm.cost_manager is None:
|
||||
llm.cost_manager = self.cost_manager
|
||||
return llm
|
||||
# def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
|
||||
# """Return a LLM instance"""
|
||||
# llm_config = self.config.get_llm_config(name, provider)
|
||||
#
|
||||
# llm = create_llm_instance(llm_config)
|
||||
# if llm.cost_manager is None:
|
||||
# llm.cost_manager = self.cost_manager
|
||||
# return llm
|
||||
|
||||
|
||||
# Global context, not in Env
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ class BaseLLM(ABC):
|
|||
# OpenAI / Azure / Others
|
||||
aclient: Optional[Union[AsyncOpenAI]] = None
|
||||
cost_manager: Optional[CostManager] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, config: LLMConfig):
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ def register_provider(key):
|
|||
return decorator
|
||||
|
||||
|
||||
def get_llm(config: LLMConfig) -> BaseLLM:
|
||||
def create_llm_instance(config: LLMConfig) -> BaseLLM:
|
||||
"""get the default llm provider"""
|
||||
return LLM_REGISTRY.get_provider(config.api_type)(config)
|
||||
|
||||
|
|
|
|||
|
|
@ -61,3 +61,12 @@ def test_context_2():
|
|||
|
||||
kwargs.test_key = "test_value"
|
||||
assert kwargs.test_key == "test_value"
|
||||
|
||||
|
||||
def test_context_3():
|
||||
ctx = Context()
|
||||
ctx.use_llm(provider=LLMType.OPENAI)
|
||||
assert ctx.llm_config is not None
|
||||
assert ctx.llm_config.api_type == LLMType.OPENAI
|
||||
assert ctx.llm is not None
|
||||
assert "gpt" in ctx.llm.model
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue