fixbug: recursive search for provider

This commit is contained in:
莘权 马 2024-01-04 12:46:41 +08:00
parent a1c754bd94
commit d051933324
2 changed files with 9 additions and 1 deletions

View file

@ -50,6 +50,9 @@ class LLMProviderEnum(Enum):
AZURE_OPENAI = "azure_openai"
OLLAMA = "ollama"
def __missing__(self, key):
return self.OPENAI
class Config(metaclass=Singleton):
"""
@ -108,6 +111,11 @@ class Config(metaclass=Singleton):
if v:
provider = k
break
if provider is None:
if self.DEFAULT_PROVIDER:
provider = LLMProviderEnum(self.DEFAULT_PROVIDER)
else:
raise NotConfiguredException("You should config a LLM configuration first")
if provider is LLMProviderEnum.GEMINI and not require_python_version(req_version=(3, 10)):
warnings.warn("Use Gemini requires Python >= 3.10")
@ -117,7 +125,6 @@ class Config(metaclass=Singleton):
if provider:
logger.info(f"API: {provider}")
return provider
raise NotConfiguredException("You should config a LLM configuration first")
def get_model_name(self, provider=None) -> str:
provider = provider or self.get_default_llm_provider_enum()