mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: harden the base url settings in SaaS mode
This commit is contained in:
parent
88d6ac425b
commit
c7b5ee1ae2
5 changed files with 339 additions and 3 deletions
|
|
@ -13,6 +13,7 @@ from api.schemas.user_configuration import (
|
|||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.utils.url_security import validate_user_configured_service_url
|
||||
|
||||
AuthContext = TypedDict(
|
||||
"AuthContext",
|
||||
|
|
@ -107,6 +108,17 @@ class UserConfigurationValidator:
|
|||
|
||||
provider = service_config.provider
|
||||
|
||||
for url_field in ("base_url", "endpoint"):
|
||||
url = getattr(service_config, url_field, None)
|
||||
if url:
|
||||
try:
|
||||
validate_user_configured_service_url(
|
||||
url,
|
||||
field_name=url_field,
|
||||
)
|
||||
except ValueError as e:
|
||||
return [{"model": service_name, "message": str(e)}]
|
||||
|
||||
# Speaches doesn't require an API key
|
||||
if provider == ServiceProviders.SPEACHES.value:
|
||||
try:
|
||||
|
|
@ -197,7 +209,10 @@ class UserConfigurationValidator:
|
|||
return []
|
||||
|
||||
def _check_api_key(
|
||||
self, provider: str, api_key: str, service_config: Optional[ServiceConfig] = None
|
||||
self,
|
||||
provider: str,
|
||||
api_key: str,
|
||||
service_config: Optional[ServiceConfig] = None,
|
||||
) -> bool:
|
||||
"""Check if an API key for a provider is valid."""
|
||||
validator = self._validator_map.get(provider)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from loguru import logger
|
|||
from openai import AsyncOpenAI
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.utils.url_security import validate_user_configured_service_url
|
||||
|
||||
from .base import BaseEmbeddingService
|
||||
|
||||
|
|
@ -54,6 +55,10 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
|
|||
if self._api_key_configured:
|
||||
client_kwargs = {"api_key": api_key}
|
||||
if base_url:
|
||||
validate_user_configured_service_url(
|
||||
base_url,
|
||||
field_name="base_url",
|
||||
)
|
||||
client_kwargs["base_url"] = base_url
|
||||
self.client = AsyncOpenAI(**client_kwargs)
|
||||
logger.info(f"OpenAI embedding service initialized with model: {model_id}")
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from loguru import logger
|
|||
from api.constants import MPS_API_URL
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pipecat.minimax_tts import MiniMaxOwnedSessionTTSService
|
||||
from api.utils.url_security import validate_user_configured_service_url
|
||||
from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings
|
||||
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
|
||||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
|
|
@ -62,6 +63,16 @@ if TYPE_CHECKING:
|
|||
from api.services.pipecat.audio_config import AudioConfig
|
||||
|
||||
|
||||
def _validate_runtime_service_url(url: str, field_name: str) -> None:
|
||||
try:
|
||||
validate_user_configured_service_url(
|
||||
url,
|
||||
field_name=field_name,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e)) from e
|
||||
|
||||
|
||||
def create_stt_service(
|
||||
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
|
||||
):
|
||||
|
|
@ -174,6 +185,7 @@ def create_stt_service(
|
|||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SPEACHES.value:
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
_validate_runtime_service_url(user_config.stt.base_url, "base_url")
|
||||
return SpeachesSTTService(
|
||||
base_url=user_config.stt.base_url,
|
||||
api_key=user_config.stt.api_key or "none",
|
||||
|
|
@ -301,6 +313,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
# ElevenLabs TTS uses WebSocket. Users configure base_url with an HTTP
|
||||
# scheme (matching ElevenLabs documentation, e.g.
|
||||
# https://api.eu.residency.elevenlabs.io); rewrite it to the WS scheme.
|
||||
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
|
||||
elevenlabs_url = user_config.tts.base_url.replace("https://", "wss://").replace(
|
||||
"http://", "ws://"
|
||||
)
|
||||
|
|
@ -376,6 +389,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
tts._settings.language = language
|
||||
return tts
|
||||
elif user_config.tts.provider == ServiceProviders.SPEACHES.value:
|
||||
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
|
||||
return SpeachesTTSService(
|
||||
base_url=user_config.tts.base_url,
|
||||
api_key=user_config.tts.api_key or "none",
|
||||
|
|
@ -461,6 +475,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
).rstrip("/")
|
||||
if not base_url.endswith("/t2a_v2"):
|
||||
base_url = f"{base_url}/t2a_v2"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
|
||||
session = aiohttp.ClientSession()
|
||||
return MiniMaxOwnedSessionTTSService(
|
||||
|
|
@ -506,6 +521,7 @@ def create_llm_service_from_provider(
|
|||
if provider == ServiceProviders.OPENAI.value:
|
||||
kwargs = {}
|
||||
if base_url:
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
kwargs["base_url"] = base_url
|
||||
if "gpt-5" in model:
|
||||
return OpenAILLMService(
|
||||
|
|
@ -529,6 +545,7 @@ def create_llm_service_from_provider(
|
|||
elif provider == ServiceProviders.OPENROUTER.value:
|
||||
kwargs = {}
|
||||
if base_url:
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
kwargs["base_url"] = base_url
|
||||
return OpenRouterLLMService(
|
||||
api_key=api_key,
|
||||
|
|
@ -548,6 +565,8 @@ def create_llm_service_from_provider(
|
|||
settings=GoogleVertexLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
if endpoint:
|
||||
_validate_runtime_service_url(endpoint, "endpoint")
|
||||
return AzureLLMService(
|
||||
api_key=api_key,
|
||||
endpoint=endpoint,
|
||||
|
|
@ -567,15 +586,19 @@ def create_llm_service_from_provider(
|
|||
settings=AWSBedrockLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.SPEACHES.value:
|
||||
base_url = base_url or "http://localhost:11434/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return SpeachesLLMService(
|
||||
base_url=base_url or "http://localhost:11434/v1",
|
||||
base_url=base_url,
|
||||
api_key=api_key or "none",
|
||||
settings=SpeachesLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.MINIMAX.value:
|
||||
base_url = base_url or "https://api.minimax.io/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return MiniMaxLLMService(
|
||||
api_key=api_key,
|
||||
base_url=base_url or "https://api.minimax.io/v1",
|
||||
base_url=base_url,
|
||||
settings=MiniMaxLLMService.Settings(
|
||||
model=model,
|
||||
temperature=temperature if temperature is not None else 1.0,
|
||||
|
|
|
|||
227
api/tests/test_user_configured_service_url_security.py
Normal file
227
api/tests/test_user_configured_service_url_security.py
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.registry import (
|
||||
ServiceProviders,
|
||||
SpeachesLLMConfiguration,
|
||||
)
|
||||
from api.services.gen_ai.embedding.openai_service import OpenAIEmbeddingService
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service_from_provider,
|
||||
create_tts_service,
|
||||
)
|
||||
from api.utils.url_security import validate_user_configured_service_url
|
||||
|
||||
|
||||
def test_oss_allows_local_service_urls(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "oss")
|
||||
|
||||
validate_user_configured_service_url(
|
||||
"http://localhost:11434/v1",
|
||||
field_name="base_url",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"url",
|
||||
[
|
||||
"http://localhost:11434/v1",
|
||||
"http://127.0.0.1:11434/v1",
|
||||
"http://10.0.0.10/v1",
|
||||
"http://169.254.169.254/latest/meta-data",
|
||||
"http://100.64.0.1/v1",
|
||||
],
|
||||
)
|
||||
def test_saas_blocks_local_and_internal_service_urls(url, monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
validate_user_configured_service_url(
|
||||
url,
|
||||
field_name="base_url",
|
||||
)
|
||||
|
||||
|
||||
def test_saas_rejects_unsupported_service_url_schemes(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
with pytest.raises(ValueError, match="http, https, ws, or wss"):
|
||||
validate_user_configured_service_url(
|
||||
"file:///etc/passwd",
|
||||
field_name="base_url",
|
||||
)
|
||||
|
||||
|
||||
def test_saas_checks_resolved_hostname_ips(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
def fake_getaddrinfo(*_args, **_kwargs):
|
||||
return [(None, None, None, None, ("10.0.0.10", 443))]
|
||||
|
||||
monkeypatch.setattr("api.utils.url_security.socket.getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
with pytest.raises(ValueError, match="public IP"):
|
||||
validate_user_configured_service_url(
|
||||
"https://internal.example.com/v1",
|
||||
field_name="base_url",
|
||||
)
|
||||
|
||||
|
||||
def test_saas_allows_public_service_url(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
def fake_getaddrinfo(*_args, **_kwargs):
|
||||
return [(None, None, None, None, ("8.8.8.8", 443))]
|
||||
|
||||
monkeypatch.setattr("api.utils.url_security.socket.getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
validate_user_configured_service_url(
|
||||
"https://api.example.com/v1",
|
||||
field_name="base_url",
|
||||
)
|
||||
|
||||
|
||||
def test_saas_allows_public_websocket_service_url(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
def fake_getaddrinfo(*_args, **_kwargs):
|
||||
return [(None, None, None, None, ("8.8.8.8", 443))]
|
||||
|
||||
monkeypatch.setattr("api.utils.url_security.socket.getaddrinfo", fake_getaddrinfo)
|
||||
|
||||
validate_user_configured_service_url(
|
||||
"wss://api.example.com/v1",
|
||||
field_name="base_url",
|
||||
)
|
||||
|
||||
|
||||
def test_saas_blocks_local_websocket_service_url(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
with pytest.raises(ValueError, match="localhost"):
|
||||
validate_user_configured_service_url(
|
||||
"ws://localhost:8000/v1",
|
||||
field_name="base_url",
|
||||
)
|
||||
|
||||
|
||||
def test_validator_blocks_speaches_local_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
validator = UserConfigurationValidator()
|
||||
config = SpeachesLLMConfiguration()
|
||||
|
||||
result = validator._validate_service(config, "llm")
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"model": "llm",
|
||||
"message": "base_url cannot point to localhost in SaaS mode",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_validator_blocks_azure_private_endpoint_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
validator = UserConfigurationValidator()
|
||||
config = SimpleNamespace(
|
||||
provider=ServiceProviders.AZURE.value,
|
||||
endpoint="http://10.0.0.10/openai",
|
||||
api_key="test-key",
|
||||
)
|
||||
|
||||
result = validator._validate_service(config, "llm")
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"model": "llm",
|
||||
"message": "endpoint must resolve to a public IP address in SaaS mode",
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def test_validator_allows_speaches_local_base_url_in_oss(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "oss")
|
||||
validator = UserConfigurationValidator()
|
||||
config = SpeachesLLMConfiguration()
|
||||
|
||||
assert validator._validate_service(config, "llm") == []
|
||||
|
||||
|
||||
def test_runtime_blocks_speaches_default_llm_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.SPEACHES.value,
|
||||
model="llama3",
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "localhost" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_runtime_blocks_openai_private_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.OPENAI.value,
|
||||
model="gpt-4.1",
|
||||
api_key="test-key",
|
||||
base_url="http://10.0.0.10/v1",
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "public IP" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_runtime_blocks_azure_private_endpoint_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
create_llm_service_from_provider(
|
||||
provider=ServiceProviders.AZURE.value,
|
||||
model="gpt-4.1-mini",
|
||||
api_key="test-key",
|
||||
endpoint="http://10.0.0.10/openai",
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "public IP" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_runtime_blocks_elevenlabs_local_tts_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.ELEVENLABS.value,
|
||||
api_key="test-key",
|
||||
model="eleven_flash_v2_5",
|
||||
voice="voice-id",
|
||||
speed=1.0,
|
||||
base_url="http://localhost:8000",
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
create_tts_service(user_config, audio_config=None)
|
||||
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "localhost" in exc_info.value.detail
|
||||
|
||||
|
||||
def test_embedding_service_blocks_private_base_url_in_saas(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"api.utils.url_security.DEPLOYMENT_MODE", "saas"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="public IP"):
|
||||
OpenAIEmbeddingService(
|
||||
db_client=SimpleNamespace(),
|
||||
api_key="test-key",
|
||||
base_url="http://10.0.0.10/v1",
|
||||
)
|
||||
66
api/utils/url_security.py
Normal file
66
api/utils/url_security.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
|
||||
_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10")
|
||||
|
||||
|
||||
def validate_user_configured_service_url(
|
||||
url: str,
|
||||
*,
|
||||
field_name: str,
|
||||
) -> None:
|
||||
"""Restrict user-configured service URLs in hosted deployments.
|
||||
|
||||
OSS deployments commonly point model services at localhost or private LAN
|
||||
hosts. SaaS deployments must not allow users to make Dograh infrastructure
|
||||
connect to private/internal network locations.
|
||||
"""
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
return
|
||||
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"http", "https", "ws", "wss"} or not parsed.hostname:
|
||||
raise ValueError(f"{field_name} must be an http, https, ws, or wss URL")
|
||||
|
||||
hostname = parsed.hostname
|
||||
if hostname.lower() == "localhost":
|
||||
raise ValueError(f"{field_name} cannot point to localhost in SaaS mode")
|
||||
|
||||
for ip in _resolve_hostname_ips(hostname, parsed.port):
|
||||
if _is_blocked_saas_service_ip(ip):
|
||||
raise ValueError(
|
||||
f"{field_name} must resolve to a public IP address in SaaS mode"
|
||||
)
|
||||
|
||||
|
||||
def _resolve_hostname_ips(
|
||||
hostname: str, port: int | None
|
||||
) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]:
|
||||
try:
|
||||
return [ipaddress.ip_address(hostname)]
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
addr_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
|
||||
except socket.gaierror as e:
|
||||
raise ValueError("Could not resolve service URL hostname") from e
|
||||
|
||||
return [ipaddress.ip_address(addr_info[4][0]) for addr_info in addr_infos]
|
||||
|
||||
|
||||
def _is_blocked_saas_service_ip(
|
||||
ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
|
||||
) -> bool:
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
or ip.is_unspecified
|
||||
or (ip.version == 4 and ip in _CGNAT_NETWORK)
|
||||
)
|
||||
Loading…
Add table
Add a link
Reference in a new issue