add llm provider registry

This commit is contained in:
geekan 2023-12-19 17:55:34 +08:00 committed by better629
parent bd12087be4
commit f32f9c82e5
9 changed files with 89 additions and 52 deletions

View file

@ -8,6 +8,7 @@ Provide configuration, singleton
"""
import os
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any
@ -31,6 +32,15 @@ class NotConfiguredException(Exception):
super().__init__(self.message)
class LLMProviderEnum(Enum):
OPENAI = "openai"
ANTHROPIC = "anthropic"
SPARK = "spark"
ZHIPUAI = "zhipuai"
FIREWORKS = "fireworks"
OPEN_LLM = "open_llm"
class Config(metaclass=Singleton):
"""
Regular usage method:
@ -46,30 +56,37 @@ class Config(metaclass=Singleton):
def __init__(self, yaml_file=default_yaml_file):
golbal_options = OPTIONS.get()
# cli paras
self.project_path = ""
self.project_name = ""
self.inc = False
self.reqa_file = ""
self.max_auto_summarize_code = 0
self._init_with_config_files_and_env(yaml_file)
self._update()
golbal_options.update(OPTIONS.get())
logger.debug("Config loading done.")
def get_default_llm_provider_enum(self):
if self._is_valid_llm_key(self.openai_api_key):
llm = LLMProviderEnum.OPENAI
elif self._is_valid_llm_key(self.anthropic_api_key):
llm = LLMProviderEnum.ANTHROPIC
elif self._is_valid_llm_key(self.zhipuai_api_key):
llm = LLMProviderEnum.ZHIPUAI
elif self._is_valid_llm_key(self.fireworks_api_key):
llm = LLMProviderEnum.FIREWORKS
elif self.open_llm_api_base:
llm = LLMProviderEnum.OPEN_LLM
else:
raise NotConfiguredException("You should config a LLM configuration first")
return llm
@staticmethod
def _is_valid_llm_key(k) -> bool:
return k and k != "YOUR_API_KEY"
def _check_llm_exists(self):
if not any(
[
self._is_valid_llm_key(self.openai_api_key),
self._is_valid_llm_key(self.anthropic_api_key),
self._is_valid_llm_key(self.zhipuai_api_key),
self._is_valid_llm_key(self.fireworks_api_key),
self.open_llm_api_base,
]
):
raise NotConfiguredException(
"Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY "
"or FIREWORKS_API_KEY or OPEN_LLM_API_BASE"
)
def _update(self):
# logger.info("Config loading done.")
self.global_proxy = self._get("GLOBAL_PROXY")
@ -80,7 +97,7 @@ class Config(metaclass=Singleton):
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
self._check_llm_exists()
_ = self.get_default_llm_provider_enum()
self.openai_api_base = self._get("OPENAI_API_BASE")
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
@ -131,13 +148,6 @@ class Config(metaclass=Singleton):
self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT))
self._ensure_workspace_exists()
def _init_cli_paras(self):
self.project_path = None
self.project_name = None
self.inc = None
self.reqa_file = None
self.max_auto_summarize_code = None
def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
"""update config via cli"""

View file

@ -8,12 +8,8 @@
from metagpt.config import CONFIG
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.fireworks_api import FireWorksGPTAPI
from metagpt.provider.human_provider import HumanProvider
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.spark_api import SparkAPI
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.llm_provider_registry import LLMProviderRegistry
_ = HumanProvider() # Avoid pre-commit error
@ -21,17 +17,4 @@ _ = HumanProvider() # Avoid pre-commit error
def LLM() -> BaseGPTAPI:
"""initialize different LLM instance according to the key field existence"""
# TODO a little trick, can use registry to initialize LLM instance further
if CONFIG.openai_api_key:
llm = OpenAIGPTAPI()
elif CONFIG.spark_api_key:
llm = SparkAPI()
elif CONFIG.zhipuai_api_key:
llm = ZhiPuAIGPTAPI()
elif CONFIG.open_llm_api_base:
llm = OpenLLMGPTAPI()
elif CONFIG.fireworks_api_key:
llm = FireWorksGPTAPI()
else:
raise RuntimeError("You should config a LLM configuration first")
return llm
return LLMProviderRegistry.get_provider(CONFIG.get_default_llm_provider_enum())

