fix: harden the base url settings in SaaS mode

This commit is contained in:
Abhishek Kumar 2026-05-27 13:04:27 +05:30
parent 88d6ac425b
commit c7b5ee1ae2
5 changed files with 339 additions and 3 deletions

View file

@ -13,6 +13,7 @@ from api.schemas.user_configuration import (
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
from api.utils.url_security import validate_user_configured_service_url
AuthContext = TypedDict(
"AuthContext",
@ -107,6 +108,17 @@ class UserConfigurationValidator:
provider = service_config.provider
for url_field in ("base_url", "endpoint"):
url = getattr(service_config, url_field, None)
if url:
try:
validate_user_configured_service_url(
url,
field_name=url_field,
)
except ValueError as e:
return [{"model": service_name, "message": str(e)}]
# Speaches doesn't require an API key
if provider == ServiceProviders.SPEACHES.value:
try:
@ -197,7 +209,10 @@ class UserConfigurationValidator:
return []
def _check_api_key(
self, provider: str, api_key: str, service_config: Optional[ServiceConfig] = None
self,
provider: str,
api_key: str,
service_config: Optional[ServiceConfig] = None,
) -> bool:
"""Check if an API key for a provider is valid."""
validator = self._validator_map.get(provider)

View file

@ -11,6 +11,7 @@ from loguru import logger
from openai import AsyncOpenAI
from api.db.db_client import DBClient
from api.utils.url_security import validate_user_configured_service_url
from .base import BaseEmbeddingService
@ -54,6 +55,10 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
if self._api_key_configured:
client_kwargs = {"api_key": api_key}
if base_url:
validate_user_configured_service_url(
base_url,
field_name="base_url",
)
client_kwargs["base_url"] = base_url
self.client = AsyncOpenAI(**client_kwargs)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")

View file

@ -7,6 +7,7 @@ from loguru import logger
from api.constants import MPS_API_URL
from api.services.configuration.registry import ServiceProviders
from api.services.pipecat.minimax_tts import MiniMaxOwnedSessionTTSService
from api.utils.url_security import validate_user_configured_service_url
from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
@ -62,6 +63,16 @@ if TYPE_CHECKING:
from api.services.pipecat.audio_config import AudioConfig
def _validate_runtime_service_url(url: str, field_name: str) -> None:
    """Validate a service URL at service-construction time.

    Delegates to ``validate_user_configured_service_url`` and converts its
    ``ValueError`` into an ``HTTPException`` with status 400 whose detail is
    the validator's message. The original error is kept as the cause.
    """
    try:
        validate_user_configured_service_url(url, field_name=field_name)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
def create_stt_service(
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
):
@ -174,6 +185,7 @@ def create_stt_service(
)
elif user_config.stt.provider == ServiceProviders.SPEACHES.value:
language = getattr(user_config.stt, "language", None)
_validate_runtime_service_url(user_config.stt.base_url, "base_url")
return SpeachesSTTService(
base_url=user_config.stt.base_url,
api_key=user_config.stt.api_key or "none",
@ -301,6 +313,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
# ElevenLabs TTS uses WebSocket. Users configure base_url with an HTTP
# scheme (matching ElevenLabs documentation, e.g.
# https://api.eu.residency.elevenlabs.io); rewrite it to the WS scheme.
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
elevenlabs_url = user_config.tts.base_url.replace("https://", "wss://").replace(
"http://", "ws://"
)
@ -376,6 +389,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
tts._settings.language = language
return tts
elif user_config.tts.provider == ServiceProviders.SPEACHES.value:
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
return SpeachesTTSService(
base_url=user_config.tts.base_url,
api_key=user_config.tts.api_key or "none",
@ -461,6 +475,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
).rstrip("/")
if not base_url.endswith("/t2a_v2"):
base_url = f"{base_url}/t2a_v2"
_validate_runtime_service_url(base_url, "base_url")
session = aiohttp.ClientSession()
return MiniMaxOwnedSessionTTSService(
@ -506,6 +521,7 @@ def create_llm_service_from_provider(
if provider == ServiceProviders.OPENAI.value:
kwargs = {}
if base_url:
_validate_runtime_service_url(base_url, "base_url")
kwargs["base_url"] = base_url
if "gpt-5" in model:
return OpenAILLMService(
@ -529,6 +545,7 @@ def create_llm_service_from_provider(
elif provider == ServiceProviders.OPENROUTER.value:
kwargs = {}
if base_url:
_validate_runtime_service_url(base_url, "base_url")
kwargs["base_url"] = base_url
return OpenRouterLLMService(
api_key=api_key,
@ -548,6 +565,8 @@ def create_llm_service_from_provider(
settings=GoogleVertexLLMSettings(model=model, temperature=0.1),
)
elif provider == ServiceProviders.AZURE.value:
if endpoint:
_validate_runtime_service_url(endpoint, "endpoint")
return AzureLLMService(
api_key=api_key,
endpoint=endpoint,
@ -567,15 +586,19 @@ def create_llm_service_from_provider(
settings=AWSBedrockLLMSettings(model=model),
)
elif provider == ServiceProviders.SPEACHES.value:
base_url = base_url or "http://localhost:11434/v1"
_validate_runtime_service_url(base_url, "base_url")
return SpeachesLLMService(
base_url=base_url or "http://localhost:11434/v1",
base_url=base_url,
api_key=api_key or "none",
settings=SpeachesLLMSettings(model=model),
)
elif provider == ServiceProviders.MINIMAX.value:
base_url = base_url or "https://api.minimax.io/v1"
_validate_runtime_service_url(base_url, "base_url")
return MiniMaxLLMService(
api_key=api_key,
base_url=base_url or "https://api.minimax.io/v1",
base_url=base_url,
settings=MiniMaxLLMService.Settings(
model=model,
temperature=temperature if temperature is not None else 1.0,

View file

@ -0,0 +1,227 @@
from types import SimpleNamespace
import pytest
from fastapi import HTTPException
from api.services.configuration.check_validity import UserConfigurationValidator
from api.services.configuration.registry import (
ServiceProviders,
SpeachesLLMConfiguration,
)
from api.services.gen_ai.embedding.openai_service import OpenAIEmbeddingService
from api.services.pipecat.service_factory import (
create_llm_service_from_provider,
create_tts_service,
)
from api.utils.url_security import validate_user_configured_service_url
def test_oss_allows_local_service_urls(monkeypatch):
    """In OSS mode a localhost service URL validates without raising."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "oss")

    local_url = "http://localhost:11434/v1"
    validate_user_configured_service_url(local_url, field_name="base_url")
@pytest.mark.parametrize(
    "url",
    [
        "http://localhost:11434/v1",  # localhost by name
        "http://127.0.0.1:11434/v1",  # IPv4 loopback
        "http://10.0.0.10/v1",  # private range
        "http://169.254.169.254/latest/meta-data",  # link-local
        "http://100.64.0.1/v1",  # CGNAT range (100.64.0.0/10)
    ],
)
def test_saas_blocks_local_and_internal_service_urls(url, monkeypatch):
    """In SaaS mode every local or internal URL is rejected with ValueError."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    with pytest.raises(ValueError):
        validate_user_configured_service_url(url, field_name="base_url")
def test_saas_rejects_unsupported_service_url_schemes(monkeypatch):
    """In SaaS mode a ``file://`` URL is rejected with the allowed-schemes message."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    unsupported_url = "file:///etc/passwd"
    with pytest.raises(ValueError, match="http, https, ws, or wss"):
        validate_user_configured_service_url(unsupported_url, field_name="base_url")
def test_saas_checks_resolved_hostname_ips(monkeypatch):
    """A hostname whose (mocked) resolution yields a private IP is rejected."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
    # Stub DNS so the hostname resolves to a private address without network access.
    monkeypatch.setattr(
        "api.utils.url_security.socket.getaddrinfo",
        lambda *_args, **_kwargs: [(None, None, None, None, ("10.0.0.10", 443))],
    )

    with pytest.raises(ValueError, match="public IP"):
        validate_user_configured_service_url(
            "https://internal.example.com/v1", field_name="base_url"
        )
def test_saas_allows_public_service_url(monkeypatch):
    """An https URL whose (mocked) resolution yields 8.8.8.8 validates without raising."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
    # Stub DNS so the check is deterministic and needs no network access.
    monkeypatch.setattr(
        "api.utils.url_security.socket.getaddrinfo",
        lambda *_args, **_kwargs: [(None, None, None, None, ("8.8.8.8", 443))],
    )

    validate_user_configured_service_url(
        "https://api.example.com/v1", field_name="base_url"
    )
def test_saas_allows_public_websocket_service_url(monkeypatch):
    """A wss URL whose (mocked) resolution yields 8.8.8.8 validates without raising."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
    # Stub DNS so the check is deterministic and needs no network access.
    monkeypatch.setattr(
        "api.utils.url_security.socket.getaddrinfo",
        lambda *_args, **_kwargs: [(None, None, None, None, ("8.8.8.8", 443))],
    )

    validate_user_configured_service_url(
        "wss://api.example.com/v1", field_name="base_url"
    )
def test_saas_blocks_local_websocket_service_url(monkeypatch):
    """In SaaS mode a ws URL pointing at localhost is rejected with the localhost message."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    local_ws_url = "ws://localhost:8000/v1"
    with pytest.raises(ValueError, match="localhost"):
        validate_user_configured_service_url(local_ws_url, field_name="base_url")
def test_validator_blocks_speaches_local_base_url_in_saas(monkeypatch):
    """A default Speaches LLM config is reported as a localhost error in SaaS mode."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    expected_errors = [
        {
            "model": "llm",
            "message": "base_url cannot point to localhost in SaaS mode",
        }
    ]
    errors = UserConfigurationValidator()._validate_service(
        SpeachesLLMConfiguration(), "llm"
    )

    assert errors == expected_errors
def test_validator_blocks_azure_private_endpoint_in_saas(monkeypatch):
    """An Azure config with a private ``endpoint`` is reported as a public-IP error."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    # Minimal stand-in exposing only the attributes this test sets.
    azure_config = SimpleNamespace(
        provider=ServiceProviders.AZURE.value,
        endpoint="http://10.0.0.10/openai",
        api_key="test-key",
    )
    expected_errors = [
        {
            "model": "llm",
            "message": "endpoint must resolve to a public IP address in SaaS mode",
        }
    ]

    errors = UserConfigurationValidator()._validate_service(azure_config, "llm")

    assert errors == expected_errors
def test_validator_allows_speaches_local_base_url_in_oss(monkeypatch):
    """In OSS mode a default Speaches LLM config validates with no errors."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "oss")

    errors = UserConfigurationValidator()._validate_service(
        SpeachesLLMConfiguration(), "llm"
    )

    assert errors == []
def test_runtime_blocks_speaches_default_llm_base_url_in_saas(monkeypatch):
    """Creating a Speaches LLM with no base_url in SaaS mode fails with HTTP 400."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    with pytest.raises(HTTPException) as exc_info:
        create_llm_service_from_provider(
            provider=ServiceProviders.SPEACHES.value,
            model="llama3",
            api_key=None,
        )

    raised = exc_info.value
    assert raised.status_code == 400
    assert "localhost" in raised.detail
def test_runtime_blocks_openai_private_base_url_in_saas(monkeypatch):
    """Creating an OpenAI LLM with a private base_url in SaaS mode fails with HTTP 400."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    with pytest.raises(HTTPException) as exc_info:
        create_llm_service_from_provider(
            provider=ServiceProviders.OPENAI.value,
            model="gpt-4.1",
            api_key="test-key",
            base_url="http://10.0.0.10/v1",
        )

    raised = exc_info.value
    assert raised.status_code == 400
    assert "public IP" in raised.detail
def test_runtime_blocks_azure_private_endpoint_in_saas(monkeypatch):
    """Creating an Azure LLM with a private endpoint in SaaS mode fails with HTTP 400."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    with pytest.raises(HTTPException) as exc_info:
        create_llm_service_from_provider(
            provider=ServiceProviders.AZURE.value,
            model="gpt-4.1-mini",
            api_key="test-key",
            endpoint="http://10.0.0.10/openai",
        )

    raised = exc_info.value
    assert raised.status_code == 400
    assert "public IP" in raised.detail
def test_runtime_blocks_elevenlabs_local_tts_base_url_in_saas(monkeypatch):
    """Creating an ElevenLabs TTS with a localhost base_url in SaaS mode fails with HTTP 400."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    # Minimal stand-in for the user configuration with only a ``tts`` section.
    tts_settings = SimpleNamespace(
        provider=ServiceProviders.ELEVENLABS.value,
        api_key="test-key",
        model="eleven_flash_v2_5",
        voice="voice-id",
        speed=1.0,
        base_url="http://localhost:8000",
    )
    user_config = SimpleNamespace(tts=tts_settings)

    with pytest.raises(HTTPException) as exc_info:
        create_tts_service(user_config, audio_config=None)

    raised = exc_info.value
    assert raised.status_code == 400
    assert "localhost" in raised.detail
def test_embedding_service_blocks_private_base_url_in_saas(monkeypatch):
    """Constructing the OpenAI embedding service with a private base_url raises ValueError."""
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    private_base_url = "http://10.0.0.10/v1"
    with pytest.raises(ValueError, match="public IP"):
        OpenAIEmbeddingService(
            db_client=SimpleNamespace(),
            api_key="test-key",
            base_url=private_base_url,
        )

66
api/utils/url_security.py Normal file
View file

@ -0,0 +1,66 @@
import ipaddress
import socket
from urllib.parse import urlparse
from api.constants import DEPLOYMENT_MODE
# Carrier-grade NAT range (100.64.0.0/10). It is tested explicitly for IPv4
# addresses in _is_blocked_saas_service_ip, in addition to the ipaddress flags.
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
def validate_user_configured_service_url(
    url: str,
    *,
    field_name: str,
) -> None:
    """Reject user-supplied service URLs that target internal hosts in SaaS.

    OSS installs routinely point model services at localhost or LAN hosts, so
    no restriction applies when ``DEPLOYMENT_MODE`` is ``"oss"``. In any other
    mode the URL must use http, https, ws or wss, must carry a hostname, must
    not name ``localhost``, and every address the hostname resolves to must
    pass ``_is_blocked_saas_service_ip``.

    Args:
        url: The user-configured service URL.
        field_name: Name of the configuration field; used as the prefix of
            the error message.

    Raises:
        ValueError: If the URL violates any of the rules above, or if the
            hostname cannot be resolved.
    """
    if DEPLOYMENT_MODE == "oss":
        return

    parts = urlparse(url)
    scheme_allowed = parts.scheme in {"http", "https", "ws", "wss"}
    if not (scheme_allowed and parts.hostname):
        raise ValueError(f"{field_name} must be an http, https, ws, or wss URL")

    host = parts.hostname
    if host.lower() == "localhost":
        raise ValueError(f"{field_name} cannot point to localhost in SaaS mode")

    # One blocked address among the resolved set is enough to reject the URL.
    resolved = _resolve_hostname_ips(host, parts.port)
    if any(_is_blocked_saas_service_ip(address) for address in resolved):
        raise ValueError(
            f"{field_name} must resolve to a public IP address in SaaS mode"
        )
def _resolve_hostname_ips(
hostname: str, port: int | None
) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]:
try:
return [ipaddress.ip_address(hostname)]
except ValueError:
pass
try:
addr_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
except socket.gaierror as e:
raise ValueError("Could not resolve service URL hostname") from e
return [ipaddress.ip_address(addr_info[4][0]) for addr_info in addr_infos]
def _is_blocked_saas_service_ip(
ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
) -> bool:
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
or ip.is_unspecified
or (ip.version == 4 and ip in _CGNAT_NETWORK)
)