mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-07-01 08:59:46 +02:00
feat: allow overriding base URL of OpenAI STT and TTS (#377)
Mirrors the LLM treatment from #368 for the OpenAI STT and OpenAI TTS providers. Users running OpenAI-compatible self-hosted services (vLLM, Speaches, llama.cpp, custom proxies) can now point Dograh at them via the OpenAI provider with `base_url`, instead of being forced onto the Speaches provider as a workaround.

Changes:

* `registry.py` — add `base_url` field (default `https://api.openai.com/v1`) to `OpenAISTTConfiguration` and `OpenAITTSService`, identical in shape to the existing `OpenAILLMService.base_url` from #368.
* `service_factory.py` — in the OPENAI branches of `create_stt_service` and `create_tts_service`, lift `base_url` off the user config, run it through `_validate_runtime_service_url`, and forward it as a kwarg to `OpenAISTTService` / `OpenAITTSService` (both already accept it). Same pattern as the LLM branch.
* `test_user_configured_service_url_security.py` — adds four runtime validation tests covering private-IP rejection and localhost rejection in SaaS mode for both STT and TTS. Existing OSS-mode permissiveness is unchanged (DEPLOYMENT_MODE=oss skips the validator, as before).

No schema migration needed — Pydantic populates the default; existing configurations without `base_url` continue to talk to api.openai.com.

`check_validity.py` requires no edits because the per-service validation loop already iterates `("base_url", "endpoint")` via `getattr`, and the `_check_openai_api_key` dispatcher already routes OPENAI providers through the base_url-aware code path (introduced in #368) for STT and TTS too.

Tests pass locally:

    pytest api/tests/test_user_configured_service_url_security.py
    23 passed in 4.80s (19 existing + 4 new)

Co-authored-by: developer603 <developer603@users.noreply.github.com>
This commit is contained in:
parent
dd85c4a1b4
commit
8a4a2e25db
3 changed files with 95 additions and 0 deletions
|
|
@ -830,6 +830,10 @@ class OpenAITTSService(BaseTTSConfiguration):
|
||||||
default="alloy",
|
default="alloy",
|
||||||
description="OpenAI TTS voice name.",
|
description="OpenAI TTS voice name.",
|
||||||
)
|
)
|
||||||
|
base_url: str = Field(
|
||||||
|
default="https://api.openai.com/v1",
|
||||||
|
description="Override only if using an OpenAI-compatible API (e.g. local TTS, proxy).",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DOGRAH_TTS_MODELS = ["default"]
|
DOGRAH_TTS_MODELS = ["default"]
|
||||||
|
|
@ -1088,6 +1092,10 @@ class OpenAISTTConfiguration(BaseSTTConfiguration):
|
||||||
description="OpenAI transcription model.",
|
description="OpenAI transcription model.",
|
||||||
json_schema_extra={"examples": OPENAI_STT_MODELS},
|
json_schema_extra={"examples": OPENAI_STT_MODELS},
|
||||||
)
|
)
|
||||||
|
base_url: str = Field(
|
||||||
|
default="https://api.openai.com/v1",
|
||||||
|
description="Override only if using an OpenAI-compatible API (e.g. local STT, proxy).",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@register_stt
|
@register_stt
|
||||||
|
|
|
||||||
|
|
@ -119,9 +119,15 @@ def create_stt_service(
|
||||||
sample_rate=audio_config.transport_in_sample_rate,
|
sample_rate=audio_config.transport_in_sample_rate,
|
||||||
)
|
)
|
||||||
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
|
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
|
||||||
|
kwargs = {}
|
||||||
|
base_url = getattr(user_config.stt, "base_url", None)
|
||||||
|
if base_url:
|
||||||
|
_validate_runtime_service_url(base_url, "base_url")
|
||||||
|
kwargs["base_url"] = base_url
|
||||||
return OpenAISTTService(
|
return OpenAISTTService(
|
||||||
api_key=user_config.stt.api_key,
|
api_key=user_config.stt.api_key,
|
||||||
settings=OpenAISTTSettings(model=user_config.stt.model),
|
settings=OpenAISTTSettings(model=user_config.stt.model),
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif user_config.stt.provider == ServiceProviders.GOOGLE.value:
|
elif user_config.stt.provider == ServiceProviders.GOOGLE.value:
|
||||||
language = getattr(user_config.stt, "language", None) or "en-US"
|
language = getattr(user_config.stt, "language", None) or "en-US"
|
||||||
|
|
@ -283,12 +289,18 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
||||||
silence_time_s=1.0,
|
silence_time_s=1.0,
|
||||||
)
|
)
|
||||||
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
||||||
|
kwargs = {}
|
||||||
|
base_url = getattr(user_config.tts, "base_url", None)
|
||||||
|
if base_url:
|
||||||
|
_validate_runtime_service_url(base_url, "base_url")
|
||||||
|
kwargs["base_url"] = base_url
|
||||||
return OpenAITTSService(
|
return OpenAITTSService(
|
||||||
api_key=user_config.tts.api_key,
|
api_key=user_config.tts.api_key,
|
||||||
settings=OpenAITTSSettings(model=user_config.tts.model),
|
settings=OpenAITTSSettings(model=user_config.tts.model),
|
||||||
text_filters=[xml_function_tag_filter],
|
text_filters=[xml_function_tag_filter],
|
||||||
skip_aggregator_types=["recording_router", "recording"],
|
skip_aggregator_types=["recording_router", "recording"],
|
||||||
silence_time_s=1.0,
|
silence_time_s=1.0,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
elif user_config.tts.provider == ServiceProviders.GOOGLE.value:
|
elif user_config.tts.provider == ServiceProviders.GOOGLE.value:
|
||||||
model = getattr(user_config.tts, "model", None) or "chirp_3_hd"
|
model = getattr(user_config.tts, "model", None) or "chirp_3_hd"
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ from api.services.configuration.registry import (
|
||||||
from api.services.gen_ai.embedding.openai_service import OpenAIEmbeddingService
|
from api.services.gen_ai.embedding.openai_service import OpenAIEmbeddingService
|
||||||
from api.services.pipecat.service_factory import (
|
from api.services.pipecat.service_factory import (
|
||||||
create_llm_service_from_provider,
|
create_llm_service_from_provider,
|
||||||
|
create_stt_service,
|
||||||
create_tts_service,
|
create_tts_service,
|
||||||
)
|
)
|
||||||
from api.utils.url_security import validate_user_configured_service_url
|
from api.utils.url_security import validate_user_configured_service_url
|
||||||
|
|
@ -214,6 +215,80 @@ def test_runtime_blocks_elevenlabs_local_tts_base_url_in_saas(monkeypatch):
|
||||||
assert "localhost" in exc_info.value.detail
|
assert "localhost" in exc_info.value.detail
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_blocks_openai_stt_private_base_url_in_saas(monkeypatch):
    """SaaS mode must reject an OpenAI STT ``base_url`` on a private IP."""
    # Pin the deployment mode to SaaS for the duration of this test.
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    # Minimal stand-in for the STT section of a user configuration; the
    # base_url targets an RFC 1918 private address.
    stt_settings = SimpleNamespace(
        provider=ServiceProviders.OPENAI.value,
        api_key="test-key",
        model="gpt-4o-transcribe",
        base_url="http://10.0.0.10/v1",
    )
    config = SimpleNamespace(stt=stt_settings)

    with pytest.raises(HTTPException) as rejected:
        create_stt_service(config, audio_config=None)

    error = rejected.value
    assert error.status_code == 400
    assert "public IP" in error.detail
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_blocks_openai_stt_localhost_base_url_in_saas(monkeypatch):
    """SaaS mode must reject an OpenAI STT ``base_url`` pointing at localhost."""
    # Pin the deployment mode to SaaS for the duration of this test.
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    # Minimal stand-in for the STT section of a user configuration; the
    # base_url targets the local machine.
    stt_settings = SimpleNamespace(
        provider=ServiceProviders.OPENAI.value,
        api_key="test-key",
        model="gpt-4o-transcribe",
        base_url="http://localhost:8000/v1",
    )
    config = SimpleNamespace(stt=stt_settings)

    with pytest.raises(HTTPException) as rejected:
        create_stt_service(config, audio_config=None)

    error = rejected.value
    assert error.status_code == 400
    assert "localhost" in error.detail
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_blocks_openai_tts_private_base_url_in_saas(monkeypatch):
    """SaaS mode must reject an OpenAI TTS ``base_url`` on a private IP."""
    # Pin the deployment mode to SaaS for the duration of this test.
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    # Minimal stand-in for the TTS section of a user configuration; the
    # base_url targets an RFC 1918 private address.
    tts_settings = SimpleNamespace(
        provider=ServiceProviders.OPENAI.value,
        api_key="test-key",
        model="gpt-4o-mini-tts",
        voice="alloy",
        base_url="http://10.0.0.10/v1",
    )
    config = SimpleNamespace(tts=tts_settings)

    with pytest.raises(HTTPException) as rejected:
        create_tts_service(config, audio_config=None)

    error = rejected.value
    assert error.status_code == 400
    assert "public IP" in error.detail
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_blocks_openai_tts_localhost_base_url_in_saas(monkeypatch):
    """SaaS mode must reject an OpenAI TTS ``base_url`` pointing at localhost."""
    # Pin the deployment mode to SaaS for the duration of this test.
    monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")

    # Minimal stand-in for the TTS section of a user configuration; the
    # base_url targets the local machine.
    tts_settings = SimpleNamespace(
        provider=ServiceProviders.OPENAI.value,
        api_key="test-key",
        model="gpt-4o-mini-tts",
        voice="alloy",
        base_url="http://localhost:8000/v1",
    )
    config = SimpleNamespace(tts=tts_settings)

    with pytest.raises(HTTPException) as rejected:
        create_tts_service(config, audio_config=None)

    error = rejected.value
    assert error.status_code == 400
    assert "localhost" in error.detail
|
||||||
|
|
||||||
|
|
||||||
def test_embedding_service_blocks_private_base_url_in_saas(monkeypatch):
|
def test_embedding_service_blocks_private_base_url_in_saas(monkeypatch):
|
||||||
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
monkeypatch.setattr("api.utils.url_security.DEPLOYMENT_MODE", "saas")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue