mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
Merge remote-tracking branch 'origin/main' into feat/user-onboarding
# Conflicts: # docs/api-reference/openapi.json # sdk/python/src/dograh_sdk/_generated_models.py # ui/src/client/index.ts # ui/src/components/AIModelConfigurationV2Editor.tsx
This commit is contained in:
commit
5559ed686f
44 changed files with 2155 additions and 321 deletions
|
|
@ -5,7 +5,7 @@ organization_configurations. Existing rows (the legacy v1 AI model
|
|||
configuration blob) are backfilled with key MODEL_CONFIGURATION.
|
||||
|
||||
Revision ID: 91cc6ba3e1c7
|
||||
Revises: 384be6596b36
|
||||
Revises: efe356f488f9
|
||||
Create Date: 2026-06-12 21:04:25.561529
|
||||
|
||||
"""
|
||||
|
|
@ -17,7 +17,7 @@ from alembic import op
|
|||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "91cc6ba3e1c7"
|
||||
down_revision: Union[str, None] = "384be6596b36"
|
||||
down_revision: Union[str, None] = "efe356f488f9"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
"""add extra column in workflow runs
|
||||
|
||||
Revision ID: efe356f488f9
|
||||
Revises: 384be6596b36
|
||||
Create Date: 2026-06-16 12:24:30.081058
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "efe356f488f9"
|
||||
down_revision: Union[str, None] = "384be6596b36"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"workflow_runs",
|
||||
sa.Column(
|
||||
"extra",
|
||||
sa.JSON(),
|
||||
server_default=sa.text("'{}'::json"),
|
||||
nullable=False,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("workflow_runs", "extra")
|
||||
|
|
@ -10,6 +10,7 @@ from api.db.filters import apply_workflow_run_filters, get_workflow_run_order_cl
|
|||
from api.db.models import CampaignModel, QueuedRunModel, WorkflowRunModel
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.workflow.run_usage_response import format_public_cost_info
|
||||
from api.utils.recording_artifacts import get_recording_storage_key
|
||||
|
||||
|
||||
class CampaignClient(BaseDBClient):
|
||||
|
|
@ -45,9 +46,11 @@ class CampaignClient(BaseDBClient):
|
|||
source_id=source_id,
|
||||
created_by=user_id,
|
||||
organization_id=organization_id,
|
||||
retry_config=retry_config
|
||||
if retry_config
|
||||
else CampaignModel.retry_config.default.arg,
|
||||
retry_config=(
|
||||
retry_config
|
||||
if retry_config
|
||||
else CampaignModel.retry_config.default.arg
|
||||
),
|
||||
orchestrator_metadata=orchestrator_metadata,
|
||||
telephony_configuration_id=telephony_configuration_id,
|
||||
)
|
||||
|
|
@ -216,6 +219,12 @@ class CampaignClient(BaseDBClient):
|
|||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"user_recording_url": get_recording_storage_key(
|
||||
run.extra, "user"
|
||||
),
|
||||
"bot_recording_url": get_recording_storage_key(
|
||||
run.extra, "bot"
|
||||
),
|
||||
"cost_info": format_public_cost_info(
|
||||
run.cost_info, run.usage_info
|
||||
),
|
||||
|
|
@ -270,9 +279,11 @@ class CampaignClient(BaseDBClient):
|
|||
source_id=parent_campaign.source_id,
|
||||
created_by=parent_campaign.created_by,
|
||||
organization_id=parent_campaign.organization_id,
|
||||
retry_config=retry_config
|
||||
if retry_config
|
||||
else CampaignModel.retry_config.default.arg,
|
||||
retry_config=(
|
||||
retry_config
|
||||
if retry_config
|
||||
else CampaignModel.retry_config.default.arg
|
||||
),
|
||||
orchestrator_metadata=child_meta,
|
||||
rate_limit_per_second=parent_campaign.rate_limit_per_second,
|
||||
total_rows=len(queued_runs_data),
|
||||
|
|
@ -338,8 +349,7 @@ class CampaignClient(BaseDBClient):
|
|||
# Retries create new queued_runs with suffixed source_uuids linked via
|
||||
# parent_queued_run_id, so group by the ROOT queued_run using a
|
||||
# recursive walk and pick the latest workflow_run across the tree.
|
||||
sql = text(
|
||||
f"""
|
||||
sql = text(f"""
|
||||
WITH RECURSIVE run_tree AS (
|
||||
SELECT id AS root_id, id AS run_id
|
||||
FROM queued_runs
|
||||
|
|
@ -366,8 +376,7 @@ class CampaignClient(BaseDBClient):
|
|||
JOIN latest_run_per_root lr ON lr.root_id = q0.id
|
||||
WHERE q0.campaign_id = :cid
|
||||
AND ({tag_filter})
|
||||
"""
|
||||
)
|
||||
""")
|
||||
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(sql, {"cid": campaign_id})
|
||||
|
|
|
|||
|
|
@ -544,6 +544,9 @@ class WorkflowRunModel(Base):
|
|||
is_completed = Column(Boolean, default=False)
|
||||
recording_url = Column(String, nullable=True)
|
||||
transcript_url = Column(String, nullable=True)
|
||||
extra = Column(
|
||||
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
|
||||
)
|
||||
# Store storage backend as string enum (s3, minio)
|
||||
storage_backend = Column(
|
||||
Enum("s3", "minio", name="storage_backend"),
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from api.db.models import (
|
|||
)
|
||||
from api.enums import OrganizationConfigurationKey, UserConfigurationKey
|
||||
from api.schemas.ai_model_configuration import EffectiveAIModelConfiguration
|
||||
from api.utils.recording_artifacts import get_recording_storage_key
|
||||
|
||||
|
||||
class OrganizationUsageClient(BaseDBClient):
|
||||
|
|
@ -226,6 +227,9 @@ class OrganizationUsageClient(BaseDBClient):
|
|||
"call_duration_seconds": int(round(call_duration)),
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"user_recording_url": get_recording_storage_key(run.extra, "user"),
|
||||
"bot_recording_url": get_recording_storage_key(run.extra, "bot"),
|
||||
"extra": run.extra,
|
||||
"public_access_token": run.public_access_token,
|
||||
"phone_number": phone_number,
|
||||
"caller_number": caller_number,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from api.db.models import (
|
|||
from api.enums import CallType, StorageBackend
|
||||
from api.schemas.workflow import WorkflowRunResponseSchema
|
||||
from api.services.workflow.run_usage_response import format_public_cost_info
|
||||
from api.utils.recording_artifacts import get_recording_storage_key
|
||||
|
||||
|
||||
class WorkflowRunClient(BaseDBClient):
|
||||
|
|
@ -188,13 +189,19 @@ class WorkflowRunClient(BaseDBClient):
|
|||
"workflow_name": run.workflow.name if run.workflow else None,
|
||||
"user_id": run.workflow.user_id if run.workflow else None,
|
||||
"organization_id": organization.id if organization else None,
|
||||
"organization_name": organization.provider_id
|
||||
if organization
|
||||
else None,
|
||||
"organization_name": (
|
||||
organization.provider_id if organization else None
|
||||
),
|
||||
"mode": run.mode,
|
||||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"user_recording_url": get_recording_storage_key(
|
||||
run.extra, "user"
|
||||
),
|
||||
"bot_recording_url": get_recording_storage_key(
|
||||
run.extra, "bot"
|
||||
),
|
||||
"usage_info": run.usage_info,
|
||||
"cost_info": run.cost_info,
|
||||
"initial_context": run.initial_context,
|
||||
|
|
@ -313,6 +320,12 @@ class WorkflowRunClient(BaseDBClient):
|
|||
"is_completed": run.is_completed,
|
||||
"recording_url": run.recording_url,
|
||||
"transcript_url": run.transcript_url,
|
||||
"user_recording_url": get_recording_storage_key(
|
||||
run.extra, "user"
|
||||
),
|
||||
"bot_recording_url": get_recording_storage_key(
|
||||
run.extra, "bot"
|
||||
),
|
||||
"cost_info": format_public_cost_info(
|
||||
run.cost_info, run.usage_info
|
||||
),
|
||||
|
|
@ -340,6 +353,7 @@ class WorkflowRunClient(BaseDBClient):
|
|||
logs: dict | None = None,
|
||||
state: str | None = None,
|
||||
annotations: dict | None = None,
|
||||
extra: dict | None = None,
|
||||
) -> WorkflowRunModel:
|
||||
async with self.async_session() as session:
|
||||
# Use SELECT FOR UPDATE to lock the row during the update
|
||||
|
|
@ -374,6 +388,8 @@ class WorkflowRunClient(BaseDBClient):
|
|||
run.logs = {**run.logs, **logs}
|
||||
if annotations:
|
||||
run.annotations = {**run.annotations, **annotations}
|
||||
if extra:
|
||||
run.extra = {**run.extra, **extra}
|
||||
if is_completed:
|
||||
run.is_completed = is_completed
|
||||
if state:
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ from api.services.configuration.masking import is_mask_of, mask_key, mask_user_c
|
|||
from api.services.configuration.registry import (
|
||||
DOGRAH_STT_LANGUAGES,
|
||||
REGISTRY,
|
||||
DograhTTSService,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
|
|
@ -210,6 +211,13 @@ async def get_telephony_config_warnings(user: UserModel = Depends(get_user)):
|
|||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _dograh_allows_custom_voice() -> bool:
|
||||
extra = DograhTTSService.model_fields["voice"].json_schema_extra
|
||||
if isinstance(extra, dict):
|
||||
return bool(extra.get("allow_custom_input", False))
|
||||
return False
|
||||
|
||||
|
||||
def _byok_provider_schemas(service_type: ServiceType) -> dict[str, dict]:
|
||||
return {
|
||||
provider: model_cls.model_json_schema()
|
||||
|
|
@ -251,6 +259,7 @@ async def get_model_configuration_v2_defaults(
|
|||
return {
|
||||
"dograh": {
|
||||
"voices": [DOGRAH_DEFAULT_VOICE],
|
||||
"allow_custom_input": _dograh_allows_custom_voice(),
|
||||
"speeds": list(DOGRAH_SPEED_OPTIONS),
|
||||
"languages": DOGRAH_STT_LANGUAGES,
|
||||
"defaults": {
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from api.services.auth.depends import get_user, get_user_with_selected_organizat
|
|||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.reports import generate_usage_runs_report_csv
|
||||
from api.utils.artifacts import artifact_url
|
||||
from api.utils.recording_artifacts import has_recording_track
|
||||
|
||||
router = APIRouter(prefix="/organizations")
|
||||
|
||||
|
|
@ -99,8 +100,12 @@ class WorkflowRunUsageResponse(BaseModel):
|
|||
call_duration_seconds: int
|
||||
recording_url: Optional[str] = None
|
||||
transcript_url: Optional[str] = None
|
||||
user_recording_url: Optional[str] = None
|
||||
bot_recording_url: Optional[str] = None
|
||||
recording_public_url: Optional[str] = None
|
||||
transcript_public_url: Optional[str] = None
|
||||
user_recording_public_url: Optional[str] = None
|
||||
bot_recording_public_url: Optional[str] = None
|
||||
public_access_token: Optional[str] = None
|
||||
phone_number: Optional[str] = Field(
|
||||
default=None,
|
||||
|
|
@ -308,14 +313,18 @@ async def get_billing_credits(
|
|||
aggregation_key=entry.get("aggregation_key"),
|
||||
usage_event_id=_optional_int(entry.get("usage_event_id")),
|
||||
workflow_run_id=_optional_int(entry.get("workflow_run_id")),
|
||||
workflow_id=workflow_ids_by_run_id.get(
|
||||
_optional_int(entry.get("workflow_run_id"))
|
||||
)
|
||||
if entry.get("workflow_run_id") is not None
|
||||
else None,
|
||||
billable_quantity=float(entry["billable_quantity"])
|
||||
if entry.get("billable_quantity") is not None
|
||||
else None,
|
||||
workflow_id=(
|
||||
workflow_ids_by_run_id.get(
|
||||
_optional_int(entry.get("workflow_run_id"))
|
||||
)
|
||||
if entry.get("workflow_run_id") is not None
|
||||
else None
|
||||
),
|
||||
billable_quantity=(
|
||||
float(entry["billable_quantity"])
|
||||
if entry.get("billable_quantity") is not None
|
||||
else None
|
||||
),
|
||||
quantity_unit=entry.get("quantity_unit"),
|
||||
metadata=entry.get("metadata") or {},
|
||||
created_at=str(entry["created_at"]),
|
||||
|
|
@ -478,6 +487,17 @@ async def get_usage_history(
|
|||
public_access_token, "transcript"
|
||||
)
|
||||
run["recording_public_url"] = artifact_url(public_access_token, "recording")
|
||||
run["user_recording_public_url"] = (
|
||||
artifact_url(public_access_token, "user_recording")
|
||||
if has_recording_track(run.get("extra"), "user")
|
||||
else None
|
||||
)
|
||||
run["bot_recording_public_url"] = (
|
||||
artifact_url(public_access_token, "bot_recording")
|
||||
if has_recording_track(run.get("extra"), "bot")
|
||||
else None
|
||||
)
|
||||
run.pop("extra", None)
|
||||
|
||||
return {
|
||||
"runs": runs,
|
||||
|
|
|
|||
|
|
@ -6,14 +6,16 @@ post-call processing for runs that execute integrations, QA, or campaign
|
|||
reporting.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.services.storage import get_storage_for_backend
|
||||
from api.utils.recording_artifacts import (
|
||||
get_recording_storage_backend,
|
||||
get_recording_storage_key,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/public/download")
|
||||
|
||||
|
|
@ -21,7 +23,7 @@ router = APIRouter(prefix="/public/download")
|
|||
@router.get("/workflow/{token}/{artifact_type}")
|
||||
async def download_workflow_artifact(
|
||||
token: str,
|
||||
artifact_type: Literal["recording", "transcript"],
|
||||
artifact_type: str,
|
||||
inline: bool = Query(
|
||||
default=False, description="Display inline in browser instead of download"
|
||||
),
|
||||
|
|
@ -36,13 +38,15 @@ async def download_workflow_artifact(
|
|||
|
||||
Args:
|
||||
token: The public access token (UUID format)
|
||||
artifact_type: Type of artifact - "recording" or "transcript"
|
||||
artifact_type: Type of artifact - "recording", "transcript",
|
||||
"user_recording", or "bot_recording"
|
||||
inline: If true, sets Content-Disposition to inline for browser preview
|
||||
|
||||
Returns:
|
||||
RedirectResponse to the signed URL (302 redirect)
|
||||
|
||||
Raises:
|
||||
HTTPException 400: If artifact type is unsupported
|
||||
HTTPException 404: If token is invalid or artifact not found
|
||||
"""
|
||||
# 1. Lookup workflow run by token
|
||||
|
|
@ -52,10 +56,26 @@ async def download_workflow_artifact(
|
|||
raise HTTPException(status_code=404, detail="Invalid or expired token")
|
||||
|
||||
# 2. Get file path based on artifact type
|
||||
artifact_storage_backend = None
|
||||
if artifact_type == "recording":
|
||||
file_path = workflow_run.recording_url
|
||||
else: # transcript
|
||||
elif artifact_type == "transcript":
|
||||
file_path = workflow_run.transcript_url
|
||||
elif artifact_type == "user_recording":
|
||||
file_path = get_recording_storage_key(workflow_run.extra, "user")
|
||||
artifact_storage_backend = get_recording_storage_backend(
|
||||
workflow_run.extra, "user"
|
||||
)
|
||||
elif artifact_type == "bot_recording":
|
||||
file_path = get_recording_storage_key(workflow_run.extra, "bot")
|
||||
artifact_storage_backend = get_recording_storage_backend(
|
||||
workflow_run.extra, "bot"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Unsupported artifact type: type={artifact_type}, workflow_run_id={workflow_run.id}"
|
||||
)
|
||||
raise HTTPException(status_code=400, detail="Unsupported artifact type")
|
||||
|
||||
if not file_path:
|
||||
logger.warning(
|
||||
|
|
@ -68,7 +88,9 @@ async def download_workflow_artifact(
|
|||
|
||||
# 3. Get storage backend for this workflow run
|
||||
try:
|
||||
storage = get_storage_for_backend(workflow_run.storage_backend)
|
||||
storage = get_storage_for_backend(
|
||||
artifact_storage_backend or workflow_run.storage_backend
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid storage backend: {workflow_run.storage_backend}")
|
||||
raise HTTPException(status_code=500, detail="Storage configuration error")
|
||||
|
|
|
|||
|
|
@ -40,14 +40,22 @@ class PresignedUploadUrlResponse(BaseModel):
|
|||
router = APIRouter(prefix="/s3", tags=["s3"])
|
||||
|
||||
|
||||
ORG_SCOPED_STORAGE_PREFIXES = ("campaigns", "knowledge_base")
|
||||
|
||||
|
||||
def _extract_org_id_from_key(key: str) -> Optional[int]:
|
||||
"""Try to extract an organization ID from a storage key.
|
||||
|
||||
Matches keys of the form ``{prefix}/{org_id}/...`` where *org_id* is a
|
||||
positive integer. Returns ``None`` when the pattern does not match.
|
||||
Matches known org-scoped keys of the form ``{prefix}/{org_id}/...`` where
|
||||
*org_id* is a positive integer. Returns ``None`` when the pattern does not
|
||||
match.
|
||||
"""
|
||||
parts = key.split("/")
|
||||
if len(parts) >= 3 and parts[1].isdigit():
|
||||
if (
|
||||
len(parts) >= 3
|
||||
and parts[0] in ORG_SCOPED_STORAGE_PREFIXES
|
||||
and parts[1].isdigit()
|
||||
):
|
||||
return int(parts[1])
|
||||
return None
|
||||
|
||||
|
|
@ -58,15 +66,20 @@ def _extract_legacy_workflow_run_id(key: str) -> Optional[int]:
|
|||
Supports:
|
||||
- ``transcripts/{run_id}.txt``
|
||||
- ``recordings/{run_id}.wav``
|
||||
- ``recordings/{run_id}/user.wav``
|
||||
- ``recordings/{run_id}/bot.wav``
|
||||
|
||||
Returns ``None`` when the key does not match a legacy pattern.
|
||||
"""
|
||||
if key.startswith("transcripts/") and key.endswith(".txt"):
|
||||
run_id_str = key[len("transcripts/") : -4]
|
||||
elif key.startswith("recordings/") and key.endswith(".wav"):
|
||||
run_id_str = key[len("recordings/") : -4]
|
||||
else:
|
||||
return None
|
||||
recording_match = re.fullmatch(
|
||||
r"recordings/(\d+)(?:\.wav|/(?:user|bot)\.wav)", key
|
||||
)
|
||||
if not recording_match:
|
||||
return None
|
||||
run_id_str = recording_match.group(1)
|
||||
|
||||
return int(run_id_str) if run_id_str.isdigit() else None
|
||||
|
||||
|
|
@ -89,8 +102,13 @@ async def _validate_and_extract_workflow_run_id(
|
|||
"""
|
||||
if key.startswith("transcripts/") and key.endswith(".txt"):
|
||||
run_id_str = key[len("transcripts/") : -4] # strip prefix & suffix
|
||||
elif key.startswith("recordings/") and key.endswith(".wav"):
|
||||
run_id_str = key[len("recordings/") : -4]
|
||||
elif key.startswith("recordings/"):
|
||||
run_id = _extract_legacy_workflow_run_id(key)
|
||||
if run_id is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Invalid workflow_run_id in key"
|
||||
)
|
||||
return run_id
|
||||
elif allow_special_paths and key.startswith("voicemail_detections/"):
|
||||
return None # Skip validation for these paths
|
||||
else:
|
||||
|
|
@ -159,9 +177,9 @@ async def get_signed_url(
|
|||
"""Return a short-lived signed URL for a file stored on S3 / MinIO.
|
||||
|
||||
Access Control:
|
||||
* Keys that embed an organization ID (``{prefix}/{org_id}/...``) are
|
||||
authorized by matching the org_id against the requesting user's
|
||||
organization.
|
||||
* Known org-scoped keys (for example ``campaigns/{org_id}/...`` and
|
||||
``knowledge_base/{org_id}/...``) are authorized by matching the org_id
|
||||
against the requesting user's organization.
|
||||
* Legacy keys (``recordings/{run_id}.wav``, ``transcripts/{run_id}.txt``)
|
||||
are authorized via the workflow run they belong to.
|
||||
* Superusers can request any key.
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import ipaddress
|
|||
import os
|
||||
from datetime import UTC, datetime
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Set
|
||||
|
||||
from aiortc import RTCIceServer
|
||||
from aiortc.sdp import candidate_from_sdp
|
||||
|
|
@ -246,6 +246,74 @@ class SignalingManager:
|
|||
def __init__(self):
|
||||
self._connections: Dict[str, WebSocket] = {}
|
||||
self._peer_connections: Dict[str, SmallWebRTCConnection] = {}
|
||||
self._connection_peer_ids: Dict[str, Set[str]] = {}
|
||||
self._peer_connection_owners: Dict[str, str] = {}
|
||||
|
||||
def _track_peer_connection(
|
||||
self, connection_id: str, pc_id: str, pc: SmallWebRTCConnection
|
||||
) -> None:
|
||||
self._peer_connections[pc_id] = pc
|
||||
self._peer_connection_owners[pc_id] = connection_id
|
||||
self._connection_peer_ids.setdefault(connection_id, set()).add(pc_id)
|
||||
|
||||
def _forget_peer_connection(self, pc_id: str) -> Optional[str]:
|
||||
connection_id = self._peer_connection_owners.pop(pc_id, None)
|
||||
self._peer_connections.pop(pc_id, None)
|
||||
|
||||
if connection_id:
|
||||
peer_ids = self._connection_peer_ids.get(connection_id)
|
||||
if peer_ids is not None:
|
||||
peer_ids.discard(pc_id)
|
||||
if not peer_ids:
|
||||
self._connection_peer_ids.pop(connection_id, None)
|
||||
|
||||
return connection_id
|
||||
|
||||
async def _send_json_if_connected(
|
||||
self, websocket: WebSocket, message: dict
|
||||
) -> bool:
|
||||
if websocket.application_state != WebSocketState.CONNECTED:
|
||||
return False
|
||||
|
||||
try:
|
||||
await websocket.send_json(message)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to send signaling WebSocket message: {e}")
|
||||
return False
|
||||
|
||||
async def _close_websocket_if_connected(
|
||||
self, websocket: WebSocket, code: int = 1000, reason: str = ""
|
||||
) -> None:
|
||||
if websocket.application_state != WebSocketState.CONNECTED:
|
||||
return
|
||||
|
||||
try:
|
||||
await websocket.close(code=code, reason=reason)
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to close signaling WebSocket: {e}")
|
||||
|
||||
async def _notify_call_ended_and_close_websocket(
|
||||
self,
|
||||
websocket: WebSocket,
|
||||
workflow_run_id: int,
|
||||
pc_id: str,
|
||||
reason: str,
|
||||
) -> None:
|
||||
await self._send_json_if_connected(
|
||||
websocket,
|
||||
{
|
||||
"type": "call-ended",
|
||||
"payload": {
|
||||
"workflow_run_id": workflow_run_id,
|
||||
"pc_id": pc_id,
|
||||
"reason": reason,
|
||||
},
|
||||
},
|
||||
)
|
||||
await self._close_websocket_if_connected(
|
||||
websocket, code=1000, reason="call ended"
|
||||
)
|
||||
|
||||
async def handle_websocket(
|
||||
self,
|
||||
|
|
@ -257,35 +325,51 @@ class SignalingManager:
|
|||
"""Handle WebSocket connection for signaling."""
|
||||
await websocket.accept()
|
||||
connection_id = f"{workflow_id}:{workflow_run_id}:{user.id}"
|
||||
self._connections[connection_id] = websocket
|
||||
connection_key = f"{connection_id}:{id(websocket)}"
|
||||
self._connections[connection_key] = websocket
|
||||
|
||||
try:
|
||||
while True:
|
||||
message = await websocket.receive_json()
|
||||
await self._handle_message(
|
||||
websocket, message, workflow_id, workflow_run_id, user
|
||||
websocket,
|
||||
message,
|
||||
workflow_id,
|
||||
workflow_run_id,
|
||||
user,
|
||||
connection_key,
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
logger.info(f"WebSocket disconnected for {connection_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for {connection_id}: {e}")
|
||||
if websocket.application_state == WebSocketState.DISCONNECTED:
|
||||
logger.info(f"WebSocket disconnected for {connection_id}")
|
||||
else:
|
||||
logger.error(f"WebSocket error for {connection_id}: {e}")
|
||||
finally:
|
||||
# Cleanup
|
||||
self._connections.pop(connection_id, None)
|
||||
self._connections.pop(connection_key, None)
|
||||
peer_ids = list(self._connection_peer_ids.pop(connection_key, set()))
|
||||
|
||||
# Unregister WebSocket sender for real-time feedback
|
||||
unregister_ws_sender(workflow_run_id)
|
||||
|
||||
# Clean up all peer connections for this workflow run
|
||||
# Clean up peer connections owned by this WebSocket.
|
||||
# Note: In a WebSocket-based signaling approach (vs HTTP PATCH),
|
||||
# we maintain our own connection map instead of relying on
|
||||
# SmallWebRTCRequestHandler's _pcs_map. This is suitable for
|
||||
# multi-worker FastAPI deployments where state cannot be shared.
|
||||
for pc_id in list(self._peer_connections.keys()):
|
||||
for pc_id in peer_ids:
|
||||
self._peer_connection_owners.pop(pc_id, None)
|
||||
pc = self._peer_connections.pop(pc_id, None)
|
||||
if pc:
|
||||
await pc.disconnect()
|
||||
logger.debug(f"Disconnected peer connection: {pc_id}")
|
||||
try:
|
||||
await pc.disconnect()
|
||||
logger.debug(f"Disconnected peer connection: {pc_id}")
|
||||
except Exception as e:
|
||||
logger.debug(
|
||||
f"Failed to disconnect peer connection {pc_id}: {e}"
|
||||
)
|
||||
|
||||
async def _handle_message(
|
||||
self,
|
||||
|
|
@ -294,17 +378,20 @@ class SignalingManager:
|
|||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user: UserModel,
|
||||
connection_key: str,
|
||||
):
|
||||
"""Handle incoming WebSocket messages."""
|
||||
msg_type = message.get("type")
|
||||
payload = message.get("payload", {})
|
||||
|
||||
if msg_type == "offer":
|
||||
await self._handle_offer(ws, payload, workflow_id, workflow_run_id, user)
|
||||
await self._handle_offer(
|
||||
ws, payload, workflow_id, workflow_run_id, user, connection_key
|
||||
)
|
||||
elif msg_type == "ice-candidate":
|
||||
await self._handle_ice_candidate(ws, payload, workflow_run_id)
|
||||
await self._handle_ice_candidate(payload, connection_key)
|
||||
elif msg_type == "renegotiate":
|
||||
await self._handle_renegotiation(ws, payload, workflow_id, workflow_run_id)
|
||||
await self._handle_renegotiation(ws, payload, connection_key)
|
||||
|
||||
async def _handle_offer(
|
||||
self,
|
||||
|
|
@ -313,6 +400,7 @@ class SignalingManager:
|
|||
workflow_id: int,
|
||||
workflow_run_id: int,
|
||||
user: UserModel,
|
||||
connection_key: str,
|
||||
):
|
||||
"""Handle offer message and create answer with ICE trickling."""
|
||||
pc_id = payload.get("pc_id")
|
||||
|
|
@ -320,6 +408,15 @@ class SignalingManager:
|
|||
type_ = payload.get("type")
|
||||
call_context_vars = payload.get("call_context_vars", {})
|
||||
|
||||
if not pc_id or not sdp or not type_:
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"payload": {"message": "Missing offer fields"},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Set run context for logging and tracing. org_id must be set before
|
||||
# pc.initialize() so that aiortc's internal tasks inherit it.
|
||||
set_current_run_id(workflow_run_id)
|
||||
|
|
@ -347,7 +444,16 @@ class SignalingManager:
|
|||
)
|
||||
return
|
||||
|
||||
if pc_id and pc_id in self._peer_connections:
|
||||
if pc_id in self._peer_connections:
|
||||
if self._peer_connection_owners.get(pc_id) != connection_key:
|
||||
await ws.send_json(
|
||||
{
|
||||
"type": "error",
|
||||
"payload": {"message": "Peer connection already owned"},
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
# Reuse existing connection
|
||||
logger.info(f"Reusing existing connection for pc_id: {pc_id}")
|
||||
pc = self._peer_connections[pc_id]
|
||||
|
|
@ -379,7 +485,7 @@ class SignalingManager:
|
|||
await pc.initialize(sdp=sdp, type=type_)
|
||||
|
||||
# Store peer connection using client's pc_id
|
||||
self._peer_connections[pc_id] = pc
|
||||
self._track_peer_connection(connection_key, pc_id, pc)
|
||||
|
||||
# Register WebSocket sender for real-time feedback
|
||||
async def ws_sender(message: dict):
|
||||
|
|
@ -392,7 +498,16 @@ class SignalingManager:
|
|||
@pc.event_handler("closed")
|
||||
async def handle_disconnected(webrtc_connection: SmallWebRTCConnection):
|
||||
logger.info(f"PeerConnection closed: {webrtc_connection.pc_id}")
|
||||
self._peer_connections.pop(webrtc_connection.pc_id, None)
|
||||
owner_connection_id = self._forget_peer_connection(
|
||||
webrtc_connection.pc_id
|
||||
)
|
||||
if owner_connection_id == connection_key:
|
||||
await self._notify_call_ended_and_close_websocket(
|
||||
ws,
|
||||
workflow_run_id,
|
||||
webrtc_connection.pc_id,
|
||||
reason="peer_connection_closed",
|
||||
)
|
||||
|
||||
# Start pipeline in background
|
||||
asyncio.create_task(
|
||||
|
|
@ -421,9 +536,7 @@ class SignalingManager:
|
|||
}
|
||||
)
|
||||
|
||||
async def _handle_ice_candidate(
|
||||
self, ws: WebSocket, payload: dict, workflow_run_id: int
|
||||
):
|
||||
async def _handle_ice_candidate(self, payload: dict, connection_key: str):
|
||||
"""Handle incoming ICE candidate from client.
|
||||
|
||||
Uses SmallWebRTC's native ICE trickling support via add_ice_candidate().
|
||||
|
|
@ -442,6 +555,9 @@ class SignalingManager:
|
|||
if not pc:
|
||||
logger.warning(f"No peer connection found for pc_id: {pc_id}")
|
||||
return
|
||||
if self._peer_connection_owners.get(pc_id) != connection_key:
|
||||
logger.warning(f"Ignoring ICE candidate for unowned pc_id: {pc_id}")
|
||||
return
|
||||
|
||||
if candidate_data:
|
||||
candidate_str = candidate_data.get("candidate", "")
|
||||
|
|
@ -466,7 +582,7 @@ class SignalingManager:
|
|||
logger.debug(f"End of ICE candidates for pc_id: {pc_id}")
|
||||
|
||||
async def _handle_renegotiation(
|
||||
self, ws: WebSocket, payload: dict, workflow_id: int, workflow_run_id: int
|
||||
self, ws: WebSocket, payload: dict, connection_key: str
|
||||
):
|
||||
"""Handle renegotiation request."""
|
||||
pc_id = payload.get("pc_id")
|
||||
|
|
@ -479,6 +595,11 @@ class SignalingManager:
|
|||
{"type": "error", "payload": {"message": "Peer connection not found"}}
|
||||
)
|
||||
return
|
||||
if self._peer_connection_owners.get(pc_id) != connection_key:
|
||||
await ws.send_json(
|
||||
{"type": "error", "payload": {"message": "Peer connection not found"}}
|
||||
)
|
||||
return
|
||||
|
||||
pc = self._peer_connections[pc_id]
|
||||
await pc.renegotiate(sdp=sdp, type=type_, restart_pc=restart_pc)
|
||||
|
|
|
|||
|
|
@ -60,6 +60,10 @@ from api.services.workflow.trigger_paths import (
|
|||
)
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
from api.utils.artifacts import artifact_url
|
||||
from api.utils.recording_artifacts import (
|
||||
get_recording_storage_key,
|
||||
has_recording_track,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/workflow")
|
||||
|
||||
|
|
@ -1255,7 +1259,16 @@ async def get_workflow_run(
|
|||
raise HTTPException(status_code=404, detail="Workflow run not found")
|
||||
|
||||
public_access_token = run.public_access_token
|
||||
if (run.transcript_url or run.recording_url) and not public_access_token:
|
||||
user_recording_url = get_recording_storage_key(run.extra, "user")
|
||||
bot_recording_url = get_recording_storage_key(run.extra, "bot")
|
||||
has_user_recording = has_recording_track(run.extra, "user")
|
||||
has_bot_recording = has_recording_track(run.extra, "bot")
|
||||
if (
|
||||
run.transcript_url
|
||||
or run.recording_url
|
||||
or has_user_recording
|
||||
or has_bot_recording
|
||||
) and not public_access_token:
|
||||
public_access_token = await db_client.ensure_public_access_token(run.id)
|
||||
|
||||
return {
|
||||
|
|
@ -1266,8 +1279,20 @@ async def get_workflow_run(
|
|||
"is_completed": run.is_completed,
|
||||
"transcript_url": run.transcript_url,
|
||||
"recording_url": run.recording_url,
|
||||
"user_recording_url": user_recording_url,
|
||||
"bot_recording_url": bot_recording_url,
|
||||
"transcript_public_url": artifact_url(public_access_token, "transcript"),
|
||||
"recording_public_url": artifact_url(public_access_token, "recording"),
|
||||
"user_recording_public_url": (
|
||||
artifact_url(public_access_token, "user_recording")
|
||||
if has_user_recording
|
||||
else None
|
||||
),
|
||||
"bot_recording_public_url": (
|
||||
artifact_url(public_access_token, "bot_recording")
|
||||
if has_bot_recording
|
||||
else None
|
||||
),
|
||||
"public_access_token": public_access_token,
|
||||
"cost_info": format_public_cost_info(run.cost_info, run.usage_info),
|
||||
"usage_info": format_public_usage_info(run.usage_info),
|
||||
|
|
|
|||
|
|
@ -15,8 +15,12 @@ class WorkflowRunResponseSchema(BaseModel):
|
|||
is_completed: bool
|
||||
transcript_url: str | None
|
||||
recording_url: str | None
|
||||
user_recording_url: str | None = None
|
||||
bot_recording_url: str | None = None
|
||||
transcript_public_url: str | None = None
|
||||
recording_public_url: str | None = None
|
||||
user_recording_public_url: str | None = None
|
||||
bot_recording_public_url: str | None = None
|
||||
public_access_token: str | None = None
|
||||
cost_info: Dict[str, Any] | None
|
||||
usage_info: Dict[str, Any] | None = None
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.CAMB.value: self._check_camb_api_key,
|
||||
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
|
||||
ServiceProviders.SPEACHES.value: self._check_speaches_api_key,
|
||||
ServiceProviders.HUGGINGFACE.value: self._check_huggingface_api_key,
|
||||
ServiceProviders.GOOGLE_VERTEX.value: self._check_google_vertex_llm_api_key,
|
||||
ServiceProviders.OPENAI_REALTIME.value: self._check_openai_api_key,
|
||||
ServiceProviders.GROK_REALTIME.value: self._check_grok_realtime_api_key,
|
||||
|
|
@ -60,6 +61,7 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.GLADIA.value: self._check_gladia_api_key,
|
||||
ServiceProviders.RIME.value: self._check_rime_api_key,
|
||||
ServiceProviders.MINIMAX.value: self._check_minimax_api_key,
|
||||
ServiceProviders.SMALLEST.value: self._check_smallest_api_key,
|
||||
}
|
||||
|
||||
async def validate(
|
||||
|
|
@ -360,6 +362,14 @@ class UserConfigurationValidator:
|
|||
raise ValueError("base_url is required for Speaches services")
|
||||
return True
|
||||
|
||||
def _check_huggingface_api_key(self, model: str, api_key: str) -> bool:
|
||||
if not api_key.startswith("hf_"):
|
||||
raise ValueError(
|
||||
"Invalid Hugging Face API token format. Use a token that starts with "
|
||||
"'hf_' and has Inference Providers permission."
|
||||
)
|
||||
return True
|
||||
|
||||
def _check_google_vertex_realtime_api_key(self, model: str, service_config) -> bool:
|
||||
if not getattr(service_config, "project_id", None):
|
||||
raise ValueError("project_id is required for Google Vertex Realtime")
|
||||
|
|
@ -389,6 +399,7 @@ class UserConfigurationValidator:
|
|||
return True
|
||||
|
||||
def _check_minimax_api_key(self, model: str, api_key: str) -> bool:
|
||||
# MiniMax doesn't publish a cheap key-validation endpoint; trust the key
|
||||
# at save time and surface auth errors at first call (same as Rime/Sarvam).
|
||||
return True
|
||||
|
||||
def _check_smallest_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -68,6 +68,7 @@ class ServiceProviders(str, Enum):
|
|||
CAMB = "camb"
|
||||
AWS_BEDROCK = "aws_bedrock"
|
||||
SPEACHES = "speaches"
|
||||
HUGGINGFACE = "huggingface"
|
||||
ASSEMBLYAI = "assemblyai"
|
||||
GLADIA = "gladia"
|
||||
RIME = "rime"
|
||||
|
|
@ -79,6 +80,7 @@ class ServiceProviders(str, Enum):
|
|||
GOOGLE_REALTIME = "google_realtime"
|
||||
GOOGLE_VERTEX_REALTIME = "google_vertex_realtime"
|
||||
AZURE_REALTIME = "azure_realtime"
|
||||
SMALLEST = "smallest"
|
||||
|
||||
|
||||
class BaseServiceConfiguration(BaseModel):
|
||||
|
|
@ -94,6 +96,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.DOGRAH,
|
||||
ServiceProviders.AWS_BEDROCK,
|
||||
ServiceProviders.SPEACHES,
|
||||
ServiceProviders.HUGGINGFACE,
|
||||
ServiceProviders.ASSEMBLYAI,
|
||||
ServiceProviders.GLADIA,
|
||||
ServiceProviders.RIME,
|
||||
|
|
@ -106,6 +109,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.GOOGLE_VERTEX_REALTIME,
|
||||
ServiceProviders.AZURE_REALTIME,
|
||||
ServiceProviders.SARVAM,
|
||||
ServiceProviders.SMALLEST,
|
||||
]
|
||||
api_key: str | list[str]
|
||||
|
||||
|
|
@ -255,6 +259,11 @@ SPEACHES_PROVIDER_MODEL_CONFIG = provider_model_config(
|
|||
),
|
||||
provider_docs_url="https://github.com/speaches-ai/speaches",
|
||||
)
|
||||
HUGGINGFACE_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Hugging Face",
|
||||
description="Hosted Hugging Face Inference Providers API for usage-based inference.",
|
||||
provider_docs_url="https://huggingface.co/docs/inference-providers/en/index",
|
||||
)
|
||||
AZURE_SPEECH_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Azure Speech Services",
|
||||
description="Azure Cognitive Services Speech — TTS and STT via the Azure Speech SDK.",
|
||||
|
|
@ -471,6 +480,35 @@ class SpeachesLLMConfiguration(BaseLLMConfiguration):
|
|||
)
|
||||
|
||||
|
||||
HUGGINGFACE_LLM_MODELS = [
|
||||
"openai/gpt-oss-120b:cerebras",
|
||||
"deepseek-ai/DeepSeek-R1:fastest",
|
||||
"Qwen/Qwen3-Coder-480B-A35B-Instruct:fastest",
|
||||
]
|
||||
|
||||
|
||||
@register_llm
|
||||
class HuggingFaceLLMConfiguration(BaseLLMConfiguration):
|
||||
model_config = HUGGINGFACE_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.HUGGINGFACE] = ServiceProviders.HUGGINGFACE
|
||||
model: str = Field(
|
||||
default="openai/gpt-oss-120b:cerebras",
|
||||
description="Hugging Face chat-completion model identifier, optionally with provider suffix.",
|
||||
json_schema_extra={
|
||||
"examples": HUGGINGFACE_LLM_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://router.huggingface.co/v1",
|
||||
description="Hugging Face OpenAI-compatible chat-completions router base URL.",
|
||||
)
|
||||
bill_to: str | None = Field(
|
||||
default=None,
|
||||
description="Optional Hugging Face organization or user to bill using X-HF-Bill-To.",
|
||||
)
|
||||
|
||||
|
||||
MINIMAX_MODELS = [
|
||||
"MiniMax-M2.7",
|
||||
"MiniMax-M2.7-highspeed",
|
||||
|
|
@ -741,6 +779,7 @@ LLMConfig = Annotated[
|
|||
DograhLLMService,
|
||||
AWSBedrockLLMConfiguration,
|
||||
SpeachesLLMConfiguration,
|
||||
HuggingFaceLLMConfiguration,
|
||||
MiniMaxLLMConfiguration,
|
||||
SarvamLLMConfiguration,
|
||||
],
|
||||
|
|
@ -907,6 +946,7 @@ class DograhTTSService(BaseTTSConfiguration):
|
|||
voice: str = Field(
|
||||
default="default",
|
||||
description="Voice preset.",
|
||||
json_schema_extra={"allow_custom_input": True},
|
||||
)
|
||||
speed: float = Field(default=1.0, ge=0.5, le=2.0, description="Speed of the voice.")
|
||||
|
||||
|
|
@ -961,6 +1001,12 @@ class SarvamTTSConfiguration(BaseTTSConfiguration):
|
|||
description="BCP-47 Indian-language code (e.g. hi-IN, en-IN).",
|
||||
json_schema_extra={"examples": SARVAM_LANGUAGES},
|
||||
)
|
||||
speed: float = Field(
|
||||
default=1.0,
|
||||
ge=0.5,
|
||||
le=2.0,
|
||||
description="Speech speed multiplier.",
|
||||
)
|
||||
|
||||
|
||||
CAMB_TTS_MODELS = ["mars-flash", "mars-pro", "mars-instruct"]
|
||||
|
|
@ -1120,6 +1166,80 @@ class AzureSpeechTTSConfiguration(BaseTTSConfiguration):
|
|||
)
|
||||
|
||||
|
||||
SMALLEST_PROVIDER_MODEL_CONFIG = provider_model_config(
|
||||
"Smallest AI",
|
||||
description="Smallest AI ultralow-latency TTS (Waves) and STT (Pulse) APIs.",
|
||||
provider_docs_url="https://smallest.ai/docs",
|
||||
)
|
||||
|
||||
SMALLEST_TTS_MODELS = ["lightning_v3.1", "lightning_v3.1_pro"]
|
||||
SMALLEST_TTS_VOICES = [
|
||||
"sophia",
|
||||
"avery",
|
||||
"liam",
|
||||
"lucas",
|
||||
"olivia",
|
||||
"ryan",
|
||||
"freya",
|
||||
"william",
|
||||
"devansh",
|
||||
"arjun",
|
||||
"niharika",
|
||||
"maya",
|
||||
"dhruv",
|
||||
"mia",
|
||||
"maithili",
|
||||
]
|
||||
SMALLEST_TTS_LANGUAGES = [
|
||||
"en",
|
||||
"hi",
|
||||
"fr",
|
||||
"de",
|
||||
"es",
|
||||
"it",
|
||||
"nl",
|
||||
"pl",
|
||||
"ru",
|
||||
"ar",
|
||||
"bn",
|
||||
"gu",
|
||||
"he",
|
||||
"kn",
|
||||
"mr",
|
||||
"ta",
|
||||
]
|
||||
|
||||
|
||||
@register_tts
|
||||
class SmallestAITTSConfiguration(BaseTTSConfiguration):
|
||||
model_config = SMALLEST_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.SMALLEST] = ServiceProviders.SMALLEST
|
||||
model: str = Field(
|
||||
default="lightning_v3.1",
|
||||
description="Smallest AI TTS model. lightning_v3.1_pro is the premium pool (American, British, Indian accents); lightning_v3.1 is the standard pool with 217 voices across 12 languages.",
|
||||
json_schema_extra={"examples": SMALLEST_TTS_MODELS},
|
||||
)
|
||||
voice: str = Field(
|
||||
default="sophia",
|
||||
description="Smallest AI voice ID.",
|
||||
json_schema_extra={"examples": SMALLEST_TTS_VOICES, "allow_custom_input": True},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en",
|
||||
description="ISO 639-1 language code for synthesis.",
|
||||
json_schema_extra={
|
||||
"examples": SMALLEST_TTS_LANGUAGES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
speed: float = Field(
|
||||
default=1.0,
|
||||
ge=0.5,
|
||||
le=2.0,
|
||||
description="Speech speed multiplier (0.5 to 2.0).",
|
||||
)
|
||||
|
||||
|
||||
TTSConfig = Annotated[
|
||||
Union[
|
||||
DeepgramTTSConfiguration,
|
||||
|
|
@ -1134,6 +1254,7 @@ TTSConfig = Annotated[
|
|||
SpeachesTTSConfiguration,
|
||||
MiniMaxTTSConfiguration,
|
||||
AzureSpeechTTSConfiguration,
|
||||
SmallestAITTSConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
@ -1334,6 +1455,38 @@ class SpeachesSTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
HUGGINGFACE_STT_MODELS = [
|
||||
"openai/whisper-large-v3-turbo",
|
||||
"openai/whisper-large-v3",
|
||||
]
|
||||
|
||||
|
||||
@register_stt
|
||||
class HuggingFaceSTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = HUGGINGFACE_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.HUGGINGFACE] = ServiceProviders.HUGGINGFACE
|
||||
model: str = Field(
|
||||
default="openai/whisper-large-v3-turbo",
|
||||
description="Hugging Face ASR model identifier served through Inference Providers.",
|
||||
json_schema_extra={
|
||||
"examples": HUGGINGFACE_STT_MODELS,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
base_url: str = Field(
|
||||
default="https://router.huggingface.co/hf-inference",
|
||||
description="Hugging Face Inference Providers router base URL.",
|
||||
)
|
||||
bill_to: str | None = Field(
|
||||
default=None,
|
||||
description="Optional Hugging Face organization or user to bill using X-HF-Bill-To.",
|
||||
)
|
||||
return_timestamps: bool = Field(
|
||||
default=False,
|
||||
description="Request timestamp chunks when supported by the selected provider/model.",
|
||||
)
|
||||
|
||||
|
||||
ASSEMBLYAI_STT_MODELS = ["u3-rt-pro"]
|
||||
ASSEMBLYAI_STT_LANGUAGES = ["en", "es", "de", "fr", "pt", "it"]
|
||||
|
||||
|
|
@ -1396,6 +1549,62 @@ class AzureSpeechSTTConfiguration(BaseSTTConfiguration):
|
|||
)
|
||||
|
||||
|
||||
SMALLEST_STT_MODELS = ["pulse"]
|
||||
SMALLEST_STT_LANGUAGES = [
|
||||
"en",
|
||||
"hi",
|
||||
"fr",
|
||||
"de",
|
||||
"es",
|
||||
"it",
|
||||
"nl",
|
||||
"pl",
|
||||
"ru",
|
||||
"pt",
|
||||
"bn",
|
||||
"gu",
|
||||
"kn",
|
||||
"ml",
|
||||
"mr",
|
||||
"ta",
|
||||
"te",
|
||||
"pa",
|
||||
"or",
|
||||
"bg",
|
||||
"cs",
|
||||
"da",
|
||||
"et",
|
||||
"fi",
|
||||
"hu",
|
||||
"lt",
|
||||
"lv",
|
||||
"mt",
|
||||
"ro",
|
||||
"sk",
|
||||
"sv",
|
||||
"uk",
|
||||
]
|
||||
|
||||
|
||||
@register_stt
|
||||
class SmallestAISTTConfiguration(BaseSTTConfiguration):
|
||||
model_config = SMALLEST_PROVIDER_MODEL_CONFIG
|
||||
provider: Literal[ServiceProviders.SMALLEST] = ServiceProviders.SMALLEST
|
||||
model: str = Field(
|
||||
default="pulse",
|
||||
description="Smallest AI STT model. Supports 38 languages with real-time streaming.",
|
||||
json_schema_extra={"examples": SMALLEST_STT_MODELS},
|
||||
)
|
||||
language: str = Field(
|
||||
default="en",
|
||||
description="ISO 639-1 language code for transcription.",
|
||||
json_schema_extra={
|
||||
"examples": SMALLEST_STT_LANGUAGES,
|
||||
"allow_custom_input": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
STTConfig = Annotated[
|
||||
Union[
|
||||
DeepgramSTTConfiguration,
|
||||
|
|
@ -1406,9 +1615,11 @@ STTConfig = Annotated[
|
|||
SpeechmaticsSTTConfiguration,
|
||||
SarvamSTTConfiguration,
|
||||
SpeachesSTTConfiguration,
|
||||
HuggingFaceSTTConfiguration,
|
||||
AssemblyAISTTConfiguration,
|
||||
GladiaSTTConfiguration,
|
||||
AzureSpeechSTTConfiguration,
|
||||
SmallestAISTTConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
|
|||
|
|
@ -9,8 +9,8 @@ from api.services.integrations import IntegrationRuntimeSession
|
|||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.audio_playback import play_audio_loop
|
||||
from api.services.pipecat.in_memory_buffers import (
|
||||
InMemoryAudioBuffer,
|
||||
InMemoryLogsBuffer,
|
||||
InMemoryRecordingBuffers,
|
||||
)
|
||||
from api.services.pipecat.pipeline_metrics_aggregator import PipelineMetricsAggregator
|
||||
from api.services.pipecat.tracing_config import get_trace_url
|
||||
|
|
@ -40,11 +40,11 @@ async def _capture_call_event(
|
|||
"workflow_run_id": workflow_run_id,
|
||||
"workflow_id": workflow_run.workflow_id if workflow_run else None,
|
||||
"call_type": workflow_run.mode if workflow_run else None,
|
||||
"call_direction": (workflow_run.initial_context or {}).get(
|
||||
"direction", "outbound"
|
||||
)
|
||||
if workflow_run
|
||||
else None,
|
||||
"call_direction": (
|
||||
(workflow_run.initial_context or {}).get("direction", "outbound")
|
||||
if workflow_run
|
||||
else None
|
||||
),
|
||||
}
|
||||
if extra_properties:
|
||||
properties.update(extra_properties)
|
||||
|
|
@ -73,7 +73,7 @@ def register_event_handlers(
|
|||
"""Register all event handlers for transport and task events.
|
||||
|
||||
Returns:
|
||||
in_memory_audio_buffer for use by other handlers.
|
||||
In-memory recording buffers for use by other handlers.
|
||||
"""
|
||||
# Initialize in-memory buffers with proper audio configuration
|
||||
sample_rate = audio_config.pipeline_sample_rate if audio_config else 16000
|
||||
|
|
@ -84,7 +84,7 @@ def register_event_handlers(
|
|||
f"with sample_rate={sample_rate}Hz, channels={num_channels}"
|
||||
)
|
||||
|
||||
in_memory_audio_buffer = InMemoryAudioBuffer(
|
||||
in_memory_audio_buffers = InMemoryRecordingBuffers(
|
||||
workflow_run_id=workflow_run_id,
|
||||
sample_rate=sample_rate,
|
||||
num_channels=num_channels,
|
||||
|
|
@ -363,14 +363,32 @@ def register_event_handlers(
|
|||
|
||||
# Write buffers to temp files and enqueue combined processing task
|
||||
audio_temp_path = None
|
||||
user_audio_temp_path = None
|
||||
bot_audio_temp_path = None
|
||||
transcript_temp_path = None
|
||||
|
||||
try:
|
||||
if not in_memory_audio_buffer.is_empty:
|
||||
audio_temp_path = await in_memory_audio_buffer.write_to_temp_file()
|
||||
if not in_memory_audio_buffers.mixed.is_empty:
|
||||
audio_temp_path = (
|
||||
await in_memory_audio_buffers.mixed.write_to_temp_file()
|
||||
)
|
||||
else:
|
||||
logger.debug("Audio buffer is empty, skipping upload")
|
||||
|
||||
if not in_memory_audio_buffers.user.is_empty:
|
||||
user_audio_temp_path = (
|
||||
await in_memory_audio_buffers.user.write_to_temp_file()
|
||||
)
|
||||
else:
|
||||
logger.debug("User audio buffer is empty, skipping upload")
|
||||
|
||||
if not in_memory_audio_buffers.bot.is_empty:
|
||||
bot_audio_temp_path = (
|
||||
await in_memory_audio_buffers.bot.write_to_temp_file()
|
||||
)
|
||||
else:
|
||||
logger.debug("Bot audio buffer is empty, skipping upload")
|
||||
|
||||
transcript_temp_path = in_memory_logs_buffer.write_transcript_to_temp_file()
|
||||
if not transcript_temp_path:
|
||||
logger.debug("No transcript events in logs buffer, skipping upload")
|
||||
|
|
@ -385,16 +403,18 @@ def register_event_handlers(
|
|||
workflow_run_id,
|
||||
audio_temp_path,
|
||||
transcript_temp_path,
|
||||
user_audio_temp_path,
|
||||
bot_audio_temp_path,
|
||||
)
|
||||
|
||||
# Return the buffer so it can be passed to other handlers
|
||||
return in_memory_audio_buffer
|
||||
return in_memory_audio_buffers
|
||||
|
||||
|
||||
def register_audio_data_handler(
|
||||
audio_buffer: AudioBufferProcessor,
|
||||
workflow_run_id,
|
||||
in_memory_buffer: InMemoryAudioBuffer,
|
||||
in_memory_buffers: InMemoryRecordingBuffers,
|
||||
):
|
||||
"""Register event handler for audio data"""
|
||||
logger.info(f"Registering audio data handler for workflow run {workflow_run_id}")
|
||||
|
|
@ -404,9 +424,19 @@ def register_audio_data_handler(
|
|||
if not audio:
|
||||
return
|
||||
|
||||
# Use in-memory buffer
|
||||
try:
|
||||
await in_memory_buffer.append(audio)
|
||||
await in_memory_buffers.mixed.append(audio)
|
||||
except MemoryError as e:
|
||||
logger.error(f"Memory buffer full: {e}")
|
||||
# Could implement overflow to disk here if needed
|
||||
logger.error(f"Mixed audio buffer full: {e}")
|
||||
|
||||
@audio_buffer.event_handler("on_track_audio_data")
|
||||
async def on_track_audio_data(
|
||||
buffer, user_audio, bot_audio, sample_rate, num_channels
|
||||
):
|
||||
try:
|
||||
if user_audio:
|
||||
await in_memory_buffers.user.append(user_audio)
|
||||
if bot_audio:
|
||||
await in_memory_buffers.bot.append(bot_audio)
|
||||
except MemoryError as e:
|
||||
logger.error(f"Track audio buffer full: {e}")
|
||||
|
|
|
|||
|
|
@ -75,6 +75,27 @@ class InMemoryAudioBuffer:
|
|||
return self._total_size
|
||||
|
||||
|
||||
class InMemoryRecordingBuffers:
|
||||
"""Holds the mixed recording plus aligned user and bot mono tracks."""
|
||||
|
||||
def __init__(self, workflow_run_id: int, sample_rate: int, num_channels: int = 1):
|
||||
self.mixed = InMemoryAudioBuffer(
|
||||
workflow_run_id=workflow_run_id,
|
||||
sample_rate=sample_rate,
|
||||
num_channels=num_channels,
|
||||
)
|
||||
self.user = InMemoryAudioBuffer(
|
||||
workflow_run_id=workflow_run_id,
|
||||
sample_rate=sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
self.bot = InMemoryAudioBuffer(
|
||||
workflow_run_id=workflow_run_id,
|
||||
sample_rate=sample_rate,
|
||||
num_channels=1,
|
||||
)
|
||||
|
||||
|
||||
class InMemoryLogsBuffer:
|
||||
"""Buffer real-time feedback events in memory during a call, then save to workflow run logs."""
|
||||
|
||||
|
|
|
|||
|
|
@ -39,8 +39,17 @@ from pipecat.services.google.vertex.llm import (
|
|||
GoogleVertexLLMSettings,
|
||||
)
|
||||
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
|
||||
from pipecat.services.huggingface.llm import (
|
||||
HuggingFaceLLMService,
|
||||
HuggingFaceLLMSettings,
|
||||
)
|
||||
from pipecat.services.huggingface.stt import (
|
||||
HuggingFaceSTTService,
|
||||
HuggingFaceSTTSettings,
|
||||
)
|
||||
from pipecat.services.minimax.llm import MiniMaxLLMService
|
||||
from pipecat.services.minimax.tts import MiniMaxTTSSettings
|
||||
from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import (
|
||||
|
|
@ -53,6 +62,8 @@ from pipecat.services.rime.tts import RimeTTSService, RimeTTSSettings
|
|||
from pipecat.services.sarvam.llm import SarvamLLMService, SarvamLLMSettings
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
|
||||
from pipecat.services.smallest.stt import SmallestSTTService, SmallestSTTSettings
|
||||
from pipecat.services.smallest.tts import SmallestTTSService, SmallestTTSSettings
|
||||
from pipecat.services.speaches.llm import SpeachesLLMService, SpeachesLLMSettings
|
||||
from pipecat.services.speaches.stt import SpeachesSTTService, SpeachesSTTSettings
|
||||
from pipecat.services.speaches.tts import SpeachesTTSService, SpeachesTTSSettings
|
||||
|
|
@ -218,6 +229,22 @@ def create_stt_service(
|
|||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.HUGGINGFACE.value:
|
||||
base_url = (
|
||||
getattr(user_config.stt, "base_url", None)
|
||||
or "https://router.huggingface.co/hf-inference"
|
||||
)
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return HuggingFaceSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
base_url=base_url,
|
||||
bill_to=getattr(user_config.stt, "bill_to", None),
|
||||
settings=HuggingFaceSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
return_timestamps=getattr(user_config.stt, "return_timestamps", False),
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.ASSEMBLYAI.value:
|
||||
language = getattr(user_config.stt, "language", None)
|
||||
settings_kwargs = {"model": user_config.stt.model, "language": language}
|
||||
|
|
@ -284,6 +311,20 @@ def create_stt_service(
|
|||
settings=AzureSTTSettings(language=pipecat_language),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SMALLEST.value:
|
||||
language_code = getattr(user_config.stt, "language", None) or "en"
|
||||
try:
|
||||
pipecat_language = Language(language_code)
|
||||
except ValueError:
|
||||
pipecat_language = Language.EN
|
||||
return SmallestSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=SmallestSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=pipecat_language,
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid STT provider {user_config.stt.provider}"
|
||||
|
|
@ -320,6 +361,7 @@ def create_tts_service(
|
|||
kwargs["base_url"] = base_url
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
sample_rate=OPENAI_SAMPLE_RATE,
|
||||
settings=OpenAITTSSettings(model=user_config.tts.model),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router", "recording"],
|
||||
|
|
@ -493,13 +535,17 @@ def create_tts_service(
|
|||
pipecat_language = language_mapping.get(language, Language.HI)
|
||||
|
||||
voice = getattr(user_config.tts, "voice", None) or "anushka"
|
||||
speed = getattr(user_config.tts, "speed", None)
|
||||
settings_kwargs = {
|
||||
"model": user_config.tts.model,
|
||||
"voice": voice,
|
||||
"language": pipecat_language,
|
||||
}
|
||||
if speed and speed != 1.0:
|
||||
settings_kwargs["pace"] = speed
|
||||
return SarvamTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
settings=SarvamTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=voice,
|
||||
language=pipecat_language,
|
||||
),
|
||||
settings=SarvamTTSSettings(**settings_kwargs),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router", "recording"],
|
||||
silence_time_s=1.0,
|
||||
|
|
@ -560,6 +606,28 @@ def create_tts_service(
|
|||
skip_aggregator_types=["recording_router", "recording"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.SMALLEST.value:
|
||||
language_code = getattr(user_config.tts, "language", None) or "en"
|
||||
try:
|
||||
pipecat_language = Language(language_code)
|
||||
except ValueError:
|
||||
pipecat_language = Language.EN
|
||||
speed = getattr(user_config.tts, "speed", None)
|
||||
model = user_config.tts.model.replace("lightning-v", "lightning_v")
|
||||
settings_kwargs = SmallestTTSSettings(
|
||||
model=model,
|
||||
voice=user_config.tts.voice,
|
||||
language=pipecat_language,
|
||||
)
|
||||
if speed and speed != 1.0:
|
||||
settings_kwargs.speed = speed
|
||||
return SmallestTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
settings=settings_kwargs,
|
||||
text_filters=[xml_function_tag_filter],
|
||||
skip_aggregator_types=["recording_router", "recording"],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Invalid TTS provider {user_config.tts.provider}"
|
||||
|
|
@ -581,6 +649,7 @@ def create_llm_service_from_provider(
|
|||
location: str | None = None,
|
||||
credentials: str | None = None,
|
||||
temperature: float | None = None,
|
||||
bill_to: str | None = None,
|
||||
):
|
||||
"""Create an LLM service from explicit provider/model/api_key.
|
||||
|
||||
|
|
@ -663,6 +732,15 @@ def create_llm_service_from_provider(
|
|||
api_key=api_key or "none",
|
||||
settings=SpeachesLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.HUGGINGFACE.value:
|
||||
base_url = base_url or "https://router.huggingface.co/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
return HuggingFaceLLMService(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
bill_to=bill_to,
|
||||
settings=HuggingFaceLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif provider == ServiceProviders.MINIMAX.value:
|
||||
base_url = base_url or "https://api.minimax.io/v1"
|
||||
_validate_runtime_service_url(base_url, "base_url")
|
||||
|
|
@ -875,6 +953,9 @@ def create_llm_service(user_config, correlation_id: str | None = None):
|
|||
kwargs["endpoint"] = user_config.llm.endpoint
|
||||
elif provider == ServiceProviders.SPEACHES.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.HUGGINGFACE.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
kwargs["bill_to"] = user_config.llm.bill_to
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
kwargs["aws_access_key"] = user_config.llm.aws_access_key
|
||||
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
|
||||
|
|
|
|||
|
|
@ -718,6 +718,8 @@ class TriggerNodeData(BaseNodeData):
|
|||
"rsvp": "{{gathered_context.rsvp}}",
|
||||
"duration": "{{cost_info.call_duration_seconds}}",
|
||||
"recording_url": "{{recording_url}}",
|
||||
"user_recording_url": "{{user_recording_url}}",
|
||||
"bot_recording_url": "{{bot_recording_url}}",
|
||||
"transcript_url": "{{transcript_url}}",
|
||||
},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ from api.services.workflow.dto import (
|
|||
)
|
||||
from api.services.workflow.qa import run_per_node_qa_analysis
|
||||
from api.utils.credential_auth import build_auth_header
|
||||
from api.utils.recording_artifacts import get_recording_storage_key
|
||||
from api.utils.template_renderer import render_template
|
||||
|
||||
|
||||
|
|
@ -339,6 +340,10 @@ def _build_render_context(
|
|||
Returns:
|
||||
Dict containing all fields available for template rendering
|
||||
"""
|
||||
extra = workflow_run.extra or {}
|
||||
user_recording_key = get_recording_storage_key(extra, "user")
|
||||
bot_recording_key = get_recording_storage_key(extra, "bot")
|
||||
|
||||
context = {
|
||||
# Top-level fields
|
||||
"workflow_run_id": workflow_run.id,
|
||||
|
|
@ -353,6 +358,7 @@ def _build_render_context(
|
|||
"cost_info": workflow_run.usage_info or {},
|
||||
# Annotations (includes QA results)
|
||||
"annotations": workflow_run.annotations or {},
|
||||
"extra": extra,
|
||||
}
|
||||
|
||||
# Add public download URLs if token is available
|
||||
|
|
@ -366,9 +372,17 @@ def _build_render_context(
|
|||
context["transcript_url"] = (
|
||||
f"{base_url}/transcript" if workflow_run.transcript_url else None
|
||||
)
|
||||
context["user_recording_url"] = (
|
||||
f"{base_url}/user_recording" if user_recording_key else None
|
||||
)
|
||||
context["bot_recording_url"] = (
|
||||
f"{base_url}/bot_recording" if bot_recording_key else None
|
||||
)
|
||||
else:
|
||||
context["recording_url"] = workflow_run.recording_url
|
||||
context["transcript_url"] = workflow_run.transcript_url
|
||||
context["user_recording_url"] = user_recording_key
|
||||
context["bot_recording_url"] = bot_recording_key
|
||||
|
||||
return context
|
||||
|
||||
|
|
|
|||
|
|
@ -12,11 +12,51 @@ from api.services.workflow_run_billing import (
|
|||
from api.tasks.run_integrations import run_integrations_post_workflow_run
|
||||
|
||||
|
||||
def _recording_metadata(storage_key: str, storage_backend: str, track: str) -> dict:
|
||||
return {
|
||||
"storage_key": storage_key,
|
||||
"storage_backend": storage_backend,
|
||||
"format": "wav",
|
||||
"track": track,
|
||||
}
|
||||
|
||||
|
||||
async def _upload_temp_file(
|
||||
workflow_run_id: int,
|
||||
temp_file_path: str,
|
||||
storage_key: str,
|
||||
label: str,
|
||||
) -> bool:
|
||||
try:
|
||||
if not os.path.exists(temp_file_path):
|
||||
logger.warning(f"{label} temp file not found: {temp_file_path}")
|
||||
return False
|
||||
|
||||
file_size = os.path.getsize(temp_file_path)
|
||||
logger.debug(f"{label} file size: {file_size} bytes")
|
||||
|
||||
await storage_fs.aupload_file(temp_file_path, storage_key)
|
||||
logger.info(f"Successfully uploaded {label}: {storage_key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading {label} for workflow {workflow_run_id}: {e}")
|
||||
return False
|
||||
finally:
|
||||
if os.path.exists(temp_file_path):
|
||||
try:
|
||||
os.remove(temp_file_path)
|
||||
logger.debug(f"Cleaned up temp {label} file: {temp_file_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp {label} file: {e}")
|
||||
|
||||
|
||||
async def process_workflow_completion(
|
||||
_ctx,
|
||||
workflow_run_id: int,
|
||||
audio_temp_path: Optional[str] = None,
|
||||
transcript_temp_path: Optional[str] = None,
|
||||
user_audio_temp_path: Optional[str] = None,
|
||||
bot_audio_temp_path: Optional[str] = None,
|
||||
):
|
||||
"""Process workflow completion: upload artifacts and run integrations.
|
||||
|
||||
|
|
@ -28,6 +68,8 @@ async def process_workflow_completion(
|
|||
workflow_run_id: The workflow run ID
|
||||
audio_temp_path: Optional path to temp audio file
|
||||
transcript_temp_path: Optional path to temp transcript file
|
||||
user_audio_temp_path: Optional path to temp user-track audio file
|
||||
bot_audio_temp_path: Optional path to temp bot-track audio file
|
||||
"""
|
||||
run_id = str(workflow_run_id)
|
||||
set_current_run_id(run_id)
|
||||
|
|
@ -37,35 +79,55 @@ async def process_workflow_completion(
|
|||
storage_backend = get_current_storage_backend()
|
||||
|
||||
# Step 1: Upload audio if provided
|
||||
recordings_metadata: dict[str, dict] = {}
|
||||
|
||||
if audio_temp_path:
|
||||
try:
|
||||
if os.path.exists(audio_temp_path):
|
||||
file_size = os.path.getsize(audio_temp_path)
|
||||
logger.debug(f"Audio file size: {file_size} bytes")
|
||||
recording_url = f"recordings/{workflow_run_id}.wav"
|
||||
logger.info(
|
||||
f"Uploading mixed audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
if await _upload_temp_file(
|
||||
workflow_run_id, audio_temp_path, recording_url, "mixed audio"
|
||||
):
|
||||
recordings_metadata["mixed"] = _recording_metadata(
|
||||
recording_url, storage_backend.value, "mixed"
|
||||
)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
recording_url=recording_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
|
||||
recording_url = f"recordings/{workflow_run_id}.wav"
|
||||
logger.info(
|
||||
f"Uploading audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
if user_audio_temp_path:
|
||||
user_recording_url = f"recordings/{workflow_run_id}/user.wav"
|
||||
logger.info(
|
||||
f"Uploading user audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
if await _upload_temp_file(
|
||||
workflow_run_id, user_audio_temp_path, user_recording_url, "user audio"
|
||||
):
|
||||
recordings_metadata["user"] = _recording_metadata(
|
||||
user_recording_url, storage_backend.value, "user"
|
||||
)
|
||||
|
||||
await storage_fs.aupload_file(audio_temp_path, recording_url)
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
recording_url=recording_url,
|
||||
storage_backend=storage_backend.value,
|
||||
)
|
||||
logger.info(f"Successfully uploaded audio: {recording_url}")
|
||||
else:
|
||||
logger.warning(f"Audio temp file not found: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error uploading audio for workflow {workflow_run_id}: {e}")
|
||||
finally:
|
||||
if audio_temp_path and os.path.exists(audio_temp_path):
|
||||
try:
|
||||
os.remove(audio_temp_path)
|
||||
logger.debug(f"Cleaned up temp audio file: {audio_temp_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to clean up temp audio file: {e}")
|
||||
if bot_audio_temp_path:
|
||||
bot_recording_url = f"recordings/{workflow_run_id}/bot.wav"
|
||||
logger.info(
|
||||
f"Uploading bot audio to {storage_backend.name} - workflow_run_id: {workflow_run_id}"
|
||||
)
|
||||
if await _upload_temp_file(
|
||||
workflow_run_id, bot_audio_temp_path, bot_recording_url, "bot audio"
|
||||
):
|
||||
recordings_metadata["bot"] = _recording_metadata(
|
||||
bot_recording_url, storage_backend.value, "bot"
|
||||
)
|
||||
|
||||
if recordings_metadata:
|
||||
await db_client.update_workflow_run(
|
||||
run_id=workflow_run_id,
|
||||
storage_backend=storage_backend.value,
|
||||
extra={"recordings": recordings_metadata},
|
||||
)
|
||||
|
||||
# Step 2: Upload transcript if provided
|
||||
if transcript_temp_path:
|
||||
|
|
|
|||
131
api/tests/test_huggingface_stt_service_factory.py
Normal file
131
api/tests/test_huggingface_stt_service_factory.py
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
HuggingFaceLLMConfiguration,
|
||||
HuggingFaceSTTConfiguration,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_stt_service,
|
||||
)
|
||||
|
||||
|
||||
def test_huggingface_stt_configuration_defaults_and_registry():
|
||||
config = HuggingFaceSTTConfiguration(api_key="hf_test")
|
||||
|
||||
assert config.provider == ServiceProviders.HUGGINGFACE
|
||||
assert config.model == "openai/whisper-large-v3-turbo"
|
||||
assert config.base_url == "https://router.huggingface.co/hf-inference"
|
||||
assert config.return_timestamps is False
|
||||
assert (
|
||||
REGISTRY[ServiceType.STT][ServiceProviders.HUGGINGFACE]
|
||||
is HuggingFaceSTTConfiguration
|
||||
)
|
||||
|
||||
|
||||
def test_huggingface_llm_configuration_defaults_and_registry():
|
||||
config = HuggingFaceLLMConfiguration(api_key="hf_test")
|
||||
|
||||
assert config.provider == ServiceProviders.HUGGINGFACE
|
||||
assert config.model == "openai/gpt-oss-120b:cerebras"
|
||||
assert config.base_url == "https://router.huggingface.co/v1"
|
||||
assert (
|
||||
REGISTRY[ServiceType.LLM][ServiceProviders.HUGGINGFACE]
|
||||
is HuggingFaceLLMConfiguration
|
||||
)
|
||||
|
||||
|
||||
def test_create_huggingface_llm_service_uses_openai_compatible_router():
|
||||
user_config = SimpleNamespace(
|
||||
llm=SimpleNamespace(
|
||||
provider=ServiceProviders.HUGGINGFACE.value,
|
||||
api_key="hf_test",
|
||||
model="deepseek-ai/DeepSeek-R1:fastest",
|
||||
base_url="https://router.huggingface.co/v1",
|
||||
bill_to="demo-org",
|
||||
)
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.HuggingFaceLLMService"
|
||||
) as mock_service:
|
||||
create_llm_service(user_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "hf_test"
|
||||
assert kwargs["base_url"] == "https://router.huggingface.co/v1"
|
||||
assert kwargs["bill_to"] == "demo-org"
|
||||
assert kwargs["settings"].model == "deepseek-ai/DeepSeek-R1:fastest"
|
||||
assert kwargs["settings"].temperature == 0.1
|
||||
|
||||
|
||||
def test_create_huggingface_stt_service_uses_hosted_defaults():
|
||||
user_config = SimpleNamespace(
|
||||
stt=SimpleNamespace(
|
||||
provider=ServiceProviders.HUGGINGFACE.value,
|
||||
api_key="hf_test",
|
||||
model="openai/whisper-large-v3-turbo",
|
||||
base_url="https://router.huggingface.co/hf-inference",
|
||||
bill_to="demo-org",
|
||||
return_timestamps=True,
|
||||
)
|
||||
)
|
||||
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.HuggingFaceSTTService"
|
||||
) as mock_service:
|
||||
create_stt_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "hf_test"
|
||||
assert kwargs["base_url"] == "https://router.huggingface.co/hf-inference"
|
||||
assert kwargs["bill_to"] == "demo-org"
|
||||
assert kwargs["sample_rate"] == 16000
|
||||
assert kwargs["settings"].model == "openai/whisper-large-v3-turbo"
|
||||
assert kwargs["settings"].return_timestamps is True
|
||||
|
||||
|
||||
def test_validator_accepts_huggingface_stt_token_format():
|
||||
validator = UserConfigurationValidator()
|
||||
|
||||
assert (
|
||||
validator._validate_service(
|
||||
HuggingFaceSTTConfiguration(api_key="hf_test"),
|
||||
"stt",
|
||||
)
|
||||
== []
|
||||
)
|
||||
assert (
|
||||
validator._validate_service(
|
||||
HuggingFaceLLMConfiguration(api_key="hf_test"),
|
||||
"llm",
|
||||
)
|
||||
== []
|
||||
)
|
||||
|
||||
|
||||
def test_validator_rejects_non_huggingface_token_format():
|
||||
validator = UserConfigurationValidator()
|
||||
|
||||
errors = validator._validate_service(
|
||||
HuggingFaceSTTConfiguration(api_key="not-hf-token"),
|
||||
"stt",
|
||||
)
|
||||
|
||||
assert errors == [
|
||||
{
|
||||
"model": "stt",
|
||||
"message": (
|
||||
"Invalid Hugging Face API token format. Use a token that starts with "
|
||||
"'hf_' and has Inference Providers permission."
|
||||
),
|
||||
}
|
||||
]
|
||||
31
api/tests/test_openai_tts_service_factory.py
Normal file
31
api/tests/test_openai_tts_service_factory.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from pipecat.services.openai._constants import OPENAI_SAMPLE_RATE
|
||||
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from api.services.pipecat.service_factory import create_tts_service
|
||||
|
||||
|
||||
def test_create_openai_tts_service_uses_openai_pcm_sample_rate():
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.OPENAI.value,
|
||||
api_key="test-key",
|
||||
model="gpt-4o-mini-tts",
|
||||
voice="alloy",
|
||||
base_url=None,
|
||||
)
|
||||
)
|
||||
audio_config = SimpleNamespace(
|
||||
transport_out_sample_rate=16000,
|
||||
transport_in_sample_rate=16000,
|
||||
)
|
||||
|
||||
with patch("api.services.pipecat.service_factory.OpenAITTSService") as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["sample_rate"] == OPENAI_SAMPLE_RATE
|
||||
assert kwargs["settings"].model == "gpt-4o-mini-tts"
|
||||
30
api/tests/test_s3_signed_url.py
Normal file
30
api/tests/test_s3_signed_url.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from api.routes.s3_signed_url import (
|
||||
_extract_legacy_workflow_run_id,
|
||||
_extract_org_id_from_key,
|
||||
)
|
||||
|
||||
|
||||
def test_split_recording_keys_are_workflow_run_artifacts_not_org_keys():
|
||||
assert _extract_legacy_workflow_run_id("recordings/1855/user.wav") == 1855
|
||||
assert _extract_legacy_workflow_run_id("recordings/1855/bot.wav") == 1855
|
||||
|
||||
assert _extract_org_id_from_key("recordings/1855/user.wav") is None
|
||||
assert _extract_org_id_from_key("recordings/1855/bot.wav") is None
|
||||
|
||||
|
||||
def test_legacy_recording_keys_do_not_fall_through_to_org_scoped_auth():
|
||||
assert _extract_legacy_workflow_run_id("recordings/1855.wav") == 1855
|
||||
assert _extract_legacy_workflow_run_id("recordings/1855/other.wav") is None
|
||||
|
||||
assert _extract_org_id_from_key("recordings/1855.wav") is None
|
||||
assert _extract_org_id_from_key("recordings/1855/other.wav") is None
|
||||
|
||||
|
||||
def test_known_org_scoped_keys_extract_org_id():
|
||||
assert _extract_org_id_from_key("campaigns/42/source.csv") == 42
|
||||
assert _extract_org_id_from_key("knowledge_base/42/document/file.pdf") == 42
|
||||
assert _extract_legacy_workflow_run_id("campaigns/42/source.csv") is None
|
||||
|
||||
|
||||
def test_unknown_numeric_prefix_is_not_treated_as_org_scoped():
|
||||
assert _extract_org_id_from_key("unknown/42/file.wav") is None
|
||||
|
|
@ -7,6 +7,7 @@ from pipecat.transcriptions.language import Language
|
|||
|
||||
from api.services.configuration.registry import (
|
||||
SarvamLLMConfiguration,
|
||||
SarvamTTSConfiguration,
|
||||
ServiceProviders,
|
||||
)
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
|
|
@ -14,6 +15,7 @@ from api.services.pipecat.service_factory import (
|
|||
create_llm_service,
|
||||
create_llm_service_from_provider,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -112,3 +114,41 @@ class TestSarvamSTTServiceFactory:
|
|||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].language == expected_language
|
||||
|
||||
|
||||
class TestSarvamTTSServiceFactory:
|
||||
def test_sarvam_tts_configuration_defaults(self):
|
||||
config = SarvamTTSConfiguration(api_key="test-key")
|
||||
|
||||
assert config.provider == ServiceProviders.SARVAM
|
||||
assert config.model == "bulbul:v2"
|
||||
assert config.voice == "anushka"
|
||||
assert config.language == "hi-IN"
|
||||
assert config.speed == 1.0
|
||||
|
||||
def test_create_sarvam_tts_service_maps_speed_to_pace(self):
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.SARVAM.value,
|
||||
api_key="test-key",
|
||||
model="bulbul:v2",
|
||||
voice="anushka",
|
||||
language="hi-IN",
|
||||
speed=1.25,
|
||||
)
|
||||
)
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000, transport_out_sample_rate=16000
|
||||
)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SarvamTTSService"
|
||||
) as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["api_key"] == "test-key"
|
||||
assert kwargs["settings"].model == "bulbul:v2"
|
||||
assert kwargs["settings"].voice == "anushka"
|
||||
assert kwargs["settings"].language == Language.HI
|
||||
assert kwargs["settings"].pace == 1.25
|
||||
|
|
|
|||
80
api/tests/test_smallest_service_factory.py
Normal file
80
api/tests/test_smallest_service_factory.py
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.registry import (
|
||||
REGISTRY,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
SmallestAISTTConfiguration,
|
||||
SmallestAITTSConfiguration,
|
||||
)
|
||||
from api.services.pipecat.service_factory import create_tts_service
|
||||
|
||||
|
||||
def test_smallest_tts_configuration_defaults_and_registry():
|
||||
config = SmallestAITTSConfiguration(api_key="test-key")
|
||||
|
||||
assert config.provider == ServiceProviders.SMALLEST
|
||||
assert config.model == "lightning_v3.1"
|
||||
assert config.voice == "sophia"
|
||||
assert config.language == "en"
|
||||
assert config.speed == 1.0
|
||||
assert (
|
||||
REGISTRY[ServiceType.TTS][ServiceProviders.SMALLEST]
|
||||
is SmallestAITTSConfiguration
|
||||
)
|
||||
|
||||
|
||||
def test_smallest_stt_configuration_defaults_and_registry():
|
||||
config = SmallestAISTTConfiguration(api_key="test-key")
|
||||
|
||||
assert config.provider == ServiceProviders.SMALLEST
|
||||
assert config.model == "pulse"
|
||||
assert config.language == "en"
|
||||
assert (
|
||||
REGISTRY[ServiceType.STT][ServiceProviders.SMALLEST]
|
||||
is SmallestAISTTConfiguration
|
||||
)
|
||||
|
||||
|
||||
def test_validator_accepts_smallest_services():
|
||||
validator = UserConfigurationValidator()
|
||||
|
||||
assert (
|
||||
validator._validate_service(
|
||||
SmallestAITTSConfiguration(api_key="test-key"),
|
||||
"tts",
|
||||
)
|
||||
== []
|
||||
)
|
||||
assert (
|
||||
validator._validate_service(
|
||||
SmallestAISTTConfiguration(api_key="test-key"),
|
||||
"stt",
|
||||
)
|
||||
== []
|
||||
)
|
||||
|
||||
|
||||
def test_create_smallest_tts_service_normalizes_hyphenated_model_values():
|
||||
user_config = SimpleNamespace(
|
||||
tts=SimpleNamespace(
|
||||
provider=ServiceProviders.SMALLEST.value,
|
||||
api_key="test-key",
|
||||
model="lightning-v3.1",
|
||||
voice="sophia",
|
||||
language="en",
|
||||
speed=1.0,
|
||||
)
|
||||
)
|
||||
audio_config = SimpleNamespace(transport_in_sample_rate=16000)
|
||||
|
||||
with patch(
|
||||
"api.services.pipecat.service_factory.SmallestTTSService"
|
||||
) as mock_service:
|
||||
create_tts_service(user_config, audio_config)
|
||||
|
||||
assert mock_service.call_count == 1
|
||||
kwargs = mock_service.call_args.kwargs
|
||||
assert kwargs["settings"].model == "lightning_v3.1"
|
||||
35
api/utils/recording_artifacts.py
Normal file
35
api/utils/recording_artifacts.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Literal
|
||||
|
||||
RecordingTrack = Literal["mixed", "user", "bot"]
|
||||
|
||||
|
||||
def get_recording_storage_key(extra: dict | None, track: RecordingTrack) -> str | None:
|
||||
recordings = (extra or {}).get("recordings", {})
|
||||
if not isinstance(recordings, dict):
|
||||
return None
|
||||
|
||||
artifact = recordings.get(track)
|
||||
if isinstance(artifact, str):
|
||||
return artifact
|
||||
if isinstance(artifact, dict):
|
||||
storage_key = artifact.get("storage_key")
|
||||
return storage_key if isinstance(storage_key, str) else None
|
||||
return None
|
||||
|
||||
|
||||
def get_recording_storage_backend(
|
||||
extra: dict | None, track: RecordingTrack
|
||||
) -> str | None:
|
||||
recordings = (extra or {}).get("recordings", {})
|
||||
if not isinstance(recordings, dict):
|
||||
return None
|
||||
|
||||
artifact = recordings.get(track)
|
||||
if isinstance(artifact, dict):
|
||||
storage_backend = artifact.get("storage_backend")
|
||||
return storage_backend if isinstance(storage_backend, str) else None
|
||||
return None
|
||||
|
||||
|
||||
def has_recording_track(extra: dict | None, track: RecordingTrack) -> bool:
|
||||
return bool(get_recording_storage_key(extra, track))
|
||||
Loading…
Add table
Add a link
Reference in a new issue