mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-24 14:15:17 +02:00
add llm provider registry
This commit is contained in:
parent
bd12087be4
commit
f32f9c82e5
9 changed files with 89 additions and 52 deletions
|
|
@ -8,6 +8,7 @@ Provide configuration, singleton
|
|||
"""
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -31,6 +32,15 @@ class NotConfiguredException(Exception):
|
|||
super().__init__(self.message)
|
||||
|
||||
|
||||
class LLMProviderEnum(Enum):
|
||||
OPENAI = "openai"
|
||||
ANTHROPIC = "anthropic"
|
||||
SPARK = "spark"
|
||||
ZHIPUAI = "zhipuai"
|
||||
FIREWORKS = "fireworks"
|
||||
OPEN_LLM = "open_llm"
|
||||
|
||||
|
||||
class Config(metaclass=Singleton):
|
||||
"""
|
||||
Regular usage method:
|
||||
|
|
@ -46,30 +56,37 @@ class Config(metaclass=Singleton):
|
|||
|
||||
def __init__(self, yaml_file=default_yaml_file):
|
||||
golbal_options = OPTIONS.get()
|
||||
# cli paras
|
||||
self.project_path = ""
|
||||
self.project_name = ""
|
||||
self.inc = False
|
||||
self.reqa_file = ""
|
||||
self.max_auto_summarize_code = 0
|
||||
|
||||
self._init_with_config_files_and_env(yaml_file)
|
||||
self._update()
|
||||
golbal_options.update(OPTIONS.get())
|
||||
logger.debug("Config loading done.")
|
||||
|
||||
def get_default_llm_provider_enum(self):
|
||||
if self._is_valid_llm_key(self.openai_api_key):
|
||||
llm = LLMProviderEnum.OPENAI
|
||||
elif self._is_valid_llm_key(self.anthropic_api_key):
|
||||
llm = LLMProviderEnum.ANTHROPIC
|
||||
elif self._is_valid_llm_key(self.zhipuai_api_key):
|
||||
llm = LLMProviderEnum.ZHIPUAI
|
||||
elif self._is_valid_llm_key(self.fireworks_api_key):
|
||||
llm = LLMProviderEnum.FIREWORKS
|
||||
elif self.open_llm_api_base:
|
||||
llm = LLMProviderEnum.OPEN_LLM
|
||||
else:
|
||||
raise NotConfiguredException("You should config a LLM configuration first")
|
||||
return llm
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_llm_key(k) -> bool:
|
||||
return k and k != "YOUR_API_KEY"
|
||||
|
||||
def _check_llm_exists(self):
|
||||
if not any(
|
||||
[
|
||||
self._is_valid_llm_key(self.openai_api_key),
|
||||
self._is_valid_llm_key(self.anthropic_api_key),
|
||||
self._is_valid_llm_key(self.zhipuai_api_key),
|
||||
self._is_valid_llm_key(self.fireworks_api_key),
|
||||
self.open_llm_api_base,
|
||||
]
|
||||
):
|
||||
raise NotConfiguredException(
|
||||
"Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY "
|
||||
"or FIREWORKS_API_KEY or OPEN_LLM_API_BASE"
|
||||
)
|
||||
|
||||
def _update(self):
|
||||
# logger.info("Config loading done.")
|
||||
self.global_proxy = self._get("GLOBAL_PROXY")
|
||||
|
|
@ -80,7 +97,7 @@ class Config(metaclass=Singleton):
|
|||
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
|
||||
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
|
||||
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
|
||||
self._check_llm_exists()
|
||||
_ = self.get_default_llm_provider_enum()
|
||||
|
||||
self.openai_api_base = self._get("OPENAI_API_BASE")
|
||||
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
|
|
@ -131,13 +148,6 @@ class Config(metaclass=Singleton):
|
|||
self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT))
|
||||
self._ensure_workspace_exists()
|
||||
|
||||
def _init_cli_paras(self):
|
||||
self.project_path = None
|
||||
self.project_name = None
|
||||
self.inc = None
|
||||
self.reqa_file = None
|
||||
self.max_auto_summarize_code = None
|
||||
|
||||
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
|
||||
"""update config via cli"""
|
||||
|
||||
|
|
|
|||
|
|
@ -8,12 +8,8 @@
|
|||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.fireworks_api import FireWorksGPTAPI
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.spark_api import SparkAPI
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
from metagpt.provider.llm_provider_registry import LLMProviderRegistry
|
||||
|
||||
_ = HumanProvider() # Avoid pre-commit error
|
||||
|
||||
|
|
@ -21,17 +17,4 @@ _ = HumanProvider() # Avoid pre-commit error
|
|||
def LLM() -> BaseGPTAPI:
|
||||
"""initialize different LLM instance according to the key field existence"""
|
||||
# TODO a little trick, can use registry to initialize LLM instance further
|
||||
if CONFIG.openai_api_key:
|
||||
llm = OpenAIGPTAPI()
|
||||
elif CONFIG.spark_api_key:
|
||||
llm = SparkAPI()
|
||||
elif CONFIG.zhipuai_api_key:
|
||||
llm = ZhiPuAIGPTAPI()
|
||||
elif CONFIG.open_llm_api_base:
|
||||
llm = OpenLLMGPTAPI()
|
||||
elif CONFIG.fireworks_api_key:
|
||||
llm = FireWorksGPTAPI()
|
||||
else:
|
||||
raise RuntimeError("You should config a LLM configuration first")
|
||||
|
||||
return llm
|
||||
return LLMProviderRegistry.get_provider(CONFIG.get_default_llm_provider_enum())
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
|
||||
import openai
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.FIREWORKS)
|
||||
class FireWorksGPTAPI(OpenAIGPTAPI):
|
||||
def __init__(self):
|
||||
self.__init_fireworks(CONFIG)
|
||||
|
|
|
|||
34
metagpt/provider/llm_provider_registry.py
Normal file
34
metagpt/provider/llm_provider_registry.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/19 17:26
|
||||
@Author : alexanderwu
|
||||
@File : llm_provider_registry.py
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
|
||||
|
||||
class LLMProviderRegistry:
|
||||
def __init__(self):
|
||||
self.providers = {}
|
||||
|
||||
def register(self, key, provider_cls):
|
||||
self.providers[key] = provider_cls
|
||||
|
||||
def get_provider(self, enum: LLMProviderEnum):
|
||||
"""get provider instance according to the enum"""
|
||||
return self.providers[enum]()
|
||||
|
||||
|
||||
# Registry instance
|
||||
LLM_REGISTRY = LLMProviderRegistry()
|
||||
|
||||
|
||||
def register_provider(key):
|
||||
"""register provider to registry"""
|
||||
|
||||
def decorator(cls):
|
||||
LLM_REGISTRY.register(key, cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
|
@ -4,8 +4,9 @@
|
|||
|
||||
import openai
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
|
||||
|
||||
|
||||
|
|
@ -31,6 +32,7 @@ class OpenLLMCostManager(CostManager):
|
|||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPEN_LLM)
|
||||
class OpenLLMGPTAPI(OpenAIGPTAPI):
|
||||
def __init__(self):
|
||||
self.__init_openllm(CONFIG)
|
||||
|
|
|
|||
|
|
@ -18,10 +18,11 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.singleton import Singleton
|
||||
from metagpt.utils.token_counter import (
|
||||
|
|
@ -137,6 +138,7 @@ See FAQ 5.8
|
|||
raise retry_state.outcome.exception()
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPENAI)
|
||||
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
||||
"""
|
||||
Check https://platform.openai.com/examples for examples
|
||||
|
|
|
|||
|
|
@ -19,11 +19,13 @@ from wsgiref.handlers import format_date_time
|
|||
|
||||
import websocket # 使用websocket_client
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.SPARK)
|
||||
class SparkAPI(BaseGPTAPI):
|
||||
def __init__(self):
|
||||
logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
|
|
|
|||
|
|
@ -16,9 +16,10 @@ from tenacity import (
|
|||
wait_random_exponential,
|
||||
)
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import CostManager, log_and_reraise
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
||||
|
|
@ -30,6 +31,7 @@ class ZhiPuEvent(Enum):
|
|||
FINISH = "finish"
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.ZHIPUAI)
|
||||
class ZhiPuAIGPTAPI(BaseGPTAPI):
|
||||
"""
|
||||
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
|
||||
|
|
|
|||
|
|
@ -167,8 +167,8 @@ class Message(BaseModel):
|
|||
@handle_exception(exception_type=JSONDecodeError, default_return=None)
|
||||
def load(val):
|
||||
"""Convert the json string to object."""
|
||||
d = json.loads(val)
|
||||
return Message(**d)
|
||||
i = json.loads(val)
|
||||
return Message(**i)
|
||||
|
||||
|
||||
class UserMessage(Message):
|
||||
|
|
@ -263,16 +263,16 @@ class MessageQueue(BaseModel):
|
|||
return json.dumps(lst)
|
||||
|
||||
@staticmethod
|
||||
def load(i) -> "MessageQueue":
|
||||
def load(data) -> "MessageQueue":
|
||||
"""Convert the json string to the `MessageQueue` object."""
|
||||
queue = MessageQueue()
|
||||
try:
|
||||
lst = json.loads(i)
|
||||
lst = json.loads(data)
|
||||
for i in lst:
|
||||
msg = Message(**i)
|
||||
queue.push(msg)
|
||||
except JSONDecodeError as e:
|
||||
logger.warning(f"JSON load failed: {i}, error:{e}")
|
||||
logger.warning(f"JSON load failed: {data}, error:{e}")
|
||||
|
||||
return queue
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue