mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-28 08:49:42 +02:00
fix: reject misrouted smallwebrtc runs on the telephony websocket (#468)
* fix: reject misrouted smallwebrtc runs on the telephony websocket A smallwebrtc (browser/WebRTC) workflow run is established through the WebRTC signaling endpoint, not the PSTN telephony websocket. When such a run reached _handle_telephony_websocket it read no "provider" from initial_context and closed with an opaque "Provider type not found". Detect smallwebrtc runs and close with a clear reason pointing to the signaling endpoint, without setting the run to running or invoking a telephony provider. Also store the provider on smallwebrtc runs at creation so they are self-describing, and make the generic no-provider close reason include the run id and mode. Closes #433 * fix: merge workflow run initial context defaults --------- Co-authored-by: Matt Van Horn <455140+mvanhorn@users.noreply.github.com> Co-authored-by: Abhishek Kumar <abhishek@a6k.me>
This commit is contained in:
parent
faa73427c6
commit
3309face2c
5 changed files with 108 additions and 5 deletions
|
|
@ -93,12 +93,17 @@ class WorkflowRunClient(BaseDBClient):
|
||||||
else workflow.template_context_variables
|
else workflow.template_context_variables
|
||||||
)
|
)
|
||||||
|
|
||||||
|
merged_initial_context = {
|
||||||
|
**(default_context or {}),
|
||||||
|
**(initial_context or {}),
|
||||||
|
}
|
||||||
|
|
||||||
new_run = WorkflowRunModel(
|
new_run = WorkflowRunModel(
|
||||||
name=name,
|
name=name,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
definition_id=target_def.id if target_def else None,
|
definition_id=target_def.id if target_def else None,
|
||||||
initial_context=initial_context or default_context,
|
initial_context=merged_initial_context,
|
||||||
gathered_context=gathered_context or {},
|
gathered_context=gathered_context or {},
|
||||||
logs=logs or {},
|
logs=logs or {},
|
||||||
campaign_id=campaign_id,
|
campaign_id=campaign_id,
|
||||||
|
|
|
||||||
|
|
@ -309,7 +309,10 @@ async def initialize_embed_session(
|
||||||
workflow_id=embed_token.workflow_id,
|
workflow_id=embed_token.workflow_id,
|
||||||
mode=WorkflowRunMode.SMALLWEBRTC.value,
|
mode=WorkflowRunMode.SMALLWEBRTC.value,
|
||||||
user_id=embed_token.created_by, # Use token creator as run owner
|
user_id=embed_token.created_by, # Use token creator as run owner
|
||||||
initial_context=init_request.context_variables,
|
initial_context={
|
||||||
|
**(init_request.context_variables or {}),
|
||||||
|
"provider": WorkflowRunMode.SMALLWEBRTC.value,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to create workflow run: {e}")
|
logger.error(f"Failed to create workflow run: {e}")
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from starlette.websockets import WebSocketDisconnect
|
||||||
|
|
||||||
from api.db import db_client
|
from api.db import db_client
|
||||||
from api.db.models import UserModel
|
from api.db.models import UserModel
|
||||||
from api.enums import CallType, WorkflowRunState
|
from api.enums import CallType, WorkflowRunMode, WorkflowRunState
|
||||||
from api.errors.telephony_errors import TelephonyError
|
from api.errors.telephony_errors import TelephonyError
|
||||||
from api.sdk_expose import sdk_expose
|
from api.sdk_expose import sdk_expose
|
||||||
from api.services.auth.depends import get_user
|
from api.services.auth.depends import get_user
|
||||||
|
|
@ -584,12 +584,36 @@ async def _handle_telephony_websocket(
|
||||||
provider_type = workflow_run.initial_context.get("provider")
|
provider_type = workflow_run.initial_context.get("provider")
|
||||||
logger.info(f"Extracted provider_type: {provider_type}")
|
logger.info(f"Extracted provider_type: {provider_type}")
|
||||||
|
|
||||||
|
if (
|
||||||
|
workflow_run.mode == WorkflowRunMode.SMALLWEBRTC.value
|
||||||
|
or provider_type == WorkflowRunMode.SMALLWEBRTC.value
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
f"SmallWebRTC workflow run {workflow_run_id} reached telephony "
|
||||||
|
f"websocket; mode={workflow_run.mode}, provider={provider_type}"
|
||||||
|
)
|
||||||
|
await websocket.close(
|
||||||
|
code=4400,
|
||||||
|
reason=(
|
||||||
|
"smallwebrtc runs connect through the WebRTC signaling endpoint, "
|
||||||
|
"not the telephony websocket"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
if not provider_type:
|
if not provider_type:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"No provider type found in workflow run {workflow_run_id}. "
|
f"No provider type found in workflow run {workflow_run_id}. "
|
||||||
f"gathered_context: {workflow_run.gathered_context}, mode: {workflow_run.mode}"
|
f"gathered_context: {workflow_run.gathered_context}, mode: {workflow_run.mode}"
|
||||||
)
|
)
|
||||||
await websocket.close(code=4400, reason="Provider type not found")
|
await websocket.close(
|
||||||
|
code=4400,
|
||||||
|
reason=(
|
||||||
|
f"No provider type found for workflow run {workflow_run_id} "
|
||||||
|
f"(mode: {workflow_run.mode}); telephony websocket requires "
|
||||||
|
"a telephony provider"
|
||||||
|
),
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,12 @@
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import ANY, AsyncMock, Mock, patch
|
from unittest.mock import ANY, AsyncMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from api.routes.telephony import router
|
from api.enums import WorkflowRunMode, WorkflowRunState
|
||||||
|
from api.routes.telephony import _handle_telephony_websocket, router
|
||||||
from api.services.auth.depends import get_user
|
from api.services.auth.depends import get_user
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -215,3 +217,41 @@ def test_initiate_call_rejects_existing_run_for_different_workflow():
|
||||||
mock_db.get_workflow_run.assert_awaited_once_with(501, organization_id=11)
|
mock_db.get_workflow_run.assert_awaited_once_with(501, organization_id=11)
|
||||||
assert not mock_db.create_workflow_run.called
|
assert not mock_db.create_workflow_run.called
|
||||||
assert provider.initiate_call.await_count == 0
|
assert provider.initiate_call.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_smallwebrtc_run_reaching_telephony_websocket_closes_without_running():
|
||||||
|
websocket = AsyncMock()
|
||||||
|
workflow_run = SimpleNamespace(
|
||||||
|
id=501,
|
||||||
|
workflow_id=33,
|
||||||
|
mode=WorkflowRunMode.SMALLWEBRTC.value,
|
||||||
|
state=WorkflowRunState.INITIALIZED.value,
|
||||||
|
initial_context={},
|
||||||
|
gathered_context={},
|
||||||
|
)
|
||||||
|
workflow = SimpleNamespace(id=33, organization_id=11)
|
||||||
|
provider_lookup = AsyncMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("api.routes.telephony.db_client") as mock_db,
|
||||||
|
patch(
|
||||||
|
"api.routes.telephony.get_telephony_provider_for_run",
|
||||||
|
new=provider_lookup,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
mock_db.get_workflow_run = AsyncMock(return_value=workflow_run)
|
||||||
|
mock_db.get_workflow_by_id = AsyncMock(return_value=workflow)
|
||||||
|
mock_db.update_workflow_run = AsyncMock()
|
||||||
|
|
||||||
|
await _handle_telephony_websocket(websocket, 33, 99, 501)
|
||||||
|
|
||||||
|
websocket.close.assert_awaited_once_with(
|
||||||
|
code=4400,
|
||||||
|
reason=(
|
||||||
|
"smallwebrtc runs connect through the WebRTC signaling endpoint, "
|
||||||
|
"not the telephony websocket"
|
||||||
|
),
|
||||||
|
)
|
||||||
|
assert mock_db.update_workflow_run.await_count == 0
|
||||||
|
assert provider_lookup.await_count == 0
|
||||||
|
|
|
||||||
|
|
@ -606,3 +606,34 @@ class TestRunDefinitionBinding:
|
||||||
)
|
)
|
||||||
|
|
||||||
assert run.definition_id == draft.id
|
assert run.definition_id == draft.id
|
||||||
|
|
||||||
|
async def test_run_initial_context_merges_with_template_context(
|
||||||
|
self, db_session, workflow_with_v1
|
||||||
|
):
|
||||||
|
"""Explicit run context should augment template context, not replace it."""
|
||||||
|
workflow, user = workflow_with_v1
|
||||||
|
await db_session.save_workflow_draft(
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
template_context_variables={
|
||||||
|
"company_name": "Acme",
|
||||||
|
"default_only": "kept",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await db_session.publish_workflow_draft(workflow.id)
|
||||||
|
|
||||||
|
run = await db_session.create_workflow_run(
|
||||||
|
name="Embed Run",
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
mode="smallwebrtc",
|
||||||
|
user_id=user.id,
|
||||||
|
initial_context={
|
||||||
|
"company_name": "Override Co",
|
||||||
|
"provider": "smallwebrtc",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert run.initial_context == {
|
||||||
|
"company_name": "Override Co",
|
||||||
|
"default_only": "kept",
|
||||||
|
"provider": "smallwebrtc",
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue