remove Dict, use direct LLMConfig / Browser. / Search. / Mermaid. instead

This commit is contained in:
geekan 2024-01-11 15:10:07 +08:00
parent 4de8fa3682
commit c275f28a37
16 changed files with 60 additions and 82 deletions

View file

@@ -184,7 +184,7 @@ class WebBrowseAndSummarize(Action):
super().__init__(**kwargs)
self.web_browser_engine = WebBrowserEngine(
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT,
run_func=self.browse_func,
)

View file

@@ -9,7 +9,7 @@ import os
from pathlib import Path
from typing import Dict, Iterable, List, Literal, Optional
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, model_validator
from metagpt.configs.browser_config import BrowserConfig
from metagpt.configs.llm_config import LLMConfig, LLMType
@@ -44,15 +44,15 @@ class Config(CLIParams, YamlModel):
"""Configurations for MetaGPT"""
# Key Parameters
llm: Dict[str, LLMConfig] = Field(default_factory=Dict)
llm: LLMConfig
# Global Proxy. Will be used if llm.proxy is not set
proxy: str = ""
# Tool Parameters
search: Dict[str, SearchConfig] = {}
browser: Dict[str, BrowserConfig] = {"default": BrowserConfig()}
mermaid: Dict[str, MermaidConfig] = {"default": MermaidConfig()}
search: Optional[SearchConfig] = None
browser: BrowserConfig = BrowserConfig()
mermaid: MermaidConfig = MermaidConfig()
# Storage Parameters
s3: Optional[S3Config] = None
@@ -110,46 +110,17 @@ class Config(CLIParams, YamlModel):
self.reqa_file = reqa_file
self.max_auto_summarize_code = max_auto_summarize_code
def _get_llm_config(self, name: Optional[str] = None) -> LLMConfig:
"""Get LLM instance by name"""
if name is None:
# Use the first LLM as default
name = list(self.llm.keys())[0]
if name not in self.llm:
raise ValueError(f"LLM {name} not found in config")
return self.llm[name]
def get_llm_configs_by_type(self, llm_type: LLMType) -> List[LLMConfig]:
"""Get LLM instance by type"""
return [v for k, v in self.llm.items() if v.api_type == llm_type]
def get_llm_config_by_type(self, llm_type: LLMType) -> Optional[LLMConfig]:
"""Get LLM instance by type"""
llm = self.get_llm_configs_by_type(llm_type)
if llm:
return llm[0]
return None
def get_llm_config(self, name: Optional[str] = None, provider: LLMType = None) -> LLMConfig:
"""Return a LLMConfig instance"""
if provider:
llm_configs = self.get_llm_configs_by_type(provider)
if len(llm_configs) == 0:
raise ValueError(f"Cannot find llm config with name {name} and provider {provider}")
# return the first one if name is None, or return the only one
llm_config = llm_configs[0]
else:
llm_config = self._get_llm_config(name)
return llm_config
def get_openai_llm(self) -> Optional[LLMConfig]:
"""Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
return self.get_llm_config_by_type(LLMType.OPENAI)
if self.llm.api_type == LLMType.OPENAI:
return self.llm
return None
def get_azure_llm(self) -> Optional[LLMConfig]:
"""Get Azure LLMConfig by name. If no Azure, raise Exception"""
return self.get_llm_config_by_type(LLMType.AZURE)
if self.llm.api_type == LLMType.AZURE:
return self.llm
return None
def merge_dict(dicts: Iterable[Dict]) -> Dict:

View file

@@ -40,6 +40,7 @@ class LLMConfig(YamlModel):
api_type: LLMType = LLMType.OPENAI
base_url: str = "https://api.openai.com/v1"
api_version: Optional[str] = None
model: Optional[str] = None # also stands for DEPLOYMENT_NAME
# For Spark(Xunfei), maybe remove later

View file

@@ -12,7 +12,7 @@ from typing import Optional
from pydantic import BaseModel, ConfigDict
from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.configs.llm_config import LLMConfig
from metagpt.const import OPTIONS
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
@@ -77,10 +77,10 @@ class Context(BaseModel):
# self._llm = None
# return self._llm
def llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM:
def llm(self) -> BaseLLM:
"""Return a LLM instance, fixme: support cache"""
# if self._llm is None:
self._llm = create_llm_instance(self.config.get_llm_config(name, provider))
self._llm = create_llm_instance(self.config.llm)
if self._llm.cost_manager is None:
self._llm.cost_manager = self.cost_manager
return self._llm
@@ -140,12 +140,6 @@ class ContextMixin(BaseModel):
"""Set llm"""
self.set("_llm", llm, override)
def use_llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM:
"""Use a LLM instance"""
self._llm_config = self.config.get_llm_config(name, provider)
self._llm = None
return self.llm
@property
def config(self) -> Config:
"""Role config: role config > context config"""

View file

@@ -6,14 +6,12 @@
@File : llm.py
"""
from typing import Optional
from metagpt.configs.llm_config import LLMType
from metagpt.context import CONTEXT
from metagpt.provider.base_llm import BaseLLM
def LLM(name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
def LLM() -> BaseLLM:
"""get the default llm provider if name is None"""
# context.use_llm(name=name, provider=provider)
return CONTEXT.llm(name=name, provider=provider)
return CONTEXT.llm()

View file

@@ -35,7 +35,7 @@ class GoogleAPIWrapper(BaseModel):
@field_validator("google_api_key", mode="before")
@classmethod
def check_google_api_key(cls, val: str):
val = val or config.search["google"].api_key
val = val or config.search.api_key
if not val:
raise ValueError(
"To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
@@ -47,7 +47,7 @@ class GoogleAPIWrapper(BaseModel):
@field_validator("google_cse_id", mode="before")
@classmethod
def check_google_cse_id(cls, val: str):
val = val or config.search["google"].cse_id
val = val or config.search.cse_id
if not val:
raise ValueError(
"To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "

View file

@@ -32,7 +32,7 @@ class SerpAPIWrapper(BaseModel):
@field_validator("serpapi_api_key", mode="before")
@classmethod
def check_serpapi_api_key(cls, val: str):
val = val or config.search["serpapi"].api_key
val = val or config.search.api_key
if not val:
raise ValueError(
"To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "

View file

@@ -25,7 +25,7 @@ class SerperWrapper(BaseModel):
@field_validator("serper_api_key", mode="before")
@classmethod
def check_serper_api_key(cls, val: str):
val = val or config.search["serper"].api_key
val = val or config.search.api_key
if not val:
raise ValueError(
"To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "

View file

@@ -282,6 +282,6 @@ class UTGenerator:
"""Choose based on different calling methods"""
result = ""
if self.chatgpt_method == "API":
result = await GPTAPI(config.get_llm_config()).aask_code(messages=messages)
result = await GPTAPI(config.get_openai_llm()).aask_code(messages=messages)
return result

View file

@@ -28,12 +28,10 @@ class PlaywrightWrapper:
def __init__(
self,
browser_type: Literal["chromium", "firefox", "webkit"] | None = None,
browser_type: Literal["chromium", "firefox", "webkit"] | None = "chromium",
launch_kwargs: dict | None = None,
**kwargs,
) -> None:
if browser_type is None:
browser_type = config.browser["playwright"].driver
self.browser_type = browser_type
launch_kwargs = launch_kwargs or {}
if config.proxy and "proxy" not in launch_kwargs: