diff --git a/api/routes/main.py b/api/routes/main.py index 8086114..a9f9a6f 100644 --- a/api/routes/main.py +++ b/api/routes/main.py @@ -58,11 +58,13 @@ class HealthResponse(BaseModel): @router.get("/health", response_model=HealthResponse) async def health() -> HealthResponse: - from api.constants import APP_VERSION, BACKEND_API_ENDPOINT + from api.constants import APP_VERSION + from api.utils.common import get_backend_endpoints logger.debug("Health endpoint called") + backend_endpoint, _ = await get_backend_endpoints() return HealthResponse( status="ok", version=APP_VERSION, - backend_api_endpoint=BACKEND_API_ENDPOINT, + backend_api_endpoint=backend_endpoint, ) diff --git a/api/routes/public_agent.py b/api/routes/public_agent.py index 23c6578..c9cae40 100644 --- a/api/routes/public_agent.py +++ b/api/routes/public_agent.py @@ -15,7 +15,7 @@ from api.db import db_client from api.enums import TriggerState from api.services.quota_service import check_dograh_quota_by_user_id from api.services.telephony.factory import get_telephony_provider -from api.utils.tunnel import TunnelURLProvider +from api.utils.common import get_backend_endpoints router = APIRouter(prefix="/public/agent") @@ -147,11 +147,11 @@ async def initiate_call( ) # 9. Construct webhook URL for telephony provider callback - backend_endpoint = await TunnelURLProvider.get_tunnel_url() + backend_endpoint, _ = await get_backend_endpoints() webhook_endpoint = provider.WEBHOOK_ENDPOINT webhook_url = ( - f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}" + f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}" f"?workflow_id={trigger.workflow_id}" f"&user_id={api_key.created_by}" f"&workflow_run_id={workflow_run.id}" diff --git a/api/routes/telephony.py b/api/routes/telephony.py index 440eb36..70aada7 100644 --- a/api/routes/telephony.py +++ b/api/routes/telephony.py @@ -30,13 +30,13 @@ from api.services.telephony.factory import ( get_all_telephony_providers, get_telephony_provider, ) +from api.utils.common import get_backend_endpoints from api.utils.telephony_helper import ( generic_hangup_response, normalize_webhook_data, numbers_match, parse_webhook_request, ) -from api.utils.tunnel import TunnelURLProvider from pipecat.utils.context import set_current_run_id router = APIRouter(prefix="/telephony") @@ -159,12 +159,12 @@ async def initiate_call( workflow_run_name = workflow_run.name # Construct webhook URL based on provider type - backend_endpoint = await TunnelURLProvider.get_tunnel_url() + backend_endpoint, _ = await get_backend_endpoints() webhook_endpoint = provider.WEBHOOK_ENDPOINT webhook_url = ( - f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}" + f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}" f"?workflow_id={request.workflow_id}" f"&user_id={user.id}" f"&workflow_run_id={workflow_run_id}" @@ -313,10 +313,8 @@ async def _validate_inbound_request( # Verify webhook signature/API key if provided provider_instance = None if x_twilio_signature or x_vobiz_signature or x_cx_apikey: - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - webhook_url = ( - f"https://{backend_endpoint}/api/v1/telephony/inbound/{workflow_id}" - ) + backend_endpoint, _ = await get_backend_endpoints() + webhook_url = f"{backend_endpoint}/api/v1/telephony/inbound/{workflow_id}" # Get the real telephony provider with actual credentials for signature verification provider_instance = await get_telephony_provider(organization_id) @@ -613,8 +611,8 @@ async def handle_twilio_status_callback( provider = await get_telephony_provider(workflow.organization_id) if x_webhook_signature: - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - full_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}" + backend_endpoint, _ = await get_backend_endpoints() + full_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}" is_valid = await provider.verify_webhook_signature( full_url, callback_data, x_webhook_signature @@ -887,8 +885,8 @@ async def handle_vobiz_hangup_callback( webhook_body = raw_body.decode("utf-8") # Verify signature - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - webhook_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}" + backend_endpoint, _ = await get_backend_endpoints() + webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}" is_valid = await provider.verify_webhook_signature( webhook_url, @@ -1009,8 +1007,10 @@ async def handle_vobiz_ring_callback( webhook_body = raw_body.decode("utf-8") # Verify signature - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - webhook_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}" + backend_endpoint, _ = await get_backend_endpoints() + webhook_url = ( + f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}" + ) is_valid = await provider.verify_webhook_signature( webhook_url, @@ -1157,8 +1157,8 @@ async def handle_vobiz_hangup_callback_by_workflow( if x_vobiz_signature: raw_body = await request.body() webhook_body = raw_body.decode("utf-8") - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - webhook_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}" + backend_endpoint, _ = await get_backend_endpoints() + webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}" is_valid = await provider.verify_webhook_signature( webhook_url, @@ -1338,8 +1338,8 @@ async def handle_inbound_telephony( ) # Generate response URLs - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - websocket_url = f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{workflow_context['user_id']}/{workflow_run_id}" + _, wss_backend_endpoint = await get_backend_endpoints() + websocket_url = f"{wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{workflow_context['user_id']}/{workflow_run_id}" response = await provider_class.generate_inbound_response( websocket_url, workflow_run_id ) diff --git a/api/services/campaign/call_dispatcher.py b/api/services/campaign/call_dispatcher.py index 5d3c38f..523df77 100644 --- a/api/services/campaign/call_dispatcher.py +++ b/api/services/campaign/call_dispatcher.py @@ -12,7 +12,7 @@ from api.enums import OrganizationConfigurationKey, WorkflowRunState from api.services.campaign.rate_limiter import rate_limiter from api.services.telephony.base import TelephonyProvider from api.services.telephony.factory import get_telephony_provider -from api.utils.tunnel import TunnelURLProvider +from api.utils.common import get_backend_endpoints class CampaignCallDispatcher: @@ -249,10 +249,10 @@ class CampaignCallDispatcher: # Initiate call via telephony provider try: # Construct webhook URL with parameters - backend_endpoint = await TunnelURLProvider.get_tunnel_url() + backend_endpoint, _ = await get_backend_endpoints() webhook_endpoint = provider.WEBHOOK_ENDPOINT webhook_url = ( - f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}" + f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}" f"?workflow_id={campaign.workflow_id}" f"&user_id={campaign.created_by}" f"&workflow_run_id={workflow_run.id}" diff --git a/api/services/telephony/providers/cloudonix_provider.py b/api/services/telephony/providers/cloudonix_provider.py index 048f49c..7126302 100644 --- a/api/services/telephony/providers/cloudonix_provider.py +++ b/api/services/telephony/providers/cloudonix_provider.py @@ -15,7 +15,7 @@ from api.services.telephony.base import ( NormalizedInboundData, TelephonyProvider, ) -from api.utils.tunnel import TunnelURLProvider +from api.utils.common import get_backend_endpoints if TYPE_CHECKING: from fastapi import WebSocket @@ -91,13 +91,13 @@ class CloudonixProvider(TelephonyProvider): # Prepare call data using Cloudonix callObject schema # Note: 'caller-id' is REQUIRED by Cloudonix API - backend_endpoint = await TunnelURLProvider.get_tunnel_url() + backend_endpoint, wss_backend_endpoint = await get_backend_endpoints() data: Dict[str, Any] = { "destination": to_number, "cxml": f""" - + """, @@ -106,7 +106,7 @@ class CloudonixProvider(TelephonyProvider): # Add status callback if workflow_run_id provided if workflow_run_id: - callback_url = f"https://{backend_endpoint}/api/v1/telephony/cloudonix/status-callback/{workflow_run_id}" + callback_url = f"{backend_endpoint}/api/v1/telephony/cloudonix/status-callback/{workflow_run_id}" data["callback"] = callback_url # Merge any additional kwargs diff --git a/api/services/telephony/providers/twilio_provider.py b/api/services/telephony/providers/twilio_provider.py index 7ac451f..bd3dab2 100644 --- a/api/services/telephony/providers/twilio_provider.py +++ b/api/services/telephony/providers/twilio_provider.py @@ -16,7 +16,7 @@ from api.services.telephony.base import ( NormalizedInboundData, TelephonyProvider, ) -from api.utils.tunnel import TunnelURLProvider +from api.utils.common import get_backend_endpoints if TYPE_CHECKING: from fastapi import WebSocket @@ -75,8 +75,8 @@ class TwilioProvider(TelephonyProvider): # Add status callback if workflow_run_id provided if workflow_run_id: - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - callback_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}" + backend_endpoint, _ = await get_backend_endpoints() + callback_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}" data.update( { "StatusCallback": callback_url, @@ -158,12 +158,12 @@ class TwilioProvider(TelephonyProvider): """ Generate TwiML response for starting a call session. """ - backend_endpoint = await TunnelURLProvider.get_tunnel_url() + _, wss_backend_endpoint = await get_backend_endpoints() twiml_content = f""" - + """ @@ -405,8 +405,8 @@ class TwilioProvider(TelephonyProvider): # Generate StatusCallback URL using same pattern as outbound calls status_callback_attr = "" if workflow_run_id: - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - status_callback_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}" + backend_endpoint, _ = await get_backend_endpoints() + status_callback_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}" status_callback_attr = f' statusCallback="{status_callback_url}"' twiml_content = f""" diff --git a/api/services/telephony/providers/vobiz_provider.py b/api/services/telephony/providers/vobiz_provider.py index 596c02f..0493e3c 100644 --- a/api/services/telephony/providers/vobiz_provider.py +++ b/api/services/telephony/providers/vobiz_provider.py @@ -15,7 +15,7 @@ from api.services.telephony.base import ( NormalizedInboundData, TelephonyProvider, ) -from api.utils.tunnel import TunnelURLProvider +from api.utils.common import get_backend_endpoints if TYPE_CHECKING: from fastapi import WebSocket @@ -89,9 +89,9 @@ class VobizProvider(TelephonyProvider): # Add hangup callback if workflow_run_id provided if workflow_run_id: - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - hangup_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}" - ring_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}" + backend_endpoint, _ = await get_backend_endpoints() + hangup_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}" + ring_url = f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}" data.update( { "hangup_url": hangup_url, @@ -254,11 +254,11 @@ class VobizProvider(TelephonyProvider): - audioTrack: Which audio to stream (inbound, outbound, both) - contentType: audio/x-mulaw;rate=8000 """ - backend_endpoint = await TunnelURLProvider.get_tunnel_url() + _, wss_backend_endpoint = await get_backend_endpoints() vobiz_xml = f""" - wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id} + {wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id} """ return vobiz_xml diff --git a/api/services/telephony/providers/vonage_provider.py b/api/services/telephony/providers/vonage_provider.py index e3b40fa..d4d9daf 100644 --- a/api/services/telephony/providers/vonage_provider.py +++ b/api/services/telephony/providers/vonage_provider.py @@ -18,7 +18,7 @@ from api.services.telephony.base import ( NormalizedInboundData, TelephonyProvider, ) -from api.utils.tunnel import TunnelURLProvider +from api.utils.common import get_backend_endpoints if TYPE_CHECKING: from fastapi import WebSocket @@ -106,8 +106,10 @@ class VonageProvider(TelephonyProvider): # Add event webhook if workflow_run_id provided if workflow_run_id: - backend_endpoint = await TunnelURLProvider.get_tunnel_url() - event_url = f"https://{backend_endpoint}/api/v1/telephony/vonage/events/{workflow_run_id}" + backend_endpoint, _ = await get_backend_endpoints() + event_url = ( + f"{backend_endpoint}/api/v1/telephony/vonage/events/{workflow_run_id}" + ) data.update({"event_url": [event_url], "event_method": "POST"}) data.update(kwargs) @@ -201,7 +203,7 @@ class VonageProvider(TelephonyProvider): Generate NCCO response for starting a call session. NCCO (Nexmo Call Control Objects) is JSON-based, unlike TwiML which is XML. """ - backend_endpoint = await TunnelURLProvider.get_tunnel_url() + _, wss_backend_endpoint = await get_backend_endpoints() # NCCO for WebSocket connection ncco = [ @@ -210,7 +212,7 @@ class VonageProvider(TelephonyProvider): "endpoint": [ { "type": "websocket", - "uri": f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}", + "uri": f"{wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}", "content-type": "audio/l16;rate=16000", # 16kHz Linear PCM "headers": {}, } diff --git a/api/tests/test_get_backend_endpoints.py b/api/tests/test_get_backend_endpoints.py new file mode 100644 index 0000000..1d32ff7 --- /dev/null +++ b/api/tests/test_get_backend_endpoints.py @@ -0,0 +1,420 @@ +""" +Tests for get_backend_endpoints function in api/utils/common.py + +Expected behavior: +- Output URLs must always have a scheme (http:// or https://, ws:// or wss://) +- Output URLs must NOT have trailing slashes +""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from api.utils.common import get_backend_endpoints, get_scheme + +# Valid test URLs covering various formats +possible_env_paths = [ + "http://localhost", + "http://localhost/", + "http://localhost:8000", + "http://localhost:8000/", + "http://127.0.0.1", + "http://127.0.0.1/", + "http://127.0.0.1:8000", + "http://127.0.0.1:8000/", + "http://xyz.com", + "http://xyz.com/", + "https://xyz.com", + "https://xyz.com/", + "localhost", + "localhost:8000", + "localhost/", + "localhost:8000/", + "xyz.com", + "xyz.com/", + "127.0.0.1", + "127.0.0.1/", + "127.0.0.1:8000", + "127.0.0.1:8000/", +] + +# Invalid URLs that should raise ValueError +invalid_env_paths = [ + "http:/localhost", # Typo: single slash in scheme + "http:/xyz.com", # Typo: single slash in scheme + "https:/xyz.com", # Typo: single slash in scheme + "htp://xyz.com", # Typo: missing 't' in http + "htps://xyz.com", # Typo: missing 't' in https + "http//xyz.com", # Missing colon + "http:xyz.com", # Missing slashes + "http://xyz.com:abc", # Invalid port (non-numeric) + "http://xyz.com:-1", # Invalid port (negative) + "http://xyz.com:99999", # Invalid port (out of range) + "http://", # Missing host + "https://", # Missing host + "http:// ", # Whitespace host + "http://xyz .com", # Space in hostname + "http://xyz\t.com", # Tab in hostname + "http://xyz\n.com", # Newline in hostname + "", # Empty string + " ", # Only whitespace +] + + +class TestGetScheme: + """Tests for the get_scheme helper function.""" + + def test_http_scheme(self): + assert get_scheme("http://example.com") == "http" + + def test_https_scheme(self): + assert get_scheme("https://example.com") == "https" + + def test_no_scheme(self): + assert get_scheme("example.com") is None + assert get_scheme("localhost:8000") is None + + def test_malformed_url_single_slash(self): + # 'http:/localhost' doesn't have '://' so returns None + assert get_scheme("http:/localhost") is None + + def test_ws_scheme(self): + assert get_scheme("ws://example.com") == "ws" + assert get_scheme("wss://example.com") == "wss" + + +class TestGetBackendEndpointsWithEnvVar: + """Tests for get_backend_endpoints when BACKEND_API_ENDPOINT is set.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "env_url,expected_http,expected_ws", + [ + # URLs with http:// scheme (with and without trailing slash -> no trailing slash) + ("http://xyz.com", "http://xyz.com", "ws://xyz.com"), + ("http://xyz.com/", "http://xyz.com", "ws://xyz.com"), + # URLs with https:// scheme (with and without trailing slash -> no trailing slash) + ("https://xyz.com", "https://xyz.com", "wss://xyz.com"), + ("https://xyz.com/", "https://xyz.com", "wss://xyz.com"), + # URLs without scheme (should add http/ws, no trailing slash) + ("xyz.com", "http://xyz.com", "ws://xyz.com"), + ("xyz.com/", "http://xyz.com", "ws://xyz.com"), + ], + ) + async def test_non_localhost_urls(self, env_url, expected_http, expected_ws): + """Test non-localhost URLs return correct http and ws endpoints.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", env_url): + http_url, ws_url = await get_backend_endpoints() + assert http_url == expected_http + assert ws_url == expected_ws + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "env_url,expected_http,expected_ws", + [ + # localhost URLs with http:// scheme + ("http://localhost", "http://localhost", "ws://localhost"), + ("http://localhost/", "http://localhost", "ws://localhost"), + ("http://localhost:8000", "http://localhost:8000", "ws://localhost:8000"), + ("http://localhost:8000/", "http://localhost:8000", "ws://localhost:8000"), + # localhost URLs without scheme (should add http/ws) + ("localhost", "http://localhost", "ws://localhost"), + ("localhost/", "http://localhost", "ws://localhost"), + ("localhost:8000", "http://localhost:8000", "ws://localhost:8000"), + ("localhost:8000/", "http://localhost:8000", "ws://localhost:8000"), + # 127.0.0.1 URLs with http:// scheme + ("http://127.0.0.1", "http://127.0.0.1", "ws://127.0.0.1"), + ("http://127.0.0.1/", "http://127.0.0.1", "ws://127.0.0.1"), + ("http://127.0.0.1:8000", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"), + ("http://127.0.0.1:8000/", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"), + # 127.0.0.1 URLs without scheme (should add http/ws) + ("127.0.0.1", "http://127.0.0.1", "ws://127.0.0.1"), + ("127.0.0.1/", "http://127.0.0.1", "ws://127.0.0.1"), + ("127.0.0.1:8000", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"), + ("127.0.0.1:8000/", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"), + ], + ) + async def test_localhost_urls_no_tunnel(self, env_url, expected_http, expected_ws): + """Test localhost/127.0.0.1 URLs when tunnel is not available.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", env_url): + with patch( + "api.utils.common.TunnelURLProvider.get_tunnel_urls", + new_callable=AsyncMock, + ) as mock_tunnel: + mock_tunnel.return_value = None + http_url, ws_url = await get_backend_endpoints() + assert http_url == expected_http + assert ws_url == expected_ws + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "env_url", + [ + "http://localhost", + "http://localhost/", + "http://localhost:8000", + "http://localhost:8000/", + "localhost", + "localhost/", + "localhost:8000", + "localhost:8000/", + "http://127.0.0.1", + "http://127.0.0.1/", + "http://127.0.0.1:8000", + "http://127.0.0.1:8000/", + "127.0.0.1", + "127.0.0.1/", + "127.0.0.1:8000", + "127.0.0.1:8000/", + ], + ) + async def test_localhost_urls_with_tunnel_available(self, env_url): + """Test localhost/127.0.0.1 URLs prefer tunnel when available.""" + tunnel_http = "https://abc123.trycloudflare.com" + tunnel_ws = "wss://abc123.trycloudflare.com" + + with patch("api.utils.common.BACKEND_API_ENDPOINT", env_url): + with patch( + "api.utils.common.TunnelURLProvider.get_tunnel_urls", + new_callable=AsyncMock, + ) as mock_tunnel: + mock_tunnel.return_value = (tunnel_http, tunnel_ws) + http_url, ws_url = await get_backend_endpoints() + assert http_url == tunnel_http + assert ws_url == tunnel_ws + + @pytest.mark.asyncio + async def test_localhost_tunnel_exception_falls_back(self): + """Test that tunnel exceptions fall back to localhost endpoint.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://localhost:8000"): + with patch( + "api.utils.common.TunnelURLProvider.get_tunnel_urls", + new_callable=AsyncMock, + ) as mock_tunnel: + mock_tunnel.side_effect = Exception("Tunnel not available") + http_url, ws_url = await get_backend_endpoints() + assert http_url == "http://localhost:8000" + assert ws_url == "ws://localhost:8000" + + @pytest.mark.asyncio + async def test_localhost_with_trailing_slash_tunnel_exception_falls_back(self): + """Test that tunnel exceptions fall back to localhost endpoint, trailing slash stripped.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://localhost:8000/"): + with patch( + "api.utils.common.TunnelURLProvider.get_tunnel_urls", + new_callable=AsyncMock, + ) as mock_tunnel: + mock_tunnel.side_effect = Exception("Tunnel not available") + http_url, ws_url = await get_backend_endpoints() + assert http_url == "http://localhost:8000" + assert ws_url == "ws://localhost:8000" + + @pytest.mark.asyncio + async def test_127_tunnel_exception_falls_back(self): + """Test that tunnel exceptions fall back to 127.0.0.1 endpoint.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://127.0.0.1:8000"): + with patch( + "api.utils.common.TunnelURLProvider.get_tunnel_urls", + new_callable=AsyncMock, + ) as mock_tunnel: + mock_tunnel.side_effect = Exception("Tunnel not available") + http_url, ws_url = await get_backend_endpoints() + assert http_url == "http://127.0.0.1:8000" + assert ws_url == "ws://127.0.0.1:8000" + + @pytest.mark.asyncio + async def test_127_with_trailing_slash_tunnel_exception_falls_back(self): + """Test that tunnel exceptions fall back to 127.0.0.1 endpoint, trailing slash stripped.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://127.0.0.1:8000/"): + with patch( + "api.utils.common.TunnelURLProvider.get_tunnel_urls", + new_callable=AsyncMock, + ) as mock_tunnel: + mock_tunnel.side_effect = Exception("Tunnel not available") + http_url, ws_url = await get_backend_endpoints() + assert http_url == "http://127.0.0.1:8000" + assert ws_url == "ws://127.0.0.1:8000" + + +class TestGetBackendEndpointsNoEnvVar: + """Tests for get_backend_endpoints when BACKEND_API_ENDPOINT is not set.""" + + @pytest.mark.asyncio + async def test_uses_tunnel_when_no_env_var(self): + """Test that tunnel URLs are used when env var is not set.""" + tunnel_http = "https://abc123.trycloudflare.com" + tunnel_ws = "wss://abc123.trycloudflare.com" + + with patch("api.utils.common.BACKEND_API_ENDPOINT", None): + with patch( + "api.utils.common.TunnelURLProvider.get_tunnel_urls", + new_callable=AsyncMock, + ) as mock_tunnel: + mock_tunnel.return_value = (tunnel_http, tunnel_ws) + http_url, ws_url = await get_backend_endpoints() + assert http_url == tunnel_http + assert ws_url == tunnel_ws + + @pytest.mark.asyncio + async def test_raises_when_no_env_var_and_no_tunnel(self): + """Test that ValueError is raised when no env var and no tunnel.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", None): + with patch( + "api.utils.common.TunnelURLProvider.get_tunnel_urls", + new_callable=AsyncMock, + ) as mock_tunnel: + mock_tunnel.return_value = None + with pytest.raises(ValueError, match="No tunnel URL available"): + await get_backend_endpoints() + + +class TestSchemeMapping: + """Tests to verify correct scheme mapping (http->ws, https->wss).""" + + @pytest.mark.asyncio + async def test_http_maps_to_ws(self): + """Test http:// maps to ws://""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://example.com"): + http_url, ws_url = await get_backend_endpoints() + assert http_url == "http://example.com" + assert ws_url == "ws://example.com" + + @pytest.mark.asyncio + async def test_https_maps_to_wss(self): + """Test https:// maps to wss://""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", "https://example.com"): + http_url, ws_url = await get_backend_endpoints() + assert http_url == "https://example.com" + assert ws_url == "wss://example.com" + + @pytest.mark.asyncio + async def test_no_scheme_defaults_to_http_ws(self): + """Test URLs without scheme default to http:// and ws://""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", "example.com"): + http_url, ws_url = await get_backend_endpoints() + assert http_url == "http://example.com" + assert ws_url == "ws://example.com" + + @pytest.mark.asyncio + async def test_trailing_slash_stripped(self): + """Test trailing slashes are stripped from output.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://example.com/"): + http_url, ws_url = await get_backend_endpoints() + assert http_url == "http://example.com" + assert ws_url == "ws://example.com" + assert not http_url.endswith("/") + assert not ws_url.endswith("/") + + +class TestInvalidUrls: + """Tests for invalid URLs that should raise ValueError.""" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "invalid_url", + [ + "http:/localhost", # Typo: single slash in scheme + "http:/xyz.com", # Typo: single slash in scheme + "https:/xyz.com", # Typo: single slash in scheme + ], + ) + async def test_malformed_scheme_single_slash(self, invalid_url): + """Test URLs with single slash in scheme raise ValueError.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url): + with patch( + "api.utils.common.TunnelURLProvider.get_tunnel_urls", + new_callable=AsyncMock, + ) as mock_tunnel: + mock_tunnel.return_value = None + with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"): + await get_backend_endpoints() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "invalid_url", + [ + "htp://xyz.com", # Typo: missing 't' in http + "htps://xyz.com", # Typo: missing 't' in https + "ftp://xyz.com", # Unsupported scheme + "file://xyz.com", # Unsupported scheme + ], + ) + async def test_invalid_or_unsupported_scheme(self, invalid_url): + """Test URLs with invalid or unsupported schemes raise ValueError.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url): + with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"): + await get_backend_endpoints() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "invalid_url", + [ + "http//xyz.com", # Missing colon + "http:xyz.com", # Missing slashes + ], + ) + async def test_malformed_scheme_separator(self, invalid_url): + """Test URLs with malformed scheme separators raise ValueError.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url): + with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"): + await get_backend_endpoints() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "invalid_url", + [ + "http://xyz.com:abc", # Invalid port (non-numeric) + "http://xyz.com:-1", # Invalid port (negative) + "http://xyz.com:99999", # Invalid port (out of range) + "http://xyz.com:", # Empty port + ], + ) + async def test_invalid_port(self, invalid_url): + """Test URLs with invalid ports raise ValueError.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url): + with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"): + await get_backend_endpoints() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "invalid_url", + [ + "http://", # Missing host + "https://", # Missing host + ], + ) + async def test_missing_host(self, invalid_url): + """Test URLs with missing host raise ValueError.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url): + with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"): + await get_backend_endpoints() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "invalid_url", + [ + "http:// ", # Whitespace host + "http://xyz .com", # Space in hostname + "http://xyz\t.com", # Tab in hostname + "http://xyz\n.com", # Newline in hostname + ], + ) + async def test_invalid_characters_in_host(self, invalid_url): + """Test URLs with invalid characters in hostname raise ValueError.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url): + with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"): + await get_backend_endpoints() + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "invalid_url", + [ + "", # Empty string + " ", # Only whitespace + ], + ) + async def test_empty_or_whitespace_url(self, invalid_url): + """Test empty or whitespace-only URLs raise ValueError.""" + with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url): + with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"): + await get_backend_endpoints() diff --git a/api/utils/common.py b/api/utils/common.py new file mode 100644 index 0000000..b843d13 --- /dev/null +++ b/api/utils/common.py @@ -0,0 +1,181 @@ +""" +Common utilities. +Shared functions used across the application. +""" + +import re + +from loguru import logger + +from api.constants import BACKEND_API_ENDPOINT +from api.utils.tunnel import TunnelURLProvider + + +def get_scheme(url: str) -> str | None: + """ + Extract scheme from a given URL if present. + Returns None if not found + """ + idx = url.find("://") + if idx == -1: + return None + return url[:idx] + + +def _validate_url(url: str) -> None: + """ + Validate URL format and raise ValueError for invalid URLs. + + Checks for: + - Empty or whitespace-only URLs + - Malformed schemes (single slash, missing colon/slashes) + - Invalid/unsupported schemes + - Invalid ports (non-numeric, out of range, empty) + - Missing hosts + - Invalid characters in hostname (whitespace) + """ + # Check for empty or whitespace-only URLs + if not url or not url.strip(): + raise ValueError( + f"Invalid BACKEND_API_ENDPOINT: URL cannot be empty or whitespace" + ) + + # Check for malformed schemes (single slash like http:/localhost) + if re.match(r"^https?:/[^/]", url): + raise ValueError(f"Invalid BACKEND_API_ENDPOINT: malformed scheme in '{url}'") + + # Check for malformed scheme separators (http// or http:xyz without //) + if re.match(r"^https?//[^/]", url) or re.match(r"^https?:[^/]", url): + raise ValueError( + f"Invalid BACKEND_API_ENDPOINT: malformed scheme separator in '{url}'" + ) + + # Check for invalid/unsupported schemes + scheme = get_scheme(url) + if scheme and scheme not in ("http", "https"): + raise ValueError( + f"Invalid BACKEND_API_ENDPOINT: unsupported scheme '{scheme}' in '{url}'" + ) + + # Parse URL for further validation + if scheme: + # URL has a scheme, extract host part + host_part = url[len(scheme) + 3 :] # Skip "scheme://" + else: + host_part = url + + # Strip trailing slash for host validation + host_part = host_part.rstrip("/") + + # Check for missing host + if not host_part or not host_part.strip(): + raise ValueError(f"Invalid BACKEND_API_ENDPOINT: missing host in '{url}'") + + # Check for invalid characters in hostname (whitespace) + if re.search(r"\s", host_part): + raise ValueError( + f"Invalid BACKEND_API_ENDPOINT: invalid characters in hostname '{url}'" + ) + + # Check for invalid port - look for colon followed by anything + port_match = re.search(r":([^/]*)$", host_part) + if port_match: + port_str = port_match.group(1) + if not port_str: + raise ValueError(f"Invalid BACKEND_API_ENDPOINT: empty port in '{url}'") + # Check if port is numeric + if not port_str.isdigit(): + raise ValueError(f"Invalid BACKEND_API_ENDPOINT: invalid port in '{url}'") + port = int(port_str) + if port < 0 or port > 65535: + raise ValueError( + f"Invalid BACKEND_API_ENDPOINT: port out of range in '{url}'" + ) + + +async def get_backend_endpoints() -> tuple[str, str]: + """ + Get the backend endpoint URLs for external access (webhooks, callbacks, WebSocket connections). + + Priority: + 1. BACKEND_API_ENDPOINT environment variable (if set and not localhost) + 2. Cloudflared Tunnel URLs (fallback for localhost or missing env var) + + Protocol Handling: + 1. If URL has http:// - returns http:// and ws:// + 2. If URL has https:// - returns https:// and wss:// + 3. If URL has no protocol - defaults to http:// and ws:// + + Returns: + tuple[str, str]: (backend_endpoint, wss_backend_endpoint) + + Raises: + ValueError: If no endpoint URL can be determined or URL is invalid + """ + + # If env var is explicitly set (even to empty/whitespace), validate it + if BACKEND_API_ENDPOINT is not None: + # Validate - this will raise for empty/whitespace + _validate_url(BACKEND_API_ENDPOINT) + + if BACKEND_API_ENDPOINT: + logger.debug( + f"Processing BACKEND_API_ENDPOINT from environment: {BACKEND_API_ENDPOINT}" + ) + + # Handle localhost/127.0.0.1 special case - use tunnel URL if available + if "localhost" in BACKEND_API_ENDPOINT or "127.0.0.1" in BACKEND_API_ENDPOINT: + logger.debug( + f"BACKEND_API_ENDPOINT is local ({BACKEND_API_ENDPOINT}), checking tunnel URL" + ) + try: + tunnel_urls = await TunnelURLProvider.get_tunnel_urls() + if tunnel_urls: + logger.debug( + f"Tunnel URLs available, using tunnel URLs instead of localhost" + ) + return tunnel_urls + else: + logger.debug( + f"Tunnel URLs returned None, proceeding with localhost endpoint" + ) + except Exception as e: + logger.debug( + f"No tunnel URLs available ({e}), proceeding with localhost endpoint" + ) + + try: + # Parse the URL to validate and handle protocol + scheme = get_scheme(BACKEND_API_ENDPOINT) + + if scheme: + http_url = BACKEND_API_ENDPOINT.rstrip("/") + ws_scheme = {"http": "ws", "https": "wss"}[scheme] + ws_url = BACKEND_API_ENDPOINT.rstrip("/").replace(scheme, ws_scheme, 1) + else: + http_url = "http://" + BACKEND_API_ENDPOINT.rstrip("/") + ws_url = "ws://" + BACKEND_API_ENDPOINT.rstrip("/") + + logger.debug( + f"Returning backend URLs - HTTP: {http_url}, WebSocket: {ws_url}" + ) + return http_url, ws_url + + except Exception as e: + # Case 4: Invalid URL format + raise ValueError( + f"Invalid BACKEND_API_ENDPOINT format: '{BACKEND_API_ENDPOINT}' - {str(e)}" + ) + + # Second priority: Query cloudflared tunnel URL when no environment variable is set + logger.debug("No BACKEND_API_ENDPOINT set, using tunnel URL") + tunnel_urls = await TunnelURLProvider.get_tunnel_urls() + if tunnel_urls: + logger.debug(f"Retrieved tunnel URLs: {tunnel_urls}") + return tunnel_urls + else: + logger.debug("No tunnel URLs available") + raise ValueError( + "No tunnel URL available. Please set BACKEND_API_ENDPOINT environment " + "variable or ensure cloudflared service is running." + ) diff --git a/api/utils/tunnel.py b/api/utils/tunnel.py index 31a4588..8e2a521 100644 --- a/api/utils/tunnel.py +++ b/api/utils/tunnel.py @@ -1,7 +1,6 @@ """Utility for getting the cloudflared tunnel URL at runtime.""" import asyncio -import os import re from typing import Optional @@ -10,37 +9,26 @@ from loguru import logger class TunnelURLProvider: - """Provider for getting the tunnel URL from cloudflared or environment.""" + """Provider for getting tunnel URLs from cloudflared service.""" @classmethod - async def get_tunnel_url(cls) -> str: + async def get_tunnel_urls(cls) -> tuple[str, str]: """ - Get the tunnel URL for external access. - - Priority: - 1. BACKEND_API_ENDPOINT environment variable (if set) - 2. Query cloudflared metrics endpoint - 3. Raise error if neither available + Get the tunnel URLs for external access. Returns: - str: The tunnel domain (without protocol) + tuple[str, str]: (https_url, wss_url) - Both URLs include full protocol Raises: ValueError: If no tunnel URL can be determined """ - # First priority: Check environment variable - env_endpoint = os.getenv("BACKEND_API_ENDPOINT") - if env_endpoint: - logger.debug(f"Using BACKEND_API_ENDPOINT from environment: {env_endpoint}") - return env_endpoint - # Second priority: Query cloudflared try: # Try to get URL from cloudflared metrics - url = await cls._get_cloudflared_url() - if url: - logger.info(f"Retrieved tunnel URL from cloudflared: {url}") - return url + urls = await cls._get_cloudflared_urls() + if urls: + logger.info(f"Retrieved tunnel URLs from cloudflared: {urls}") + return urls except Exception as e: logger.warning(f"Failed to get tunnel URL from cloudflared: {e}") @@ -50,12 +38,12 @@ class TunnelURLProvider: ) @classmethod - async def _get_cloudflared_url(cls) -> Optional[str]: + async def _get_cloudflared_urls(cls) -> Optional[tuple[str, str]]: """ - Query cloudflared metrics endpoint to get the tunnel URL. + Query cloudflared metrics endpoint to get the tunnel URLs. Returns: - Optional[str]: The tunnel domain (without protocol), or None if not found + Optional[tuple[str, str]]: (https_url, wss_url) with full protocols, or None if not found """ try: # Try to connect to cloudflared metrics endpoint @@ -83,12 +71,16 @@ class TunnelURLProvider: hostname = hostname.replace("https://", "").replace( "wss://", "" ) - return hostname + return "https://" + hostname, "wss://" + hostname # Alternative: Look for trycloudflare.com domain match = re.search(r"([a-z0-9-]+\.trycloudflare\.com)", text) if match: - return match.group(1) + hostname = match.group(1) + hostname = hostname.replace("https://", "").replace( + "wss://", "" + ) + return f"https://{hostname}", f"wss://{hostname}" logger.warning("Could not find tunnel URL in cloudflared metrics") return None