add llm provider registry

geekan 2023-12-19 17:55:34 +08:00 committed by better629
parent bd12087be4
commit f32f9c82e5
9 changed files with 89 additions and 52 deletions


@@ -8,6 +8,7 @@ Provide configuration, singleton
"""
import os
from copy import deepcopy
from enum import Enum
from pathlib import Path
from typing import Any
@@ -31,6 +32,15 @@ class NotConfiguredException(Exception):
        super().__init__(self.message)


class LLMProviderEnum(Enum):
    OPENAI = "openai"
    ANTHROPIC = "anthropic"
    SPARK = "spark"
    ZHIPUAI = "zhipuai"
    FIREWORKS = "fireworks"
    OPEN_LLM = "open_llm"


class Config(metaclass=Singleton):
    """
    Regular usage method:
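The enum above is the key type for the registry that gives this commit its name; the registry itself lives in the other changed files, which this excerpt does not show. As a rough sketch of the pattern, assuming hypothetical names (LLMProviderRegistry, register_provider) rather than the ones actually used elsewhere in the commit:

# Hypothetical sketch of a registry keyed on LLMProviderEnum; all names
# here are illustrative, not taken from the commit's other files.
class LLMProviderRegistry:
    def __init__(self):
        self.providers = {}

    def register(self, key: LLMProviderEnum, provider_cls):
        self.providers[key] = provider_cls

    def get_provider(self, key: LLMProviderEnum):
        return self.providers[key]


LLM_REGISTRY = LLMProviderRegistry()


def register_provider(key: LLMProviderEnum):
    """Class decorator: register an LLM client class under an enum key."""
    def decorator(cls):
        LLM_REGISTRY.register(key, cls)
        return cls
    return decorator

With this shape, a client module can self-register via @register_provider(LLMProviderEnum.OPENAI), and callers look up a class by enum member instead of branching on strings.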
@@ -46,30 +56,37 @@ class Config(metaclass=Singleton):
    def __init__(self, yaml_file=default_yaml_file):
        global_options = OPTIONS.get()
        # cli params
        self.project_path = ""
        self.project_name = ""
        self.inc = False
        self.reqa_file = ""
        self.max_auto_summarize_code = 0
        self._init_with_config_files_and_env(yaml_file)
        self._update()
        global_options.update(OPTIONS.get())
        logger.debug("Config loading done.")

    def get_default_llm_provider_enum(self):
        if self._is_valid_llm_key(self.openai_api_key):
            llm = LLMProviderEnum.OPENAI
        elif self._is_valid_llm_key(self.anthropic_api_key):
            llm = LLMProviderEnum.ANTHROPIC
        elif self._is_valid_llm_key(self.zhipuai_api_key):
            llm = LLMProviderEnum.ZHIPUAI
        elif self._is_valid_llm_key(self.fireworks_api_key):
            llm = LLMProviderEnum.FIREWORKS
        elif self.open_llm_api_base:
            llm = LLMProviderEnum.OPEN_LLM
        else:
            raise NotConfiguredException("You should configure an LLM provider first")
        return llm

    @staticmethod
    def _is_valid_llm_key(k) -> bool:
        return bool(k) and k != "YOUR_API_KEY"

    def _check_llm_exists(self):
        if not any(
            [
                self._is_valid_llm_key(self.openai_api_key),
                self._is_valid_llm_key(self.anthropic_api_key),
                self._is_valid_llm_key(self.zhipuai_api_key),
                self._is_valid_llm_key(self.fireworks_api_key),
                self.open_llm_api_base,
            ]
        ):
            raise NotConfiguredException(
                "Set OPENAI_API_KEY or ANTHROPIC_API_KEY or ZHIPUAI_API_KEY "
                "or FIREWORKS_API_KEY or OPEN_LLM_API_BASE"
            )

    def _update(self):
        # logger.info("Config loading done.")
        self.global_proxy = self._get("GLOBAL_PROXY")
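Two details in this hunk are worth calling out. The if/elif chain fixes a priority order (OpenAI first, then Anthropic, ZhipuAI, Fireworks, with the open-LLM base URL as the final fallback), so a config that sets several keys resolves deterministically to the first valid one. And _is_valid_llm_key rejects the literal string "YOUR_API_KEY", presumably the placeholder shipped in the config template, so an untouched template does not count as configured:

# How the validity check treats typical values (illustrative):
Config._is_valid_llm_key("sk-abc123")     # True  -> counts as configured
Config._is_valid_llm_key("YOUR_API_KEY")  # False -> template placeholder
Config._is_valid_llm_key("")              # False -> empty
Config._is_valid_llm_key(None)            # False -> unset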
@@ -80,7 +97,7 @@ class Config(metaclass=Singleton):
        self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
        self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
        self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
        self._check_llm_exists()
        _ = self.get_default_llm_provider_enum()
        self.openai_api_base = self._get("OPENAI_API_BASE")
        self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
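Replacing _check_llm_exists() with a discarded call to get_default_llm_provider_enum() preserves the fail-fast behaviour at load time while routing it through the same resolution logic callers use, so the existence check and the selection order can no longer drift apart. A minimal sketch of the effect, assuming nothing is configured in config.yaml or the environment:

# Sketch: constructing Config with no provider configured still raises
# (assumes an empty config.yaml and no relevant environment variables).
try:
    cfg = Config()  # _update() -> get_default_llm_provider_enum() -> raise
except NotConfiguredException as exc:
    print(exc)  # "You should configure an LLM provider first"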
@@ -131,13 +148,6 @@ class Config(metaclass=Singleton):
        self.workspace_path = Path(self._get("WORKSPACE_PATH", DEFAULT_WORKSPACE_ROOT))
        self._ensure_workspace_exists()

    def _init_cli_paras(self):
        self.project_path = None
        self.project_name = None
        self.inc = None
        self.reqa_file = None
        self.max_auto_summarize_code = None

    def update_via_cli(self, project_path, project_name, inc, reqa_file, max_auto_summarize_code):
        """update config via cli"""