mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-07-01 08:59:46 +02:00
Feat/Add API Trigger and Webhooks in Agent Builder (#83)
* feat: add api trigger node for agent runs * feat: add webhook node * Execute webhook nodes post workflow run * Add hint to go to API keys
This commit is contained in:
parent
4ddb144dd0
commit
55b727a872
37 changed files with 3667 additions and 494 deletions
|
|
@ -0,0 +1,109 @@
|
|||
"""add external credentials model
|
||||
|
||||
Revision ID: 36b5dbf670e4
|
||||
Revises: c7c56dd36b21
|
||||
Create Date: 2025-12-22 05:29:31.061141
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "36b5dbf670e4"
|
||||
down_revision: Union[str, None] = "c7c56dd36b21"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
sa.Enum(
|
||||
"none",
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"basic_auth",
|
||||
"custom_header",
|
||||
name="webhook_credential_type",
|
||||
).create(op.get_bind())
|
||||
op.create_table(
|
||||
"external_credentials",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("credential_uuid", sa.String(length=36), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.String(), nullable=False),
|
||||
sa.Column("description", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"credential_type",
|
||||
postgresql.ENUM(
|
||||
"none",
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"basic_auth",
|
||||
"custom_header",
|
||||
name="webhook_credential_type",
|
||||
create_type=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("credential_data", sa.JSON(), nullable=False),
|
||||
sa.Column("created_by", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["created_by"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint(
|
||||
"organization_id", "name", name="unique_org_credential_name"
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_external_credentials_credential_uuid"),
|
||||
"external_credentials",
|
||||
["credential_uuid"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_webhook_credentials_organization_id",
|
||||
"external_credentials",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_webhook_credentials_uuid",
|
||||
"external_credentials",
|
||||
["credential_uuid"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_webhook_credentials_uuid", table_name="external_credentials")
|
||||
op.drop_index(
|
||||
"ix_webhook_credentials_organization_id", table_name="external_credentials"
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_external_credentials_credential_uuid"),
|
||||
table_name="external_credentials",
|
||||
)
|
||||
op.drop_table("external_credentials")
|
||||
sa.Enum(
|
||||
"none",
|
||||
"api_key",
|
||||
"bearer_token",
|
||||
"basic_auth",
|
||||
"custom_header",
|
||||
name="webhook_credential_type",
|
||||
).drop(op.get_bind())
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -5,13 +5,14 @@ Revises: a188ff90e76f
|
|||
Create Date: 2025-12-10 17:34:31.232048
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = '49a8fe6841e6'
|
||||
down_revision: Union[str, None] = 'a188ff90e76f'
|
||||
revision: str = "49a8fe6841e6"
|
||||
down_revision: Union[str, None] = "a188ff90e76f"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
|
@ -19,21 +20,20 @@ depends_on: Union[str, Sequence[str], None] = None
|
|||
def upgrade() -> None:
|
||||
# Create the workflow_run_state enum type
|
||||
workflow_run_state_enum = sa.Enum(
|
||||
'initialized', 'running', 'completed',
|
||||
name='workflow_run_state'
|
||||
"initialized", "running", "completed", name="workflow_run_state"
|
||||
)
|
||||
workflow_run_state_enum.create(op.get_bind())
|
||||
|
||||
|
||||
# Add the state column to workflow_runs table (nullable first)
|
||||
op.add_column(
|
||||
'workflow_runs',
|
||||
"workflow_runs",
|
||||
sa.Column(
|
||||
'state',
|
||||
sa.Enum('initialized', 'running', 'completed', name='workflow_run_state'),
|
||||
nullable=True
|
||||
)
|
||||
"state",
|
||||
sa.Enum("initialized", "running", "completed", name="workflow_run_state"),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Set appropriate state values for existing records
|
||||
# Completed workflows should be marked as 'completed'
|
||||
# Non-completed workflows should be marked as 'initialized'
|
||||
|
|
@ -44,19 +44,16 @@ def upgrade() -> None:
|
|||
ELSE 'initialized'::workflow_run_state
|
||||
END
|
||||
""")
|
||||
|
||||
|
||||
# Now make the column non-nullable with 'initialized' as default for new records
|
||||
op.alter_column(
|
||||
'workflow_runs',
|
||||
'state',
|
||||
nullable=False,
|
||||
server_default='initialized'
|
||||
"workflow_runs", "state", nullable=False, server_default="initialized"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the state column
|
||||
op.drop_column('workflow_runs', 'state')
|
||||
|
||||
op.drop_column("workflow_runs", "state")
|
||||
|
||||
# Drop the enum type
|
||||
sa.Enum(name='workflow_run_state').drop(op.get_bind())
|
||||
sa.Enum(name="workflow_run_state").drop(op.get_bind())
|
||||
|
|
|
|||
68
api/alembic/versions/c7c56dd36b21_add_agent_trigger.py
Normal file
68
api/alembic/versions/c7c56dd36b21_add_agent_trigger.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
"""add agent trigger
|
||||
|
||||
Revision ID: c7c56dd36b21
|
||||
Revises: 49a8fe6841e6
|
||||
Create Date: 2025-12-21 08:21:06.692772
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c7c56dd36b21"
|
||||
down_revision: Union[str, None] = "49a8fe6841e6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
sa.Enum("active", "archived", name="trigger_state").create(op.get_bind())
|
||||
op.create_table(
|
||||
"agent_triggers",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("trigger_path", sa.String(length=36), nullable=False),
|
||||
sa.Column("workflow_id", sa.Integer(), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"state",
|
||||
postgresql.ENUM(
|
||||
"active", "archived", name="trigger_state", create_type=False
|
||||
),
|
||||
server_default=sa.text("'active'::trigger_state"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["workflow_id"], ["workflows.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_agent_triggers_state", "agent_triggers", ["state"], unique=False
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_agent_triggers_trigger_path"),
|
||||
"agent_triggers",
|
||||
["trigger_path"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_agent_triggers_workflow_id", "agent_triggers", ["workflow_id"], unique=False
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index("ix_agent_triggers_workflow_id", table_name="agent_triggers")
|
||||
op.drop_index(op.f("ix_agent_triggers_trigger_path"), table_name="agent_triggers")
|
||||
op.drop_index("ix_agent_triggers_state", table_name="agent_triggers")
|
||||
op.drop_table("agent_triggers")
|
||||
sa.Enum("active", "archived", name="trigger_state").drop(op.get_bind())
|
||||
# ### end Alembic commands ###
|
||||
118
api/db/agent_trigger_client.py
Normal file
118
api/db/agent_trigger_client.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""Database client for managing agent triggers."""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import and_, select, update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import AgentTriggerModel
|
||||
from api.enums import TriggerState
|
||||
|
||||
|
||||
class AgentTriggerClient(BaseDBClient):
|
||||
"""Client for managing agent triggers (UUID -> workflow_id mappings)."""
|
||||
|
||||
async def get_agent_trigger_by_path(
|
||||
self, trigger_path: str, active_only: bool = True
|
||||
) -> Optional[AgentTriggerModel]:
|
||||
"""Get an agent trigger by its unique path (UUID).
|
||||
|
||||
Args:
|
||||
trigger_path: The unique trigger UUID
|
||||
active_only: If True, only return active triggers
|
||||
|
||||
Returns:
|
||||
AgentTriggerModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(AgentTriggerModel).where(
|
||||
AgentTriggerModel.trigger_path == trigger_path
|
||||
)
|
||||
|
||||
if active_only:
|
||||
query = query.where(
|
||||
AgentTriggerModel.state == TriggerState.ACTIVE.value
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def sync_triggers_for_workflow(
|
||||
self, workflow_id: int, organization_id: int, trigger_paths: List[str]
|
||||
) -> None:
|
||||
"""Sync triggers for a workflow based on the trigger nodes in the workflow definition.
|
||||
|
||||
This creates/reactivates triggers that are in the workflow definition
|
||||
and archives triggers that are no longer in the workflow.
|
||||
|
||||
Args:
|
||||
workflow_id: ID of the workflow
|
||||
organization_id: ID of the organization
|
||||
trigger_paths: List of trigger UUIDs from the workflow definition
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# Get all existing triggers for this workflow (including archived)
|
||||
result = await session.execute(
|
||||
select(AgentTriggerModel).where(
|
||||
AgentTriggerModel.workflow_id == workflow_id
|
||||
)
|
||||
)
|
||||
existing_triggers = {t.trigger_path: t for t in result.scalars().all()}
|
||||
|
||||
existing_paths = set(existing_triggers.keys())
|
||||
new_paths = set(trigger_paths)
|
||||
|
||||
# Archive triggers that are no longer in the workflow definition
|
||||
paths_to_archive = existing_paths - new_paths
|
||||
if paths_to_archive:
|
||||
await session.execute(
|
||||
update(AgentTriggerModel)
|
||||
.where(AgentTriggerModel.trigger_path.in_(paths_to_archive))
|
||||
.values(state=TriggerState.ARCHIVED.value)
|
||||
)
|
||||
logger.info(
|
||||
f"Archived {len(paths_to_archive)} triggers for workflow {workflow_id}"
|
||||
)
|
||||
|
||||
# Reactivate existing triggers that are back in the workflow
|
||||
paths_to_reactivate = new_paths & existing_paths
|
||||
if paths_to_reactivate:
|
||||
await session.execute(
|
||||
update(AgentTriggerModel)
|
||||
.where(
|
||||
and_(
|
||||
AgentTriggerModel.trigger_path.in_(paths_to_reactivate),
|
||||
AgentTriggerModel.state == TriggerState.ARCHIVED.value,
|
||||
)
|
||||
)
|
||||
.values(state=TriggerState.ACTIVE.value)
|
||||
)
|
||||
|
||||
# Add new triggers
|
||||
paths_to_add = new_paths - existing_paths
|
||||
for trigger_path in paths_to_add:
|
||||
stmt = insert(AgentTriggerModel).values(
|
||||
trigger_path=trigger_path,
|
||||
workflow_id=workflow_id,
|
||||
organization_id=organization_id,
|
||||
state=TriggerState.ACTIVE.value,
|
||||
)
|
||||
# Handle race condition where trigger might already exist for another workflow
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["trigger_path"],
|
||||
set_={
|
||||
"workflow_id": workflow_id,
|
||||
"organization_id": organization_id,
|
||||
"state": TriggerState.ACTIVE.value,
|
||||
},
|
||||
)
|
||||
await session.execute(stmt)
|
||||
|
||||
if paths_to_add:
|
||||
logger.info(
|
||||
f"Added {len(paths_to_add)} triggers for workflow {workflow_id}"
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
from api.db.agent_trigger_client import AgentTriggerClient
|
||||
from api.db.api_key_client import APIKeyClient
|
||||
from api.db.campaign_client import CampaignClient
|
||||
from api.db.embed_token_client import EmbedTokenClient
|
||||
|
|
@ -8,6 +9,7 @@ from api.db.organization_configuration_client import OrganizationConfigurationCl
|
|||
from api.db.organization_usage_client import OrganizationUsageClient
|
||||
from api.db.reports_client import ReportsClient
|
||||
from api.db.user_client import UserClient
|
||||
from api.db.webhook_credential_client import WebhookCredentialClient
|
||||
from api.db.workflow_client import WorkflowClient
|
||||
from api.db.workflow_run_client import WorkflowRunClient
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
|
|
@ -27,6 +29,8 @@ class DBClient(
|
|||
ReportsClient,
|
||||
APIKeyClient,
|
||||
EmbedTokenClient,
|
||||
AgentTriggerClient,
|
||||
WebhookCredentialClient,
|
||||
):
|
||||
"""
|
||||
Unified database client that combines all specialized database operations.
|
||||
|
|
@ -45,6 +49,8 @@ class DBClient(
|
|||
- ReportsClient: handles reports and analytics operations
|
||||
- APIKeyClient: handles API key operations
|
||||
- EmbedTokenClient: handles embed token and session operations
|
||||
- AgentTriggerClient: handles agent trigger operations for API-based call triggering
|
||||
- WebhookCredentialClient: handles webhook credential operations
|
||||
"""
|
||||
|
||||
pass
|
||||
|
|
|
|||
126
api/db/models.py
126
api/db/models.py
|
|
@ -1,3 +1,4 @@
|
|||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from loguru import logger
|
||||
|
|
@ -19,7 +20,14 @@ from sqlalchemy import (
|
|||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from ..enums import IntegrationAction, WorkflowRunMode, WorkflowRunState, WorkflowStatus
|
||||
from ..enums import (
|
||||
IntegrationAction,
|
||||
TriggerState,
|
||||
WebhookCredentialType,
|
||||
WorkflowRunMode,
|
||||
WorkflowRunState,
|
||||
WorkflowStatus,
|
||||
)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
|
@ -676,3 +684,119 @@ class EmbedSessionModel(Base):
|
|||
# Relationships
|
||||
embed_token = relationship("EmbedTokenModel", back_populates="sessions")
|
||||
workflow_run = relationship("WorkflowRunModel")
|
||||
|
||||
|
||||
class AgentTriggerModel(Base):
|
||||
"""Model for storing agent trigger mappings (UUID -> workflow_id).
|
||||
|
||||
This is a minimal lookup table that maps trigger UUIDs to workflows.
|
||||
The trigger node in the workflow definition is the source of truth.
|
||||
"""
|
||||
|
||||
__tablename__ = "agent_triggers"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Unique trigger path (UUID format) - generated by UI when trigger node is created
|
||||
trigger_path = Column(String(36), unique=True, nullable=False, index=True)
|
||||
|
||||
# Link to workflow
|
||||
workflow_id = Column(
|
||||
Integer, ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# State management (active/archived)
|
||||
state = Column(
|
||||
Enum(*[state.value for state in TriggerState], name="trigger_state"),
|
||||
nullable=False,
|
||||
default=TriggerState.ACTIVE.value,
|
||||
server_default=text("'active'::trigger_state"),
|
||||
)
|
||||
|
||||
# Audit
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Relationships
|
||||
workflow = relationship("WorkflowModel")
|
||||
organization = relationship("OrganizationModel")
|
||||
|
||||
# Indexes for performance
|
||||
__table_args__ = (
|
||||
Index("ix_agent_triggers_workflow_id", "workflow_id"),
|
||||
Index("ix_agent_triggers_state", "state"),
|
||||
)
|
||||
|
||||
|
||||
class ExternalCredentialModel(Base):
|
||||
"""Model for storing external authentication credentials.
|
||||
|
||||
Credentials are stored separately from webhook configurations to allow
|
||||
reuse across multiple workflows and secure storage of sensitive data.
|
||||
"""
|
||||
|
||||
__tablename__ = "external_credentials"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Public UUID reference (used in APIs and workflow definitions)
|
||||
# This prevents enumeration attacks and hides internal IDs
|
||||
credential_uuid = Column(
|
||||
String(36),
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
default=lambda: str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
# Organization scoping
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# Credential metadata
|
||||
name = Column(String, nullable=False) # Display name, e.g., "Salesforce API"
|
||||
description = Column(String, nullable=True) # Optional description
|
||||
|
||||
# Credential type - uses enum from api/enums.py
|
||||
credential_type = Column(
|
||||
Enum(
|
||||
*[t.value for t in WebhookCredentialType],
|
||||
name="webhook_credential_type",
|
||||
),
|
||||
nullable=False,
|
||||
default=WebhookCredentialType.NONE.value,
|
||||
)
|
||||
|
||||
# Encrypted credential data (JSON)
|
||||
# Structure depends on credential_type:
|
||||
# - api_key: {"header_name": "X-API-Key", "api_key": "value"}
|
||||
# - bearer_token: {"token": "value"}
|
||||
# - basic_auth: {"username": "user", "password": "value"}
|
||||
# - custom_header: {"header_name": "X-Custom", "header_value": "value"}
|
||||
credential_data = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Audit fields
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
updated_at = Column(
|
||||
DateTime(timezone=True),
|
||||
default=lambda: datetime.now(UTC),
|
||||
onupdate=lambda: datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Soft delete for safety
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel")
|
||||
created_by_user = relationship("UserModel")
|
||||
|
||||
# Indexes and constraints
|
||||
__table_args__ = (
|
||||
Index("ix_webhook_credentials_organization_id", "organization_id"),
|
||||
Index("ix_webhook_credentials_uuid", "credential_uuid"),
|
||||
UniqueConstraint("organization_id", "name", name="unique_org_credential_name"),
|
||||
)
|
||||
|
|
|
|||
220
api/db/webhook_credential_client.py
Normal file
220
api/db/webhook_credential_client.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""Database client for managing webhook credentials."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import ExternalCredentialModel
|
||||
|
||||
|
||||
class WebhookCredentialClient(BaseDBClient):
|
||||
"""Client for managing webhook credentials (organization-scoped, UUID-referenced)."""
|
||||
|
||||
async def create_credential(
|
||||
self,
|
||||
organization_id: int,
|
||||
user_id: int,
|
||||
name: str,
|
||||
credential_type: str,
|
||||
credential_data: dict,
|
||||
description: Optional[str] = None,
|
||||
) -> ExternalCredentialModel:
|
||||
"""Create a new webhook credential.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
user_id: ID of the user creating the credential
|
||||
name: Display name for the credential
|
||||
credential_type: Type of credential (none, api_key, bearer_token, basic_auth, custom_header)
|
||||
credential_data: JSON data containing the credential details
|
||||
description: Optional description
|
||||
|
||||
Returns:
|
||||
The created ExternalCredentialModel with auto-generated UUID
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
credential = ExternalCredentialModel(
|
||||
organization_id=organization_id,
|
||||
created_by=user_id,
|
||||
name=name,
|
||||
description=description,
|
||||
credential_type=credential_type,
|
||||
credential_data=credential_data,
|
||||
)
|
||||
|
||||
session.add(credential)
|
||||
await session.commit()
|
||||
await session.refresh(credential)
|
||||
|
||||
logger.info(
|
||||
f"Created webhook credential '{name}' ({credential.credential_uuid}) "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
return credential
|
||||
|
||||
async def get_credentials_for_organization(
|
||||
self, organization_id: int, active_only: bool = True
|
||||
) -> List[ExternalCredentialModel]:
|
||||
"""Get all credentials for an organization.
|
||||
|
||||
Args:
|
||||
organization_id: ID of the organization
|
||||
active_only: If True, only return active (non-deleted) credentials
|
||||
|
||||
Returns:
|
||||
List of ExternalCredentialModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(ExternalCredentialModel).where(
|
||||
ExternalCredentialModel.organization_id == organization_id
|
||||
)
|
||||
|
||||
if active_only:
|
||||
query = query.where(ExternalCredentialModel.is_active.is_(True))
|
||||
|
||||
query = query.order_by(ExternalCredentialModel.name)
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_credential_by_uuid(
|
||||
self, credential_uuid: str, organization_id: int, active_only: bool = True
|
||||
) -> Optional[ExternalCredentialModel]:
|
||||
"""Get a credential by its UUID, scoped to organization.
|
||||
|
||||
Args:
|
||||
credential_uuid: The unique credential UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
active_only: If True, only return if active
|
||||
|
||||
Returns:
|
||||
ExternalCredentialModel if found and authorized, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(ExternalCredentialModel).where(
|
||||
ExternalCredentialModel.credential_uuid == credential_uuid,
|
||||
ExternalCredentialModel.organization_id == organization_id,
|
||||
)
|
||||
|
||||
if active_only:
|
||||
query = query.where(ExternalCredentialModel.is_active.is_(True))
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update_credential(
|
||||
self,
|
||||
credential_uuid: str,
|
||||
organization_id: int,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
credential_type: Optional[str] = None,
|
||||
credential_data: Optional[dict] = None,
|
||||
) -> Optional[ExternalCredentialModel]:
|
||||
"""Update a credential by UUID.
|
||||
|
||||
Args:
|
||||
credential_uuid: The unique credential UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
name: New name (if provided)
|
||||
description: New description (if provided)
|
||||
credential_type: New credential type (if provided)
|
||||
credential_data: New credential data (if provided)
|
||||
|
||||
Returns:
|
||||
Updated ExternalCredentialModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
# First check if credential exists and belongs to organization
|
||||
credential = await self.get_credential_by_uuid(
|
||||
credential_uuid, organization_id
|
||||
)
|
||||
if not credential:
|
||||
return None
|
||||
|
||||
# Build update values
|
||||
update_values = {"updated_at": datetime.now(UTC)}
|
||||
if name is not None:
|
||||
update_values["name"] = name
|
||||
if description is not None:
|
||||
update_values["description"] = description
|
||||
if credential_type is not None:
|
||||
update_values["credential_type"] = credential_type
|
||||
if credential_data is not None:
|
||||
update_values["credential_data"] = credential_data
|
||||
|
||||
await session.execute(
|
||||
update(ExternalCredentialModel)
|
||||
.where(
|
||||
ExternalCredentialModel.credential_uuid == credential_uuid,
|
||||
ExternalCredentialModel.organization_id == organization_id,
|
||||
)
|
||||
.values(**update_values)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Fetch updated credential
|
||||
result = await session.execute(
|
||||
select(ExternalCredentialModel).where(
|
||||
ExternalCredentialModel.credential_uuid == credential_uuid
|
||||
)
|
||||
)
|
||||
updated_credential = result.scalar_one()
|
||||
|
||||
logger.info(
|
||||
f"Updated webhook credential {credential_uuid} "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
return updated_credential
|
||||
|
||||
async def delete_credential(
|
||||
self, credential_uuid: str, organization_id: int
|
||||
) -> bool:
|
||||
"""Soft delete a credential by UUID.
|
||||
|
||||
Args:
|
||||
credential_uuid: The unique credential UUID
|
||||
organization_id: ID of the organization (for authorization)
|
||||
|
||||
Returns:
|
||||
True if credential was deleted, False if not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
update(ExternalCredentialModel)
|
||||
.where(
|
||||
ExternalCredentialModel.credential_uuid == credential_uuid,
|
||||
ExternalCredentialModel.organization_id == organization_id,
|
||||
ExternalCredentialModel.is_active.is_(True),
|
||||
)
|
||||
.values(is_active=False, updated_at=datetime.now(UTC))
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
if result.rowcount > 0:
|
||||
logger.info(
|
||||
f"Soft deleted webhook credential {credential_uuid} "
|
||||
f"for organization {organization_id}"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
async def validate_credential_uuid(
|
||||
self, credential_uuid: str, organization_id: int
|
||||
) -> bool:
|
||||
"""Check if a credential UUID exists and belongs to the organization.
|
||||
|
||||
This is useful for workflow validation to ensure referenced credentials exist.
|
||||
|
||||
Args:
|
||||
credential_uuid: The credential UUID to validate
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
True if valid, False otherwise
|
||||
"""
|
||||
credential = await self.get_credential_by_uuid(credential_uuid, organization_id)
|
||||
return credential is not None
|
||||
|
|
@ -394,8 +394,9 @@ class WorkflowRunClient(BaseDBClient):
|
|||
result = await session.execute(
|
||||
select(WorkflowRunModel)
|
||||
.options(
|
||||
selectinload(WorkflowRunModel.workflow).selectinload(
|
||||
WorkflowModel.user
|
||||
selectinload(WorkflowRunModel.workflow).options(
|
||||
selectinload(WorkflowModel.user),
|
||||
selectinload(WorkflowModel.current_definition),
|
||||
)
|
||||
)
|
||||
.where(WorkflowRunModel.id == workflow_run_id)
|
||||
|
|
|
|||
21
api/enums.py
21
api/enums.py
|
|
@ -56,8 +56,8 @@ class StorageBackend(Enum):
|
|||
|
||||
class WorkflowRunState(Enum):
|
||||
INITIALIZED = "initialized" # Workflow run created, ready for connection
|
||||
RUNNING = "running" # Websocket connected and pipeline active
|
||||
COMPLETED = "completed" # Workflow run finished
|
||||
RUNNING = "running" # Websocket connected and pipeline active
|
||||
COMPLETED = "completed" # Workflow run finished
|
||||
|
||||
|
||||
class WorkflowRunStatus(Enum):
|
||||
|
|
@ -92,3 +92,20 @@ class RedisChannel(Enum):
|
|||
"""Redis pub/sub channel names"""
|
||||
|
||||
CAMPAIGN_EVENTS = "campaign_events"
|
||||
|
||||
|
||||
class TriggerState(Enum):
|
||||
"""Agent trigger state values"""
|
||||
|
||||
ACTIVE = "active"
|
||||
ARCHIVED = "archived"
|
||||
|
||||
|
||||
class WebhookCredentialType(Enum):
|
||||
"""Webhook credential authentication types"""
|
||||
|
||||
NONE = "none" # No authentication
|
||||
API_KEY = "api_key" # API key in header
|
||||
BEARER_TOKEN = "bearer_token" # Bearer token auth
|
||||
BASIC_AUTH = "basic_auth" # Username/password
|
||||
CUSTOM_HEADER = "custom_header" # Custom header key-value
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from api.db.models import UserModel
|
|||
from api.enums import OrganizationConfigurationKey
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.campaign.runner import campaign_runner_service
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
router = APIRouter(prefix="/campaign")
|
||||
|
|
@ -182,6 +183,11 @@ async def start_campaign(
|
|||
detail="You must configure telephony first by going to APP_URL/configure-telephony",
|
||||
)
|
||||
|
||||
# Check Dograh quota before starting campaign
|
||||
quota_result = await check_dograh_quota(user)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Verify campaign exists and belongs to organization
|
||||
campaign = await db_client.get_campaign(campaign_id, user.selected_organization_id)
|
||||
if not campaign:
|
||||
|
|
@ -290,6 +296,11 @@ async def resume_campaign(
|
|||
detail="You must configure telephony first by going to APP_URL/configure-telephony",
|
||||
)
|
||||
|
||||
# Check Dograh quota before resuming campaign
|
||||
quota_result = await check_dograh_quota(user)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Verify campaign exists and belongs to organization
|
||||
campaign = await db_client.get_campaign(campaign_id, user.selected_organization_id)
|
||||
if not campaign:
|
||||
|
|
|
|||
284
api/routes/credentials.py
Normal file
284
api/routes/credentials.py
Normal file
|
|
@ -0,0 +1,284 @@
|
|||
"""API routes for managing webhook credentials."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.enums import WebhookCredentialType
|
||||
from api.services.auth.depends import get_user
|
||||
|
||||
router = APIRouter(prefix="/credentials")
|
||||
|
||||
|
||||
# Request/Response schemas
|
||||
class CreateCredentialRequest(BaseModel):
|
||||
"""Request schema for creating a webhook credential."""
|
||||
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
credential_type: WebhookCredentialType
|
||||
credential_data: dict # Validated based on credential_type
|
||||
|
||||
|
||||
class UpdateCredentialRequest(BaseModel):
|
||||
"""Request schema for updating a webhook credential."""
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
credential_type: Optional[WebhookCredentialType] = None
|
||||
credential_data: Optional[dict] = None
|
||||
|
||||
|
||||
class CredentialResponse(BaseModel):
|
||||
"""Response schema for a webhook credential (never includes sensitive data)."""
|
||||
|
||||
uuid: str
|
||||
name: str
|
||||
description: Optional[str]
|
||||
credential_type: str
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
def validate_credential_data(
|
||||
credential_type: WebhookCredentialType, credential_data: dict
|
||||
) -> None:
|
||||
"""Validate that credential_data matches the expected structure for the credential type.
|
||||
|
||||
Args:
|
||||
credential_type: The type of credential
|
||||
credential_data: The credential data to validate
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails
|
||||
"""
|
||||
if credential_type == WebhookCredentialType.NONE:
|
||||
# No data required
|
||||
return
|
||||
|
||||
if credential_type == WebhookCredentialType.API_KEY:
|
||||
if "header_name" not in credential_data or "api_key" not in credential_data:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="API Key credential requires 'header_name' and 'api_key' fields",
|
||||
)
|
||||
|
||||
elif credential_type == WebhookCredentialType.BEARER_TOKEN:
|
||||
if "token" not in credential_data:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Bearer Token credential requires 'token' field",
|
||||
)
|
||||
|
||||
elif credential_type == WebhookCredentialType.BASIC_AUTH:
|
||||
if "username" not in credential_data or "password" not in credential_data:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Basic Auth credential requires 'username' and 'password' fields",
|
||||
)
|
||||
|
||||
elif credential_type == WebhookCredentialType.CUSTOM_HEADER:
|
||||
if (
|
||||
"header_name" not in credential_data
|
||||
or "header_value" not in credential_data
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Custom Header credential requires 'header_name' and 'header_value' fields",
|
||||
)
|
||||
|
||||
|
||||
def build_credential_response(credential) -> CredentialResponse:
|
||||
"""Build a response from a credential model (excluding sensitive data)."""
|
||||
return CredentialResponse(
|
||||
uuid=credential.credential_uuid,
|
||||
name=credential.name,
|
||||
description=credential.description,
|
||||
credential_type=credential.credential_type,
|
||||
created_at=credential.created_at,
|
||||
updated_at=credential.updated_at,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/")
|
||||
async def list_credentials(
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> List[CredentialResponse]:
|
||||
"""
|
||||
List all webhook credentials for the user's organization.
|
||||
|
||||
Returns:
|
||||
List of credentials (without sensitive data)
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
credentials = await db_client.get_credentials_for_organization(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
return [build_credential_response(cred) for cred in credentials]
|
||||
|
||||
|
||||
@router.post("/")
|
||||
async def create_credential(
|
||||
request: CreateCredentialRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> CredentialResponse:
|
||||
"""
|
||||
Create a new webhook credential.
|
||||
|
||||
Args:
|
||||
request: The credential creation request
|
||||
|
||||
Returns:
|
||||
The created credential (without sensitive data)
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
# Validate credential data structure
|
||||
validate_credential_data(request.credential_type, request.credential_data)
|
||||
|
||||
try:
|
||||
credential = await db_client.create_credential(
|
||||
organization_id=user.selected_organization_id,
|
||||
user_id=user.id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
credential_type=request.credential_type.value,
|
||||
credential_data=request.credential_data,
|
||||
)
|
||||
|
||||
return build_credential_response(credential)
|
||||
|
||||
except Exception as e:
|
||||
# Handle unique constraint violation
|
||||
if "unique_org_credential_name" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A credential with the name '{request.name}' already exists",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/{credential_uuid}")
|
||||
async def get_credential(
|
||||
credential_uuid: str,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> CredentialResponse:
|
||||
"""
|
||||
Get a specific webhook credential by UUID.
|
||||
|
||||
Args:
|
||||
credential_uuid: The UUID of the credential
|
||||
|
||||
Returns:
|
||||
The credential (without sensitive data)
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
credential = await db_client.get_credential_by_uuid(
|
||||
credential_uuid, user.selected_organization_id
|
||||
)
|
||||
|
||||
if not credential:
|
||||
raise HTTPException(status_code=404, detail="Credential not found")
|
||||
|
||||
return build_credential_response(credential)
|
||||
|
||||
|
||||
@router.put("/{credential_uuid}")
|
||||
async def update_credential(
|
||||
credential_uuid: str,
|
||||
request: UpdateCredentialRequest,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> CredentialResponse:
|
||||
"""
|
||||
Update a webhook credential.
|
||||
|
||||
Args:
|
||||
credential_uuid: The UUID of the credential to update
|
||||
request: The update request
|
||||
|
||||
Returns:
|
||||
The updated credential (without sensitive data)
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
# Validate credential data if provided
|
||||
if request.credential_type and request.credential_data:
|
||||
validate_credential_data(request.credential_type, request.credential_data)
|
||||
|
||||
try:
|
||||
credential = await db_client.update_credential(
|
||||
credential_uuid=credential_uuid,
|
||||
organization_id=user.selected_organization_id,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
credential_type=request.credential_type.value
|
||||
if request.credential_type
|
||||
else None,
|
||||
credential_data=request.credential_data,
|
||||
)
|
||||
|
||||
if not credential:
|
||||
raise HTTPException(status_code=404, detail="Credential not found")
|
||||
|
||||
return build_credential_response(credential)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
if "unique_org_credential_name" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail=f"A credential with the name '{request.name}' already exists",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete("/{credential_uuid}")
|
||||
async def delete_credential(
|
||||
credential_uuid: str,
|
||||
user: UserModel = Depends(get_user),
|
||||
) -> dict:
|
||||
"""
|
||||
Delete (soft delete) a webhook credential.
|
||||
|
||||
Args:
|
||||
credential_uuid: The UUID of the credential to delete
|
||||
|
||||
Returns:
|
||||
Success message
|
||||
"""
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected for the user"
|
||||
)
|
||||
|
||||
deleted = await db_client.delete_credential(
|
||||
credential_uuid, user.selected_organization_id
|
||||
)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Credential not found")
|
||||
|
||||
return {"status": "deleted", "uuid": credential_uuid}
|
||||
|
|
@ -2,10 +2,12 @@ from fastapi import APIRouter
|
|||
from loguru import logger
|
||||
|
||||
from api.routes.campaign import router as campaign_router
|
||||
from api.routes.credentials import router as credentials_router
|
||||
from api.routes.integration import router as integration_router
|
||||
from api.routes.looptalk import router as looptalk_router
|
||||
from api.routes.organization import router as organization_router
|
||||
from api.routes.organization_usage import router as organization_usage_router
|
||||
from api.routes.public_agent import router as public_agent_router
|
||||
from api.routes.public_embed import router as public_embed_router
|
||||
from api.routes.reports import router as reports_router
|
||||
from api.routes.rtc_offer import router as rtc_offer_router
|
||||
|
|
@ -29,6 +31,7 @@ router.include_router(superuser_router)
|
|||
router.include_router(workflow_router)
|
||||
router.include_router(user_router)
|
||||
router.include_router(campaign_router)
|
||||
router.include_router(credentials_router)
|
||||
router.include_router(integration_router)
|
||||
router.include_router(organization_router)
|
||||
router.include_router(s3_router)
|
||||
|
|
@ -38,6 +41,7 @@ router.include_router(organization_usage_router)
|
|||
router.include_router(reports_router)
|
||||
router.include_router(webrtc_signaling_router)
|
||||
router.include_router(public_embed_router)
|
||||
router.include_router(public_agent_router)
|
||||
router.include_router(workflow_embed_router)
|
||||
|
||||
|
||||
|
|
|
|||
187
api/routes/public_agent.py
Normal file
187
api/routes/public_agent.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Public API endpoints for agent triggers.
|
||||
|
||||
These endpoints are accessible with API key authentication and allow
|
||||
external systems to programmatically trigger phone calls.
|
||||
"""
|
||||
|
||||
import random
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.db import db_client
|
||||
from api.enums import TriggerState
|
||||
from api.services.quota_service import check_dograh_quota_by_user_id
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
|
||||
router = APIRouter(prefix="/public/agent")
|
||||
|
||||
|
||||
class TriggerCallRequest(BaseModel):
|
||||
"""Request model for triggering a call via API"""
|
||||
|
||||
phone_number: str
|
||||
initial_context: Optional[dict] = None
|
||||
|
||||
|
||||
class TriggerCallResponse(BaseModel):
|
||||
"""Response model for successful call initiation"""
|
||||
|
||||
status: str
|
||||
workflow_run_id: int
|
||||
workflow_run_name: str
|
||||
|
||||
|
||||
def trigger_exists_in_workflow(workflow_definition: dict, trigger_path: str) -> bool:
|
||||
"""Check if trigger node exists in workflow definition.
|
||||
|
||||
Args:
|
||||
workflow_definition: The workflow definition JSON
|
||||
trigger_path: The trigger UUID to look for
|
||||
|
||||
Returns:
|
||||
True if trigger node exists, False otherwise
|
||||
"""
|
||||
nodes = workflow_definition.get("nodes", [])
|
||||
for node in nodes:
|
||||
if node.get("type") == "trigger":
|
||||
if node.get("data", {}).get("trigger_path") == trigger_path:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@router.post("/{uuid}", response_model=TriggerCallResponse)
|
||||
async def initiate_call(
|
||||
uuid: str,
|
||||
request: TriggerCallRequest,
|
||||
x_api_key: str = Header(..., alias="X-API-Key"),
|
||||
):
|
||||
"""Initiate a phone call via API trigger.
|
||||
|
||||
This endpoint allows external systems (CRMs, automation tools, etc.) to
|
||||
programmatically trigger outbound phone calls with custom context variables.
|
||||
|
||||
Args:
|
||||
uuid: The unique trigger UUID
|
||||
request: The call request with phone number and optional context
|
||||
x_api_key: API key for authentication (passed in X-API-Key header)
|
||||
|
||||
Returns:
|
||||
TriggerCallResponse with workflow run details
|
||||
|
||||
Raises:
|
||||
HTTPException: Various error conditions (401, 403, 404, 400)
|
||||
"""
|
||||
# 1. Validate API key
|
||||
api_key = await db_client.validate_api_key(x_api_key)
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# 2. Lookup agent trigger by UUID
|
||||
trigger = await db_client.get_agent_trigger_by_path(uuid)
|
||||
if not trigger:
|
||||
raise HTTPException(status_code=404, detail="Agent trigger not found")
|
||||
|
||||
# 3. Validate organization match (API key org must match trigger org)
|
||||
if api_key.organization_id != trigger.organization_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
# 4. Validate trigger is active
|
||||
if trigger.state != TriggerState.ACTIVE.value:
|
||||
raise HTTPException(status_code=404, detail="Agent trigger is not active")
|
||||
|
||||
# 4.5 Check Dograh quota before initiating the call
|
||||
quota_result = await check_dograh_quota_by_user_id(api_key.created_by)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# 5. Get workflow and validate trigger exists in definition
|
||||
workflow = await db_client.get_workflow_by_id(trigger.workflow_id)
|
||||
if not workflow:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
# Get workflow definition (with fallback to legacy field)
|
||||
workflow_definition = workflow.workflow_definition_with_fallback
|
||||
|
||||
# Validate trigger node still exists in the workflow definition
|
||||
if not trigger_exists_in_workflow(workflow_definition, uuid):
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Trigger not found or has been removed from workflow",
|
||||
)
|
||||
|
||||
# 6. Get telephony provider for the organization
|
||||
provider = await get_telephony_provider(trigger.organization_id)
|
||||
|
||||
# Validate provider is configured
|
||||
if not provider.validate_config():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Telephony provider not configured for this organization",
|
||||
)
|
||||
|
||||
# 7. Determine the workflow run mode based on provider type
|
||||
workflow_run_mode = provider.PROVIDER_NAME
|
||||
|
||||
# 8. Create workflow run
|
||||
workflow_run_name = f"WR-API-{random.randint(1000, 9999)}"
|
||||
workflow_run = await db_client.create_workflow_run(
|
||||
name=workflow_run_name,
|
||||
workflow_id=trigger.workflow_id,
|
||||
mode=workflow_run_mode,
|
||||
initial_context={
|
||||
"phone_number": request.phone_number,
|
||||
"agent_uuid": uuid,
|
||||
**(request.initial_context or {}),
|
||||
},
|
||||
user_id=api_key.created_by,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created workflow run {workflow_run.id} for API trigger {uuid} "
|
||||
f"to phone number {request.phone_number}"
|
||||
)
|
||||
|
||||
# 9. Construct webhook URL for telephony provider callback
|
||||
backend_endpoint = await TunnelURLProvider.get_tunnel_url()
|
||||
webhook_endpoint = provider.WEBHOOK_ENDPOINT
|
||||
|
||||
webhook_url = (
|
||||
f"https://{backend_endpoint}/api/v1/telephony/{webhook_endpoint}"
|
||||
f"?workflow_id={trigger.workflow_id}"
|
||||
f"&user_id={api_key.created_by}"
|
||||
f"&workflow_run_id={workflow_run.id}"
|
||||
f"&organization_id={trigger.organization_id}"
|
||||
)
|
||||
|
||||
# 10. Initiate call via telephony provider
|
||||
result = await provider.initiate_call(
|
||||
to_number=request.phone_number,
|
||||
webhook_url=webhook_url,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
|
||||
# 11. Store provider metadata in workflow run context
|
||||
gathered_context = {
|
||||
"provider": provider.PROVIDER_NAME,
|
||||
"triggered_by": "api",
|
||||
"trigger_uuid": uuid,
|
||||
**(result.provider_metadata or {}),
|
||||
}
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run.id, gathered_context=gathered_context
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Call initiated successfully for workflow run {workflow_run.id} "
|
||||
f"via trigger {uuid}"
|
||||
)
|
||||
|
||||
return TriggerCallResponse(
|
||||
status="initiated",
|
||||
workflow_run_id=workflow_run.id,
|
||||
workflow_run_name=workflow_run_name,
|
||||
)
|
||||
|
|
@ -18,6 +18,7 @@ from api.enums import WorkflowRunState
|
|||
from api.services.auth.depends import get_user
|
||||
from api.services.campaign.call_dispatcher import campaign_call_dispatcher
|
||||
from api.services.campaign.campaign_event_publisher import get_campaign_event_publisher
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from api.services.telephony.factory import get_telephony_provider
|
||||
from api.utils.tunnel import TunnelURLProvider
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
|
@ -100,6 +101,11 @@ async def initiate_call(
|
|||
detail="telephony_not_configured",
|
||||
)
|
||||
|
||||
# Check Dograh quota before initiating the call
|
||||
quota_result = await check_dograh_quota(user)
|
||||
if not quota_result.has_quota:
|
||||
raise HTTPException(status_code=402, detail=quota_result.error_message)
|
||||
|
||||
# Determine the workflow run mode based on provider type
|
||||
workflow_run_mode = provider.PROVIDER_NAME
|
||||
|
||||
|
|
@ -234,7 +240,9 @@ async def websocket_endpoint(
|
|||
logger.warning(
|
||||
f"Workflow run {workflow_run_id} not in initialized state: {workflow_run.state}"
|
||||
)
|
||||
await websocket.close(code=4409, reason="Workflow run not available for connection")
|
||||
await websocket.close(
|
||||
code=4409, reason="Workflow run not available for connection"
|
||||
)
|
||||
return
|
||||
|
||||
# Extract provider type from workflow run context
|
||||
|
|
@ -267,10 +275,9 @@ async def websocket_endpoint(
|
|||
|
||||
# Set workflow run state to 'running' before starting the pipeline
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
state=WorkflowRunState.RUNNING.value
|
||||
run_id=workflow_run_id, state=WorkflowRunState.RUNNING.value
|
||||
)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"[run {workflow_run_id}] Set workflow run state to 'running' for {provider_type} provider"
|
||||
)
|
||||
|
|
@ -382,9 +389,9 @@ async def _process_status_update(
|
|||
|
||||
# Mark workflow run as completed
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
run_id=workflow_run_id,
|
||||
is_completed=True,
|
||||
state=WorkflowRunState.COMPLETED.value
|
||||
state=WorkflowRunState.COMPLETED.value,
|
||||
)
|
||||
|
||||
elif status.status in ["failed", "busy", "no-answer", "canceled"]:
|
||||
|
|
|
|||
|
|
@ -22,9 +22,8 @@ from loguru import logger
|
|||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user_ws
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.pipecat.run_pipeline import run_pipeline_smallwebrtc
|
||||
from api.services.quota_service import check_dograh_quota
|
||||
from pipecat.transports.smallwebrtc.connection import SmallWebRTCConnection
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
|
|
@ -67,75 +66,6 @@ class SignalingManager:
|
|||
self._connections: Dict[str, WebSocket] = {}
|
||||
self._peer_connections: Dict[str, SmallWebRTCConnection] = {}
|
||||
|
||||
async def _check_dograh_quota(self, user: UserModel) -> tuple[bool, str]:
|
||||
"""Check if user has sufficient Dograh quota for making a call.
|
||||
|
||||
Args:
|
||||
user_id: The user ID to check quota for
|
||||
|
||||
Returns:
|
||||
Tuple of (has_quota, error_message)
|
||||
- has_quota: True if user has sufficient quota or not using Dograh
|
||||
- error_message: Error message if quota check fails, empty string otherwise
|
||||
"""
|
||||
try:
|
||||
# Get user configurations
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
|
||||
# Check if user is using any Dograh service
|
||||
using_dograh = False
|
||||
dograh_api_keys = set()
|
||||
|
||||
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.llm.api_key)
|
||||
|
||||
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.stt.api_key)
|
||||
|
||||
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.tts.api_key)
|
||||
|
||||
# If not using Dograh, quota check passes
|
||||
if not using_dograh:
|
||||
return True, ""
|
||||
|
||||
# Check quota for ALL Dograh keys
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key, created_by=user.provider_id
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < 0.10:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
return False, (
|
||||
"You have exhausted your trial credits."
|
||||
"Please email founders@dograh.com for additional credits."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
return False, "Could not verify Dograh credits. Please try again."
|
||||
|
||||
return True, ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during quota check: {str(e)}")
|
||||
# On unexpected error, allow the call to proceed
|
||||
return True, ""
|
||||
|
||||
async def handle_websocket(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
|
|
@ -210,15 +140,15 @@ class SignalingManager:
|
|||
set_current_run_id(workflow_run_id)
|
||||
|
||||
# Check Dograh quota before initiating the call
|
||||
has_quota, error_message = await self._check_dograh_quota(user)
|
||||
if not has_quota:
|
||||
quota_result = await check_dograh_quota(user)
|
||||
if not quota_result.has_quota:
|
||||
# Send error response for quota issues
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"payload": {
|
||||
"error_type": "quota_exceeded",
|
||||
"message": error_message,
|
||||
"message": quota_result.error_message,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
|
|
@ -18,6 +19,62 @@ from api.services.workflow.dto import ReactFlowDTO
|
|||
from api.services.workflow.errors import ItemKind, WorkflowError
|
||||
from api.services.workflow.workflow import WorkflowGraph
|
||||
|
||||
|
||||
def extract_trigger_paths(workflow_definition: dict) -> List[str]:
|
||||
"""Extract trigger UUIDs from workflow definition.
|
||||
|
||||
Args:
|
||||
workflow_definition: The workflow definition JSON
|
||||
|
||||
Returns:
|
||||
List of trigger UUIDs found in the workflow
|
||||
"""
|
||||
if not workflow_definition:
|
||||
return []
|
||||
|
||||
nodes = workflow_definition.get("nodes", [])
|
||||
trigger_paths = []
|
||||
|
||||
for node in nodes:
|
||||
if node.get("type") == "trigger":
|
||||
trigger_path = node.get("data", {}).get("trigger_path")
|
||||
if trigger_path:
|
||||
trigger_paths.append(trigger_path)
|
||||
|
||||
return trigger_paths
|
||||
|
||||
|
||||
def regenerate_trigger_uuids(workflow_definition: dict) -> dict:
|
||||
"""Regenerate UUIDs for all trigger nodes in a workflow definition.
|
||||
|
||||
This should be called when creating a new workflow from a template or
|
||||
duplicating a workflow to avoid trigger UUID conflicts.
|
||||
|
||||
Args:
|
||||
workflow_definition: The workflow definition JSON
|
||||
|
||||
Returns:
|
||||
Updated workflow definition with new trigger UUIDs
|
||||
"""
|
||||
if not workflow_definition:
|
||||
return workflow_definition
|
||||
|
||||
# Deep copy to avoid modifying the original
|
||||
import copy
|
||||
|
||||
updated_definition = copy.deepcopy(workflow_definition)
|
||||
|
||||
nodes = updated_definition.get("nodes", [])
|
||||
for node in nodes:
|
||||
if node.get("type") == "trigger":
|
||||
# Generate a new UUID for this trigger
|
||||
if "data" not in node:
|
||||
node["data"] = {}
|
||||
node["data"]["trigger_path"] = str(uuid.uuid4())
|
||||
|
||||
return updated_definition
|
||||
|
||||
|
||||
router = APIRouter(prefix="/workflow")
|
||||
|
||||
|
||||
|
|
@ -181,6 +238,17 @@ async def create_workflow(
|
|||
user.id,
|
||||
user.selected_organization_id,
|
||||
)
|
||||
|
||||
# Sync agent triggers if workflow definition contains any
|
||||
if request.workflow_definition:
|
||||
trigger_paths = extract_trigger_paths(request.workflow_definition)
|
||||
if trigger_paths:
|
||||
await db_client.sync_triggers_for_workflow(
|
||||
workflow_id=workflow.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
trigger_paths=trigger_paths,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
|
|
@ -238,13 +306,27 @@ async def create_workflow_from_template(
|
|||
)
|
||||
|
||||
# Create the workflow in our database
|
||||
# Regenerate trigger UUIDs to avoid conflicts with existing triggers
|
||||
workflow_def = regenerate_trigger_uuids(
|
||||
workflow_data.get("workflow_definition", {})
|
||||
)
|
||||
workflow = await db_client.create_workflow(
|
||||
name=workflow_data.get("name", f"{request.use_case} - {request.call_type}"),
|
||||
workflow_definition=workflow_data.get("workflow_definition", {}),
|
||||
workflow_definition=workflow_def,
|
||||
user_id=user.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
# Sync agent triggers if workflow definition contains any
|
||||
if workflow_def:
|
||||
trigger_paths = extract_trigger_paths(workflow_def)
|
||||
if trigger_paths:
|
||||
await db_client.sync_triggers_for_workflow(
|
||||
workflow_id=workflow.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
trigger_paths=trigger_paths,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
|
|
@ -434,6 +516,16 @@ async def update_workflow(
|
|||
workflow_configurations=request.workflow_configurations,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
# Sync agent triggers if workflow definition was updated
|
||||
if request.workflow_definition:
|
||||
trigger_paths = extract_trigger_paths(request.workflow_definition)
|
||||
await db_client.sync_triggers_for_workflow(
|
||||
workflow_id=workflow.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
trigger_paths=trigger_paths,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
|
|
@ -645,13 +737,25 @@ async def duplicate_workflow_template(
|
|||
)
|
||||
|
||||
# Create a new workflow from the template
|
||||
# Regenerate trigger UUIDs to avoid conflicts with existing triggers
|
||||
workflow_def = regenerate_trigger_uuids(template.template_json)
|
||||
workflow = await db_client.create_workflow(
|
||||
request.workflow_name,
|
||||
template.template_json,
|
||||
workflow_def,
|
||||
user.id,
|
||||
user.selected_organization_id,
|
||||
)
|
||||
|
||||
# Sync agent triggers if template contains any
|
||||
if workflow_def:
|
||||
trigger_paths = extract_trigger_paths(workflow_def)
|
||||
if trigger_paths:
|
||||
await db_client.sync_triggers_for_workflow(
|
||||
workflow_id=workflow.id,
|
||||
organization_id=user.selected_organization_id,
|
||||
trigger_paths=trigger_paths,
|
||||
)
|
||||
|
||||
return {
|
||||
"id": workflow.id,
|
||||
"name": workflow.name,
|
||||
|
|
|
|||
|
|
@ -19,7 +19,6 @@ from pipecat.audio.mixers.soundfile_mixer import SoundfileMixer
|
|||
from pipecat.audio.turn.smart_turn.base_smart_turn import SmartTurnParams
|
||||
from pipecat.audio.turn.smart_turn.local_smart_turn_v3 import LocalSmartTurnAnalyzerV3
|
||||
from pipecat.audio.vad.silero import SileroVADAnalyzer, VADParams
|
||||
from pipecat.serializers.plivo import PlivoFrameSerializer
|
||||
from pipecat.serializers.twilio import TwilioFrameSerializer
|
||||
from pipecat.serializers.vobiz import VobizFrameSerializer
|
||||
from pipecat.serializers.vonage import VonageFrameSerializer
|
||||
|
|
|
|||
122
api/services/quota_service.py
Normal file
122
api/services/quota_service.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Quota checking service for Dograh credits.
|
||||
|
||||
This module provides reusable quota checking functionality that can be used
|
||||
across different endpoints (WebRTC signaling, telephony, public API triggers).
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
@dataclass
|
||||
class QuotaCheckResult:
|
||||
"""Result of a quota check."""
|
||||
|
||||
has_quota: bool
|
||||
error_message: str = ""
|
||||
|
||||
|
||||
async def check_dograh_quota(user: UserModel) -> QuotaCheckResult:
|
||||
"""Check if user has sufficient Dograh quota for making a call.
|
||||
|
||||
This function checks if the user is using any Dograh services (LLM, STT, TTS)
|
||||
and validates that they have sufficient credits remaining.
|
||||
|
||||
Args:
|
||||
user: The user to check quota for
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with has_quota=True if user has sufficient quota or
|
||||
is not using Dograh services, or has_quota=False with error_message
|
||||
if quota is insufficient.
|
||||
"""
|
||||
try:
|
||||
# Get user configurations
|
||||
user_config = await db_client.get_user_configurations(user.id)
|
||||
|
||||
# Check if user is using any Dograh service
|
||||
using_dograh = False
|
||||
dograh_api_keys = set()
|
||||
|
||||
if user_config.llm and user_config.llm.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.llm.api_key)
|
||||
|
||||
if user_config.stt and user_config.stt.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.stt.api_key)
|
||||
|
||||
if user_config.tts and user_config.tts.provider == ServiceProviders.DOGRAH:
|
||||
using_dograh = True
|
||||
dograh_api_keys.add(user_config.tts.api_key)
|
||||
|
||||
# If not using Dograh, quota check passes
|
||||
if not using_dograh:
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
# Check quota for ALL Dograh keys
|
||||
for api_key in dograh_api_keys:
|
||||
try:
|
||||
usage = await mps_service_key_client.check_service_key_usage(
|
||||
api_key, created_by=user.provider_id
|
||||
)
|
||||
remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
# Require at least $0.10 for a short call
|
||||
if remaining < 0.10:
|
||||
logger.warning(
|
||||
f"Insufficient Dograh credits for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_message=(
|
||||
"You have exhausted your trial credits. "
|
||||
"Please email founders@dograh.com for additional Dograh credits "
|
||||
"or change providers in Models configurations."
|
||||
),
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Dograh quota check passed for key ...{api_key[-8:]}: "
|
||||
f"${remaining:.2f} remaining"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check quota for Dograh key: {str(e)}")
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_message="Could not verify Dograh credits. Please try again.",
|
||||
)
|
||||
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error during quota check: {str(e)}")
|
||||
# On unexpected error, allow the call to proceed
|
||||
return QuotaCheckResult(has_quota=True)
|
||||
|
||||
|
||||
async def check_dograh_quota_by_user_id(user_id: int) -> QuotaCheckResult:
|
||||
"""Check Dograh quota by user ID.
|
||||
|
||||
Convenience function that fetches the user and then checks quota.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user to check quota for
|
||||
|
||||
Returns:
|
||||
QuotaCheckResult with quota status
|
||||
"""
|
||||
user = await db_client.get_user_by_id(user_id)
|
||||
if not user:
|
||||
return QuotaCheckResult(
|
||||
has_quota=False,
|
||||
error_message="User not found",
|
||||
)
|
||||
return await check_dograh_quota(user)
|
||||
|
|
@ -299,11 +299,11 @@ class VobizProvider(TelephonyProvider):
|
|||
message handling to VobizFrameSerializer.
|
||||
"""
|
||||
from api.services.pipecat.run_pipeline import run_pipeline_vobiz
|
||||
|
||||
|
||||
first_msg = await websocket.receive_text()
|
||||
start_msg = json.loads(first_msg)
|
||||
logger.debug(f"Received the first message: {start_msg}")
|
||||
|
||||
|
||||
# Validate that this is a start event
|
||||
if start_msg.get("event") != "start":
|
||||
logger.error(f"Expected 'start' event, got: {start_msg.get('event')}")
|
||||
|
|
@ -317,7 +317,7 @@ class VobizProvider(TelephonyProvider):
|
|||
start_data = start_msg.get("start", {})
|
||||
stream_id = start_data.get("streamId")
|
||||
call_id = start_data.get("callId")
|
||||
|
||||
|
||||
if not stream_id or not call_id:
|
||||
logger.error(f"Missing streamId or callId in start event: {start_data}")
|
||||
await websocket.close(code=4400, reason="Missing streamId or callId")
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ class NodeType(str, Enum):
|
|||
endNode = "endCall"
|
||||
agentNode = "agentNode"
|
||||
globalNode = "globalNode"
|
||||
trigger = "trigger"
|
||||
webhook = "webhook"
|
||||
|
||||
|
||||
class Position(BaseModel):
|
||||
|
|
@ -28,9 +30,20 @@ class ExtractionVariableDTO(BaseModel):
|
|||
prompt: Optional[str] = None
|
||||
|
||||
|
||||
class CustomHeaderDTO(BaseModel):
|
||||
key: str
|
||||
value: str
|
||||
|
||||
|
||||
class RetryConfigDTO(BaseModel):
|
||||
enabled: bool = False
|
||||
max_retries: int = 3
|
||||
retry_delay_seconds: int = 5
|
||||
|
||||
|
||||
class NodeDataDTO(BaseModel):
|
||||
name: str = Field(..., min_length=1)
|
||||
prompt: str = Field(..., min_length=1)
|
||||
prompt: Optional[str] = Field(default=None)
|
||||
is_static: bool = False
|
||||
is_start: bool = False
|
||||
is_end: bool = False
|
||||
|
|
@ -44,6 +57,15 @@ class NodeDataDTO(BaseModel):
|
|||
detect_voicemail: bool = True
|
||||
delayed_start: bool = False
|
||||
delayed_start_duration: Optional[float] = None
|
||||
trigger_path: Optional[str] = None
|
||||
# Webhook node specific fields
|
||||
enabled: bool = True
|
||||
http_method: Optional[str] = None
|
||||
endpoint_url: Optional[str] = None
|
||||
credential_uuid: Optional[str] = None
|
||||
custom_headers: Optional[list[CustomHeaderDTO]] = None
|
||||
payload_template: Optional[dict] = None
|
||||
retry_config: Optional[RetryConfigDTO] = None
|
||||
|
||||
|
||||
class RFNodeDTO(BaseModel):
|
||||
|
|
@ -52,6 +74,14 @@ class RFNodeDTO(BaseModel):
|
|||
position: Position
|
||||
data: NodeDataDTO
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_prompt_required(self):
|
||||
"""Require prompt for all node types except trigger and webhook."""
|
||||
if self.type not in (NodeType.trigger, NodeType.webhook):
|
||||
if not self.data.prompt or len(self.data.prompt.strip()) == 0:
|
||||
raise ValueError("Prompt is required for non-trigger nodes")
|
||||
return self
|
||||
|
||||
|
||||
class EdgeDataDTO(BaseModel):
|
||||
label: str = Field(..., min_length=1)
|
||||
|
|
|
|||
|
|
@ -1,227 +1,227 @@
|
|||
import os
|
||||
"""Execute webhook integrations after workflow run completion."""
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict
|
||||
|
||||
import aiohttp
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import IntegrationModel
|
||||
from api.enums import OrganizationConfigurationKey, WorkflowRunMode
|
||||
from api.db.models import ExternalCredentialModel, WorkflowRunModel
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.utils.context import set_current_run_id
|
||||
|
||||
|
||||
async def run_integrations_post_workflow_run(ctx, workflow_run_id: int):
|
||||
async def run_integrations_post_workflow_run(_ctx, workflow_run_id: int):
|
||||
"""
|
||||
Run integrations after a workflow run completes.
|
||||
Run webhook integrations after a workflow run completes.
|
||||
|
||||
This function:
|
||||
1. Gets the workflow run and its gathered_context
|
||||
2. Determines the organization_id through the workflow -> user -> organization chain
|
||||
3. Fetches all active integrations for that organization
|
||||
4. For Slack integrations, sends the gathered_context to the webhook URL
|
||||
|
||||
Args:
|
||||
workflow_run_id: The ID of the completed workflow run
|
||||
1. Gets the workflow run and its contexts
|
||||
2. Extracts webhook nodes from workflow definition
|
||||
3. Executes each enabled webhook node
|
||||
"""
|
||||
# Set the workflow_run_id in context variable for consistent logging format
|
||||
set_current_run_id(workflow_run_id)
|
||||
logger.info("Running integrations for workflow run")
|
||||
logger.info("Running webhook integrations for workflow run")
|
||||
|
||||
try:
|
||||
# Step 1: Get workflow run details with gathered_context using DB client
|
||||
# Step 1: Get workflow run with full context
|
||||
workflow_run, organization_id = await db_client.get_workflow_run_with_context(
|
||||
workflow_run_id
|
||||
)
|
||||
|
||||
if not workflow_run:
|
||||
logger.error("Workflow run not found")
|
||||
if not workflow_run or not workflow_run.workflow:
|
||||
logger.error("Workflow run or workflow not found")
|
||||
return
|
||||
|
||||
if not workflow_run.workflow:
|
||||
logger.error("Workflow not found for workflow run")
|
||||
return
|
||||
|
||||
if not workflow_run.workflow.user:
|
||||
logger.error("User not found for workflow run")
|
||||
return
|
||||
|
||||
gathered_context = workflow_run.gathered_context
|
||||
initial_context = workflow_run.initial_context
|
||||
|
||||
if not gathered_context:
|
||||
logger.info("No gathered context for workflow run, skipping integrations")
|
||||
return
|
||||
|
||||
# Check if workflow run mode is stasis and sync with vendor
|
||||
if workflow_run.mode == WorkflowRunMode.STASIS.value:
|
||||
await _sync_vendor_data(initial_context, gathered_context)
|
||||
|
||||
# Step 2: Check if organization_id is available
|
||||
if not organization_id:
|
||||
logger.warning(
|
||||
f"No organization found for workflow run, skipping integrations"
|
||||
)
|
||||
logger.warning("No organization found, skipping webhooks")
|
||||
return
|
||||
|
||||
logger.debug(f"Found organization_id {organization_id} for workflow run")
|
||||
# Step 2: Get workflow definition
|
||||
workflow_definition = workflow_run.workflow.workflow_definition_with_fallback
|
||||
if not workflow_definition:
|
||||
logger.debug("No workflow definition, skipping webhooks")
|
||||
return
|
||||
|
||||
# Step 3: Get all active integrations for the organization using DB client
|
||||
integrations = await db_client.get_active_integrations_by_organization(
|
||||
organization_id
|
||||
)
|
||||
# Step 3: Extract webhook nodes
|
||||
nodes = workflow_definition.get("nodes", [])
|
||||
webhook_nodes = [n for n in nodes if n.get("type") == "webhook"]
|
||||
|
||||
logger.info(
|
||||
f"Found {len(integrations)} active integrations for organization {organization_id}"
|
||||
)
|
||||
if not webhook_nodes:
|
||||
logger.debug("No webhook nodes in workflow")
|
||||
return
|
||||
|
||||
# Step 4: Process each integration
|
||||
for integration in integrations:
|
||||
await _process_integration(integration, gathered_context)
|
||||
logger.info(f"Found {len(webhook_nodes)} webhook nodes to execute")
|
||||
|
||||
# Step 4: Build render context
|
||||
render_context = _build_render_context(workflow_run)
|
||||
|
||||
# Step 5: Execute each webhook node
|
||||
for node in webhook_nodes:
|
||||
webhook_data = node.get("data", {})
|
||||
try:
|
||||
await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context=render_context,
|
||||
organization_id=organization_id,
|
||||
)
|
||||
except Exception as e:
|
||||
# Log error but continue with other webhooks
|
||||
logger.error(
|
||||
f"Failed to execute webhook '{webhook_data.get('name', 'unknown')}': {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running integrations for workflow run: {str(e)}")
|
||||
logger.error(f"Error running webhook integrations: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def _sync_vendor_data(initial_context: dict, gathered_context: dict):
|
||||
def _build_render_context(workflow_run: WorkflowRunModel) -> Dict[str, Any]:
|
||||
"""Build the context dict for template rendering."""
|
||||
return {
|
||||
# Top-level fields
|
||||
"workflow_run_id": workflow_run.id,
|
||||
"workflow_run_name": workflow_run.name,
|
||||
"workflow_id": workflow_run.workflow_id,
|
||||
"workflow_name": workflow_run.workflow.name if workflow_run.workflow else None,
|
||||
# Nested contexts
|
||||
"initial_context": workflow_run.initial_context or {},
|
||||
"gathered_context": workflow_run.gathered_context or {},
|
||||
"cost_info": workflow_run.usage_info or {},
|
||||
"recording_url": getattr(workflow_run, "recording_url", None),
|
||||
"transcript_url": getattr(workflow_run, "transcript_url", None),
|
||||
}
|
||||
|
||||
|
||||
async def _execute_webhook_node(
|
||||
webhook_data: Dict[str, Any],
|
||||
render_context: Dict[str, Any],
|
||||
organization_id: int,
|
||||
) -> bool:
|
||||
"""
|
||||
Sync data with external vendor for stasis mode workflow runs.
|
||||
Execute a single webhook node.
|
||||
|
||||
Args:
|
||||
initial_context: The initial context containing lead_id
|
||||
gathered_context: The gathered context containing mapped_call_disposition
|
||||
webhook_data: The webhook node's data dict from workflow definition
|
||||
render_context: Context for template rendering
|
||||
organization_id: For credential lookup
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
if not os.getenv("ARI_DATA_SYNCING_URI"):
|
||||
logger.info("ARI_DATA_SYNCING_URI not configured, skipping vendor sync")
|
||||
return
|
||||
webhook_name = webhook_data.get("name", "Unnamed Webhook")
|
||||
|
||||
try:
|
||||
lead_id = initial_context.get("lead_id")
|
||||
status = gathered_context.get("mapped_call_disposition")
|
||||
# 1. Check if enabled
|
||||
if not webhook_data.get("enabled", True):
|
||||
logger.debug(f"Webhook '{webhook_name}' is disabled, skipping")
|
||||
return True
|
||||
|
||||
if lead_id and status:
|
||||
ari_data_uri = os.getenv("ARI_DATA_SYNCING_URI")
|
||||
# Add URL params to the base URL
|
||||
sync_url = f"{ari_data_uri}&lead_id={lead_id}&status={status}"
|
||||
# 2. Validate endpoint URL
|
||||
url = webhook_data.get("endpoint_url")
|
||||
if not url:
|
||||
logger.error(f"Webhook '{webhook_name}' has no endpoint URL")
|
||||
return False
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(sync_url, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
logger.info(
|
||||
f"Successfully synced data for lead_id: {lead_id} with status: {status}"
|
||||
)
|
||||
# 3. Build headers
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
# 4. Add auth header if credential configured
|
||||
credential_uuid = webhook_data.get("credential_uuid")
|
||||
if credential_uuid:
|
||||
credential = await db_client.get_credential_by_uuid(
|
||||
credential_uuid, organization_id
|
||||
)
|
||||
if credential:
|
||||
auth_header = _build_auth_header(credential)
|
||||
headers.update(auth_header)
|
||||
logger.debug(f"Applied credential '{credential.name}' to webhook")
|
||||
else:
|
||||
logger.warning(
|
||||
f"Missing lead_id or status for syncing - lead_id: {lead_id}, status: {status}"
|
||||
f"Credential {credential_uuid} not found for webhook '{webhook_name}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to sync data to ARI_DATA_SYNCING_URI: {e}")
|
||||
|
||||
# 5. Add custom headers
|
||||
custom_headers = webhook_data.get("custom_headers", [])
|
||||
for h in custom_headers:
|
||||
if h.get("key") and h.get("value"):
|
||||
headers[h["key"]] = h["value"]
|
||||
|
||||
async def _process_integration(
|
||||
integration: IntegrationModel,
|
||||
gathered_context: dict,
|
||||
):
|
||||
"""
|
||||
Process a single integration.
|
||||
# 6. Render payload template
|
||||
payload_template = webhook_data.get("payload_template", {})
|
||||
payload = render_template(payload_template, render_context)
|
||||
|
||||
Args:
|
||||
integration: The integration model
|
||||
gathered_context: The gathered context from the workflow run
|
||||
workflow_run_name: Name of the workflow run
|
||||
run_id: The workflow run ID for logging context
|
||||
"""
|
||||
logger.info(
|
||||
f"Processing integration {integration.id} (provider: {integration.provider})"
|
||||
)
|
||||
# 7. Make HTTP request
|
||||
method = webhook_data.get("http_method", "POST").upper()
|
||||
|
||||
logger.info(f"Executing webhook '{webhook_name}': {method}")
|
||||
|
||||
try:
|
||||
if integration.provider.lower() == "slack":
|
||||
await _process_slack_integration(integration, gathered_context)
|
||||
else:
|
||||
logger.info(
|
||||
f"Integration provider '{integration.provider}' not supported yet"
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
if method in ("POST", "PUT", "PATCH"):
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
json=payload,
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
else: # GET, DELETE
|
||||
response = await client.request(
|
||||
method=method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=30.0,
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
logger.info(f"Webhook '{webhook_name}' succeeded: {response.status_code}")
|
||||
return True
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(
|
||||
f"Webhook '{webhook_name}' failed: {e.response.status_code} - {e.response.text[:200]}"
|
||||
)
|
||||
return False
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Webhook '{webhook_name}' request error: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing integration {integration.id}: {str(e)}")
|
||||
logger.error(f"Webhook '{webhook_name}' unexpected error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def _process_slack_integration(
|
||||
integration: IntegrationModel, gathered_context: dict
|
||||
):
|
||||
def _build_auth_header(credential: ExternalCredentialModel) -> Dict[str, str]:
|
||||
"""
|
||||
Process a Slack integration by sending gathered_context to the webhook.
|
||||
Build authentication header based on credential type.
|
||||
|
||||
Args:
|
||||
integration: The Slack integration model
|
||||
gathered_context: The gathered context from the workflow run
|
||||
workflow_run_name: Name of the workflow run
|
||||
run_id: The workflow run ID for logging context
|
||||
credential: The credential model
|
||||
|
||||
Returns:
|
||||
Dict with header name and value
|
||||
"""
|
||||
logger.info(f"Processing Slack integration {integration.id}")
|
||||
cred_type = credential.credential_type
|
||||
cred_data = credential.credential_data or {}
|
||||
|
||||
# TODO: Generalise this
|
||||
if gathered_context.get("mapped_call_disposition") != "XFER":
|
||||
logger.debug(
|
||||
f"Not sending message on slack since not XFER: {gathered_context.get('mapped_call_disposition')}"
|
||||
)
|
||||
return
|
||||
if cred_type == "bearer_token":
|
||||
token = cred_data.get("token", "")
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
try:
|
||||
# Extract webhook URL from connection_details
|
||||
connection_details = integration.connection_details
|
||||
elif cred_type == "api_key":
|
||||
header_name = cred_data.get("header_name", "X-API-Key")
|
||||
api_key = cred_data.get("api_key", "")
|
||||
return {header_name: api_key}
|
||||
|
||||
if not connection_details:
|
||||
logger.error(
|
||||
f"No connection details found for Slack integration {integration.id}"
|
||||
)
|
||||
return
|
||||
elif cred_type == "basic_auth":
|
||||
username = cred_data.get("username", "")
|
||||
password = cred_data.get("password", "")
|
||||
encoded = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||
return {"Authorization": f"Basic {encoded}"}
|
||||
|
||||
# Navigate to incoming_webhook.url in the connection_details
|
||||
webhook_url = connection_details.get("connection_config", {}).get(
|
||||
"incoming_webhook.url"
|
||||
)
|
||||
if not webhook_url:
|
||||
logger.error(
|
||||
f"No incoming_webhook found in connection details for integration {integration.id}"
|
||||
)
|
||||
return
|
||||
elif cred_type == "custom_header":
|
||||
header_name = cred_data.get("header_name", "X-Custom")
|
||||
header_value = cred_data.get("header_value", "")
|
||||
return {header_name: header_value}
|
||||
|
||||
logger.info(f"Found Slack webhook URL for integration {integration.id}")
|
||||
|
||||
# Get message template configuration
|
||||
# Get organization_id from the integration model
|
||||
organization_id = integration.organisation_id
|
||||
message_templates = await db_client.get_configuration_value(
|
||||
organization_id,
|
||||
OrganizationConfigurationKey.DISPOSITION_MESSAGE_TEMPLATE.value,
|
||||
default={},
|
||||
)
|
||||
|
||||
# Check if there's a custom template for Slack
|
||||
slack_template = message_templates.get("slack", {})
|
||||
rendered_text = render_template(slack_template, gathered_context)
|
||||
|
||||
slack_message = {"text": rendered_text}
|
||||
|
||||
# Send to Slack webhook
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
webhook_url,
|
||||
json=slack_message,
|
||||
headers={"Content-Type": "application/json"},
|
||||
) as response:
|
||||
if response.status == 200:
|
||||
logger.info(
|
||||
f"Successfully sent message to Slack for integration {integration.id}"
|
||||
)
|
||||
else:
|
||||
error_text = await response.text()
|
||||
logger.error(
|
||||
f"Failed to send Slack message for integration {integration.id}: {response.status} - {error_text}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing Slack integration {integration.id}: {str(e)}")
|
||||
return {}
|
||||
|
|
|
|||
|
|
@ -1,136 +1,330 @@
|
|||
"""Tests for webhook execution in run_integrations.py."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from api.tasks.run_integrations import _process_slack_integration
|
||||
from api.tasks.run_integrations import (
|
||||
_build_auth_header,
|
||||
_build_render_context,
|
||||
_execute_webhook_node,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_logger():
|
||||
"""Mock the logger for all tests."""
|
||||
with patch("api.tasks.run_integrations.logger") as mock_logger:
|
||||
# Mock the bind method to return the logger itself
|
||||
mock_logger.bind.return_value = mock_logger
|
||||
yield mock_logger
|
||||
with patch("api.tasks.run_integrations.logger") as mock_log:
|
||||
mock_log.bind.return_value = mock_log
|
||||
yield mock_log
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slack_integration_with_template():
|
||||
"""Test that Slack integration uses render_template correctly."""
|
||||
# Mock integration
|
||||
mock_integration = MagicMock()
|
||||
mock_integration.id = 1
|
||||
mock_integration.organisation_id = 123
|
||||
mock_integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"}
|
||||
}
|
||||
class TestBuildAuthHeader:
|
||||
"""Tests for _build_auth_header function."""
|
||||
|
||||
# Mock gathered context
|
||||
gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"mapped_call_disposition": "XFER", # Required for Slack integration to proceed
|
||||
"call_duration": "300",
|
||||
"agent_name": "Alex",
|
||||
}
|
||||
def test_bearer_token(self):
|
||||
"""Test bearer token auth header."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "bearer_token"
|
||||
credential.credential_data = {"token": "my-secret-token"}
|
||||
|
||||
# Mock db_client
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock message template configuration
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"slack": {
|
||||
"DISPOSITION_CODE": "Agent: {{agent_name}}\\nDisposition: {{call_disposition}}\\nDuration: {{call_duration}}s"
|
||||
}
|
||||
}
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"Authorization": "Bearer my-secret-token"}
|
||||
|
||||
def test_api_key(self):
|
||||
"""Test API key auth header."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "api_key"
|
||||
credential.credential_data = {"header_name": "X-API-Key", "api_key": "key123"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"X-API-Key": "key123"}
|
||||
|
||||
def test_api_key_default_header(self):
|
||||
"""Test API key with default header name."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "api_key"
|
||||
credential.credential_data = {"api_key": "key123"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"X-API-Key": "key123"}
|
||||
|
||||
def test_basic_auth(self):
|
||||
"""Test basic auth header."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "basic_auth"
|
||||
credential.credential_data = {"username": "user", "password": "pass"}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
# base64 of "user:pass" is "dXNlcjpwYXNz"
|
||||
assert result == {"Authorization": "Basic dXNlcjpwYXNz"}
|
||||
|
||||
def test_custom_header(self):
|
||||
"""Test custom header auth."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "custom_header"
|
||||
credential.credential_data = {
|
||||
"header_name": "X-Custom-Auth",
|
||||
"header_value": "custom-value",
|
||||
}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {"X-Custom-Auth": "custom-value"}
|
||||
|
||||
def test_unknown_type(self):
|
||||
"""Test unknown credential type returns empty dict."""
|
||||
credential = MagicMock()
|
||||
credential.credential_type = "unknown"
|
||||
credential.credential_data = {}
|
||||
|
||||
result = _build_auth_header(credential)
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestBuildRenderContext:
|
||||
"""Tests for _build_render_context function."""
|
||||
|
||||
def test_basic_context(self):
|
||||
"""Test building render context from workflow run."""
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = 123
|
||||
workflow_run.name = "WR-TEST-001"
|
||||
workflow_run.workflow_id = 456
|
||||
workflow_run.workflow.name = "Test Workflow"
|
||||
workflow_run.initial_context = {"phone_number": "+1234567890"}
|
||||
workflow_run.gathered_context = {
|
||||
"customer_name": "John",
|
||||
"mapped_call_disposition": "QUALIFIED",
|
||||
}
|
||||
workflow_run.usage_info = {"call_duration_seconds": 120}
|
||||
workflow_run.completed_at = None
|
||||
|
||||
result = _build_render_context(workflow_run)
|
||||
|
||||
assert result["workflow_run_id"] == 123
|
||||
assert result["workflow_run_name"] == "WR-TEST-001"
|
||||
assert result["workflow_id"] == 456
|
||||
assert result["workflow_name"] == "Test Workflow"
|
||||
assert result["initial_context"]["phone_number"] == "+1234567890"
|
||||
assert result["gathered_context"]["customer_name"] == "John"
|
||||
assert result["cost_info"]["call_duration_seconds"] == 120
|
||||
assert result["disposition_code"] == "QUALIFIED"
|
||||
|
||||
def test_empty_contexts(self):
|
||||
"""Test with empty/None contexts."""
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.id = 1
|
||||
workflow_run.name = "Test"
|
||||
workflow_run.workflow_id = 1
|
||||
workflow_run.workflow.name = "Workflow"
|
||||
workflow_run.initial_context = None
|
||||
workflow_run.gathered_context = None
|
||||
workflow_run.usage_info = None
|
||||
workflow_run.completed_at = None
|
||||
|
||||
result = _build_render_context(workflow_run)
|
||||
|
||||
assert result["initial_context"] == {}
|
||||
assert result["gathered_context"] == {}
|
||||
assert result["cost_info"] == {}
|
||||
assert result["disposition_code"] is None
|
||||
|
||||
|
||||
class TestExecuteWebhookNode:
|
||||
"""Tests for _execute_webhook_node function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disabled_webhook_skipped(self):
|
||||
"""Test that disabled webhooks are skipped."""
|
||||
webhook_data = {"name": "Test Webhook", "enabled": False}
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
# Mock aiohttp session
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
assert result is True # Returns True for skipped webhooks
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_url_returns_false(self):
|
||||
"""Test that missing endpoint URL returns False."""
|
||||
webhook_data = {"name": "Test Webhook", "enabled": True, "endpoint_url": None}
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
|
||||
# Call the function
|
||||
await _process_slack_integration(mock_integration, gathered_context)
|
||||
|
||||
# Verify the message was formatted correctly
|
||||
mock_session.post.assert_called_once()
|
||||
call_args = mock_session.post.call_args
|
||||
|
||||
# Check the webhook URL
|
||||
assert call_args[0][0] == "https://hooks.slack.com/test"
|
||||
|
||||
# Check the message content
|
||||
json_data = call_args[1]["json"]
|
||||
|
||||
# Check that the template was rendered correctly
|
||||
expected_text = "Agent: Alex\nDisposition: XFER\nDuration: 300s"
|
||||
assert json_data["text"] == expected_text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_slack_integration_with_missing_template_vars():
|
||||
"""Test template rendering with missing variables."""
|
||||
# Mock integration
|
||||
mock_integration = MagicMock()
|
||||
mock_integration.id = 1
|
||||
mock_integration.organisation_id = 123
|
||||
mock_integration.connection_details = {
|
||||
"connection_config": {"incoming_webhook.url": "https://hooks.slack.com/test"}
|
||||
}
|
||||
|
||||
# Mock gathered context with missing values
|
||||
gathered_context = {
|
||||
"call_disposition": "XFER",
|
||||
"mapped_call_disposition": "XFER", # Required for Slack integration to proceed
|
||||
# call_duration is missing
|
||||
}
|
||||
|
||||
# Mock db_client
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db_client:
|
||||
# Mock message template configuration with fallback
|
||||
mock_db_client.get_configuration_value = AsyncMock(
|
||||
return_value={
|
||||
"slack": {
|
||||
"DISPOSITION_CODE": "Disposition: {{call_disposition}}\\nDuration: {{call_duration | fallback:N/A}}"
|
||||
}
|
||||
}
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
# Mock aiohttp session
|
||||
with patch(
|
||||
"api.tasks.run_integrations.aiohttp.ClientSession"
|
||||
) as mock_session_class:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = 200
|
||||
assert result is False
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__aenter__.return_value = mock_session
|
||||
mock_session.__aexit__.return_value = AsyncMock()
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_post_request(self):
|
||||
"""Test successful POST webhook execution."""
|
||||
webhook_data = {
|
||||
"name": "CRM Sync",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"payload_template": {
|
||||
"call_id": "{{workflow_run_id}}",
|
||||
"phone": "{{initial_context.phone_number}}",
|
||||
},
|
||||
}
|
||||
|
||||
mock_post = MagicMock()
|
||||
mock_post.__aenter__.return_value = mock_response
|
||||
mock_post.__aexit__.return_value = AsyncMock()
|
||||
render_context = {
|
||||
"workflow_run_id": 123,
|
||||
"initial_context": {"phone_number": "+1234567890"},
|
||||
}
|
||||
|
||||
mock_session.post.return_value = mock_post
|
||||
mock_session_class.return_value = mock_session
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
# Call the function
|
||||
await _process_slack_integration(mock_integration, gathered_context)
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
# Check that the template was rendered with fallback
|
||||
json_data = mock_session.post.call_args[1]["json"]
|
||||
expected_text = "Disposition: XFER\nDuration: N/A"
|
||||
assert json_data["text"] == expected_text
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context=render_context,
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify the request was made correctly
|
||||
mock_client_instance.request.assert_called_once()
|
||||
call_kwargs = mock_client_instance.request.call_args[1]
|
||||
assert call_kwargs["method"] == "POST"
|
||||
assert call_kwargs["url"] == "https://api.example.com/webhook"
|
||||
assert call_kwargs["json"] == {
|
||||
"call_id": "123",
|
||||
"phone": "+1234567890",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_with_credential(self):
|
||||
"""Test webhook execution with credential auth."""
|
||||
webhook_data = {
|
||||
"name": "Authenticated Webhook",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"credential_uuid": "cred-123",
|
||||
"payload_template": {},
|
||||
}
|
||||
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.name = "API Key"
|
||||
mock_credential.credential_type = "bearer_token"
|
||||
mock_credential.credential_data = {"token": "secret-token"}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=mock_credential)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify auth header was included
|
||||
call_kwargs = mock_client_instance.request.call_args[1]
|
||||
assert call_kwargs["headers"]["Authorization"] == "Bearer secret-token"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_with_custom_headers(self):
|
||||
"""Test webhook execution with custom headers."""
|
||||
webhook_data = {
|
||||
"name": "Custom Headers Webhook",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"custom_headers": [
|
||||
{"key": "X-Source", "value": "dograh"},
|
||||
{"key": "X-Workflow", "value": "test"},
|
||||
],
|
||||
"payload_template": {},
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is True
|
||||
|
||||
# Verify custom headers were included
|
||||
call_kwargs = mock_client_instance.request.call_args[1]
|
||||
assert call_kwargs["headers"]["X-Source"] == "dograh"
|
||||
assert call_kwargs["headers"]["X-Workflow"] == "test"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_webhook_http_error(self):
|
||||
"""Test webhook execution with HTTP error."""
|
||||
import httpx
|
||||
|
||||
webhook_data = {
|
||||
"name": "Failing Webhook",
|
||||
"enabled": True,
|
||||
"http_method": "POST",
|
||||
"endpoint_url": "https://api.example.com/webhook",
|
||||
"payload_template": {},
|
||||
}
|
||||
|
||||
with patch("api.tasks.run_integrations.db_client") as mock_db:
|
||||
mock_db.get_credential_by_uuid = AsyncMock(return_value=None)
|
||||
|
||||
with patch("api.tasks.run_integrations.httpx.AsyncClient") as mock_client:
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.text = "Internal Server Error"
|
||||
mock_response.raise_for_status = MagicMock(
|
||||
side_effect=httpx.HTTPStatusError(
|
||||
"Server Error",
|
||||
request=MagicMock(),
|
||||
response=mock_response,
|
||||
)
|
||||
)
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.request = AsyncMock(return_value=mock_response)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
result = await _execute_webhook_node(
|
||||
webhook_data=webhook_data,
|
||||
render_context={},
|
||||
organization_id=1,
|
||||
)
|
||||
|
||||
assert result is False
|
||||
|
|
|
|||
|
|
@ -1,46 +1,126 @@
|
|||
"""Common template rendering utility."""
|
||||
"""Template rendering utility with support for nested JSON paths."""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
|
||||
def render_template(template_str: str, template_var_mapping: Dict[str, Any]) -> str: # noqa: C901 – complex but self-contained
|
||||
"""Replace template placeholders in *template_str* with values from *template_var_mapping*.
|
||||
|
||||
Supported syntax:
|
||||
* ``{{ variable_name }}``
|
||||
* ``{{ variable_name | fallback }}``
|
||||
* ``{{ variable_name | fallback:default_value }}``
|
||||
|
||||
If the variable is undefined and a *fallback* filter is specified the value
|
||||
of *default_value* (or the *variable_name* itself if no default is given)
|
||||
is used instead.
|
||||
def get_nested_value(obj: Any, path: str) -> Any:
|
||||
"""
|
||||
Get a nested value from a dictionary using dot notation.
|
||||
|
||||
Args:
|
||||
obj: The object to traverse (dict or any)
|
||||
path: Dot-separated path (e.g., "a.b.c")
|
||||
|
||||
Returns:
|
||||
The value at the path, or None if not found
|
||||
|
||||
Examples:
|
||||
get_nested_value({"a": {"b": 1}}, "a.b") -> 1
|
||||
get_nested_value({"a": {"b": {"c": 2}}}, "a.b.c") -> 2
|
||||
get_nested_value({"a": 1}, "a.b") -> None
|
||||
"""
|
||||
if not path:
|
||||
return obj
|
||||
|
||||
keys = path.split(".")
|
||||
current = obj
|
||||
|
||||
for key in keys:
|
||||
if isinstance(current, dict):
|
||||
current = current.get(key)
|
||||
else:
|
||||
return None
|
||||
|
||||
if current is None:
|
||||
return None
|
||||
|
||||
return current
|
||||
|
||||
|
||||
def render_template(
|
||||
template: Union[str, dict, list, None],
|
||||
context: Dict[str, Any],
|
||||
) -> Union[str, dict, list, None]: # noqa: C901 – complex but self-contained
|
||||
"""
|
||||
Render a template with variable substitution supporting nested paths.
|
||||
|
||||
Supports:
|
||||
- String templates: "Hello {{name}}"
|
||||
- JSON templates: {"key": "{{value}}"}
|
||||
- Nested paths: "{{initial_context.phone_number}}"
|
||||
- Deep nesting: "{{gathered_context.customer.address.city}}"
|
||||
- Fallback: "{{name | fallback:Unknown}}"
|
||||
|
||||
Args:
|
||||
template: String, dict, list, or None with {{variable}} placeholders
|
||||
context: Dict containing all available variables
|
||||
|
||||
Returns:
|
||||
Rendered template with variables replaced
|
||||
"""
|
||||
if template is None:
|
||||
return None
|
||||
|
||||
# Handle dict templates recursively
|
||||
if isinstance(template, dict):
|
||||
return {
|
||||
_render_string(str(k), context)
|
||||
if isinstance(k, str)
|
||||
else k: render_template(v, context)
|
||||
for k, v in template.items()
|
||||
}
|
||||
|
||||
# Handle list templates recursively
|
||||
if isinstance(template, list):
|
||||
return [render_template(item, context) for item in template]
|
||||
|
||||
# Handle non-string types (int, float, bool, etc.)
|
||||
if not isinstance(template, str):
|
||||
return template
|
||||
|
||||
return _render_string(template, context)
|
||||
|
||||
|
||||
def _render_string(template_str: str, context: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Render a string template with variable substitution.
|
||||
|
||||
Args:
|
||||
template_str: String with {{variable}} placeholders
|
||||
context: Dict containing all available variables
|
||||
|
||||
Returns:
|
||||
Rendered string with variables replaced
|
||||
"""
|
||||
if not template_str:
|
||||
return template_str
|
||||
|
||||
# Regex matches e.g. ``{{ name }}``, ``{{ name | fallback }}``, ``{{ name | fallback:John }}``
|
||||
# Pattern: {{ path }} or {{ path | filter }} or {{ path | filter:default }}
|
||||
pattern = r"\{\{\s*([^|\s}]+)(?:\s*\|\s*([^:}]+)(?::([^}]+))?)?\s*\}\}"
|
||||
|
||||
def _replace(match: re.Match[str]) -> str: # type: ignore[type-arg]
|
||||
variable_name = match.group(1).strip()
|
||||
variable_path = match.group(1).strip()
|
||||
filter_name = match.group(2).strip() if match.group(2) else None
|
||||
filter_value = match.group(3).strip() if match.group(3) else None
|
||||
|
||||
# Pull value from context
|
||||
value = template_var_mapping.get(variable_name)
|
||||
# Get value using nested path lookup
|
||||
value = get_nested_value(context, variable_path)
|
||||
|
||||
# Apply filters
|
||||
if filter_name == "fallback":
|
||||
if value is None or value == "":
|
||||
# Use explicit default value or a title-cased variable name.
|
||||
value = (
|
||||
filter_value if filter_value is not None else variable_name.title()
|
||||
filter_value if filter_value is not None else variable_path.title()
|
||||
)
|
||||
|
||||
# Convert *None* to an empty string so that re.sub replacement works.
|
||||
return str(value) if value is not None else ""
|
||||
# Convert to string for substitution
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value)
|
||||
return str(value)
|
||||
|
||||
# Replace template variables
|
||||
result = re.sub(pattern, _replace, template_str)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue