fix: harden the base url settings in SaaS mode

This commit is contained in:
Abhishek Kumar 2026-05-27 13:04:27 +05:30
parent 88d6ac425b
commit c7b5ee1ae2
5 changed files with 339 additions and 3 deletions

View file

@ -13,6 +13,7 @@ from api.schemas.user_configuration import (
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
from api.utils.url_security import validate_user_configured_service_url
AuthContext = TypedDict(
"AuthContext",
@ -107,6 +108,17 @@ class UserConfigurationValidator:
provider = service_config.provider
for url_field in ("base_url", "endpoint"):
url = getattr(service_config, url_field, None)
if url:
try:
validate_user_configured_service_url(
url,
field_name=url_field,
)
except ValueError as e:
return [{"model": service_name, "message": str(e)}]
# Speaches doesn't require an API key
if provider == ServiceProviders.SPEACHES.value:
try:
@ -197,7 +209,10 @@ class UserConfigurationValidator:
return []
def _check_api_key(
self, provider: str, api_key: str, service_config: Optional[ServiceConfig] = None
self,
provider: str,
api_key: str,
service_config: Optional[ServiceConfig] = None,
) -> bool:
"""Check if an API key for a provider is valid."""
validator = self._validator_map.get(provider)

View file

@ -11,6 +11,7 @@ from loguru import logger
from openai import AsyncOpenAI
from api.db.db_client import DBClient
from api.utils.url_security import validate_user_configured_service_url
from .base import BaseEmbeddingService
@ -54,6 +55,10 @@ class OpenAIEmbeddingService(BaseEmbeddingService):
if self._api_key_configured:
client_kwargs = {"api_key": api_key}
if base_url:
validate_user_configured_service_url(
base_url,
field_name="base_url",
)
client_kwargs["base_url"] = base_url
self.client = AsyncOpenAI(**client_kwargs)
logger.info(f"OpenAI embedding service initialized with model: {model_id}")

View file

@ -7,6 +7,7 @@ from loguru import logger
from api.constants import MPS_API_URL
from api.services.configuration.registry import ServiceProviders
from api.services.pipecat.minimax_tts import MiniMaxOwnedSessionTTSService
from api.utils.url_security import validate_user_configured_service_url
from pipecat.services.assemblyai.stt import AssemblyAISTTService, AssemblyAISTTSettings
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
@ -62,6 +63,16 @@ if TYPE_CHECKING:
from api.services.pipecat.audio_config import AudioConfig
def _validate_runtime_service_url(url: str, field_name: str) -> None:
try:
validate_user_configured_service_url(
url,
field_name=field_name,
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) from e
def create_stt_service(
user_config, audio_config: "AudioConfig", keyterms: list[str] | None = None
):
@ -174,6 +185,7 @@ def create_stt_service(
)
elif user_config.stt.provider == ServiceProviders.SPEACHES.value:
language = getattr(user_config.stt, "language", None)
_validate_runtime_service_url(user_config.stt.base_url, "base_url")
return SpeachesSTTService(
base_url=user_config.stt.base_url,
api_key=user_config.stt.api_key or "none",
@ -301,6 +313,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
# ElevenLabs TTS uses WebSocket. Users configure base_url with an HTTP
# scheme (matching ElevenLabs documentation, e.g.
# https://api.eu.residency.elevenlabs.io); rewrite it to the WS scheme.
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
elevenlabs_url = user_config.tts.base_url.replace("https://", "wss://").replace(
"http://", "ws://"
)
@ -376,6 +389,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
tts._settings.language = language
return tts
elif user_config.tts.provider == ServiceProviders.SPEACHES.value:
_validate_runtime_service_url(user_config.tts.base_url, "base_url")
return SpeachesTTSService(
base_url=user_config.tts.base_url,
api_key=user_config.tts.api_key or "none",
@ -461,6 +475,7 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
).rstrip("/")
if not base_url.endswith("/t2a_v2"):
base_url = f"{base_url}/t2a_v2"
_validate_runtime_service_url(base_url, "base_url")
session = aiohttp.ClientSession()
return MiniMaxOwnedSessionTTSService(
@ -506,6 +521,7 @@ def create_llm_service_from_provider(
if provider == ServiceProviders.OPENAI.value:
kwargs = {}
if base_url:
_validate_runtime_service_url(base_url, "base_url")
kwargs["base_url"] = base_url
if "gpt-5" in model:
return OpenAILLMService(
@ -529,6 +545,7 @@ def create_llm_service_from_provider(
elif provider == ServiceProviders.OPENROUTER.value:
kwargs = {}
if base_url:
_validate_runtime_service_url(base_url, "base_url")
kwargs["base_url"] = base_url
return OpenRouterLLMService(
api_key=api_key,
@ -548,6 +565,8 @@ def create_llm_service_from_provider(
settings=GoogleVertexLLMSettings(model=model, temperature=0.1),
)
elif provider == ServiceProviders.AZURE.value:
if endpoint:
_validate_runtime_service_url(endpoint, "endpoint")
return AzureLLMService(
api_key=api_key,
endpoint=endpoint,
@ -567,15 +586,19 @@ def create_llm_service_from_provider(
settings=AWSBedrockLLMSettings(model=model),
)
elif provider == ServiceProviders.SPEACHES.value:
base_url = base_url or "http://localhost:11434/v1"
_validate_runtime_service_url(base_url, "base_url")
return SpeachesLLMService(
base_url=base_url or "http://localhost:11434/v1",
base_url=base_url,
api_key=api_key or "none",
settings=SpeachesLLMSettings(model=model),
)
elif provider == ServiceProviders.MINIMAX.value:
base_url = base_url or "https://api.minimax.io/v1"
_validate_runtime_service_url(base_url, "base_url")
return MiniMaxLLMService(
api_key=api_key,
base_url=base_url or "https://api.minimax.io/v1",
base_url=base_url,
settings=MiniMaxLLMService.Settings(
model=model,
temperature=temperature if temperature is not None else 1.0,