fix: reject misrouted smallwebrtc runs on the telephony websocket (#468)

* fix: reject misrouted smallwebrtc runs on the telephony websocket

A smallwebrtc (browser/WebRTC) workflow run is established through the WebRTC
signaling endpoint, not the PSTN telephony websocket. When such a run reached
_handle_telephony_websocket it read no "provider" from initial_context and
closed with an opaque "Provider type not found". Detect smallwebrtc runs and
close with a clear reason pointing to the signaling endpoint, without setting
the run to running or invoking a telephony provider. Also store the provider on
smallwebrtc runs at creation so they are self-describing, and make the generic
no-provider close reason include the run id and mode.

Closes #433

* fix: merge workflow run initial context defaults

---------

Co-authored-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com>
Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
Matt Van Horn 2026-06-26 07:07:40 -07:00 committed by GitHub
parent faa73427c6
commit 3309face2c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 108 additions and 5 deletions

View file

@ -93,12 +93,17 @@ class WorkflowRunClient(BaseDBClient):
else workflow.template_context_variables
)
merged_initial_context = {
**(default_context or {}),
**(initial_context or {}),
}
new_run = WorkflowRunModel(
name=name,
workflow=workflow,
mode=mode,
definition_id=target_def.id if target_def else None,
initial_context=initial_context or default_context,
initial_context=merged_initial_context,
gathered_context=gathered_context or {},
logs=logs or {},
campaign_id=campaign_id,

View file

@ -309,7 +309,10 @@ async def initialize_embed_session(
workflow_id=embed_token.workflow_id,
mode=WorkflowRunMode.SMALLWEBRTC.value,
user_id=embed_token.created_by, # Use token creator as run owner
initial_context=init_request.context_variables,
initial_context={
**(init_request.context_variables or {}),
"provider": WorkflowRunMode.SMALLWEBRTC.value,
},
)
except Exception as e:
logger.error(f"Failed to create workflow run: {e}")

View file

@ -21,7 +21,7 @@ from starlette.websockets import WebSocketDisconnect
from api.db import db_client
from api.db.models import UserModel
from api.enums import CallType, WorkflowRunState
from api.enums import CallType, WorkflowRunMode, WorkflowRunState
from api.errors.telephony_errors import TelephonyError
from api.sdk_expose import sdk_expose
from api.services.auth.depends import get_user
@ -584,12 +584,36 @@ async def _handle_telephony_websocket(
provider_type = workflow_run.initial_context.get("provider")
logger.info(f"Extracted provider_type: {provider_type}")
if (
workflow_run.mode == WorkflowRunMode.SMALLWEBRTC.value
or provider_type == WorkflowRunMode.SMALLWEBRTC.value
):
logger.warning(
f"SmallWebRTC workflow run {workflow_run_id} reached telephony "
f"websocket; mode={workflow_run.mode}, provider={provider_type}"
)
await websocket.close(
code=4400,
reason=(
"smallwebrtc runs connect through the WebRTC signaling endpoint, "
"not the telephony websocket"
),
)
return
if not provider_type:
logger.error(
f"No provider type found in workflow run {workflow_run_id}. "
f"gathered_context: {workflow_run.gathered_context}, mode: {workflow_run.mode}"
)
await websocket.close(code=4400, reason="Provider type not found")
await websocket.close(
code=4400,
reason=(
f"No provider type found for workflow run {workflow_run_id} "
f"(mode: {workflow_run.mode}); telephony websocket requires "
"a telephony provider"
),
)
return
logger.info(

View file

@ -1,10 +1,12 @@
from types import SimpleNamespace
from unittest.mock import ANY, AsyncMock, Mock, patch
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from api.routes.telephony import router
from api.enums import WorkflowRunMode, WorkflowRunState
from api.routes.telephony import _handle_telephony_websocket, router
from api.services.auth.depends import get_user
@ -215,3 +217,41 @@ def test_initiate_call_rejects_existing_run_for_different_workflow():
mock_db.get_workflow_run.assert_awaited_once_with(501, organization_id=11)
assert not mock_db.create_workflow_run.called
assert provider.initiate_call.await_count == 0
@pytest.mark.asyncio
async def test_smallwebrtc_run_reaching_telephony_websocket_closes_without_running():
websocket = AsyncMock()
workflow_run = SimpleNamespace(
id=501,
workflow_id=33,
mode=WorkflowRunMode.SMALLWEBRTC.value,
state=WorkflowRunState.INITIALIZED.value,
initial_context={},
gathered_context={},
)
workflow = SimpleNamespace(id=33, organization_id=11)
provider_lookup = AsyncMock()
with (
patch("api.routes.telephony.db_client") as mock_db,
patch(
"api.routes.telephony.get_telephony_provider_for_run",
new=provider_lookup,
),
):
mock_db.get_workflow_run = AsyncMock(return_value=workflow_run)
mock_db.get_workflow_by_id = AsyncMock(return_value=workflow)
mock_db.update_workflow_run = AsyncMock()
await _handle_telephony_websocket(websocket, 33, 99, 501)
websocket.close.assert_awaited_once_with(
code=4400,
reason=(
"smallwebrtc runs connect through the WebRTC signaling endpoint, "
"not the telephony websocket"
),
)
assert mock_db.update_workflow_run.await_count == 0
assert provider_lookup.await_count == 0

View file

@ -606,3 +606,34 @@ class TestRunDefinitionBinding:
)
assert run.definition_id == draft.id
async def test_run_initial_context_merges_with_template_context(
self, db_session, workflow_with_v1
):
"""Explicit run context should augment template context, not replace it."""
workflow, user = workflow_with_v1
await db_session.save_workflow_draft(
workflow_id=workflow.id,
template_context_variables={
"company_name": "Acme",
"default_only": "kept",
},
)
await db_session.publish_workflow_draft(workflow.id)
run = await db_session.create_workflow_run(
name="Embed Run",
workflow_id=workflow.id,
mode="smallwebrtc",
user_id=user.id,
initial_context={
"company_name": "Override Co",
"provider": "smallwebrtc",
},
)
assert run.initial_context == {
"company_name": "Override Co",
"default_only": "kept",
"provider": "smallwebrtc",
}