From fc04f31639e0d326525d6840ca117babe2b25ea8 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Sat, 16 May 2026 18:37:38 +0530 Subject: [PATCH] fix: force FORCE_TURN_RELAY for local IPs in setup --- api/services/telephony/factory.py | 32 +++++----- api/tests/test_telephony_factory.py | 96 +++++++++++++++++++++++++++++ scripts/lib/setup_common.sh | 22 +++++++ scripts/setup_local.sh | 17 +++++ scripts/setup_remote.sh | 11 ++++ 5 files changed, 160 insertions(+), 18 deletions(-) create mode 100644 api/tests/test_telephony_factory.py diff --git a/api/services/telephony/factory.py b/api/services/telephony/factory.py index 5147a16..b5ca671 100644 --- a/api/services/telephony/factory.py +++ b/api/services/telephony/factory.py @@ -28,7 +28,7 @@ from api.services.telephony.base import TelephonyProvider async def load_telephony_config_by_id( - telephony_configuration_id: int, + telephony_configuration_id: int | str | None, organization_id: int, ) -> Dict[str, Any]: """Load and normalize the config row by primary key, scoped to the org. @@ -39,17 +39,19 @@ async def load_telephony_config_by_id( or doesn't belong to ``organization_id`` — the org scope is what makes this safe to expose to user-driven request flows. """ - if not telephony_configuration_id: - raise ValueError("telephony_configuration_id is required") + try: + resolved_cfg_id = int(telephony_configuration_id) + except (TypeError, ValueError) as e: + raise ValueError("telephony_configuration_id must be an integer") from e if not organization_id: raise ValueError("organization_id is required") row = await db_client.get_telephony_configuration_for_org( - telephony_configuration_id, organization_id + resolved_cfg_id, organization_id ) if not row: raise ValueError( - f"Telephony configuration {telephony_configuration_id} not found " + f"Telephony configuration {resolved_cfg_id} not found " f"for organization {organization_id}" ) return await _normalize_with_phone_numbers(row) @@ -120,7 +122,7 @@ async def find_telephony_config_for_inbound( async def get_telephony_provider_by_id( - telephony_configuration_id: int, + telephony_configuration_id: int | str | None, organization_id: int, ) -> TelephonyProvider: config = await load_telephony_config_by_id( @@ -142,7 +144,7 @@ async def get_telephony_provider_for_run( still resolve. """ cfg_id = (workflow_run.initial_context or {}).get("telephony_configuration_id") - if cfg_id: + if cfg_id is not None: return await get_telephony_provider_by_id(cfg_id, organization_id) return await get_default_telephony_provider(organization_id) @@ -167,7 +169,7 @@ async def get_telephony_provider_for_inbound( async def load_credentials_for_transport( organization_id: int, - telephony_configuration_id: Optional[int], + telephony_configuration_id: Optional[int | str], expected_provider: str, ) -> Dict[str, Any]: """Helper for per-provider transport modules. @@ -178,10 +180,9 @@ async def load_credentials_for_transport( so legacy runs created before the multi-config migration still work. Raises ValueError when the resolved config is for a different provider. """ - if telephony_configuration_id: - config = await load_telephony_config_by_id( - telephony_configuration_id, organization_id - ) + resolved_cfg_id = telephony_configuration_id + if resolved_cfg_id is not None: + config = await load_telephony_config_by_id(resolved_cfg_id, organization_id) else: config = await load_default_telephony_config(organization_id) @@ -189,7 +190,7 @@ async def load_credentials_for_transport( if actual != expected_provider: raise ValueError( f"Expected {expected_provider} provider, got {actual} " - f"(config_id={telephony_configuration_id}, org={organization_id})" + f"(config_id={resolved_cfg_id}, org={organization_id})" ) return config @@ -199,11 +200,6 @@ async def get_all_telephony_providers() -> List[Type[TelephonyProvider]]: return [spec.provider_cls for spec in registry.all_specs()] -# --------------------------------------------------------------------------- -# Internals -# --------------------------------------------------------------------------- - - async def _normalize_with_phone_numbers( row: TelephonyConfigurationModel, ) -> Dict[str, Any]: diff --git a/api/tests/test_telephony_factory.py b/api/tests/test_telephony_factory.py new file mode 100644 index 0000000..cca9da4 --- /dev/null +++ b/api/tests/test_telephony_factory.py @@ -0,0 +1,96 @@ +from types import SimpleNamespace +from unittest.mock import AsyncMock, patch + +import pytest + +from api.services.telephony.factory import ( + get_telephony_provider_for_run, + load_credentials_for_transport, + load_telephony_config_by_id, +) + + +@pytest.mark.asyncio +async def test_get_telephony_provider_for_run_casts_numeric_string_config_id(): + workflow_run = SimpleNamespace( + initial_context={"telephony_configuration_id": "213"} + ) + + with ( + patch( + "api.services.telephony.factory.get_telephony_provider_by_id", + new_callable=AsyncMock, + return_value="provider", + ) as get_provider, + patch( + "api.services.telephony.factory.get_default_telephony_provider", + new_callable=AsyncMock, + ) as get_default, + ): + result = await get_telephony_provider_for_run(workflow_run, 2617) + + assert result == "provider" + get_provider.assert_awaited_once_with("213", 2617) + get_default.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_get_telephony_provider_for_run_rejects_non_numeric_string_config_id(): + workflow_run = SimpleNamespace( + initial_context={"telephony_configuration_id": "twilio-main"} + ) + + with patch( + "api.services.telephony.factory.get_default_telephony_provider", + new_callable=AsyncMock, + ) as get_default: + with pytest.raises( + ValueError, + match="telephony_configuration_id must be an integer", + ): + await get_telephony_provider_for_run(workflow_run, 2617) + + get_default.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_load_credentials_for_transport_casts_numeric_string_config_id(): + with ( + patch( + "api.services.telephony.factory.load_telephony_config_by_id", + new_callable=AsyncMock, + return_value={"provider": "twilio"}, + ) as load_by_id, + patch( + "api.services.telephony.factory.load_default_telephony_config", + new_callable=AsyncMock, + ) as load_default, + ): + result = await load_credentials_for_transport(2617, "213", "twilio") + + assert result == {"provider": "twilio"} + load_by_id.assert_awaited_once_with("213", 2617) + load_default.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_load_telephony_config_by_id_casts_numeric_string_before_db_lookup(): + row = SimpleNamespace(id=213) + + with ( + patch( + "api.services.telephony.factory.db_client.get_telephony_configuration_for_org", + new_callable=AsyncMock, + return_value=row, + ) as get_config, + patch( + "api.services.telephony.factory._normalize_with_phone_numbers", + new_callable=AsyncMock, + return_value={"provider": "twilio"}, + ) as normalize, + ): + result = await load_telephony_config_by_id("213", 2617) + + assert result == {"provider": "twilio"} + get_config.assert_awaited_once_with(213, 2617) + normalize.assert_awaited_once_with(row) diff --git a/scripts/lib/setup_common.sh b/scripts/lib/setup_common.sh index e22b9ee..85758e6 100644 --- a/scripts/lib/setup_common.sh +++ b/scripts/lib/setup_common.sh @@ -98,6 +98,28 @@ dograh_is_ipv4() { [[ "$1" =~ ^([0-9]{1,3}\.){3}[0-9]{1,3}$ ]] } +dograh_is_local_ipv4() { + local ip=$1 + local o1 o2 o3 o4 octet + + dograh_is_ipv4 "$ip" || return 1 + IFS=. read -r o1 o2 o3 o4 <<< "$ip" + + for octet in "$o1" "$o2" "$o3" "$o4"; do + [[ "$octet" =~ ^[0-9]+$ ]] || return 1 + (( octet >= 0 && octet <= 255 )) || return 1 + done + + (( o1 == 10 )) && return 0 + (( o1 == 127 )) && return 0 + (( o1 == 169 && o2 == 254 )) && return 0 + (( o1 == 172 && o2 >= 16 && o2 <= 31 )) && return 0 + (( o1 == 192 && o2 == 168 )) && return 0 + (( o1 == 100 && o2 >= 64 && o2 <= 127 )) && return 0 + + return 1 +} + dograh_infer_server_ip() { local project_dir=${1:-$(dograh_project_dir)} local turn_conf="$project_dir/turnserver.conf" diff --git a/scripts/setup_local.sh b/scripts/setup_local.sh index be31384..590b774 100755 --- a/scripts/setup_local.sh +++ b/scripts/setup_local.sh @@ -68,6 +68,8 @@ if [[ "${ENABLE_COTURN:-false}" == "true" ]]; then ip=$(hostname -I 2>/dev/null | awk '{print $1}') [[ -n "$ip" ]] && { echo "$ip"; return; } fi + + return 0 } DEFAULT_TURN_HOST="$(detect_lan_ip)" @@ -100,6 +102,17 @@ if [[ "${ENABLE_COTURN:-false}" == "true" ]]; then fi fi +if [[ "${ENABLE_COTURN:-false}" != "true" ]]; then + FORCE_TURN_RELAY=false +elif [[ -z "${FORCE_TURN_RELAY:-}" ]]; then + if dograh_is_local_ipv4 "$TURN_HOST"; then + FORCE_TURN_RELAY=true + echo -e "${YELLOW}Detected a local/private TURN host IP; enabling FORCE_TURN_RELAY=true.${NC}" + else + FORCE_TURN_RELAY=false + fi +fi + # Telemetry opt-out (default: true) ENABLE_TELEMETRY="${ENABLE_TELEMETRY:-true}" @@ -112,6 +125,7 @@ echo -e " Coturn: ${BLUE}${ENABLE_COTURN:-false}${NC}" if [[ "${ENABLE_COTURN:-false}" == "true" ]]; then echo -e " TURN Host: ${BLUE}$TURN_HOST${NC}" echo -e " TURN Secret: ${BLUE}********${NC}" + echo -e " Force relay: ${BLUE}$FORCE_TURN_RELAY${NC}" fi echo -e " Telemetry: ${BLUE}$ENABLE_TELEMETRY${NC}" echo -e " Registry: ${BLUE}$REGISTRY${NC}" @@ -155,6 +169,9 @@ OSS_JWT_SECRET=$OSS_JWT_SECRET # Telemetry (set to false to disable) ENABLE_TELEMETRY=$ENABLE_TELEMETRY + +# Relay-only ICE candidates (auto-enabled for local/private TURN host IPs) +FORCE_TURN_RELAY=$FORCE_TURN_RELAY ENV_EOF if [[ "${ENABLE_COTURN:-false}" == "true" ]]; then diff --git a/scripts/setup_remote.sh b/scripts/setup_remote.sh index cd01f7f..073689f 100755 --- a/scripts/setup_remote.sh +++ b/scripts/setup_remote.sh @@ -49,6 +49,15 @@ if ! dograh_is_ipv4 "$SERVER_IP"; then dograh_fail "Invalid IP address format" fi +if [[ -z "${FORCE_TURN_RELAY:-}" ]]; then + if dograh_is_local_ipv4 "$SERVER_IP"; then + FORCE_TURN_RELAY=true + dograh_warn "Detected a local/private server IP; enabling FORCE_TURN_RELAY=true." + else + FORCE_TURN_RELAY=false + fi +fi + # Get the TURN secret (skip prompt if TURN_SECRET is already set) if [[ -z "${TURN_SECRET:-}" ]]; then echo -e "${YELLOW}Enter a shared secret for the TURN server (press Enter to generate a random one):${NC}" @@ -185,6 +194,7 @@ echo -e "${GREEN}Configuration:${NC}" echo -e " Server IP: ${BLUE}$SERVER_IP${NC}" echo -e " TURN Secret: ${BLUE}********${NC}" echo -e " Deploy mode: ${BLUE}$DEPLOY_MODE${NC}" +echo -e " Force TURN relay: ${BLUE}$FORCE_TURN_RELAY${NC}" echo -e " FastAPI workers: ${BLUE}$FASTAPI_WORKERS${NC} (ports 8000..$((8000 + FASTAPI_WORKERS - 1)))" if [[ "$DEPLOY_MODE" == "build" ]]; then if [[ "${REPO_SOURCE:-}" == "clone" ]]; then @@ -267,6 +277,7 @@ MINIO_PUBLIC_ENDPOINT=https://$SERVER_IP # TURN Server Configuration (time-limited credentials via TURN REST API) TURN_HOST=$SERVER_IP TURN_SECRET=$TURN_SECRET +FORCE_TURN_RELAY=$FORCE_TURN_RELAY # JWT secret for OSS authentication OSS_JWT_SECRET=$OSS_JWT_SECRET