refine code

This commit is contained in:
geekan 2024-01-09 17:13:22 +08:00
parent b338dfca64
commit ad525acd33
2 changed files with 22 additions and 19 deletions

View file

@ -9,7 +9,7 @@ import os
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig, LLMType
@ -42,31 +42,33 @@ class AttrDict(BaseModel):
raise AttributeError(f"No such attribute: {key}")
class LLMMixin:
class LLMMixin(BaseModel):
"""Mixin class for LLM"""
model_config = ConfigDict(arbitrary_types_allowed=True)
# _config: Optional[Config] = None
_llm_config: Optional[LLMConfig] = None
_llm_instance: Optional[BaseLLM] = None
llm_config: Optional[LLMConfig] = Field(default=None, exclude=True)
llm_instance: Optional[BaseLLM] = Field(default=None, exclude=True)
def use_llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI):
"""Use a LLM provider"""
# 更新LLM配置
self._llm_config = self._config.get_llm_config(name, provider)
self.llm_config = self.config.get_llm_config(name, provider)
# 重置LLM实例
self._llm_instance = None
self.llm_instance = None
@property
def llm(self) -> BaseLLM:
"""Return the LLM instance"""
if not self._llm_config:
if not self.llm_config:
self.use_llm()
if not self._llm_instance and self._llm_config:
self._llm_instance = create_llm_instance(self._llm_config)
return self._llm_instance
if not self.llm_instance and self.llm_config:
self.llm_instance = create_llm_instance(self.llm_config)
return self.llm_instance
class Context(LLMMixin, BaseModel):
class Context(BaseModel):
"""Env context for MetaGPT"""
model_config = ConfigDict(arbitrary_types_allowed=True)
@ -93,14 +95,14 @@ class Context(LLMMixin, BaseModel):
env.update({k: v for k, v in i.items() if isinstance(v, str)})
return env
# def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
# """Return a LLM instance"""
# llm_config = self.config.get_llm_config(name, provider)
#
# llm = create_llm_instance(llm_config)
# if llm.cost_manager is None:
# llm.cost_manager = self.cost_manager
# return llm
def llm(self, name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
"""Return a LLM instance"""
llm_config = self.config.get_llm_config(name, provider)
llm = create_llm_instance(llm_config)
if llm.cost_manager is None:
llm.cost_manager = self.cost_manager
return llm
class ContextMixin: