diff --git a/api/routes/main.py b/api/routes/main.py
index 8086114..a9f9a6f 100644
--- a/api/routes/main.py
+++ b/api/routes/main.py
@@ -58,11 +58,13 @@ class HealthResponse(BaseModel):
@router.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
- from api.constants import APP_VERSION, BACKEND_API_ENDPOINT
+ from api.constants import APP_VERSION
+ from api.utils.common import get_backend_endpoints
logger.debug("Health endpoint called")
+ backend_endpoint, _ = await get_backend_endpoints()
return HealthResponse(
status="ok",
version=APP_VERSION,
- backend_api_endpoint=BACKEND_API_ENDPOINT,
+ backend_api_endpoint=backend_endpoint,
)
diff --git a/api/routes/public_agent.py b/api/routes/public_agent.py
index 23c6578..c9cae40 100644
--- a/api/routes/public_agent.py
+++ b/api/routes/public_agent.py
@@ -15,7 +15,7 @@ from api.db import db_client
from api.enums import TriggerState
from api.services.quota_service import check_dograh_quota_by_user_id
from api.services.telephony.factory import get_telephony_provider
-from api.utils.tunnel import TunnelURLProvider
+from api.utils.common import get_backend_endpoints
router = APIRouter(prefix="/public/agent")
@@ -147,11 +147,11 @@ async def initiate_call(
)
# 9. Construct webhook URL for telephony provider callback
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
+ backend_endpoint, _ = await get_backend_endpoints()
webhook_endpoint = provider.WEBHOOK_ENDPOINT
webhook_url = (
- f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
+ f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
f"?workflow_id={trigger.workflow_id}"
f"&user_id={api_key.created_by}"
f"&workflow_run_id={workflow_run.id}"
diff --git a/api/routes/telephony.py b/api/routes/telephony.py
index 440eb36..70aada7 100644
--- a/api/routes/telephony.py
+++ b/api/routes/telephony.py
@@ -30,13 +30,13 @@ from api.services.telephony.factory import (
get_all_telephony_providers,
get_telephony_provider,
)
+from api.utils.common import get_backend_endpoints
from api.utils.telephony_helper import (
generic_hangup_response,
normalize_webhook_data,
numbers_match,
parse_webhook_request,
)
-from api.utils.tunnel import TunnelURLProvider
from pipecat.utils.context import set_current_run_id
router = APIRouter(prefix="/telephony")
@@ -159,12 +159,12 @@ async def initiate_call(
workflow_run_name = workflow_run.name
# Construct webhook URL based on provider type
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
+ backend_endpoint, _ = await get_backend_endpoints()
webhook_endpoint = provider.WEBHOOK_ENDPOINT
webhook_url = (
- f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
+ f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
f"?workflow_id={request.workflow_id}"
f"&user_id={user.id}"
f"&workflow_run_id={workflow_run_id}"
@@ -313,10 +313,8 @@ async def _validate_inbound_request(
# Verify webhook signature/API key if provided
provider_instance = None
if x_twilio_signature or x_vobiz_signature or x_cx_apikey:
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- webhook_url = (
- f"https://{backend_endpoint}/api/v1/telephony/inbound/{workflow_id}"
- )
+ backend_endpoint, _ = await get_backend_endpoints()
+ webhook_url = f"{backend_endpoint}/api/v1/telephony/inbound/{workflow_id}"
# Get the real telephony provider with actual credentials for signature verification
provider_instance = await get_telephony_provider(organization_id)
@@ -613,8 +611,8 @@ async def handle_twilio_status_callback(
provider = await get_telephony_provider(workflow.organization_id)
if x_webhook_signature:
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- full_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
+ backend_endpoint, _ = await get_backend_endpoints()
+ full_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
is_valid = await provider.verify_webhook_signature(
full_url, callback_data, x_webhook_signature
@@ -887,8 +885,8 @@ async def handle_vobiz_hangup_callback(
webhook_body = raw_body.decode("utf-8")
# Verify signature
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- webhook_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
+ backend_endpoint, _ = await get_backend_endpoints()
+ webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
is_valid = await provider.verify_webhook_signature(
webhook_url,
@@ -1009,8 +1007,10 @@ async def handle_vobiz_ring_callback(
webhook_body = raw_body.decode("utf-8")
# Verify signature
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- webhook_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
+ backend_endpoint, _ = await get_backend_endpoints()
+ webhook_url = (
+ f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
+ )
is_valid = await provider.verify_webhook_signature(
webhook_url,
@@ -1157,8 +1157,8 @@ async def handle_vobiz_hangup_callback_by_workflow(
if x_vobiz_signature:
raw_body = await request.body()
webhook_body = raw_body.decode("utf-8")
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- webhook_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
+ backend_endpoint, _ = await get_backend_endpoints()
+ webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
is_valid = await provider.verify_webhook_signature(
webhook_url,
@@ -1338,8 +1338,8 @@ async def handle_inbound_telephony(
)
# Generate response URLs
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- websocket_url = f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{workflow_context['user_id']}/{workflow_run_id}"
+ _, wss_backend_endpoint = await get_backend_endpoints()
+ websocket_url = f"{wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{workflow_context['user_id']}/{workflow_run_id}"
response = await provider_class.generate_inbound_response(
websocket_url, workflow_run_id
)
diff --git a/api/services/campaign/call_dispatcher.py b/api/services/campaign/call_dispatcher.py
index 5d3c38f..523df77 100644
--- a/api/services/campaign/call_dispatcher.py
+++ b/api/services/campaign/call_dispatcher.py
@@ -12,7 +12,7 @@ from api.enums import OrganizationConfigurationKey, WorkflowRunState
from api.services.campaign.rate_limiter import rate_limiter
from api.services.telephony.base import TelephonyProvider
from api.services.telephony.factory import get_telephony_provider
-from api.utils.tunnel import TunnelURLProvider
+from api.utils.common import get_backend_endpoints
class CampaignCallDispatcher:
@@ -249,10 +249,10 @@ class CampaignCallDispatcher:
# Initiate call via telephony provider
try:
# Construct webhook URL with parameters
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
+ backend_endpoint, _ = await get_backend_endpoints()
webhook_endpoint = provider.WEBHOOK_ENDPOINT
webhook_url = (
- f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
+ f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
f"?workflow_id={campaign.workflow_id}"
f"&user_id={campaign.created_by}"
f"&workflow_run_id={workflow_run.id}"
diff --git a/api/services/telephony/providers/cloudonix_provider.py b/api/services/telephony/providers/cloudonix_provider.py
index 048f49c..7126302 100644
--- a/api/services/telephony/providers/cloudonix_provider.py
+++ b/api/services/telephony/providers/cloudonix_provider.py
@@ -15,7 +15,7 @@ from api.services.telephony.base import (
NormalizedInboundData,
TelephonyProvider,
)
-from api.utils.tunnel import TunnelURLProvider
+from api.utils.common import get_backend_endpoints
if TYPE_CHECKING:
from fastapi import WebSocket
@@ -91,13 +91,13 @@ class CloudonixProvider(TelephonyProvider):
# Prepare call data using Cloudonix callObject schema
# Note: 'caller-id' is REQUIRED by Cloudonix API
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
+ backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
data: Dict[str, Any] = {
"destination": to_number,
"cxml": f"""
-
+
""",
@@ -106,7 +106,7 @@ class CloudonixProvider(TelephonyProvider):
# Add status callback if workflow_run_id provided
if workflow_run_id:
- callback_url = f"https://{backend_endpoint}/api/v1/telephony/cloudonix/status-callback/{workflow_run_id}"
+ callback_url = f"{backend_endpoint}/api/v1/telephony/cloudonix/status-callback/{workflow_run_id}"
data["callback"] = callback_url
# Merge any additional kwargs
diff --git a/api/services/telephony/providers/twilio_provider.py b/api/services/telephony/providers/twilio_provider.py
index 7ac451f..bd3dab2 100644
--- a/api/services/telephony/providers/twilio_provider.py
+++ b/api/services/telephony/providers/twilio_provider.py
@@ -16,7 +16,7 @@ from api.services.telephony.base import (
NormalizedInboundData,
TelephonyProvider,
)
-from api.utils.tunnel import TunnelURLProvider
+from api.utils.common import get_backend_endpoints
if TYPE_CHECKING:
from fastapi import WebSocket
@@ -75,8 +75,8 @@ class TwilioProvider(TelephonyProvider):
# Add status callback if workflow_run_id provided
if workflow_run_id:
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- callback_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
+ backend_endpoint, _ = await get_backend_endpoints()
+ callback_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
data.update(
{
"StatusCallback": callback_url,
@@ -158,12 +158,12 @@ class TwilioProvider(TelephonyProvider):
"""
Generate TwiML response for starting a call session.
"""
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
+ _, wss_backend_endpoint = await get_backend_endpoints()
twiml_content = f"""
-
+
"""
@@ -405,8 +405,8 @@ class TwilioProvider(TelephonyProvider):
# Generate StatusCallback URL using same pattern as outbound calls
status_callback_attr = ""
if workflow_run_id:
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- status_callback_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
+ backend_endpoint, _ = await get_backend_endpoints()
+ status_callback_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
status_callback_attr = f' statusCallback="{status_callback_url}"'
twiml_content = f"""
diff --git a/api/services/telephony/providers/vobiz_provider.py b/api/services/telephony/providers/vobiz_provider.py
index 596c02f..0493e3c 100644
--- a/api/services/telephony/providers/vobiz_provider.py
+++ b/api/services/telephony/providers/vobiz_provider.py
@@ -15,7 +15,7 @@ from api.services.telephony.base import (
NormalizedInboundData,
TelephonyProvider,
)
-from api.utils.tunnel import TunnelURLProvider
+from api.utils.common import get_backend_endpoints
if TYPE_CHECKING:
from fastapi import WebSocket
@@ -89,9 +89,9 @@ class VobizProvider(TelephonyProvider):
# Add hangup callback if workflow_run_id provided
if workflow_run_id:
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- hangup_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
- ring_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
+ backend_endpoint, _ = await get_backend_endpoints()
+ hangup_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
+ ring_url = f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
data.update(
{
"hangup_url": hangup_url,
@@ -254,11 +254,11 @@ class VobizProvider(TelephonyProvider):
- audioTrack: Which audio to stream (inbound, outbound, both)
- contentType: audio/x-mulaw;rate=8000
"""
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
+ _, wss_backend_endpoint = await get_backend_endpoints()
vobiz_xml = f"""
- wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}
+ {wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}
"""
return vobiz_xml
diff --git a/api/services/telephony/providers/vonage_provider.py b/api/services/telephony/providers/vonage_provider.py
index e3b40fa..d4d9daf 100644
--- a/api/services/telephony/providers/vonage_provider.py
+++ b/api/services/telephony/providers/vonage_provider.py
@@ -18,7 +18,7 @@ from api.services.telephony.base import (
NormalizedInboundData,
TelephonyProvider,
)
-from api.utils.tunnel import TunnelURLProvider
+from api.utils.common import get_backend_endpoints
if TYPE_CHECKING:
from fastapi import WebSocket
@@ -106,8 +106,10 @@ class VonageProvider(TelephonyProvider):
# Add event webhook if workflow_run_id provided
if workflow_run_id:
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
- event_url = f"https://{backend_endpoint}/api/v1/telephony/vonage/events/{workflow_run_id}"
+ backend_endpoint, _ = await get_backend_endpoints()
+ event_url = (
+ f"{backend_endpoint}/api/v1/telephony/vonage/events/{workflow_run_id}"
+ )
data.update({"event_url": [event_url], "event_method": "POST"})
data.update(kwargs)
@@ -201,7 +203,7 @@ class VonageProvider(TelephonyProvider):
Generate NCCO response for starting a call session.
NCCO (Nexmo Call Control Objects) is JSON-based, unlike TwiML which is XML.
"""
- backend_endpoint = await TunnelURLProvider.get_tunnel_url()
+ _, wss_backend_endpoint = await get_backend_endpoints()
# NCCO for WebSocket connection
ncco = [
@@ -210,7 +212,7 @@ class VonageProvider(TelephonyProvider):
"endpoint": [
{
"type": "websocket",
- "uri": f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}",
+ "uri": f"{wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}",
"content-type": "audio/l16;rate=16000", # 16kHz Linear PCM
"headers": {},
}
diff --git a/api/tests/test_get_backend_endpoints.py b/api/tests/test_get_backend_endpoints.py
new file mode 100644
index 0000000..1d32ff7
--- /dev/null
+++ b/api/tests/test_get_backend_endpoints.py
@@ -0,0 +1,420 @@
+"""
+Tests for get_backend_endpoints function in api/utils/common.py
+
+Expected behavior:
+- Output URLs must always have a scheme (http:// or https://, ws:// or wss://)
+- Output URLs must NOT have trailing slashes
+"""
+
+from unittest.mock import AsyncMock, patch
+
+import pytest
+
+from api.utils.common import get_backend_endpoints, get_scheme
+
+# Valid test URLs covering various formats
+possible_env_paths = [
+ "http://localhost",
+ "http://localhost/",
+ "http://localhost:8000",
+ "http://localhost:8000/",
+ "http://127.0.0.1",
+ "http://127.0.0.1/",
+ "http://127.0.0.1:8000",
+ "http://127.0.0.1:8000/",
+ "http://xyz.com",
+ "http://xyz.com/",
+ "https://xyz.com",
+ "https://xyz.com/",
+ "localhost",
+ "localhost:8000",
+ "localhost/",
+ "localhost:8000/",
+ "xyz.com",
+ "xyz.com/",
+ "127.0.0.1",
+ "127.0.0.1/",
+ "127.0.0.1:8000",
+ "127.0.0.1:8000/",
+]
+
+# Invalid URLs that should raise ValueError
+invalid_env_paths = [
+ "http:/localhost", # Typo: single slash in scheme
+ "http:/xyz.com", # Typo: single slash in scheme
+ "https:/xyz.com", # Typo: single slash in scheme
+ "htp://xyz.com", # Typo: missing 't' in http
+ "htps://xyz.com", # Typo: missing 't' in https
+ "http//xyz.com", # Missing colon
+ "http:xyz.com", # Missing slashes
+ "http://xyz.com:abc", # Invalid port (non-numeric)
+ "http://xyz.com:-1", # Invalid port (negative)
+ "http://xyz.com:99999", # Invalid port (out of range)
+ "http://", # Missing host
+ "https://", # Missing host
+ "http:// ", # Whitespace host
+ "http://xyz .com", # Space in hostname
+ "http://xyz\t.com", # Tab in hostname
+ "http://xyz\n.com", # Newline in hostname
+ "", # Empty string
+ " ", # Only whitespace
+]
+
+
+class TestGetScheme:
+ """Tests for the get_scheme helper function."""
+
+ def test_http_scheme(self):
+ assert get_scheme("http://example.com") == "http"
+
+ def test_https_scheme(self):
+ assert get_scheme("https://example.com") == "https"
+
+ def test_no_scheme(self):
+ assert get_scheme("example.com") is None
+ assert get_scheme("localhost:8000") is None
+
+ def test_malformed_url_single_slash(self):
+ # 'http:/localhost' doesn't have '://' so returns None
+ assert get_scheme("http:/localhost") is None
+
+ def test_ws_scheme(self):
+ assert get_scheme("ws://example.com") == "ws"
+ assert get_scheme("wss://example.com") == "wss"
+
+
+class TestGetBackendEndpointsWithEnvVar:
+ """Tests for get_backend_endpoints when BACKEND_API_ENDPOINT is set."""
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "env_url,expected_http,expected_ws",
+ [
+ # URLs with http:// scheme (with and without trailing slash -> no trailing slash)
+ ("http://xyz.com", "http://xyz.com", "ws://xyz.com"),
+ ("http://xyz.com/", "http://xyz.com", "ws://xyz.com"),
+ # URLs with https:// scheme (with and without trailing slash -> no trailing slash)
+ ("https://xyz.com", "https://xyz.com", "wss://xyz.com"),
+ ("https://xyz.com/", "https://xyz.com", "wss://xyz.com"),
+ # URLs without scheme (should add http/ws, no trailing slash)
+ ("xyz.com", "http://xyz.com", "ws://xyz.com"),
+ ("xyz.com/", "http://xyz.com", "ws://xyz.com"),
+ ],
+ )
+ async def test_non_localhost_urls(self, env_url, expected_http, expected_ws):
+ """Test non-localhost URLs return correct http and ws endpoints."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", env_url):
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == expected_http
+ assert ws_url == expected_ws
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "env_url,expected_http,expected_ws",
+ [
+ # localhost URLs with http:// scheme
+ ("http://localhost", "http://localhost", "ws://localhost"),
+ ("http://localhost/", "http://localhost", "ws://localhost"),
+ ("http://localhost:8000", "http://localhost:8000", "ws://localhost:8000"),
+ ("http://localhost:8000/", "http://localhost:8000", "ws://localhost:8000"),
+ # localhost URLs without scheme (should add http/ws)
+ ("localhost", "http://localhost", "ws://localhost"),
+ ("localhost/", "http://localhost", "ws://localhost"),
+ ("localhost:8000", "http://localhost:8000", "ws://localhost:8000"),
+ ("localhost:8000/", "http://localhost:8000", "ws://localhost:8000"),
+ # 127.0.0.1 URLs with http:// scheme
+ ("http://127.0.0.1", "http://127.0.0.1", "ws://127.0.0.1"),
+ ("http://127.0.0.1/", "http://127.0.0.1", "ws://127.0.0.1"),
+ ("http://127.0.0.1:8000", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"),
+ ("http://127.0.0.1:8000/", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"),
+ # 127.0.0.1 URLs without scheme (should add http/ws)
+ ("127.0.0.1", "http://127.0.0.1", "ws://127.0.0.1"),
+ ("127.0.0.1/", "http://127.0.0.1", "ws://127.0.0.1"),
+ ("127.0.0.1:8000", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"),
+ ("127.0.0.1:8000/", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"),
+ ],
+ )
+ async def test_localhost_urls_no_tunnel(self, env_url, expected_http, expected_ws):
+ """Test localhost/127.0.0.1 URLs when tunnel is not available."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", env_url):
+ with patch(
+ "api.utils.common.TunnelURLProvider.get_tunnel_urls",
+ new_callable=AsyncMock,
+ ) as mock_tunnel:
+ mock_tunnel.return_value = None
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == expected_http
+ assert ws_url == expected_ws
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "env_url",
+ [
+ "http://localhost",
+ "http://localhost/",
+ "http://localhost:8000",
+ "http://localhost:8000/",
+ "localhost",
+ "localhost/",
+ "localhost:8000",
+ "localhost:8000/",
+ "http://127.0.0.1",
+ "http://127.0.0.1/",
+ "http://127.0.0.1:8000",
+ "http://127.0.0.1:8000/",
+ "127.0.0.1",
+ "127.0.0.1/",
+ "127.0.0.1:8000",
+ "127.0.0.1:8000/",
+ ],
+ )
+ async def test_localhost_urls_with_tunnel_available(self, env_url):
+ """Test localhost/127.0.0.1 URLs prefer tunnel when available."""
+ tunnel_http = "https://abc123.trycloudflare.com"
+ tunnel_ws = "wss://abc123.trycloudflare.com"
+
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", env_url):
+ with patch(
+ "api.utils.common.TunnelURLProvider.get_tunnel_urls",
+ new_callable=AsyncMock,
+ ) as mock_tunnel:
+ mock_tunnel.return_value = (tunnel_http, tunnel_ws)
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == tunnel_http
+ assert ws_url == tunnel_ws
+
+ @pytest.mark.asyncio
+ async def test_localhost_tunnel_exception_falls_back(self):
+ """Test that tunnel exceptions fall back to localhost endpoint."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://localhost:8000"):
+ with patch(
+ "api.utils.common.TunnelURLProvider.get_tunnel_urls",
+ new_callable=AsyncMock,
+ ) as mock_tunnel:
+ mock_tunnel.side_effect = Exception("Tunnel not available")
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == "http://localhost:8000"
+ assert ws_url == "ws://localhost:8000"
+
+ @pytest.mark.asyncio
+ async def test_localhost_with_trailing_slash_tunnel_exception_falls_back(self):
+ """Test that tunnel exceptions fall back to localhost endpoint, trailing slash stripped."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://localhost:8000/"):
+ with patch(
+ "api.utils.common.TunnelURLProvider.get_tunnel_urls",
+ new_callable=AsyncMock,
+ ) as mock_tunnel:
+ mock_tunnel.side_effect = Exception("Tunnel not available")
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == "http://localhost:8000"
+ assert ws_url == "ws://localhost:8000"
+
+ @pytest.mark.asyncio
+ async def test_127_tunnel_exception_falls_back(self):
+ """Test that tunnel exceptions fall back to 127.0.0.1 endpoint."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://127.0.0.1:8000"):
+ with patch(
+ "api.utils.common.TunnelURLProvider.get_tunnel_urls",
+ new_callable=AsyncMock,
+ ) as mock_tunnel:
+ mock_tunnel.side_effect = Exception("Tunnel not available")
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == "http://127.0.0.1:8000"
+ assert ws_url == "ws://127.0.0.1:8000"
+
+ @pytest.mark.asyncio
+ async def test_127_with_trailing_slash_tunnel_exception_falls_back(self):
+ """Test that tunnel exceptions fall back to 127.0.0.1 endpoint, trailing slash stripped."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://127.0.0.1:8000/"):
+ with patch(
+ "api.utils.common.TunnelURLProvider.get_tunnel_urls",
+ new_callable=AsyncMock,
+ ) as mock_tunnel:
+ mock_tunnel.side_effect = Exception("Tunnel not available")
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == "http://127.0.0.1:8000"
+ assert ws_url == "ws://127.0.0.1:8000"
+
+
+class TestGetBackendEndpointsNoEnvVar:
+ """Tests for get_backend_endpoints when BACKEND_API_ENDPOINT is not set."""
+
+ @pytest.mark.asyncio
+ async def test_uses_tunnel_when_no_env_var(self):
+ """Test that tunnel URLs are used when env var is not set."""
+ tunnel_http = "https://abc123.trycloudflare.com"
+ tunnel_ws = "wss://abc123.trycloudflare.com"
+
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", None):
+ with patch(
+ "api.utils.common.TunnelURLProvider.get_tunnel_urls",
+ new_callable=AsyncMock,
+ ) as mock_tunnel:
+ mock_tunnel.return_value = (tunnel_http, tunnel_ws)
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == tunnel_http
+ assert ws_url == tunnel_ws
+
+ @pytest.mark.asyncio
+ async def test_raises_when_no_env_var_and_no_tunnel(self):
+ """Test that ValueError is raised when no env var and no tunnel."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", None):
+ with patch(
+ "api.utils.common.TunnelURLProvider.get_tunnel_urls",
+ new_callable=AsyncMock,
+ ) as mock_tunnel:
+ mock_tunnel.return_value = None
+ with pytest.raises(ValueError, match="No tunnel URL available"):
+ await get_backend_endpoints()
+
+
+class TestSchemeMapping:
+ """Tests to verify correct scheme mapping (http->ws, https->wss)."""
+
+ @pytest.mark.asyncio
+ async def test_http_maps_to_ws(self):
+ """Test http:// maps to ws://"""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://example.com"):
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == "http://example.com"
+ assert ws_url == "ws://example.com"
+
+ @pytest.mark.asyncio
+ async def test_https_maps_to_wss(self):
+ """Test https:// maps to wss://"""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", "https://example.com"):
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == "https://example.com"
+ assert ws_url == "wss://example.com"
+
+ @pytest.mark.asyncio
+ async def test_no_scheme_defaults_to_http_ws(self):
+ """Test URLs without scheme default to http:// and ws://"""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", "example.com"):
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == "http://example.com"
+ assert ws_url == "ws://example.com"
+
+ @pytest.mark.asyncio
+ async def test_trailing_slash_stripped(self):
+ """Test trailing slashes are stripped from output."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://example.com/"):
+ http_url, ws_url = await get_backend_endpoints()
+ assert http_url == "http://example.com"
+ assert ws_url == "ws://example.com"
+ assert not http_url.endswith("/")
+ assert not ws_url.endswith("/")
+
+
+class TestInvalidUrls:
+ """Tests for invalid URLs that should raise ValueError."""
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "invalid_url",
+ [
+ "http:/localhost", # Typo: single slash in scheme
+ "http:/xyz.com", # Typo: single slash in scheme
+ "https:/xyz.com", # Typo: single slash in scheme
+ ],
+ )
+ async def test_malformed_scheme_single_slash(self, invalid_url):
+ """Test URLs with single slash in scheme raise ValueError."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
+ with patch(
+ "api.utils.common.TunnelURLProvider.get_tunnel_urls",
+ new_callable=AsyncMock,
+ ) as mock_tunnel:
+ mock_tunnel.return_value = None
+ with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
+ await get_backend_endpoints()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "invalid_url",
+ [
+ "htp://xyz.com", # Typo: missing 't' in http
+ "htps://xyz.com", # Typo: missing 't' in https
+ "ftp://xyz.com", # Unsupported scheme
+ "file://xyz.com", # Unsupported scheme
+ ],
+ )
+ async def test_invalid_or_unsupported_scheme(self, invalid_url):
+ """Test URLs with invalid or unsupported schemes raise ValueError."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
+ with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
+ await get_backend_endpoints()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "invalid_url",
+ [
+ "http//xyz.com", # Missing colon
+ "http:xyz.com", # Missing slashes
+ ],
+ )
+ async def test_malformed_scheme_separator(self, invalid_url):
+ """Test URLs with malformed scheme separators raise ValueError."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
+ with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
+ await get_backend_endpoints()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "invalid_url",
+ [
+ "http://xyz.com:abc", # Invalid port (non-numeric)
+ "http://xyz.com:-1", # Invalid port (negative)
+ "http://xyz.com:99999", # Invalid port (out of range)
+ "http://xyz.com:", # Empty port
+ ],
+ )
+ async def test_invalid_port(self, invalid_url):
+ """Test URLs with invalid ports raise ValueError."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
+ with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
+ await get_backend_endpoints()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "invalid_url",
+ [
+ "http://", # Missing host
+ "https://", # Missing host
+ ],
+ )
+ async def test_missing_host(self, invalid_url):
+ """Test URLs with missing host raise ValueError."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
+ with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
+ await get_backend_endpoints()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "invalid_url",
+ [
+ "http:// ", # Whitespace host
+ "http://xyz .com", # Space in hostname
+ "http://xyz\t.com", # Tab in hostname
+ "http://xyz\n.com", # Newline in hostname
+ ],
+ )
+ async def test_invalid_characters_in_host(self, invalid_url):
+ """Test URLs with invalid characters in hostname raise ValueError."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
+ with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
+ await get_backend_endpoints()
+
+ @pytest.mark.asyncio
+ @pytest.mark.parametrize(
+ "invalid_url",
+ [
+ "", # Empty string
+ " ", # Only whitespace
+ ],
+ )
+ async def test_empty_or_whitespace_url(self, invalid_url):
+ """Test empty or whitespace-only URLs raise ValueError."""
+ with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
+ with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
+ await get_backend_endpoints()
diff --git a/api/utils/common.py b/api/utils/common.py
new file mode 100644
index 0000000..b843d13
--- /dev/null
+++ b/api/utils/common.py
@@ -0,0 +1,181 @@
+"""
+Common utilities.
+Shared functions used across the application.
+"""
+
+import re
+
+from loguru import logger
+
+from api.constants import BACKEND_API_ENDPOINT
+from api.utils.tunnel import TunnelURLProvider
+
+
+def get_scheme(url: str) -> str | None:
+ """
+ Extract scheme from a given URL if present.
+ Returns None if not found
+ """
+ idx = url.find("://")
+ if idx == -1:
+ return None
+ return url[:idx]
+
+
+def _validate_url(url: str) -> None:
+ """
+ Validate URL format and raise ValueError for invalid URLs.
+
+ Checks for:
+ - Empty or whitespace-only URLs
+ - Malformed schemes (single slash, missing colon/slashes)
+ - Invalid/unsupported schemes
+ - Invalid ports (non-numeric, out of range, empty)
+ - Missing hosts
+ - Invalid characters in hostname (whitespace)
+ """
+ # Check for empty or whitespace-only URLs
+ if not url or not url.strip():
+ raise ValueError(
+ f"Invalid BACKEND_API_ENDPOINT: URL cannot be empty or whitespace"
+ )
+
+ # Check for malformed schemes (single slash like http:/localhost)
+ if re.match(r"^https?:/[^/]", url):
+ raise ValueError(f"Invalid BACKEND_API_ENDPOINT: malformed scheme in '{url}'")
+
+ # Check for malformed scheme separators (http// or http:xyz without //)
+ if re.match(r"^https?//[^/]", url) or re.match(r"^https?:[^/]", url):
+ raise ValueError(
+ f"Invalid BACKEND_API_ENDPOINT: malformed scheme separator in '{url}'"
+ )
+
+ # Check for invalid/unsupported schemes
+ scheme = get_scheme(url)
+ if scheme and scheme not in ("http", "https"):
+ raise ValueError(
+ f"Invalid BACKEND_API_ENDPOINT: unsupported scheme '{scheme}' in '{url}'"
+ )
+
+ # Parse URL for further validation
+ if scheme:
+ # URL has a scheme, extract host part
+ host_part = url[len(scheme) + 3 :] # Skip "scheme://"
+ else:
+ host_part = url
+
+ # Strip trailing slash for host validation
+ host_part = host_part.rstrip("/")
+
+ # Check for missing host
+ if not host_part or not host_part.strip():
+ raise ValueError(f"Invalid BACKEND_API_ENDPOINT: missing host in '{url}'")
+
+ # Check for invalid characters in hostname (whitespace)
+ if re.search(r"\s", host_part):
+ raise ValueError(
+ f"Invalid BACKEND_API_ENDPOINT: invalid characters in hostname '{url}'"
+ )
+
+ # Check for invalid port - look for colon followed by anything
+ port_match = re.search(r":([^/]*)$", host_part)
+ if port_match:
+ port_str = port_match.group(1)
+ if not port_str:
+ raise ValueError(f"Invalid BACKEND_API_ENDPOINT: empty port in '{url}'")
+ # Check if port is numeric
+ if not port_str.isdigit():
+ raise ValueError(f"Invalid BACKEND_API_ENDPOINT: invalid port in '{url}'")
+ port = int(port_str)
+ if port < 0 or port > 65535:
+ raise ValueError(
+ f"Invalid BACKEND_API_ENDPOINT: port out of range in '{url}'"
+ )
+
+
+async def get_backend_endpoints() -> tuple[str, str]:
+ """
+ Get the backend endpoint URLs for external access (webhooks, callbacks, WebSocket connections).
+
+ Priority:
+ 1. BACKEND_API_ENDPOINT environment variable (if set and not localhost)
+ 2. Cloudflared Tunnel URLs (fallback for localhost or missing env var)
+
+ Protocol Handling:
+ 1. If URL has http:// - returns http:// and ws://
+ 2. If URL has https:// - returns https:// and wss://
+ 3. If URL has no protocol - defaults to http:// and ws://
+
+ Returns:
+ tuple[str, str]: (backend_endpoint, wss_backend_endpoint)
+
+ Raises:
+ ValueError: If no endpoint URL can be determined or URL is invalid
+ """
+
+ # If env var is explicitly set (even to empty/whitespace), validate it
+ if BACKEND_API_ENDPOINT is not None:
+ # Validate - this will raise for empty/whitespace
+ _validate_url(BACKEND_API_ENDPOINT)
+
+ if BACKEND_API_ENDPOINT:
+ logger.debug(
+ f"Processing BACKEND_API_ENDPOINT from environment: {BACKEND_API_ENDPOINT}"
+ )
+
+ # Handle localhost/127.0.0.1 special case - use tunnel URL if available
+ if "localhost" in BACKEND_API_ENDPOINT or "127.0.0.1" in BACKEND_API_ENDPOINT:
+ logger.debug(
+ f"BACKEND_API_ENDPOINT is local ({BACKEND_API_ENDPOINT}), checking tunnel URL"
+ )
+ try:
+ tunnel_urls = await TunnelURLProvider.get_tunnel_urls()
+ if tunnel_urls:
+ logger.debug(
+ f"Tunnel URLs available, using tunnel URLs instead of localhost"
+ )
+ return tunnel_urls
+ else:
+ logger.debug(
+ f"Tunnel URLs returned None, proceeding with localhost endpoint"
+ )
+ except Exception as e:
+ logger.debug(
+ f"No tunnel URLs available ({e}), proceeding with localhost endpoint"
+ )
+
+ try:
+ # Parse the URL to validate and handle protocol
+ scheme = get_scheme(BACKEND_API_ENDPOINT)
+
+ if scheme:
+ http_url = BACKEND_API_ENDPOINT.rstrip("/")
+ ws_scheme = {"http": "ws", "https": "wss"}[scheme]
+ ws_url = BACKEND_API_ENDPOINT.rstrip("/").replace(scheme, ws_scheme, 1)
+ else:
+ http_url = "http://" + BACKEND_API_ENDPOINT.rstrip("/")
+ ws_url = "ws://" + BACKEND_API_ENDPOINT.rstrip("/")
+
+ logger.debug(
+ f"Returning backend URLs - HTTP: {http_url}, WebSocket: {ws_url}"
+ )
+ return http_url, ws_url
+
+ except Exception as e:
+ # Case 4: Invalid URL format
+ raise ValueError(
+ f"Invalid BACKEND_API_ENDPOINT format: '{BACKEND_API_ENDPOINT}' - {str(e)}"
+ )
+
+ # Second priority: Query cloudflared tunnel URL when no environment variable is set
+ logger.debug("No BACKEND_API_ENDPOINT set, using tunnel URL")
+ tunnel_urls = await TunnelURLProvider.get_tunnel_urls()
+ if tunnel_urls:
+ logger.debug(f"Retrieved tunnel URLs: {tunnel_urls}")
+ return tunnel_urls
+ else:
+ logger.debug("No tunnel URLs available")
+ raise ValueError(
+ "No tunnel URL available. Please set BACKEND_API_ENDPOINT environment "
+ "variable or ensure cloudflared service is running."
+ )
diff --git a/api/utils/tunnel.py b/api/utils/tunnel.py
index 31a4588..8e2a521 100644
--- a/api/utils/tunnel.py
+++ b/api/utils/tunnel.py
@@ -1,7 +1,6 @@
"""Utility for getting the cloudflared tunnel URL at runtime."""
import asyncio
-import os
import re
from typing import Optional
@@ -10,37 +9,26 @@ from loguru import logger
class TunnelURLProvider:
- """Provider for getting the tunnel URL from cloudflared or environment."""
+ """Provider for getting tunnel URLs from cloudflared service."""
@classmethod
- async def get_tunnel_url(cls) -> str:
+ async def get_tunnel_urls(cls) -> tuple[str, str]:
"""
- Get the tunnel URL for external access.
-
- Priority:
- 1. BACKEND_API_ENDPOINT environment variable (if set)
- 2. Query cloudflared metrics endpoint
- 3. Raise error if neither available
+ Get the tunnel URLs for external access.
Returns:
- str: The tunnel domain (without protocol)
+ tuple[str, str]: (https_url, wss_url) - Both URLs include full protocol
Raises:
ValueError: If no tunnel URL can be determined
"""
- # First priority: Check environment variable
- env_endpoint = os.getenv("BACKEND_API_ENDPOINT")
- if env_endpoint:
- logger.debug(f"Using BACKEND_API_ENDPOINT from environment: {env_endpoint}")
- return env_endpoint
- # Second priority: Query cloudflared
try:
# Try to get URL from cloudflared metrics
- url = await cls._get_cloudflared_url()
- if url:
- logger.info(f"Retrieved tunnel URL from cloudflared: {url}")
- return url
+ urls = await cls._get_cloudflared_urls()
+ if urls:
+ logger.info(f"Retrieved tunnel URLs from cloudflared: {urls}")
+ return urls
except Exception as e:
logger.warning(f"Failed to get tunnel URL from cloudflared: {e}")
@@ -50,12 +38,12 @@ class TunnelURLProvider:
)
@classmethod
- async def _get_cloudflared_url(cls) -> Optional[str]:
+ async def _get_cloudflared_urls(cls) -> Optional[tuple[str, str]]:
"""
- Query cloudflared metrics endpoint to get the tunnel URL.
+ Query cloudflared metrics endpoint to get the tunnel URLs.
Returns:
- Optional[str]: The tunnel domain (without protocol), or None if not found
+ Optional[tuple[str, str]]: (https_url, wss_url) with full protocols, or None if not found
"""
try:
# Try to connect to cloudflared metrics endpoint
@@ -83,12 +71,16 @@ class TunnelURLProvider:
hostname = hostname.replace("https://", "").replace(
"wss://", ""
)
- return hostname
+ return "https://" + hostname, "wss://" + hostname
# Alternative: Look for trycloudflare.com domain
match = re.search(r"([a-z0-9-]+\.trycloudflare\.com)", text)
if match:
- return match.group(1)
+ hostname = match.group(1)
+ hostname = hostname.replace("https://", "").replace(
+ "wss://", ""
+ )
+ return f"https://{hostname}", f"wss://{hostname}"
logger.warning("Could not find tunnel URL in cloudflared metrics")
return None