feat: allow overriding base URL of OpenAI models (#368)

* Add OpenAI-compatible API option in model configuration

Backend-only cherry-pick from 20617db37a.

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: harden the base url settings in SaaS mode

---------

Co-authored-by: Chris Briddock <briddockchristopher@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Abhishek 2026-05-27 13:07:45 +05:30 committed by GitHub
parent 9675151549
commit 8a58b0992d
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 425 additions and 11 deletions

View file

@ -13,6 +13,7 @@ from api.schemas.user_configuration import (
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
from api.utils.url_security import validate_user_configured_service_url
AuthContext = TypedDict(
"AuthContext",
@ -107,6 +108,17 @@ class UserConfigurationValidator:
provider = service_config.provider
for url_field in ("base_url", "endpoint"):
url = getattr(service_config, url_field, None)
if url:
try:
validate_user_configured_service_url(
url,
field_name=url_field,
)
except ValueError as e:
return [{"model": service_name, "message": str(e)}]
# Speaches doesn't require an API key
if provider == ServiceProviders.SPEACHES.value:
try:
@ -181,30 +193,92 @@ class UserConfigurationValidator:
api_key = service_config.api_key
try:
if not self._check_api_key(provider, api_key):
if not self._check_api_key(provider, api_key, service_config):
return [
{"model": service_name, "message": f"Invalid {provider} API key"}
{
"model": service_name,
"message": (
f"Invalid {provider} API key. Please verify your API key is "
f"correct, has not expired, and has the required permissions."
),
}
]
except ValueError as e:
return [{"model": service_name, "message": str(e)}]
return []
def _check_api_key(self, provider: str, api_key: str) -> bool:
def _check_api_key(
self,
provider: str,
api_key: str,
service_config: Optional[ServiceConfig] = None,
) -> bool:
"""Check if an API key for a provider is valid."""
validator = self._validator_map.get(provider)
if not validator:
return False
if provider in (
ServiceProviders.OPENAI.value,
ServiceProviders.OPENAI_REALTIME.value,
):
return validator(provider, api_key, service_config)
return validator(provider, api_key)
def _check_openai_api_key(self, model: str, api_key: str) -> bool:
client = openai.OpenAI(api_key=api_key)
def _check_openai_api_key(
self, model: str, api_key: str, service_config: Optional[ServiceConfig] = None
) -> bool:
client_kwargs: dict[str, str] = {"api_key": api_key}
base_url = getattr(service_config, "base_url", None) if service_config else None
if base_url:
client_kwargs["base_url"] = base_url
client = openai.OpenAI(**client_kwargs)
try:
client.models.list()
return True
except openai.AuthenticationError:
return False
if base_url and "openai.com" not in base_url:
raise ValueError(
f"Invalid OpenAI API key. The key was rejected by the API at {base_url}. "
"Please check that your API key is correct and has not been revoked."
)
raise ValueError(
"Invalid OpenAI API key. The key was rejected by the OpenAI API. "
"Please check that your API key is correct and has not been revoked. "
"You can verify your keys at https://platform.openai.com/api-keys."
)
except openai.APIConnectionError:
if base_url:
raise ValueError(
f"Could not connect to the OpenAI-compatible API at {base_url}. "
"Please verify that the base_url is correct and reachable, and try again."
)
raise ValueError(
"Could not connect to the OpenAI API. Please check your network connection "
"and try again."
)
except openai.APIError:
if base_url:
raise ValueError(
f"The OpenAI-compatible API at {base_url} returned an error while "
"validating the API key. Please verify that the base_url is correct, "
"the service is available, and the API key is valid."
)
raise ValueError(
"The OpenAI API returned an error while validating the API key. "
"Please try again later."
)
except Exception:
if base_url:
raise ValueError(
f"Failed to validate the OpenAI API key using the API at {base_url}. "
"Please verify that the base_url is correct and reachable, and that the "
"API key is valid."
)
raise ValueError(
"Failed to validate the OpenAI API key. Please try again later."
)
def _check_deepgram_api_key(self, model: str, api_key: str) -> bool:
try:
@ -212,7 +286,11 @@ class UserConfigurationValidator:
deepgram.manage.v1.projects.list()
return True
except Exception:
return False
raise ValueError(
"Invalid Deepgram API key. The key was rejected by the Deepgram API. "
"Please check that your API key is correct and active. "
"You can verify your keys at https://console.deepgram.com/."
)
def _check_groq_api_key(self, model: str, api_key: str) -> bool:
client = Groq(api_key=api_key)
@ -220,7 +298,11 @@ class UserConfigurationValidator:
client.models.list()
return True
except Exception:
return False
raise ValueError(
"Invalid Groq API key. The key was rejected by the Groq API. "
"Please check that your API key is correct and active. "
"You can verify your keys at https://console.groq.com/keys."
)
def _validate_elevenlabs_api_key(self, model: str, api_key: str) -> bool:
return True

View file

@ -290,6 +290,10 @@ class OpenAILLMService(BaseLLMConfiguration):
description="OpenAI chat model to use.",
json_schema_extra={"examples": OPENAI_MODELS, "allow_custom_input": True},
)
base_url: str = Field(
default="https://api.openai.com/v1",
description="Override only if using an OpenAI-compatible API (e.g. local LLM, proxy).",
)
@register_llm

View file

@ -11,6 +11,7 @@ from loguru import logger
from openai import AsyncOpenAI
from api.db.db_client import DBClient
from api.utils.url_security import validate_user_configured_service_url
from .base import BaseEmbeddingService
@ -54,6 +55,10 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
if self._api_key_configured:
client_kwargs = {"api_key": api_key}
if base_url:
validate_user_configured_service_url(
base_url,
field_name="base_url",
)
client_kwargs["base_url"] = base_url
self.client = AsyncOpenAI(**client_kwargs)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")

View file

@ -7,6 +7,7 @@ from loguru import logger
from api.constants import MPS_API_URL
from api.services.configuration.registry import ServiceProviders
from api.services.pipecat.minimax_tts import MiniMaxOwnedSessionTTSService
from api.utils.url_security import validate_user_configured_service_url
from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
@ -62,6 +63,16 @@ if TYPE_CHECKING:
from api.services.pipecat.audio_config import AudioConfig
def _validate_runtime_service_url(url: str, field_name: str) -> None:
try:
validate_user_configured_service_url(
url,
field_name=field_name,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
def create_stt_service(
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
):
@ -174,6 +185,7 @@ def create_stt_service(
)
elif user_config.stt.provider == ServiceProviders.SPEACHES.value:
language = getattr(user_config.stt, "language", None)
_validate_runtime_service_url(user_config.stt.base_url, "base_url")
return SpeachesSTTService(
base_url=user_config.stt.base_url,
api_key=user_config.stt.api_key or "none",
@ -301,6 +313,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
# ElevenLabs TTS uses WebSocket. Users configure base_url with an HTTP
# scheme (matching ElevenLabs documentation, e.g.
# https://api.eu.residency.elevenlabs.io); rewrite it to the WS scheme.
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
elevenlabs_url = user_config.tts.base_url.replace("https://", "wss://").replace(
"http://", "ws://"
)
@ -376,6 +389,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
tts._settings.language = language
return tts
elif user_config.tts.provider == ServiceProviders.SPEACHES.value:
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
return SpeachesTTSService(
base_url=user_config.tts.base_url,
api_key=user_config.tts.api_key or "none",
@ -461,6 +475,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
).rstrip("/")
if not base_url.endswith("/t2a_v2"):
base_url = f"{base_url}/t2a_v2"
_validate_runtime_service_url(base_url, "base_url")
session = aiohttp.ClientSession()
return MiniMaxOwnedSessionTTSService(
@ -504,6 +519,10 @@ def create_llm_service_from_provider(
"""
logger.info(f"Creating LLM service: provider={provider}, model={model}")
if provider == ServiceProviders.OPENAI.value:
kwargs = {}
if base_url:
_validate_runtime_service_url(base_url, "base_url")
kwargs["base_url"] = base_url
if "gpt-5" in model:
return OpenAILLMService(
api_key=api_key,
@ -511,10 +530,12 @@ def create_llm_service_from_provider(
model=model,
extra={"reasoning_effort": "minimal", "verbosity": "low"},
),
**kwargs,
)
return OpenAILLMService(
api_key=api_key,
settings=OpenAILLMSettings(model=model, temperature=0.1),
**kwargs,
)
elif provider == ServiceProviders.GROQ.value:
return GroqLLMService(
@ -524,6 +545,7 @@ def create_llm_service_from_provider(
elif provider == ServiceProviders.OPENROUTER.value:
kwargs = {}
if base_url:
_validate_runtime_service_url(base_url, "base_url")
kwargs["base_url"] = base_url
return OpenRouterLLMService(
api_key=api_key,
@ -543,6 +565,8 @@ def create_llm_service_from_provider(
settings=GoogleVertexLLMSettings(model=model, temperature=0.1),
)
elif provider == ServiceProviders.AZURE.value:
if endpoint:
_validate_runtime_service_url(endpoint, "endpoint")
return AzureLLMService(
api_key=api_key,
endpoint=endpoint,
@ -562,15 +586,19 @@ def create_llm_service_from_provider(
settings=AWSBedrockLLMSettings(model=model),
)
elif provider == ServiceProviders.SPEACHES.value:
base_url = base_url or "http://localhost:11434/v1"
_validate_runtime_service_url(base_url, "base_url")
return SpeachesLLMService(
base_url=base_url or "http://localhost:11434/v1",
base_url=base_url,
api_key=api_key or "none",
settings=SpeachesLLMSettings(model=model),
)
elif provider == ServiceProviders.MINIMAX.value:
base_url = base_url or "https://api.minimax.io/v1"
_validate_runtime_service_url(base_url, "base_url")
return MiniMaxLLMService(
api_key=api_key,
base_url=base_url or "https://api.minimax.io/v1",
base_url=base_url,
settings=MiniMaxLLMService.Settings(
model=model,
temperature=temperature if temperature is not None else 1.0,
@ -709,7 +737,9 @@ def create_llm_service(user_config):
api_key = user_config.llm.api_key
kwargs = {}
if provider == ServiceProviders.OPENROUTER.value:
if provider == ServiceProviders.OPENAI.value:
kwargs["base_url"] = user_config.llm.base_url
elif provider == ServiceProviders.OPENROUTER.value:
kwargs["base_url"] = user_config.llm.base_url
elif provider == ServiceProviders.AZURE.value:
kwargs["endpoint"] = user_config.llm.endpoint

View file

@ -0,0 +1,227 @@
from types import SimpleNamespace
import pytest
from fastapi import HTTPException
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.registry import (
ServiceProviders,
SpeachesLLMConfiguration,
)
from api.services.gen_ai.embedding.openai_service import OpenAIEmbeddingService
from api.services.pipecat.service_factory import (
create_llm_service_from_provider,
create_tts_service,
)
from api.utils.url_security import validate_user_configured_service_url
def test_oss_allows_local_service_urls(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "oss")
validate_user_configured_service_url(
"http://localhost:11434/v1",
field_name="base_url",
)
@pytest.mark.parametrize(
"url",
[
"http://localhost:11434/v1",
"http://127.0.0.1:11434/v1",
"http://10.0.0.10/v1",
"http://169.254.169.254/latest/meta-data",
"http://100.64.0.1/v1",
],
)
def test_saas_blocks_local_and_internal_service_urls(url, monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
with pytest.raises(ValueError):
validate_user_configured_service_url(
url,
field_name="base_url",
)
def test_saas_rejects_unsupported_service_url_schemes(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
with pytest.raises(ValueError, match="http, https, ws, or wss"):
validate_user_configured_service_url(
"file:///etc/passwd",
field_name="base_url",
)
def test_saas_checks_resolved_hostname_ips(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
def fake_getaddrinfo(*_args, **_kwargs):
return [(None, None, None, None, ("10.0.0.10", 443))]
monkeypatch.setattr("api.utils.url_security.socket.getaddrinfo", fake_getaddrinfo)
with pytest.raises(ValueError, match="public IP"):
validate_user_configured_service_url(
"https://internal.example.com/v1",
field_name="base_url",
)
def test_saas_allows_public_service_url(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
def fake_getaddrinfo(*_args, **_kwargs):
return [(None, None, None, None, ("8.8.8.8", 443))]
monkeypatch.setattr("api.utils.url_security.socket.getaddrinfo", fake_getaddrinfo)
validate_user_configured_service_url(
"https://api.example.com/v1",
field_name="base_url",
)
def test_saas_allows_public_websocket_service_url(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
def fake_getaddrinfo(*_args, **_kwargs):
return [(None, None, None, None, ("8.8.8.8", 443))]
monkeypatch.setattr("api.utils.url_security.socket.getaddrinfo", fake_getaddrinfo)
validate_user_configured_service_url(
"wss://api.example.com/v1",
field_name="base_url",
)
def test_saas_blocks_local_websocket_service_url(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
with pytest.raises(ValueError, match="localhost"):
validate_user_configured_service_url(
"ws://localhost:8000/v1",
field_name="base_url",
)
def test_validator_blocks_speaches_local_base_url_in_saas(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
validator = UserConfigurationValidator()
config = SpeachesLLMConfiguration()
result = validator._validate_service(config, "llm")
assert result == [
{
"model": "llm",
"message": "base_url cannot point to localhost in SaaS mode",
}
]
def test_validator_blocks_azure_private_endpoint_in_saas(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
validator = UserConfigurationValidator()
config = SimpleNamespace(
provider=ServiceProviders.AZURE.value,
endpoint="http://10.0.0.10/openai",
api_key="test-key",
)
result = validator._validate_service(config, "llm")
assert result == [
{
"model": "llm",
"message": "endpoint must resolve to a public IP address in SaaS mode",
}
]
def test_validator_allows_speaches_local_base_url_in_oss(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "oss")
validator = UserConfigurationValidator()
config = SpeachesLLMConfiguration()
assert validator._validate_service(config, "llm") == []
def test_runtime_blocks_speaches_default_llm_base_url_in_saas(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
with pytest.raises(HTTPException) as exc_info:
create_llm_service_from_provider(
provider=ServiceProviders.SPEACHES.value,
model="llama3",
api_key=None,
)
assert exc_info.value.status_code == 400
assert "localhost" in exc_info.value.detail
def test_runtime_blocks_openai_private_base_url_in_saas(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
with pytest.raises(HTTPException) as exc_info:
create_llm_service_from_provider(
provider=ServiceProviders.OPENAI.value,
model="gpt-4.1",
api_key="test-key",
base_url="http://10.0.0.10/v1",
)
assert exc_info.value.status_code == 400
assert "public IP" in exc_info.value.detail
def test_runtime_blocks_azure_private_endpoint_in_saas(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
with pytest.raises(HTTPException) as exc_info:
create_llm_service_from_provider(
provider=ServiceProviders.AZURE.value,
model="gpt-4.1-mini",
api_key="test-key",
endpoint="http://10.0.0.10/openai",
)
assert exc_info.value.status_code == 400
assert "public IP" in exc_info.value.detail
def test_runtime_blocks_elevenlabs_local_tts_base_url_in_saas(monkeypatch):
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
user_config = SimpleNamespace(
tts=SimpleNamespace(
provider=ServiceProviders.ELEVENLABS.value,
api_key="test-key",
model="eleven_flash_v2_5",
voice="voice-id",
speed=1.0,
base_url="http://localhost:8000",
)
)
with pytest.raises(HTTPException) as exc_info:
create_tts_service(user_config, audio_config=None)
assert exc_info.value.status_code == 400
assert "localhost" in exc_info.value.detail
def test_embedding_service_blocks_private_base_url_in_saas(monkeypatch):
monkeypatch.setattr(
"api.utils.url_security.DEPLOYMENT_MODE", "saas"
)
with pytest.raises(ValueError, match="public IP"):
OpenAIEmbeddingService(
db_client=SimpleNamespace(),
api_key="test-key",
base_url="http://10.0.0.10/v1",
)

66
api/utils/url_security.py Normal file
View file

@ -0,0 +1,66 @@
import ipaddress
import socket
from urllib.parse import urlparse
from api.constants import DEPLOYMENT_MODE
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
def validate_user_configured_service_url(
url: str,
*,
field_name: str,
) -> None:
"""Restrict user-configured service URLs in hosted deployments.
OSS deployments commonly point model services at localhost or private LAN
hosts. SaaS deployments must not allow users to make Dograh infrastructure
connect to private/internal network locations.
"""
if DEPLOYMENT_MODE == "oss":
return
parsed = urlparse(url)
if parsed.scheme not in {"http", "https", "ws", "wss"} or not parsed.hostname:
raise ValueError(f"{field_name} must be an http, https, ws, or wss URL")
hostname = parsed.hostname
if hostname.lower() == "localhost":
raise ValueError(f"{field_name} cannot point to localhost in SaaS mode")
for ip in _resolve_hostname_ips(hostname, parsed.port):
if _is_blocked_saas_service_ip(ip):
raise ValueError(
f"{field_name} must resolve to a public IP address in SaaS mode"
)
def _resolve_hostname_ips(
hostname: str, port: int | None
) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]:
try:
return [ipaddress.ip_address(hostname)]
except ValueError:
pass
try:
addr_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
except socket.gaierror as e:
raise ValueError("Could not resolve service URL hostname") from e
return [ipaddress.ip_address(addr_info[4][0]) for addr_info in addr_infos]
def _is_blocked_saas_service_ip(
ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
) -> bool:
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
or (ip.version == 4 and ip in _CGNAT_NETWORK)
)