View file

@ -4,10 +4,12 @@
import openai
from metagpt.config import CONFIG
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
@register_provider(LLMProviderEnum.FIREWORKS)
class FireWorksGPTAPI(OpenAIGPTAPI):
def __init__(self):
self.__init_fireworks(CONFIG)

View file

@ -0,0 +1,34 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2023/12/19 17:26
@Author : alexanderwu
@File : llm_provider_registry.py
"""
from metagpt.config import LLMProviderEnum
class LLMProviderRegistry:
def __init__(self):
self.providers = {}
def register(self, key, provider_cls):
self.providers[key] = provider_cls
def get_provider(self, enum: LLMProviderEnum):
"""get provider instance according to the enum"""
return self.providers[enum]()
# Registry instance
LLM_REGISTRY = LLMProviderRegistry()
def register_provider(key):
"""register provider to registry"""
def decorator(cls):
LLM_REGISTRY.register(key, cls)
return cls
return decorator

View file

@ -4,8 +4,9 @@
import openai
from metagpt.config import CONFIG
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, OpenAIGPTAPI, RateLimiter
@ -31,6 +32,7 @@ class OpenLLMCostManager(CostManager):
CONFIG.total_cost = self.total_cost
@register_provider(LLMProviderEnum.OPEN_LLM)
class OpenLLMGPTAPI(OpenAIGPTAPI):
def __init__(self):
self.__init_openllm(CONFIG)

View file

@ -18,10 +18,11 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.schema import Message
from metagpt.utils.singleton import Singleton
from metagpt.utils.token_counter import (
@ -137,6 +138,7 @@ See FAQ 5.8
raise retry_state.outcome.exception()
@register_provider(LLMProviderEnum.OPENAI)
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
"""
Check https://platform.openai.com/examples for examples

View file

@ -19,11 +19,13 @@ from wsgiref.handlers import format_date_time
import websocket # 使用websocket_client
from metagpt.config import CONFIG
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.llm_provider_registry import register_provider
@register_provider(LLMProviderEnum.SPARK)
class SparkAPI(BaseGPTAPI):
def __init__(self):
logger.warning("当前方法无法支持异步运行。当你使用acompletion时并不能并行访问。")

View file

@ -16,9 +16,10 @@ from tenacity import (
wait_random_exponential,
)
from metagpt.config import CONFIG
from metagpt.config import CONFIG, LLMProviderEnum
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.provider.llm_provider_registry import register_provider
from metagpt.provider.openai_api import CostManager, log_and_reraise
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
@ -30,6 +31,7 @@ class ZhiPuEvent(Enum):
FINISH = "finish"
@register_provider(LLMProviderEnum.ZHIPUAI)
class ZhiPuAIGPTAPI(BaseGPTAPI):
"""
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`

View file

@ -167,8 +167,8 @@ class Message(BaseModel):
@handle_exception(exception_type=JSONDecodeError, default_return=None)
def load(val):
"""Convert the json string to object."""
d = json.loads(val)
return Message(**d)
i = json.loads(val)
return Message(**i)
class UserMessage(Message):
@ -263,16 +263,16 @@ class MessageQueue(BaseModel):
return json.dumps(lst)
@staticmethod
def load(i) -> "MessageQueue":
def load(data) -> "MessageQueue":
"""Convert the json string to the `MessageQueue` object."""
queue = MessageQueue()
try:
lst = json.loads(i)
lst = json.loads(data)
for i in lst:
msg = Message(**i)
queue.push(msg)
except JSONDecodeError as e:
logger.warning(f"JSON load failed: {i}, error:{e}")
logger.warning(f"JSON load failed: {data}, error:{e}")
return queue