From 8a58b0992d588c199f6ee1f77d959efc16a2a97c Mon Sep 17 00:00:00 2001 From: Abhishek Date: Wed, 27 May 2026 13:07:45 +0530 Subject: [PATCH] feat: allow overriding base URL of OpenAI models (#368) * Add OpenAI-compatible API option in model configuration Backend-only cherry-pick from 20617db37a8417e4ee4f64efb6063fc5cd4aea98. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix: harden the base url settings in SaaS mode --------- Co-authored-by: Chris Briddock Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- api/services/configuration/check_validity.py | 98 +++++++- api/services/configuration/registry.py | 4 + .../gen_ai/embedding/openai_service.py | 5 + api/services/pipecat/service_factory.py | 36 ++- ...st_user_configured_service_url_security.py | 227 ++++++++++++++++++ api/utils/url_security.py | 66 +++++ 6 files changed, 425 insertions(+), 11 deletions(-) create mode 100644 api/tests/test_user_configured_service_url_security.py create mode 100644 api/utils/url_security.py diff --git a/api/services/configuration/check_validity.py b/api/services/configuration/check_validity.py index 721884b..3a76147 100644 --- a/api/services/configuration/check_validity.py +++ b/api/services/configuration/check_validity.py @@ -13,6 +13,7 @@ from api.schemas.user_configuration import ( ) from api.services.configuration.registry import ServiceConfig, ServiceProviders from api.services.mps_service_key_client import mps_service_key_client +from api.utils.url_security import validate_user_configured_service_url AuthContext = TypedDict( "AuthContext", @@ -107,6 +108,17 @@ class UserConfigurationValidator: provider = service_config.provider + for url_field in ("base_url", "endpoint"): + url = getattr(service_config, url_field, None) + if url: + try: + validate_user_configured_service_url( + url, + field_name=url_field, + ) + except ValueError as e: + return [{"model": service_name, "message": str(e)}] + # Speaches doesn't require an API key if provider == ServiceProviders.SPEACHES.value: try: @@ -181,30 +193,92 @@ class UserConfigurationValidator: api_key = service_config.api_key try: - if not self._check_api_key(provider, api_key): + if not self._check_api_key(provider, api_key, service_config): return [ - {"model": service_name, "message": f"Invalid {provider} API key"} + { + "model": service_name, + "message": ( + f"Invalid {provider} API key. Please verify your API key is " + f"correct, has not expired, and has the required permissions." + ), + } ] except ValueError as e: return [{"model": service_name, "message": str(e)}] return [] - def _check_api_key(self, provider: str, api_key: str) -> bool: + def _check_api_key( + self, + provider: str, + api_key: str, + service_config: Optional[ServiceConfig] = None, + ) -> bool: """Check if an API key for a provider is valid.""" validator = self._validator_map.get(provider) if not validator: return False + if provider in ( + ServiceProviders.OPENAI.value, + ServiceProviders.OPENAI_REALTIME.value, + ): + return validator(provider, api_key, service_config) return validator(provider, api_key) - def _check_openai_api_key(self, model: str, api_key: str) -> bool: - client = openai.OpenAI(api_key=api_key) + def _check_openai_api_key( + self, model: str, api_key: str, service_config: Optional[ServiceConfig] = None + ) -> bool: + client_kwargs: dict[str, str] = {"api_key": api_key} + base_url = getattr(service_config, "base_url", None) if service_config else None + if base_url: + client_kwargs["base_url"] = base_url + client = openai.OpenAI(**client_kwargs) try: client.models.list() return True except openai.AuthenticationError: - return False + if base_url and "openai.com" not in base_url: + raise ValueError( + f"Invalid OpenAI API key. The key was rejected by the API at {base_url}. " + "Please check that your API key is correct and has not been revoked." + ) + raise ValueError( + "Invalid OpenAI API key. The key was rejected by the OpenAI API. " + "Please check that your API key is correct and has not been revoked. " + "You can verify your keys at https://platform.openai.com/api-keys." + ) + except openai.APIConnectionError: + if base_url: + raise ValueError( + f"Could not connect to the OpenAI-compatible API at {base_url}. " + "Please verify that the base_url is correct and reachable, and try again." + ) + raise ValueError( + "Could not connect to the OpenAI API. Please check your network connection " + "and try again." + ) + except openai.APIError: + if base_url: + raise ValueError( + f"The OpenAI-compatible API at {base_url} returned an error while " + "validating the API key. Please verify that the base_url is correct, " + "the service is available, and the API key is valid." + ) + raise ValueError( + "The OpenAI API returned an error while validating the API key. " + "Please try again later." + ) + except Exception: + if base_url: + raise ValueError( + f"Failed to validate the OpenAI API key using the API at {base_url}. " + "Please verify that the base_url is correct and reachable, and that the " + "API key is valid." + ) + raise ValueError( + "Failed to validate the OpenAI API key. Please try again later." + ) def _check_deepgram_api_key(self, model: str, api_key: str) -> bool: try: @@ -212,7 +286,11 @@ class UserConfigurationValidator: deepgram.manage.v1.projects.list() return True except Exception: - return False + raise ValueError( + "Invalid Deepgram API key. The key was rejected by the Deepgram API. " + "Please check that your API key is correct and active. " + "You can verify your keys at https://console.deepgram.com/." + ) def _check_groq_api_key(self, model: str, api_key: str) -> bool: client = Groq(api_key=api_key) @@ -220,7 +298,11 @@ class UserConfigurationValidator: client.models.list() return True except Exception: - return False + raise ValueError( + "Invalid Groq API key. The key was rejected by the Groq API. " + "Please check that your API key is correct and active. " + "You can verify your keys at https://console.groq.com/keys." + ) def _validate_elevenlabs_api_key(self, model: str, api_key: str) -> bool: return True diff --git a/api/services/configuration/registry.py b/api/services/configuration/registry.py index e60db18..498a6fc 100644 --- a/api/services/configuration/registry.py +++ b/api/services/configuration/registry.py @@ -290,6 +290,10 @@ class OpenAILLMService(BaseLLMConfiguration): description="OpenAI chat model to use.", json_schema_extra={"examples": OPENAI_MODELS, "allow_custom_input": True}, ) + base_url: str = Field( + default="https://api.openai.com/v1", + description="Override only if using an OpenAI-compatible API (e.g. local LLM, proxy).", + ) @register_llm diff --git a/api/services/gen_ai/embedding/openai_service.py b/api/services/gen_ai/embedding/openai_service.py index 2b54644..da5d3d4 100644 --- a/api/services/gen_ai/embedding/openai_service.py +++ b/api/services/gen_ai/embedding/openai_service.py @@ -11,6 +11,7 @@ from loguru import logger from openai import AsyncOpenAI from api.db.db_client import DBClient +from api.utils.url_security import validate_user_configured_service_url from .base import BaseEmbeddingService @@ -54,6 +55,10 @@ class OpenAIEmbeddingService(BaseEmbeddingService): if self._api_key_configured: client_kwargs = {"api_key": api_key} if base_url: + validate_user_configured_service_url( + base_url, + field_name="base_url", + ) client_kwargs["base_url"] = base_url self.client = AsyncOpenAI(**client_kwargs) logger.info(f"OpenAI embedding service initialized with model: {model_id}") diff --git a/api/services/pipecat/service_factory.py b/api/services/pipecat/service_factory.py index ad5c357..1c796e4 100644 --- a/api/services/pipecat/service_factory.py +++ b/api/services/pipecat/service_factory.py @@ -7,6 +7,7 @@ from loguru import logger from api.constants import MPS_API_URL from api.services.configuration.registry import ServiceProviders from api.services.pipecat.minimax_tts import MiniMaxOwnedSessionTTSService +from api.utils.url_security import validate_user_configured_service_url from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings @@ -62,6 +63,16 @@ if TYPE_CHECKING: from api.services.pipecat.audio_config import AudioConfig +def _validate_runtime_service_url(url: str, field_name: str) -> None: + try: + validate_user_configured_service_url( + url, + field_name=field_name, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) from e + + def create_stt_service( user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None ): @@ -174,6 +185,7 @@ def create_stt_service( ) elif user_config.stt.provider == ServiceProviders.SPEACHES.value: language = getattr(user_config.stt, "language", None) + _validate_runtime_service_url(user_config.stt.base_url, "base_url") return SpeachesSTTService( base_url=user_config.stt.base_url, api_key=user_config.stt.api_key or "none", @@ -301,6 +313,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"): # ElevenLabs TTS uses WebSocket. Users configure base_url with an HTTP # scheme (matching ElevenLabs documentation, e.g. # https://api.eu.residency.elevenlabs.io); rewrite it to the WS scheme. + _validate_runtime_service_url(user_config.tts.base_url, "base_url") elevenlabs_url = user_config.tts.base_url.replace("https://", "wss://").replace( "http://", "ws://" ) @@ -376,6 +389,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"): tts._settings.language = language return tts elif user_config.tts.provider == ServiceProviders.SPEACHES.value: + _validate_runtime_service_url(user_config.tts.base_url, "base_url") return SpeachesTTSService( base_url=user_config.tts.base_url, api_key=user_config.tts.api_key or "none", @@ -461,6 +475,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"): ).rstrip("/") if not base_url.endswith("/t2a_v2"): base_url = f"{base_url}/t2a_v2" + _validate_runtime_service_url(base_url, "base_url") session = aiohttp.ClientSession() return MiniMaxOwnedSessionTTSService( @@ -504,6 +519,10 @@ def create_llm_service_from_provider( """ logger.info(f"Creating LLM service: provider={provider}, model={model}") if provider == ServiceProviders.OPENAI.value: + kwargs = {} + if base_url: + _validate_runtime_service_url(base_url, "base_url") + kwargs["base_url"] = base_url if "gpt-5" in model: return OpenAILLMService( api_key=api_key, @@ -511,10 +530,12 @@ def create_llm_service_from_provider( model=model, extra={"reasoning_effort": "minimal", "verbosity": "low"}, ), + **kwargs, ) return OpenAILLMService( api_key=api_key, settings=OpenAILLMSettings(model=model, temperature=0.1), + **kwargs, ) elif provider == ServiceProviders.GROQ.value: return GroqLLMService( @@ -524,6 +545,7 @@ def create_llm_service_from_provider( elif provider == ServiceProviders.OPENROUTER.value: kwargs = {} if base_url: + _validate_runtime_service_url(base_url, "base_url") kwargs["base_url"] = base_url return OpenRouterLLMService( api_key=api_key, @@ -543,6 +565,8 @@ def create_llm_service_from_provider( settings=GoogleVertexLLMSettings(model=model, temperature=0.1), ) elif provider == ServiceProviders.AZURE.value: + if endpoint: + _validate_runtime_service_url(endpoint, "endpoint") return AzureLLMService( api_key=api_key, endpoint=endpoint, @@ -562,15 +586,19 @@ def create_llm_service_from_provider( settings=AWSBedrockLLMSettings(model=model), ) elif provider == ServiceProviders.SPEACHES.value: + base_url = base_url or "http://localhost:11434/v1" + _validate_runtime_service_url(base_url, "base_url") return SpeachesLLMService( - base_url=base_url or "http://localhost:11434/v1", + base_url=base_url, api_key=api_key or "none", settings=SpeachesLLMSettings(model=model), ) elif provider == ServiceProviders.MINIMAX.value: + base_url = base_url or "https://api.minimax.io/v1" + _validate_runtime_service_url(base_url, "base_url") return MiniMaxLLMService( api_key=api_key, - base_url=base_url or "https://api.minimax.io/v1", + base_url=base_url, settings=MiniMaxLLMService.Settings( model=model, temperature=temperature if temperature is not None else 1.0, @@ -709,7 +737,9 @@ def create_llm_service(user_config): api_key = user_config.llm.api_key kwargs = {} - if provider == ServiceProviders.OPENROUTER.value: + if provider == ServiceProviders.OPENAI.value: + kwargs["base_url"] = user_config.llm.base_url + elif provider == ServiceProviders.OPENROUTER.value: kwargs["base_url"] = user_config.llm.base_url elif provider == ServiceProviders.AZURE.value: kwargs["endpoint"] = user_config.llm.endpoint diff --git a/api/tests/test_user_configured_service_url_security.py b/api/tests/test_user_configured_service_url_security.py new file mode 100644 index 0000000..ecfc1c3 --- /dev/null +++ b/api/tests/test_user_configured_service_url_security.py @@ -0,0 +1,227 @@ +from types import SimpleNamespace + +import pytest +from fastapi import HTTPException + +from api.services.configuration.check_validity import UserConfigurationValidator +from api.services.configuration.registry import ( + ServiceProviders, + SpeachesLLMConfiguration, +) +from api.services.gen_ai.embedding.openai_service import OpenAIEmbeddingService +from api.services.pipecat.service_factory import ( + create_llm_service_from_provider, + create_tts_service, +) +from api.utils.url_security import validate_user_configured_service_url + + +def test_oss_allows_local_service_urls(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "oss") + + validate_user_configured_service_url( + "http://localhost:11434/v1", + field_name="base_url", + ) + + +@pytest.mark.parametrize( + "url", + [ + "http://localhost:11434/v1", + "http://127.0.0.1:11434/v1", + "http://10.0.0.10/v1", + "http://169.254.169.254/latest/meta-data", + "http://100.64.0.1/v1", + ], +) +def test_saas_blocks_local_and_internal_service_urls(url, monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + + with pytest.raises(ValueError): + validate_user_configured_service_url( + url, + field_name="base_url", + ) + + +def test_saas_rejects_unsupported_service_url_schemes(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + + with pytest.raises(ValueError, match="http, https, ws, or wss"): + validate_user_configured_service_url( + "file:///etc/passwd", + field_name="base_url", + ) + + +def test_saas_checks_resolved_hostname_ips(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + + def fake_getaddrinfo(*_args, **_kwargs): + return [(None, None, None, None, ("10.0.0.10", 443))] + + monkeypatch.setattr("api.utils.url_security.socket.getaddrinfo", fake_getaddrinfo) + + with pytest.raises(ValueError, match="public IP"): + validate_user_configured_service_url( + "https://internal.example.com/v1", + field_name="base_url", + ) + + +def test_saas_allows_public_service_url(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + + def fake_getaddrinfo(*_args, **_kwargs): + return [(None, None, None, None, ("8.8.8.8", 443))] + + monkeypatch.setattr("api.utils.url_security.socket.getaddrinfo", fake_getaddrinfo) + + validate_user_configured_service_url( + "https://api.example.com/v1", + field_name="base_url", + ) + + +def test_saas_allows_public_websocket_service_url(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + + def fake_getaddrinfo(*_args, **_kwargs): + return [(None, None, None, None, ("8.8.8.8", 443))] + + monkeypatch.setattr("api.utils.url_security.socket.getaddrinfo", fake_getaddrinfo) + + validate_user_configured_service_url( + "wss://api.example.com/v1", + field_name="base_url", + ) + + +def test_saas_blocks_local_websocket_service_url(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + + with pytest.raises(ValueError, match="localhost"): + validate_user_configured_service_url( + "ws://localhost:8000/v1", + field_name="base_url", + ) + + +def test_validator_blocks_speaches_local_base_url_in_saas(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + validator = UserConfigurationValidator() + config = SpeachesLLMConfiguration() + + result = validator._validate_service(config, "llm") + + assert result == [ + { + "model": "llm", + "message": "base_url cannot point to localhost in SaaS mode", + } + ] + + +def test_validator_blocks_azure_private_endpoint_in_saas(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + validator = UserConfigurationValidator() + config = SimpleNamespace( + provider=ServiceProviders.AZURE.value, + endpoint="http://10.0.0.10/openai", + api_key="test-key", + ) + + result = validator._validate_service(config, "llm") + + assert result == [ + { + "model": "llm", + "message": "endpoint must resolve to a public IP address in SaaS mode", + } + ] + + +def test_validator_allows_speaches_local_base_url_in_oss(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "oss") + validator = UserConfigurationValidator() + config = SpeachesLLMConfiguration() + + assert validator._validate_service(config, "llm") == [] + + +def test_runtime_blocks_speaches_default_llm_base_url_in_saas(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + + with pytest.raises(HTTPException) as exc_info: + create_llm_service_from_provider( + provider=ServiceProviders.SPEACHES.value, + model="llama3", + api_key=None, + ) + + assert exc_info.value.status_code == 400 + assert "localhost" in exc_info.value.detail + + +def test_runtime_blocks_openai_private_base_url_in_saas(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + + with pytest.raises(HTTPException) as exc_info: + create_llm_service_from_provider( + provider=ServiceProviders.OPENAI.value, + model="gpt-4.1", + api_key="test-key", + base_url="http://10.0.0.10/v1", + ) + + assert exc_info.value.status_code == 400 + assert "public IP" in exc_info.value.detail + + +def test_runtime_blocks_azure_private_endpoint_in_saas(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + + with pytest.raises(HTTPException) as exc_info: + create_llm_service_from_provider( + provider=ServiceProviders.AZURE.value, + model="gpt-4.1-mini", + api_key="test-key", + endpoint="http://10.0.0.10/openai", + ) + + assert exc_info.value.status_code == 400 + assert "public IP" in exc_info.value.detail + + +def test_runtime_blocks_elevenlabs_local_tts_base_url_in_saas(monkeypatch): + monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas") + user_config = SimpleNamespace( + tts=SimpleNamespace( + provider=ServiceProviders.ELEVENLABS.value, + api_key="test-key", + model="eleven_flash_v2_5", + voice="voice-id", + speed=1.0, + base_url="http://localhost:8000", + ) + ) + + with pytest.raises(HTTPException) as exc_info: + create_tts_service(user_config, audio_config=None) + + assert exc_info.value.status_code == 400 + assert "localhost" in exc_info.value.detail + + +def test_embedding_service_blocks_private_base_url_in_saas(monkeypatch): + monkeypatch.setattr( + "api.utils.url_security.DEPLOYMENT_MODE", "saas" + ) + + with pytest.raises(ValueError, match="public IP"): + OpenAIEmbeddingService( + db_client=SimpleNamespace(), + api_key="test-key", + base_url="http://10.0.0.10/v1", + ) diff --git a/api/utils/url_security.py b/api/utils/url_security.py new file mode 100644 index 0000000..b2c9db9 --- /dev/null +++ b/api/utils/url_security.py @@ -0,0 +1,66 @@ +import ipaddress +import socket +from urllib.parse import urlparse + +from api.constants import DEPLOYMENT_MODE + +_CGNAT_NETWORK = ipaddress.ip_network("100.64.0.0/10") + + +def validate_user_configured_service_url( + url: str, + *, + field_name: str, +) -> None: + """Restrict user-configured service URLs in hosted deployments. + + OSS deployments commonly point model services at localhost or private LAN + hosts. SaaS deployments must not allow users to make Dograh infrastructure + connect to private/internal network locations. + """ + if DEPLOYMENT_MODE == "oss": + return + + parsed = urlparse(url) + if parsed.scheme not in {"http", "https", "ws", "wss"} or not parsed.hostname: + raise ValueError(f"{field_name} must be an http, https, ws, or wss URL") + + hostname = parsed.hostname + if hostname.lower() == "localhost": + raise ValueError(f"{field_name} cannot point to localhost in SaaS mode") + + for ip in _resolve_hostname_ips(hostname, parsed.port): + if _is_blocked_saas_service_ip(ip): + raise ValueError( + f"{field_name} must resolve to a public IP address in SaaS mode" + ) + + +def _resolve_hostname_ips( + hostname: str, port: int | None +) -> list[ipaddress.IPv4Address | ipaddress.IPv6Address]: + try: + return [ipaddress.ip_address(hostname)] + except ValueError: + pass + + try: + addr_infos = socket.getaddrinfo(hostname, port, type=socket.SOCK_STREAM) + except socket.gaierror as e: + raise ValueError("Could not resolve service URL hostname") from e + + return [ipaddress.ip_address(addr_info[4][0]) for addr_info in addr_infos] + + +def _is_blocked_saas_service_ip( + ip: ipaddress.IPv4Address | ipaddress.IPv6Address, +) -> bool: + return ( + ip.is_private + or ip.is_loopback + or ip.is_link_local + or ip.is_multicast + or ip.is_reserved + or ip.is_unspecified + or (ip.version == 4 and ip in _CGNAT_NETWORK) + )