mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
feat: allow overriding base URL of OpenAI models (#368)
* Add OpenAI-compatible API option in model configuration
Backend-only cherry-pick from 20617db37a.
* Potential fix for pull request finding
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
* fix: harden the base url settings in SaaS mode
---------
Co-authored-by: Chris Briddock <briddockchristopher@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
9675151549
commit
8a58b0992d
6 changed files with 425 additions and 11 deletions
|
|
@ -13,6 +13,7 @@ from api.schemas.user_configuration import (
|
|||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.utils.url_security import validate_user_configured_service_url
|
||||
|
||||
AuthContext = TypedDict(
|
||||
"AuthContext",
|
||||
|
|
@ -107,6 +108,17 @@ class UserConfigurationValidator:
|
|||
|
||||
provider = service_config.provider
|
||||
|
||||
for url_field in ("base_url", "endpoint"):
|
||||
url = getattr(service_config, url_field, None)
|
||||
if url:
|
||||
try:
|
||||
validate_user_configured_service_url(
|
||||
url,
|
||||
field_name=url_field,
|
||||
)
|
||||
except ValueError as e:
|
||||
return [{"model": service_name, "message": str(e)}]
|
||||
|
||||
# Speaches doesn't require an API key
|
||||
if provider == ServiceProviders.SPEACHES.value:
|
||||
try:
|
||||
|
|
@ -181,30 +193,92 @@ class UserConfigurationValidator:
|
|||
api_key = service_config.api_key
|
||||
|
||||
try:
|
||||
if not self._check_api_key(provider, api_key):
|
||||
if not self._check_api_key(provider, api_key, service_config):
|
||||
return [
|
||||
{"model": service_name, "message": f"Invalid {provider} API key"}
|
||||
{
|
||||
"model": service_name,
|
||||
"message": (
|
||||
f"Invalid {provider} API key. Please verify your API key is "
|
||||
f"correct, has not expired, and has the required permissions."
|
||||
),
|
||||
}
|
||||
]
|
||||
except ValueError as e:
|
||||
return [{"model": service_name, "message": str(e)}]
|
||||
|
||||
return []
|
||||
|
||||
def _check_api_key(self, provider: str, api_key: str) -> bool:
|
||||
def _check_api_key(
|
||||
self,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
service_config: Optional[ServiceConfig] = None,
|
||||
) -> bool:
|
||||
"""Check if an API key for a provider is valid."""
|
||||
validator = self._validator_map.get(provider)
|
||||
if not validator:
|
||||
return False
|
||||
|
||||
if provider in (
|
||||
ServiceProviders.OPENAI.value,
|
||||
ServiceProviders.OPENAI_REALTIME.value,
|
||||
):
|
||||
return validator(provider, api_key, service_config)
|
||||
return validator(provider, api_key)
|
||||
|
||||
def _check_openai_api_key(self, model: str, api_key: str) -> bool:
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
def _check_openai_api_key(
|
||||
self, model: str, api_key: str, service_config: Optional[ServiceConfig] = None
|
||||
) -> bool:
|
||||
client_kwargs: dict[str, str] = {"api_key": api_key}
|
||||
base_url = getattr(service_config, "base_url", None) if service_config else None
|
||||
if base_url:
|
||||
client_kwargs["base_url"] = base_url
|
||||
client = openai.OpenAI(**client_kwargs)
|
||||
try:
|
||||
client.models.list()
|
||||
return True
|
||||
except openai.AuthenticationError:
|
||||
return False
|
||||
if base_url and "openai.com" not in base_url:
|
||||
raise ValueError(
|
||||
f"Invalid OpenAI API key. The key was rejected by the API at {base_url}. "
|
||||
"Please check that your API key is correct and has not been revoked."
|
||||
)
|
||||
raise ValueError(
|
||||
"Invalid OpenAI API key. The key was rejected by the OpenAI API. "
|
||||
"Please check that your API key is correct and has not been revoked. "
|
||||
"You can verify your keys at https://platform.openai.com/api-keys."
|
||||
)
|
||||
except openai.APIConnectionError:
|
||||
if base_url:
|
||||
raise ValueError(
|
||||
f"Could not connect to the OpenAI-compatible API at {base_url}. "
|
||||
"Please verify that the base_url is correct and reachable, and try again."
|
||||
)
|
||||
raise ValueError(
|
||||
"Could not connect to the OpenAI API. Please check your network connection "
|
||||
"and try again."
|
||||
)
|
||||
except openai.APIError:
|
||||
if base_url:
|
||||
raise ValueError(
|
||||
f"The OpenAI-compatible API at {base_url} returned an error while "
|
||||
"validating the API key. Please verify that the base_url is correct, "
|
||||
"the service is available, and the API key is valid."
|
||||
)
|
||||
raise ValueError(
|
||||
"The OpenAI API returned an error while validating the API key. "
|
||||
"Please try again later."
|
||||
)
|
||||
except Exception:
|
||||
if base_url:
|
||||
raise ValueError(
|
||||
f"Failed to validate the OpenAI API key using the API at {base_url}. "
|
||||
"Please verify that the base_url is correct and reachable, and that the "
|
||||
"API key is valid."
|
||||
)
|
||||
raise ValueError(
|
||||
"Failed to validate the OpenAI API key. Please try again later."
|
||||
)
|
||||
|
||||
def _check_deepgram_api_key(self, model: str, api_key: str) -> bool:
|
||||
try:
|
||||
|
|
@ -212,7 +286,11 @@ class UserConfigurationValidator:
|
|||
deepgram.manage.v1.projects.list()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
raise ValueError(
|
||||
"Invalid Deepgram API key. The key was rejected by the Deepgram API. "
|
||||
"Please check that your API key is correct and active. "
|
||||
"You can verify your keys at https://console.deepgram.com/."
|
||||
)
|
||||
|
||||
def _check_groq_api_key(self, model: str, api_key: str) -> bool:
|
||||
client = Groq(api_key=api_key)
|
||||
|
|
@ -220,7 +298,11 @@ class UserConfigurationValidator:
|
|||
client.models.list()
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
raise ValueError(
|
||||
"Invalid Groq API key. The key was rejected by the Groq API. "
|
||||
"Please check that your API key is correct and active. "
|
||||
"You can verify your keys at https://console.groq.com/keys."
|
||||
)
|
||||
|
||||
def _validate_elevenlabs_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -290,6 +290,10 @@ class OpenAILLMService(BaseLLMConfiguration):
|
|||
description="OpenAI chat model to use.",
|
||||
json_schema_extra={"examples": OPENAI_MODELS, "allow_custom_input": True},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://api.openai.com/v1",
|
||||
description="Override only if using an OpenAI-compatible API (e.g. local LLM, proxy).",
|
||||
)
|
||||
|
||||
|
||||
@register_llm
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from loguru import logger
|
|||
from openai import AsyncOpenAI
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.utils.url_security import validate_user_configured_service_url
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
|
|
@ -54,6 +55,10 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
if self._api_key_configured:
|
||||
client_kwargs = {"api_key": api_key}
|
||||
if base_url:
|
||||
validate_user_configured_service_url(
|
||||
base_url,
|
||||
field_name="base_url",
|
||||
)
|
||||
client_kwargs["base_url"] = base_url
|
||||
self.client = AsyncOpenAI(**client_kwargs)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from loguru import logger
|
|||
from api.constants import MPS_API_URL
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pipecat.minimax_tts import MiniMaxOwnedSessionTTSService
|
||||
from api.utils.url_security import validate_user_configured_service_url
|
||||
from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
|
||||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
|
|
@ -62,6 +63,16 @@ if TYPE_CHECKING:
|
|||
from api.services.pipecat.audio_config import AudioConfig
|
||||
|
||||
|
||||
def _validate_runtime_service_url(url: str, field_name: str) -> None:
|
||||
try:
|
||||
validate_user_configured_service_url(
|
||||
url,
|
||||
field_name=field_name,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
def create_stt_service(
|
||||
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
|
||||
):
|
||||
|
|
@ -174,6 +185,7 @@ def create_stt_service(
|
|||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SPEACHES.value:
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
_validate_runtime_service_url(user_config.stt.base_url, "base_url")
|
||||
return SpeachesSTTService(
|
||||
base_url=user_config.stt.base_url,
|
||||
api_key=user_config.stt.api_key or "none",
|
||||
|
|
@ -301,6 +313,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
# ElevenLabs TTS uses WebSocket. Users configure base_url with an HTTP
|
||||
# scheme (matching ElevenLabs documentation, e.g.
|
||||
# https://api.eu.residency.elevenlabs.io); rewrite it to the WS scheme.
|
||||
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
|
||||
elevenlabs_url = user_config.tts.base_url.replace("https://", "wss://").replace(
|
||||
"http://", "ws://"
|
||||
)
|
||||
|
|
@ -376,6 +389,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
tts._settings.language = language
|
||||
return tts
|
||||
elif user_config.tts.provider == ServiceProviders.SPEACHES.value:
|
||||
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
|
||||
return SpeachesTTSService(
|
||||
base_url=user_config.tts.base_url,
|
||||
api_key=user_config.tts.api_key or "none",
|
||||
|
|
@ -461,6 +475,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
).rstrip("/")
|
||||
if not base_url.endswith("/t2a_v2"):
|
||||
base_url = f"{base_url}/t2a_v2"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
|
||||
session = aiohttp.ClientSession()
|
||||
return MiniMaxOwnedSessionTTSService(
|
||||
|
|
@ -504,6 +519,10 @@ def create_llm_service_from_provider(
|
|||
"""
|
||||
logger.info(f"Creating LLM service: provider={provider}, model={model}")
|
||||
if provider == ServiceProviders.OPENAI.value:
|
||||
kwargs = {}
|
||||
if base_url:
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
kwargs["base_url"] = base_url
|
||||
if "gpt-5" in model:
|
||||
return OpenAILLMService(
|
||||
api_key=api_key,
|
||||
|
|
@ -511,10 +530,12 @@ def create_llm_service_from_provider(
|
|||
model=model,
|
||||
extra={"reasoning_effort": "minimal", "verbosity": "low"},
|
||||
),
|
||||
**kwargs,
|
||||
)
|
||||
return OpenAILLMService(
|
||||
api_key=api_key,
|
||||
settings=OpenAILLMSettings(model=model, temperature=0.1),
|
||||
**kwargs,
|
||||
)
|
||||
elif provider == ServiceProviders.GROQ.value:
|
||||
return GroqLLMService(
|
||||
|
|
@ -524,6 +545,7 @@ def create_llm_service_from_provider(
|
|||
elif provider == ServiceProviders.OPENROUTER.value:
|
||||
kwargs = {}
|
||||
if base_url:
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
kwargs["base_url"] = base_url
|
||||
return OpenRouterLLMService(
|
||||
api_key=api_key,
|
||||
|
|
@ -543,6 +565,8 @@ def create_llm_service_from_provider(
|
|||
settings=GoogleVertexLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
if endpoint:
|
||||
_validate_runtime_service_url(endpoint, "endpoint")
|
||||
return AzureLLMService(
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
|
|
@ -562,15 +586,19 @@ def create_llm_service_from_provider(
|
|||
settings=AWSBedrockLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.SPEACHES.value:
|
||||
base_url = base_url or "http://localhost:11434/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return SpeachesLLMService(
|
||||
base_url=base_url or "http://localhost:11434/v1",
|
||||
base_url=base_url,
|
||||
api_key=api_key or "none",
|
||||
settings=SpeachesLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.MINIMAX.value:
|
||||
base_url = base_url or "https://api.minimax.io/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return MiniMaxLLMService(
|
||||
api_key=api_key,
|
||||
base_url=base_url or "https://api.minimax.io/v1",
|
||||
base_url=base_url,
|
||||
settings=MiniMaxLLMService.Settings(
|
||||
model=model,
|
||||
temperature=temperature if temperature is not None else 1.0,
|
||||
|
|
@ -709,7 +737,9 @@ def create_llm_service(user_config):
|
|||
api_key = user_config.llm.api_key
|
||||
|
||||
kwargs = {}
|
||||
if provider == ServiceProviders.OPENROUTER.value:
|
||||
if provider == ServiceProviders.OPENAI.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.OPENROUTER.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
kwargs["endpoint"] = user_config.llm.endpoint
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue