Mirror of https://github.com/FoundationAgents/MetaGPT.git, synced 2026-04-25 00:36:55 +02:00

Commit c275f28a37 (parent 4de8fa3682): remove Dict, use direct LLMConfig / Browser / Search / Mermaid configs instead

16 changed files with 60 additions and 82 deletions
@@ -1,4 +1,3 @@
 llm:
-  gpt3t:
-    api_key: "YOUR_API_KEY"
-    model: "gpt-3.5-turbo-1106"
+  api_key: "YOUR_API_KEY"
+  model: "gpt-3.5-turbo-1106"
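With the named-entry level gone from the YAML, Config takes a single LLMConfig rather than a name-keyed Dict. A minimal construction sketch in Python (field names taken from the hunks below; illustrative, not part of the commit):

from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig

# One flat LLM config; the {"gpt3t": ...} naming level no longer exists.
cfg = Config(llm=LLMConfig(api_key="YOUR_API_KEY", model="gpt-3.5-turbo-1106"))
assert cfg.llm.model == "gpt-3.5-turbo-1106"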
Binary file not shown.
@@ -184,7 +184,7 @@ class WebBrowseAndSummarize(Action):
         super().__init__(**kwargs)

         self.web_browser_engine = WebBrowserEngine(
-            engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
+            engine=WebBrowserEngineType.CUSTOM if self.browse_func else WebBrowserEngineType.PLAYWRIGHT,
             run_func=self.browse_func,
         )
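The fallback engine is now explicit: constructing the action without a custom browse_func selects Playwright instead of passing engine=None. A hedged sketch of the resulting behavior (imports omitted; default construction and the .engine attribute are assumptions, not shown in this diff):

action = WebBrowseAndSummarize()  # no custom browse_func supplied
# Hypothetical check: the engine falls back to Playwright rather than None.
assert action.web_browser_engine.engine == WebBrowserEngineType.PLAYWRIGHT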
@@ -9,7 +9,7 @@ import os
 from pathlib import Path
 from typing import Dict, Iterable, List, Literal, Optional

-from pydantic import BaseModel, Field, model_validator
+from pydantic import BaseModel, model_validator

 from metagpt.configs.browser_config import BrowserConfig
 from metagpt.configs.llm_config import LLMConfig, LLMType
@@ -44,15 +44,15 @@ class Config(CLIParams, YamlModel):
     """Configurations for MetaGPT"""

     # Key Parameters
-    llm: Dict[str, LLMConfig] = Field(default_factory=Dict)
+    llm: LLMConfig

     # Global Proxy. Will be used if llm.proxy is not set
     proxy: str = ""

     # Tool Parameters
-    search: Dict[str, SearchConfig] = {}
-    browser: Dict[str, BrowserConfig] = {"default": BrowserConfig()}
-    mermaid: Dict[str, MermaidConfig] = {"default": MermaidConfig()}
+    search: Optional[SearchConfig] = None
+    browser: BrowserConfig = BrowserConfig()
+    mermaid: MermaidConfig = MermaidConfig()

     # Storage Parameters
     s3: Optional[S3Config] = None
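The tool configs lose the name-keyed Dict level as well, so callers read attributes directly instead of indexing by provider name. A minimal sketch under that assumption:

from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig

cfg = Config(llm=LLMConfig(api_key="YOUR_API_KEY"))
assert cfg.search is None  # was an empty Dict; now None until a SearchConfig is set
cfg.browser                # BrowserConfig(); was cfg.browser["default"]
cfg.mermaid                # MermaidConfig(); was cfg.mermaid["default"]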
@@ -110,46 +110,17 @@ class Config(CLIParams, YamlModel):
         self.reqa_file = reqa_file
         self.max_auto_summarize_code = max_auto_summarize_code

-    def _get_llm_config(self, name: Optional[str] = None) -> LLMConfig:
-        """Get LLM instance by name"""
-        if name is None:
-            # Use the first LLM as default
-            name = list(self.llm.keys())[0]
-        if name not in self.llm:
-            raise ValueError(f"LLM {name} not found in config")
-        return self.llm[name]
-
-    def get_llm_configs_by_type(self, llm_type: LLMType) -> List[LLMConfig]:
-        """Get LLM instance by type"""
-        return [v for k, v in self.llm.items() if v.api_type == llm_type]
-
-    def get_llm_config_by_type(self, llm_type: LLMType) -> Optional[LLMConfig]:
-        """Get LLM instance by type"""
-        llm = self.get_llm_configs_by_type(llm_type)
-        if llm:
-            return llm[0]
-        return None
-
-    def get_llm_config(self, name: Optional[str] = None, provider: LLMType = None) -> LLMConfig:
-        """Return a LLMConfig instance"""
-        if provider:
-            llm_configs = self.get_llm_configs_by_type(provider)
-
-            if len(llm_configs) == 0:
-                raise ValueError(f"Cannot find llm config with name {name} and provider {provider}")
-            # return the first one if name is None, or return the only one
-            llm_config = llm_configs[0]
-        else:
-            llm_config = self._get_llm_config(name)
-        return llm_config
-
     def get_openai_llm(self) -> Optional[LLMConfig]:
         """Get OpenAI LLMConfig by name. If no OpenAI, raise Exception"""
-        return self.get_llm_config_by_type(LLMType.OPENAI)
+        if self.llm.api_type == LLMType.OPENAI:
+            return self.llm
+        return None

     def get_azure_llm(self) -> Optional[LLMConfig]:
         """Get Azure LLMConfig by name. If no Azure, raise Exception"""
-        return self.get_llm_config_by_type(LLMType.AZURE)
+        if self.llm.api_type == LLMType.AZURE:
+            return self.llm
+        return None


 def merge_dict(dicts: Iterable[Dict]) -> Dict:
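Provider lookup collapses from a name/type search over a Dict to a simple api_type check on the single llm field. Note that the retained docstrings still say "raise Exception" although both methods now return None on a mismatch. A short sketch of the new behavior:

from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig, LLMType

cfg = Config(llm=LLMConfig(api_key="YOUR_API_KEY", api_type=LLMType.OPENAI))
assert cfg.get_openai_llm() is cfg.llm  # api_type matches, so the single config is returned
assert cfg.get_azure_llm() is None      # no Azure config in a single-LLM setup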
@@ -40,6 +40,7 @@ class LLMConfig(YamlModel):
     api_type: LLMType = LLMType.OPENAI
+    base_url: str = "https://api.openai.com/v1"
     api_version: Optional[str] = None

     model: Optional[str] = None  # also stands for DEPLOYMENT_NAME

     # For Spark(Xunfei), maybe remove later
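base_url now carries an explicit OpenAI default, so a minimal config needs only an API key. A sketch (assuming no other required LLMConfig fields):

from metagpt.configs.llm_config import LLMConfig, LLMType

llm = LLMConfig(api_key="YOUR_API_KEY")
assert llm.base_url == "https://api.openai.com/v1"  # default introduced in this hunk
assert llm.api_type == LLMType.OPENAI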
@@ -12,7 +12,7 @@ from typing import Optional
 from pydantic import BaseModel, ConfigDict

 from metagpt.config2 import Config
-from metagpt.configs.llm_config import LLMConfig, LLMType
+from metagpt.configs.llm_config import LLMConfig
 from metagpt.const import OPTIONS
 from metagpt.provider.base_llm import BaseLLM
 from metagpt.provider.llm_provider_registry import create_llm_instance
@@ -77,10 +77,10 @@ class Context(BaseModel):
     #     self._llm = None
     #     return self._llm

-    def llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM:
+    def llm(self) -> BaseLLM:
         """Return a LLM instance, fixme: support cache"""
         # if self._llm is None:
-        self._llm = create_llm_instance(self.config.get_llm_config(name, provider))
+        self._llm = create_llm_instance(self.config.llm)
         if self._llm.cost_manager is None:
             self._llm.cost_manager = self.cost_manager
         return self._llm
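Context.llm() drops its name/provider parameters and always builds from the single config.llm. A hedged usage sketch (the global CONTEXT comes from the llm.py hunk below; per-call reconstruction follows from the commented-out cache check):

from metagpt.context import CONTEXT

llm = CONTEXT.llm()  # built from CONTEXT.config.llm; no name/provider selection
# Note: the cache fixme remains, so each call constructs a fresh instance.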
@@ -140,12 +140,6 @@ class ContextMixin(BaseModel):
         """Set llm"""
         self.set("_llm", llm, override)

-    def use_llm(self, name: Optional[str] = None, provider: LLMType = None) -> BaseLLM:
-        """Use a LLM instance"""
-        self._llm_config = self.config.get_llm_config(name, provider)
-        self._llm = None
-        return self.llm
-
     @property
     def config(self) -> Config:
         """Role config: role config > context config"""
@@ -6,14 +6,12 @@
 @File    : llm.py
 """

-from typing import Optional

-from metagpt.configs.llm_config import LLMType
 from metagpt.context import CONTEXT
 from metagpt.provider.base_llm import BaseLLM


-def LLM(name: Optional[str] = None, provider: LLMType = LLMType.OPENAI) -> BaseLLM:
+def LLM() -> BaseLLM:
     """get the default llm provider if name is None"""
     # context.use_llm(name=name, provider=provider)
-    return CONTEXT.llm(name=name, provider=provider)
+    return CONTEXT.llm()
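Call sites shrink accordingly, as the updated tests below show. A one-line usage sketch:

from metagpt.llm import LLM

llm = LLM()  # no name= or provider= arguments anymore; uses the single configured LLM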
@@ -35,7 +35,7 @@ class GoogleAPIWrapper(BaseModel):
     @field_validator("google_api_key", mode="before")
     @classmethod
     def check_google_api_key(cls, val: str):
-        val = val or config.search["google"].api_key
+        val = val or config.search.api_key
         if not val:
             raise ValueError(
                 "To use, make sure you provide the google_api_key when constructing an object. Alternatively, "
@@ -47,7 +47,7 @@ class GoogleAPIWrapper(BaseModel):
     @field_validator("google_cse_id", mode="before")
     @classmethod
     def check_google_cse_id(cls, val: str):
-        val = val or config.search["google"].cse_id
+        val = val or config.search.cse_id
         if not val:
             raise ValueError(
                 "To use, make sure you provide the google_cse_id when constructing an object. Alternatively, "
@@ -32,7 +32,7 @@ class SerpAPIWrapper(BaseModel):
     @field_validator("serpapi_api_key", mode="before")
     @classmethod
     def check_serpapi_api_key(cls, val: str):
-        val = val or config.search["serpapi"].api_key
+        val = val or config.search.api_key
         if not val:
             raise ValueError(
                 "To use, make sure you provide the serpapi_api_key when constructing an object. Alternatively, "
@@ -25,7 +25,7 @@ class SerperWrapper(BaseModel):
     @field_validator("serper_api_key", mode="before")
     @classmethod
     def check_serper_api_key(cls, val: str):
-        val = val or config.search["serper"].api_key
+        val = val or config.search.api_key
         if not val:
             raise ValueError(
                 "To use, make sure you provide the serper_api_key when constructing an object. Alternatively, "
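All three search wrappers follow the same pattern: the key validator falls back to the flat config.search rather than a provider-named entry, so the one active SearchConfig serves whichever wrapper is in use. A hedged construction sketch (imports omitted; field names taken from the validators above, other fields may exist):

wrapper = GoogleAPIWrapper(google_api_key="YOUR_API_KEY", google_cse_id="YOUR_CSE_ID")
# Or rely on the fallback: with config.search populated, explicit keys can be omitted.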
@@ -282,6 +282,6 @@ class UTGenerator:
         """Choose based on different calling methods"""
         result = ""
         if self.chatgpt_method == "API":
-            result = await GPTAPI(config.get_llm_config()).aask_code(messages=messages)
+            result = await GPTAPI(config.get_openai_llm()).aask_code(messages=messages)

         return result
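Since get_llm_config() is gone, UTGenerator asks for the OpenAI config explicitly. An observation (not in the diff): get_openai_llm() returns None when the configured LLM is not OpenAI, so GPTAPI would then receive None. A hedged defensive sketch:

llm_config = config.get_openai_llm()
if llm_config is None:  # hypothetical guard, not part of the commit
    raise ValueError("UTGenerator's API method requires an OpenAI LLM config")
result = await GPTAPI(llm_config).aask_code(messages=messages)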
@@ -28,12 +28,10 @@ class PlaywrightWrapper:

     def __init__(
         self,
-        browser_type: Literal["chromium", "firefox", "webkit"] | None = None,
+        browser_type: Literal["chromium", "firefox", "webkit"] | None = "chromium",
        launch_kwargs: dict | None = None,
         **kwargs,
     ) -> None:
-        if browser_type is None:
-            browser_type = config.browser["playwright"].driver
         self.browser_type = browser_type
         launch_kwargs = launch_kwargs or {}
         if config.proxy and "proxy" not in launch_kwargs:
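The config.browser["playwright"].driver lookup disappears; the default browser is now fixed in the signature, so the wrapper no longer reads the driver from BrowserConfig and a non-default browser must be passed explicitly. A usage sketch (import omitted):

pw = PlaywrightWrapper()                         # defaults to "chromium"
pw_ff = PlaywrightWrapper(browser_type="firefox")  # explicit choice replaces the config lookup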
File diff suppressed because one or more lines are too long
@@ -12,7 +12,7 @@ from tests.metagpt.provider.mock_llm_config import (

 @pytest.mark.asyncio
 async def test_aask_code():
-    llm = LLM(name="gpt3t")
+    llm = LLM()
     msg = [{"role": "user", "content": "Write a python hello world code."}]
     rsp = await llm.aask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}

@@ -24,7 +24,7 @@ async def test_aask_code():

 @pytest.mark.asyncio
 async def test_aask_code_str():
-    llm = LLM(name="gpt3t")
+    llm = LLM()
     msg = "Write a python hello world code."
     rsp = await llm.aask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
     assert "language" in rsp
@@ -34,7 +34,7 @@ async def test_aask_code_str():

 @pytest.mark.asyncio
 async def test_aask_code_message():
-    llm = LLM(name="gpt3t")
+    llm = LLM()
     msg = UserMessage("Write a python hello world code.")
     rsp = await llm.aask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
     assert "language" in rsp
@@ -10,7 +10,10 @@ from pydantic import BaseModel
 from metagpt.config2 import Config
 from metagpt.configs.llm_config import LLMType
 from metagpt.context import ContextMixin
-from tests.metagpt.provider.mock_llm_config import mock_llm_config
+from tests.metagpt.provider.mock_llm_config import (
+    mock_llm_config,
+    mock_llm_config_proxy,
+)


 def test_config_1():
@@ -21,9 +24,9 @@ def test_config_1():


 def test_config_from_dict():
-    cfg = Config(llm={"default": mock_llm_config})
+    cfg = Config(llm=mock_llm_config)
     assert cfg
-    assert cfg.llm["default"].api_key == "mock_api_key"
+    assert cfg.llm.api_key == "mock_api_key"


 class ModelX(ContextMixin, BaseModel):
@@ -47,11 +50,11 @@ def test_config_mixin_1():


 def test_config_mixin_2():
-    i = Config(llm={"default": mock_llm_config})
-    j = Config(llm={"new": mock_llm_config})
+    i = Config(llm=mock_llm_config)
+    j = Config(llm=mock_llm_config_proxy)
     obj = ModelX(config=i)
     assert obj._config == i
-    assert obj._config.llm["default"] == mock_llm_config
+    assert obj._config.llm == mock_llm_config

     obj.set_config(j)
     # obj already has a config, so it will not be set
@@ -60,16 +63,16 @@ def test_config_mixin_2():

 def test_config_mixin_3():
     """Test config mixin with multiple inheritance"""
-    i = Config(llm={"default": mock_llm_config})
-    j = Config(llm={"new": mock_llm_config})
+    i = Config(llm=mock_llm_config)
+    j = Config(llm=mock_llm_config_proxy)
     obj = ModelY(config=i)
     assert obj._config == i
-    assert obj._config.llm["default"] == mock_llm_config
+    assert obj._config.llm == mock_llm_config

     obj.set_config(j)
     # obj already has a config, so it will not be set
     assert obj._config == i
-    assert obj._config.llm["default"] == mock_llm_config
+    assert obj._config.llm == mock_llm_config

     assert obj.a == "a"
     assert obj.b == "b"
@@ -49,13 +49,14 @@ class MockSearchEnine:
 async def test_search_engine(search_engine_type, run_func: Callable, max_results: int, as_string: bool, aiohttp_mocker):
     # Prerequisites
     cache_json_path = None
+    # FIXME: don't rely on the global config here; instantiate a dedicated config instead
     if search_engine_type is SearchEngineType.SERPAPI_GOOGLE:
-        assert config.search["serpapi"]
+        assert config.search
         cache_json_path = search_cache_path / f"serpapi-metagpt-{max_results}.json"
     elif search_engine_type is SearchEngineType.DIRECT_GOOGLE:
-        assert config.search["google"]
+        assert config.search
     elif search_engine_type is SearchEngineType.SERPER_GOOGLE:
-        assert config.search["serper"]
+        assert config.search
         cache_json_path = search_cache_path / f"serper-metagpt-{max_results}.json"

     if cache_json_path:
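A hedged sketch of what that FIXME asks for (the SearchConfig import path is an assumption; the api_key field is grounded in the validators above):

from metagpt.config2 import Config
from metagpt.configs.search_config import SearchConfig  # path assumed

test_config = Config(
    llm=mock_llm_config,
    search=SearchConfig(api_key="YOUR_API_KEY"),  # per-test config instead of the global one
)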