mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-07 07:55:16 +02:00
fix: BACKEND_API_ENDPOINT resolution from env and cloudflared tunnel (#135)
This commit is contained in:
parent
814271e7b1
commit
4a8e4fe7a1
11 changed files with 669 additions and 72 deletions
|
|
@ -58,11 +58,13 @@ class HealthResponse(BaseModel):
|
|||
|
||||
@router.get("/health", response_model=HealthResponse)
|
||||
async def health() -> HealthResponse:
|
||||
from api.constants import APP_VERSION, BACKEND_API_ENDPOINT
|
||||
from api.constants import APP_VERSION
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
logger.debug("Health endpoint called")
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
return HealthResponse(
|
||||
status="ok",
|
||||
version=APP_VERSION,
|
||||
backend_api_endpoint=BACKEND_API_ENDPOINT,
|
||||
backend_api_endpoint=backend_endpoint,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from api.db import db_client
|
|||
from api.enums import TriggerState
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
router = APIRouter(prefix="/public/agent")
|
||||
|
||||
|
|
@ -147,11 +147,11 @@ async def initiate_call(
|
|||
)
|
||||
|
||||
# 9. Construct webhook URL for telephony provider callback
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_endpoint = provider.WEBHOOK_ENDPOINT
|
||||
|
||||
webhook_url = (
|
||||
f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"?workflow_id={trigger.workflow_id}"
|
||||
f"&user_id={api_key.created_by}"
|
||||
f"&workflow_run_id={workflow_run.id}"
|
||||
|
|
|
|||
|
|
@ -30,13 +30,13 @@ from api.services.telephony.factory import (
|
|||
get_all_telephony_providers,
|
||||
get_telephony_provider,
|
||||
)
|
||||
from api.utils.common import get_backend_endpoints
|
||||
from api.utils.telephony_helper import (
|
||||
generic_hangup_response,
|
||||
normalize_webhook_data,
|
||||
numbers_match,
|
||||
parse_webhook_request,
|
||||
)
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
router = APIRouter(prefix="/telephony")
|
||||
|
|
@ -159,12 +159,12 @@ async def initiate_call(
|
|||
workflow_run_name = workflow_run.name
|
||||
|
||||
# Construct webhook URL based on provider type
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
|
||||
webhook_endpoint = provider.WEBHOOK_ENDPOINT
|
||||
|
||||
webhook_url = (
|
||||
f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"?workflow_id={request.workflow_id}"
|
||||
f"&user_id={user.id}"
|
||||
f"&workflow_run_id={workflow_run_id}"
|
||||
|
|
@ -313,10 +313,8 @@ async def _validate_inbound_request(
|
|||
# Verify webhook signature/API key if provided
|
||||
provider_instance = None
|
||||
if x_twilio_signature or x_vobiz_signature or x_cx_apikey:
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
webhook_url = (
|
||||
f"https://{backend_endpoint}/api/v1/telephony/inbound/{workflow_id}"
|
||||
)
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/inbound/{workflow_id}"
|
||||
|
||||
# Get the real telephony provider with actual credentials for signature verification
|
||||
provider_instance = await get_telephony_provider(organization_id)
|
||||
|
|
@ -613,8 +611,8 @@ async def handle_twilio_status_callback(
|
|||
provider = await get_telephony_provider(workflow.organization_id)
|
||||
|
||||
if x_webhook_signature:
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
full_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
full_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
|
||||
|
||||
is_valid = await provider.verify_webhook_signature(
|
||||
full_url, callback_data, x_webhook_signature
|
||||
|
|
@ -887,8 +885,8 @@ async def handle_vobiz_hangup_callback(
|
|||
webhook_body = raw_body.decode("utf-8")
|
||||
|
||||
# Verify signature
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
webhook_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
|
||||
|
||||
is_valid = await provider.verify_webhook_signature(
|
||||
webhook_url,
|
||||
|
|
@ -1009,8 +1007,10 @@ async def handle_vobiz_ring_callback(
|
|||
webhook_body = raw_body.decode("utf-8")
|
||||
|
||||
# Verify signature
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
webhook_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
|
||||
)
|
||||
|
||||
is_valid = await provider.verify_webhook_signature(
|
||||
webhook_url,
|
||||
|
|
@ -1157,8 +1157,8 @@ async def handle_vobiz_hangup_callback_by_workflow(
|
|||
if x_vobiz_signature:
|
||||
raw_body = await request.body()
|
||||
webhook_body = raw_body.decode("utf-8")
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
webhook_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/workflow/{workflow_id}"
|
||||
|
||||
is_valid = await provider.verify_webhook_signature(
|
||||
webhook_url,
|
||||
|
|
@ -1338,8 +1338,8 @@ async def handle_inbound_telephony(
|
|||
)
|
||||
|
||||
# Generate response URLs
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
websocket_url = f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{workflow_context['user_id']}/{workflow_run_id}"
|
||||
_, wss_backend_endpoint = await get_backend_endpoints()
|
||||
websocket_url = f"{wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{workflow_context['user_id']}/{workflow_run_id}"
|
||||
response = await provider_class.generate_inbound_response(
|
||||
websocket_url, workflow_run_id
|
||||
)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from api.enums import OrganizationConfigurationKey, WorkflowRunState
|
|||
from api.services.campaign.rate_limiter import rate_limiter
|
||||
from api.services.telephony.base import TelephonyProvider
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
|
||||
class CampaignCallDispatcher:
|
||||
|
|
@ -249,10 +249,10 @@ class CampaignCallDispatcher:
|
|||
# Initiate call via telephony provider
|
||||
try:
|
||||
# Construct webhook URL with parameters
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
webhook_endpoint = provider.WEBHOOK_ENDPOINT
|
||||
webhook_url = (
|
||||
f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"?workflow_id={campaign.workflow_id}"
|
||||
f"&user_id={campaign.created_by}"
|
||||
f"&workflow_run_id={workflow_run.id}"
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from api.services.telephony.base import (
|
|||
NormalizedInboundData,
|
||||
TelephonyProvider,
|
||||
)
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
|
|
@ -91,13 +91,13 @@ class CloudonixProvider(TelephonyProvider):
|
|||
|
||||
# Prepare call data using Cloudonix callObject schema
|
||||
# Note: 'caller-id' is REQUIRED by Cloudonix API
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
backend_endpoint, wss_backend_endpoint = await get_backend_endpoints()
|
||||
data: Dict[str, Any] = {
|
||||
"destination": to_number,
|
||||
"cxml": f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
<Connect>
|
||||
<Stream url="wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}"></Stream>
|
||||
<Stream url="{wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}"></Stream>
|
||||
</Connect>
|
||||
<Pause length="40"/>
|
||||
</Response>""",
|
||||
|
|
@ -106,7 +106,7 @@ class CloudonixProvider(TelephonyProvider):
|
|||
|
||||
# Add status callback if workflow_run_id provided
|
||||
if workflow_run_id:
|
||||
callback_url = f"https://{backend_endpoint}/api/v1/telephony/cloudonix/status-callback/{workflow_run_id}"
|
||||
callback_url = f"{backend_endpoint}/api/v1/telephony/cloudonix/status-callback/{workflow_run_id}"
|
||||
data["callback"] = callback_url
|
||||
|
||||
# Merge any additional kwargs
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from api.services.telephony.base import (
|
|||
NormalizedInboundData,
|
||||
TelephonyProvider,
|
||||
)
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
|
|
@ -75,8 +75,8 @@ class TwilioProvider(TelephonyProvider):
|
|||
|
||||
# Add status callback if workflow_run_id provided
|
||||
if workflow_run_id:
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
callback_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
callback_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
|
||||
data.update(
|
||||
{
|
||||
"StatusCallback": callback_url,
|
||||
|
|
@ -158,12 +158,12 @@ class TwilioProvider(TelephonyProvider):
|
|||
"""
|
||||
Generate TwiML response for starting a call session.
|
||||
"""
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
_, wss_backend_endpoint = await get_backend_endpoints()
|
||||
|
||||
twiml_content = f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
<Connect>
|
||||
<Stream url="wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}"></Stream>
|
||||
<Stream url="{wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}"></Stream>
|
||||
</Connect>
|
||||
<Pause length="40"/>
|
||||
</Response>"""
|
||||
|
|
@ -405,8 +405,8 @@ class TwilioProvider(TelephonyProvider):
|
|||
# Generate StatusCallback URL using same pattern as outbound calls
|
||||
status_callback_attr = ""
|
||||
if workflow_run_id:
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
status_callback_url = f"https://{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
status_callback_url = f"{backend_endpoint}/api/v1/telephony/twilio/status-callback/{workflow_run_id}"
|
||||
status_callback_attr = f' statusCallback="{status_callback_url}"'
|
||||
|
||||
twiml_content = f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from api.services.telephony.base import (
|
|||
NormalizedInboundData,
|
||||
TelephonyProvider,
|
||||
)
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
|
|
@ -89,9 +89,9 @@ class VobizProvider(TelephonyProvider):
|
|||
|
||||
# Add hangup callback if workflow_run_id provided
|
||||
if workflow_run_id:
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
hangup_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
|
||||
ring_url = f"https://{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
hangup_url = f"{backend_endpoint}/api/v1/telephony/vobiz/hangup-callback/{workflow_run_id}"
|
||||
ring_url = f"{backend_endpoint}/api/v1/telephony/vobiz/ring-callback/{workflow_run_id}"
|
||||
data.update(
|
||||
{
|
||||
"hangup_url": hangup_url,
|
||||
|
|
@ -254,11 +254,11 @@ class VobizProvider(TelephonyProvider):
|
|||
- audioTrack: Which audio to stream (inbound, outbound, both)
|
||||
- contentType: audio/x-mulaw;rate=8000
|
||||
"""
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
_, wss_backend_endpoint = await get_backend_endpoints()
|
||||
|
||||
vobiz_xml = f"""<?xml version="1.0" encoding="UTF-8"?>
|
||||
<Response>
|
||||
<Stream bidirectional="true" keepCallAlive="true" contentType="audio/x-mulaw;rate=8000">wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}</Stream>
|
||||
<Stream bidirectional="true" keepCallAlive="true" contentType="audio/x-mulaw;rate=8000">{wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}</Stream>
|
||||
</Response>"""
|
||||
return vobiz_xml
|
||||
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from api.services.telephony.base import (
|
|||
NormalizedInboundData,
|
||||
TelephonyProvider,
|
||||
)
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from api.utils.common import get_backend_endpoints
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import WebSocket
|
||||
|
|
@ -106,8 +106,10 @@ class VonageProvider(TelephonyProvider):
|
|||
|
||||
# Add event webhook if workflow_run_id provided
|
||||
if workflow_run_id:
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
event_url = f"https://{backend_endpoint}/api/v1/telephony/vonage/events/{workflow_run_id}"
|
||||
backend_endpoint, _ = await get_backend_endpoints()
|
||||
event_url = (
|
||||
f"{backend_endpoint}/api/v1/telephony/vonage/events/{workflow_run_id}"
|
||||
)
|
||||
data.update({"event_url": [event_url], "event_method": "POST"})
|
||||
|
||||
data.update(kwargs)
|
||||
|
|
@ -201,7 +203,7 @@ class VonageProvider(TelephonyProvider):
|
|||
Generate NCCO response for starting a call session.
|
||||
NCCO (Nexmo Call Control Objects) is JSON-based, unlike TwiML which is XML.
|
||||
"""
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
_, wss_backend_endpoint = await get_backend_endpoints()
|
||||
|
||||
# NCCO for WebSocket connection
|
||||
ncco = [
|
||||
|
|
@ -210,7 +212,7 @@ class VonageProvider(TelephonyProvider):
|
|||
"endpoint": [
|
||||
{
|
||||
"type": "websocket",
|
||||
"uri": f"wss://{backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}",
|
||||
"uri": f"{wss_backend_endpoint}/api/v1/telephony/ws/{workflow_id}/{user_id}/{workflow_run_id}",
|
||||
"content-type": "audio/l16;rate=16000", # 16kHz Linear PCM
|
||||
"headers": {},
|
||||
}
|
||||
|
|
|
|||
420
api/tests/test_get_backend_endpoints.py
Normal file
420
api/tests/test_get_backend_endpoints.py
Normal file
|
|
@ -0,0 +1,420 @@
|
|||
"""
|
||||
Tests for get_backend_endpoints function in api/utils/common.py
|
||||
|
||||
Expected behavior:
|
||||
- Output URLs must always have a scheme (http:// or https://, ws:// or wss://)
|
||||
- Output URLs must NOT have trailing slashes
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.utils.common import get_backend_endpoints, get_scheme
|
||||
|
||||
# Valid test URLs covering various formats
|
||||
possible_env_paths = [
|
||||
"http://localhost",
|
||||
"http://localhost/",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8000/",
|
||||
"http://127.0.0.1",
|
||||
"http://127.0.0.1/",
|
||||
"http://127.0.0.1:8000",
|
||||
"http://127.0.0.1:8000/",
|
||||
"http://xyz.com",
|
||||
"http://xyz.com/",
|
||||
"https://xyz.com",
|
||||
"https://xyz.com/",
|
||||
"localhost",
|
||||
"localhost:8000",
|
||||
"localhost/",
|
||||
"localhost:8000/",
|
||||
"xyz.com",
|
||||
"xyz.com/",
|
||||
"127.0.0.1",
|
||||
"127.0.0.1/",
|
||||
"127.0.0.1:8000",
|
||||
"127.0.0.1:8000/",
|
||||
]
|
||||
|
||||
# Invalid URLs that should raise ValueError
|
||||
invalid_env_paths = [
|
||||
"http:/localhost", # Typo: single slash in scheme
|
||||
"http:/xyz.com", # Typo: single slash in scheme
|
||||
"https:/xyz.com", # Typo: single slash in scheme
|
||||
"htp://xyz.com", # Typo: missing 't' in http
|
||||
"htps://xyz.com", # Typo: missing 't' in https
|
||||
"http//xyz.com", # Missing colon
|
||||
"http:xyz.com", # Missing slashes
|
||||
"http://xyz.com:abc", # Invalid port (non-numeric)
|
||||
"http://xyz.com:-1", # Invalid port (negative)
|
||||
"http://xyz.com:99999", # Invalid port (out of range)
|
||||
"http://", # Missing host
|
||||
"https://", # Missing host
|
||||
"http:// ", # Whitespace host
|
||||
"http://xyz .com", # Space in hostname
|
||||
"http://xyz\t.com", # Tab in hostname
|
||||
"http://xyz\n.com", # Newline in hostname
|
||||
"", # Empty string
|
||||
" ", # Only whitespace
|
||||
]
|
||||
|
||||
|
||||
class TestGetScheme:
|
||||
"""Tests for the get_scheme helper function."""
|
||||
|
||||
def test_http_scheme(self):
|
||||
assert get_scheme("http://example.com") == "http"
|
||||
|
||||
def test_https_scheme(self):
|
||||
assert get_scheme("https://example.com") == "https"
|
||||
|
||||
def test_no_scheme(self):
|
||||
assert get_scheme("example.com") is None
|
||||
assert get_scheme("localhost:8000") is None
|
||||
|
||||
def test_malformed_url_single_slash(self):
|
||||
# 'http:/localhost' doesn't have '://' so returns None
|
||||
assert get_scheme("http:/localhost") is None
|
||||
|
||||
def test_ws_scheme(self):
|
||||
assert get_scheme("ws://example.com") == "ws"
|
||||
assert get_scheme("wss://example.com") == "wss"
|
||||
|
||||
|
||||
class TestGetBackendEndpointsWithEnvVar:
|
||||
"""Tests for get_backend_endpoints when BACKEND_API_ENDPOINT is set."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"env_url,expected_http,expected_ws",
|
||||
[
|
||||
# URLs with http:// scheme (with and without trailing slash -> no trailing slash)
|
||||
("http://xyz.com", "http://xyz.com", "ws://xyz.com"),
|
||||
("http://xyz.com/", "http://xyz.com", "ws://xyz.com"),
|
||||
# URLs with https:// scheme (with and without trailing slash -> no trailing slash)
|
||||
("https://xyz.com", "https://xyz.com", "wss://xyz.com"),
|
||||
("https://xyz.com/", "https://xyz.com", "wss://xyz.com"),
|
||||
# URLs without scheme (should add http/ws, no trailing slash)
|
||||
("xyz.com", "http://xyz.com", "ws://xyz.com"),
|
||||
("xyz.com/", "http://xyz.com", "ws://xyz.com"),
|
||||
],
|
||||
)
|
||||
async def test_non_localhost_urls(self, env_url, expected_http, expected_ws):
|
||||
"""Test non-localhost URLs return correct http and ws endpoints."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", env_url):
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == expected_http
|
||||
assert ws_url == expected_ws
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"env_url,expected_http,expected_ws",
|
||||
[
|
||||
# localhost URLs with http:// scheme
|
||||
("http://localhost", "http://localhost", "ws://localhost"),
|
||||
("http://localhost/", "http://localhost", "ws://localhost"),
|
||||
("http://localhost:8000", "http://localhost:8000", "ws://localhost:8000"),
|
||||
("http://localhost:8000/", "http://localhost:8000", "ws://localhost:8000"),
|
||||
# localhost URLs without scheme (should add http/ws)
|
||||
("localhost", "http://localhost", "ws://localhost"),
|
||||
("localhost/", "http://localhost", "ws://localhost"),
|
||||
("localhost:8000", "http://localhost:8000", "ws://localhost:8000"),
|
||||
("localhost:8000/", "http://localhost:8000", "ws://localhost:8000"),
|
||||
# 127.0.0.1 URLs with http:// scheme
|
||||
("http://127.0.0.1", "http://127.0.0.1", "ws://127.0.0.1"),
|
||||
("http://127.0.0.1/", "http://127.0.0.1", "ws://127.0.0.1"),
|
||||
("http://127.0.0.1:8000", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"),
|
||||
("http://127.0.0.1:8000/", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"),
|
||||
# 127.0.0.1 URLs without scheme (should add http/ws)
|
||||
("127.0.0.1", "http://127.0.0.1", "ws://127.0.0.1"),
|
||||
("127.0.0.1/", "http://127.0.0.1", "ws://127.0.0.1"),
|
||||
("127.0.0.1:8000", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"),
|
||||
("127.0.0.1:8000/", "http://127.0.0.1:8000", "ws://127.0.0.1:8000"),
|
||||
],
|
||||
)
|
||||
async def test_localhost_urls_no_tunnel(self, env_url, expected_http, expected_ws):
|
||||
"""Test localhost/127.0.0.1 URLs when tunnel is not available."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", env_url):
|
||||
with patch(
|
||||
"api.utils.common.TunnelURLProvider.get_tunnel_urls",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tunnel:
|
||||
mock_tunnel.return_value = None
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == expected_http
|
||||
assert ws_url == expected_ws
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"env_url",
|
||||
[
|
||||
"http://localhost",
|
||||
"http://localhost/",
|
||||
"http://localhost:8000",
|
||||
"http://localhost:8000/",
|
||||
"localhost",
|
||||
"localhost/",
|
||||
"localhost:8000",
|
||||
"localhost:8000/",
|
||||
"http://127.0.0.1",
|
||||
"http://127.0.0.1/",
|
||||
"http://127.0.0.1:8000",
|
||||
"http://127.0.0.1:8000/",
|
||||
"127.0.0.1",
|
||||
"127.0.0.1/",
|
||||
"127.0.0.1:8000",
|
||||
"127.0.0.1:8000/",
|
||||
],
|
||||
)
|
||||
async def test_localhost_urls_with_tunnel_available(self, env_url):
|
||||
"""Test localhost/127.0.0.1 URLs prefer tunnel when available."""
|
||||
tunnel_http = "https://abc123.trycloudflare.com"
|
||||
tunnel_ws = "wss://abc123.trycloudflare.com"
|
||||
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", env_url):
|
||||
with patch(
|
||||
"api.utils.common.TunnelURLProvider.get_tunnel_urls",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tunnel:
|
||||
mock_tunnel.return_value = (tunnel_http, tunnel_ws)
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == tunnel_http
|
||||
assert ws_url == tunnel_ws
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localhost_tunnel_exception_falls_back(self):
|
||||
"""Test that tunnel exceptions fall back to localhost endpoint."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://localhost:8000"):
|
||||
with patch(
|
||||
"api.utils.common.TunnelURLProvider.get_tunnel_urls",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tunnel:
|
||||
mock_tunnel.side_effect = Exception("Tunnel not available")
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == "http://localhost:8000"
|
||||
assert ws_url == "ws://localhost:8000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_localhost_with_trailing_slash_tunnel_exception_falls_back(self):
|
||||
"""Test that tunnel exceptions fall back to localhost endpoint, trailing slash stripped."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://localhost:8000/"):
|
||||
with patch(
|
||||
"api.utils.common.TunnelURLProvider.get_tunnel_urls",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tunnel:
|
||||
mock_tunnel.side_effect = Exception("Tunnel not available")
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == "http://localhost:8000"
|
||||
assert ws_url == "ws://localhost:8000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_127_tunnel_exception_falls_back(self):
|
||||
"""Test that tunnel exceptions fall back to 127.0.0.1 endpoint."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://127.0.0.1:8000"):
|
||||
with patch(
|
||||
"api.utils.common.TunnelURLProvider.get_tunnel_urls",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tunnel:
|
||||
mock_tunnel.side_effect = Exception("Tunnel not available")
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == "http://127.0.0.1:8000"
|
||||
assert ws_url == "ws://127.0.0.1:8000"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_127_with_trailing_slash_tunnel_exception_falls_back(self):
|
||||
"""Test that tunnel exceptions fall back to 127.0.0.1 endpoint, trailing slash stripped."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://127.0.0.1:8000/"):
|
||||
with patch(
|
||||
"api.utils.common.TunnelURLProvider.get_tunnel_urls",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tunnel:
|
||||
mock_tunnel.side_effect = Exception("Tunnel not available")
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == "http://127.0.0.1:8000"
|
||||
assert ws_url == "ws://127.0.0.1:8000"
|
||||
|
||||
|
||||
class TestGetBackendEndpointsNoEnvVar:
|
||||
"""Tests for get_backend_endpoints when BACKEND_API_ENDPOINT is not set."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_tunnel_when_no_env_var(self):
|
||||
"""Test that tunnel URLs are used when env var is not set."""
|
||||
tunnel_http = "https://abc123.trycloudflare.com"
|
||||
tunnel_ws = "wss://abc123.trycloudflare.com"
|
||||
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", None):
|
||||
with patch(
|
||||
"api.utils.common.TunnelURLProvider.get_tunnel_urls",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tunnel:
|
||||
mock_tunnel.return_value = (tunnel_http, tunnel_ws)
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == tunnel_http
|
||||
assert ws_url == tunnel_ws
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raises_when_no_env_var_and_no_tunnel(self):
|
||||
"""Test that ValueError is raised when no env var and no tunnel."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", None):
|
||||
with patch(
|
||||
"api.utils.common.TunnelURLProvider.get_tunnel_urls",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tunnel:
|
||||
mock_tunnel.return_value = None
|
||||
with pytest.raises(ValueError, match="No tunnel URL available"):
|
||||
await get_backend_endpoints()
|
||||
|
||||
|
||||
class TestSchemeMapping:
|
||||
"""Tests to verify correct scheme mapping (http->ws, https->wss)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_maps_to_ws(self):
|
||||
"""Test http:// maps to ws://"""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://example.com"):
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == "http://example.com"
|
||||
assert ws_url == "ws://example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_https_maps_to_wss(self):
|
||||
"""Test https:// maps to wss://"""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", "https://example.com"):
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == "https://example.com"
|
||||
assert ws_url == "wss://example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_scheme_defaults_to_http_ws(self):
|
||||
"""Test URLs without scheme default to http:// and ws://"""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", "example.com"):
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == "http://example.com"
|
||||
assert ws_url == "ws://example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trailing_slash_stripped(self):
|
||||
"""Test trailing slashes are stripped from output."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", "http://example.com/"):
|
||||
http_url, ws_url = await get_backend_endpoints()
|
||||
assert http_url == "http://example.com"
|
||||
assert ws_url == "ws://example.com"
|
||||
assert not http_url.endswith("/")
|
||||
assert not ws_url.endswith("/")
|
||||
|
||||
|
||||
class TestInvalidUrls:
|
||||
"""Tests for invalid URLs that should raise ValueError."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_url",
|
||||
[
|
||||
"http:/localhost", # Typo: single slash in scheme
|
||||
"http:/xyz.com", # Typo: single slash in scheme
|
||||
"https:/xyz.com", # Typo: single slash in scheme
|
||||
],
|
||||
)
|
||||
async def test_malformed_scheme_single_slash(self, invalid_url):
|
||||
"""Test URLs with single slash in scheme raise ValueError."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
|
||||
with patch(
|
||||
"api.utils.common.TunnelURLProvider.get_tunnel_urls",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_tunnel:
|
||||
mock_tunnel.return_value = None
|
||||
with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
|
||||
await get_backend_endpoints()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_url",
|
||||
[
|
||||
"htp://xyz.com", # Typo: missing 't' in http
|
||||
"htps://xyz.com", # Typo: missing 't' in https
|
||||
"ftp://xyz.com", # Unsupported scheme
|
||||
"file://xyz.com", # Unsupported scheme
|
||||
],
|
||||
)
|
||||
async def test_invalid_or_unsupported_scheme(self, invalid_url):
|
||||
"""Test URLs with invalid or unsupported schemes raise ValueError."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
|
||||
with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
|
||||
await get_backend_endpoints()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_url",
|
||||
[
|
||||
"http//xyz.com", # Missing colon
|
||||
"http:xyz.com", # Missing slashes
|
||||
],
|
||||
)
|
||||
async def test_malformed_scheme_separator(self, invalid_url):
|
||||
"""Test URLs with malformed scheme separators raise ValueError."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
|
||||
with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
|
||||
await get_backend_endpoints()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_url",
|
||||
[
|
||||
"http://xyz.com:abc", # Invalid port (non-numeric)
|
||||
"http://xyz.com:-1", # Invalid port (negative)
|
||||
"http://xyz.com:99999", # Invalid port (out of range)
|
||||
"http://xyz.com:", # Empty port
|
||||
],
|
||||
)
|
||||
async def test_invalid_port(self, invalid_url):
|
||||
"""Test URLs with invalid ports raise ValueError."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
|
||||
with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
|
||||
await get_backend_endpoints()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_url",
|
||||
[
|
||||
"http://", # Missing host
|
||||
"https://", # Missing host
|
||||
],
|
||||
)
|
||||
async def test_missing_host(self, invalid_url):
|
||||
"""Test URLs with missing host raise ValueError."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
|
||||
with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
|
||||
await get_backend_endpoints()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_url",
|
||||
[
|
||||
"http:// ", # Whitespace host
|
||||
"http://xyz .com", # Space in hostname
|
||||
"http://xyz\t.com", # Tab in hostname
|
||||
"http://xyz\n.com", # Newline in hostname
|
||||
],
|
||||
)
|
||||
async def test_invalid_characters_in_host(self, invalid_url):
|
||||
"""Test URLs with invalid characters in hostname raise ValueError."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
|
||||
with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
|
||||
await get_backend_endpoints()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_url",
|
||||
[
|
||||
"", # Empty string
|
||||
" ", # Only whitespace
|
||||
],
|
||||
)
|
||||
async def test_empty_or_whitespace_url(self, invalid_url):
|
||||
"""Test empty or whitespace-only URLs raise ValueError."""
|
||||
with patch("api.utils.common.BACKEND_API_ENDPOINT", invalid_url):
|
||||
with pytest.raises(ValueError, match="Invalid BACKEND_API_ENDPOINT"):
|
||||
await get_backend_endpoints()
|
||||
181
api/utils/common.py
Normal file
181
api/utils/common.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
"""
|
||||
Common utilities.
|
||||
Shared functions used across the application.
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import BACKEND_API_ENDPOINT
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
|
||||
|
||||
def get_scheme(url: str) -> str | None:
|
||||
"""
|
||||
Extract scheme from a given URL if present.
|
||||
Returns None if not found
|
||||
"""
|
||||
idx = url.find("://")
|
||||
if idx == -1:
|
||||
return None
|
||||
return url[:idx]
|
||||
|
||||
|
||||
def _validate_url(url: str) -> None:
|
||||
"""
|
||||
Validate URL format and raise ValueError for invalid URLs.
|
||||
|
||||
Checks for:
|
||||
- Empty or whitespace-only URLs
|
||||
- Malformed schemes (single slash, missing colon/slashes)
|
||||
- Invalid/unsupported schemes
|
||||
- Invalid ports (non-numeric, out of range, empty)
|
||||
- Missing hosts
|
||||
- Invalid characters in hostname (whitespace)
|
||||
"""
|
||||
# Check for empty or whitespace-only URLs
|
||||
if not url or not url.strip():
|
||||
raise ValueError(
|
||||
f"Invalid BACKEND_API_ENDPOINT: URL cannot be empty or whitespace"
|
||||
)
|
||||
|
||||
# Check for malformed schemes (single slash like http:/localhost)
|
||||
if re.match(r"^https?:/[^/]", url):
|
||||
raise ValueError(f"Invalid BACKEND_API_ENDPOINT: malformed scheme in '{url}'")
|
||||
|
||||
# Check for malformed scheme separators (http// or http:xyz without //)
|
||||
if re.match(r"^https?//[^/]", url) or re.match(r"^https?:[^/]", url):
|
||||
raise ValueError(
|
||||
f"Invalid BACKEND_API_ENDPOINT: malformed scheme separator in '{url}'"
|
||||
)
|
||||
|
||||
# Check for invalid/unsupported schemes
|
||||
scheme = get_scheme(url)
|
||||
if scheme and scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"Invalid BACKEND_API_ENDPOINT: unsupported scheme '{scheme}' in '{url}'"
|
||||
)
|
||||
|
||||
# Parse URL for further validation
|
||||
if scheme:
|
||||
# URL has a scheme, extract host part
|
||||
host_part = url[len(scheme) + 3 :] # Skip "scheme://"
|
||||
else:
|
||||
host_part = url
|
||||
|
||||
# Strip trailing slash for host validation
|
||||
host_part = host_part.rstrip("/")
|
||||
|
||||
# Check for missing host
|
||||
if not host_part or not host_part.strip():
|
||||
raise ValueError(f"Invalid BACKEND_API_ENDPOINT: missing host in '{url}'")
|
||||
|
||||
# Check for invalid characters in hostname (whitespace)
|
||||
if re.search(r"\s", host_part):
|
||||
raise ValueError(
|
||||
f"Invalid BACKEND_API_ENDPOINT: invalid characters in hostname '{url}'"
|
||||
)
|
||||
|
||||
# Check for invalid port - look for colon followed by anything
|
||||
port_match = re.search(r":([^/]*)$", host_part)
|
||||
if port_match:
|
||||
port_str = port_match.group(1)
|
||||
if not port_str:
|
||||
raise ValueError(f"Invalid BACKEND_API_ENDPOINT: empty port in '{url}'")
|
||||
# Check if port is numeric
|
||||
if not port_str.isdigit():
|
||||
raise ValueError(f"Invalid BACKEND_API_ENDPOINT: invalid port in '{url}'")
|
||||
port = int(port_str)
|
||||
if port < 0 or port > 65535:
|
||||
raise ValueError(
|
||||
f"Invalid BACKEND_API_ENDPOINT: port out of range in '{url}'"
|
||||
)
|
||||
|
||||
|
||||
async def get_backend_endpoints() -> tuple[str, str]:
|
||||
"""
|
||||
Get the backend endpoint URLs for external access (webhooks, callbacks, WebSocket connections).
|
||||
|
||||
Priority:
|
||||
1. BACKEND_API_ENDPOINT environment variable (if set and not localhost)
|
||||
2. Cloudflared Tunnel URLs (fallback for localhost or missing env var)
|
||||
|
||||
Protocol Handling:
|
||||
1. If URL has http:// - returns http:// and ws://
|
||||
2. If URL has https:// - returns https:// and wss://
|
||||
3. If URL has no protocol - defaults to http:// and ws://
|
||||
|
||||
Returns:
|
||||
tuple[str, str]: (backend_endpoint, wss_backend_endpoint)
|
||||
|
||||
Raises:
|
||||
ValueError: If no endpoint URL can be determined or URL is invalid
|
||||
"""
|
||||
|
||||
# If env var is explicitly set (even to empty/whitespace), validate it
|
||||
if BACKEND_API_ENDPOINT is not None:
|
||||
# Validate - this will raise for empty/whitespace
|
||||
_validate_url(BACKEND_API_ENDPOINT)
|
||||
|
||||
if BACKEND_API_ENDPOINT:
|
||||
logger.debug(
|
||||
f"Processing BACKEND_API_ENDPOINT from environment: {BACKEND_API_ENDPOINT}"
|
||||
)
|
||||
|
||||
# Handle localhost/127.0.0.1 special case - use tunnel URL if available
|
||||
if "localhost" in BACKEND_API_ENDPOINT or "127.0.0.1" in BACKEND_API_ENDPOINT:
|
||||
logger.debug(
|
||||
f"BACKEND_API_ENDPOINT is local ({BACKEND_API_ENDPOINT}), checking tunnel URL"
|
||||
)
|
||||
try:
|
||||
tunnel_urls = await TunnelURLProvider.get_tunnel_urls()
|
||||
if tunnel_urls:
|
||||
logger.debug(
|
||||
f"Tunnel URLs available, using tunnel URLs instead of localhost"
|
||||
)
|
||||
return tunnel_urls
|
||||
else:
|
||||
logger.debug(
|
||||
f"Tunnel URLs returned None, proceeding with localhost endpoint"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"No tunnel URLs available ({e}), proceeding with localhost endpoint"
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse the URL to validate and handle protocol
|
||||
scheme = get_scheme(BACKEND_API_ENDPOINT)
|
||||
|
||||
if scheme:
|
||||
http_url = BACKEND_API_ENDPOINT.rstrip("/")
|
||||
ws_scheme = {"http": "ws", "https": "wss"}[scheme]
|
||||
ws_url = BACKEND_API_ENDPOINT.rstrip("/").replace(scheme, ws_scheme, 1)
|
||||
else:
|
||||
http_url = "http://" + BACKEND_API_ENDPOINT.rstrip("/")
|
||||
ws_url = "ws://" + BACKEND_API_ENDPOINT.rstrip("/")
|
||||
|
||||
logger.debug(
|
||||
f"Returning backend URLs - HTTP: {http_url}, WebSocket: {ws_url}"
|
||||
)
|
||||
return http_url, ws_url
|
||||
|
||||
except Exception as e:
|
||||
# Case 4: Invalid URL format
|
||||
raise ValueError(
|
||||
f"Invalid BACKEND_API_ENDPOINT format: '{BACKEND_API_ENDPOINT}' - {str(e)}"
|
||||
)
|
||||
|
||||
# Second priority: Query cloudflared tunnel URL when no environment variable is set
|
||||
logger.debug("No BACKEND_API_ENDPOINT set, using tunnel URL")
|
||||
tunnel_urls = await TunnelURLProvider.get_tunnel_urls()
|
||||
if tunnel_urls:
|
||||
logger.debug(f"Retrieved tunnel URLs: {tunnel_urls}")
|
||||
return tunnel_urls
|
||||
else:
|
||||
logger.debug("No tunnel URLs available")
|
||||
raise ValueError(
|
||||
"No tunnel URL available. Please set BACKEND_API_ENDPOINT environment "
|
||||
"variable or ensure cloudflared service is running."
|
||||
)
|
||||
|
|
@ -1,7 +1,6 @@
|
|||
"""Utility for getting the cloudflared tunnel URL at runtime."""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
|
|
@ -10,37 +9,26 @@ from loguru import logger
|
|||
|
||||
|
||||
class TunnelURLProvider:
|
||||
"""Provider for getting the tunnel URL from cloudflared or environment."""
|
||||
"""Provider for getting tunnel URLs from cloudflared service."""
|
||||
|
||||
@classmethod
|
||||
async def get_tunnel_url(cls) -> str:
|
||||
async def get_tunnel_urls(cls) -> tuple[str, str]:
|
||||
"""
|
||||
Get the tunnel URL for external access.
|
||||
|
||||
Priority:
|
||||
1. BACKEND_API_ENDPOINT environment variable (if set)
|
||||
2. Query cloudflared metrics endpoint
|
||||
3. Raise error if neither available
|
||||
Get the tunnel URLs for external access.
|
||||
|
||||
Returns:
|
||||
str: The tunnel domain (without protocol)
|
||||
tuple[str, str]: (https_url, wss_url) - Both URLs include full protocol
|
||||
|
||||
Raises:
|
||||
ValueError: If no tunnel URL can be determined
|
||||
"""
|
||||
# First priority: Check environment variable
|
||||
env_endpoint = os.getenv("BACKEND_API_ENDPOINT")
|
||||
if env_endpoint:
|
||||
logger.debug(f"Using BACKEND_API_ENDPOINT from environment: {env_endpoint}")
|
||||
return env_endpoint
|
||||
|
||||
# Second priority: Query cloudflared
|
||||
try:
|
||||
# Try to get URL from cloudflared metrics
|
||||
url = await cls._get_cloudflared_url()
|
||||
if url:
|
||||
logger.info(f"Retrieved tunnel URL from cloudflared: {url}")
|
||||
return url
|
||||
urls = await cls._get_cloudflared_urls()
|
||||
if urls:
|
||||
logger.info(f"Retrieved tunnel URLs from cloudflared: {urls}")
|
||||
return urls
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get tunnel URL from cloudflared: {e}")
|
||||
|
||||
|
|
@ -50,12 +38,12 @@ class TunnelURLProvider:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
async def _get_cloudflared_url(cls) -> Optional[str]:
|
||||
async def _get_cloudflared_urls(cls) -> Optional[tuple[str, str]]:
|
||||
"""
|
||||
Query cloudflared metrics endpoint to get the tunnel URL.
|
||||
Query cloudflared metrics endpoint to get the tunnel URLs.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The tunnel domain (without protocol), or None if not found
|
||||
Optional[tuple[str, str]]: (https_url, wss_url) with full protocols, or None if not found
|
||||
"""
|
||||
try:
|
||||
# Try to connect to cloudflared metrics endpoint
|
||||
|
|
@ -83,12 +71,16 @@ class TunnelURLProvider:
|
|||
hostname = hostname.replace("https://", "").replace(
|
||||
"wss://", ""
|
||||
)
|
||||
return hostname
|
||||
return "https://" + hostname, "wss://" + hostname
|
||||
|
||||
# Alternative: Look for trycloudflare.com domain
|
||||
match = re.search(r"([a-z0-9-]+\.trycloudflare\.com)", text)
|
||||
if match:
|
||||
return match.group(1)
|
||||
hostname = match.group(1)
|
||||
hostname = hostname.replace("https://", "").replace(
|
||||
"wss://", ""
|
||||
)
|
||||
return f"https://{hostname}", f"wss://{hostname}"
|
||||
|
||||
logger.warning("Could not find tunnel URL in cloudflared metrics")
|
||||
return None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue