feat: add hybrid text + recording functionality in agents (#191)

* feat: add recording feature in agents

* chore: pin pipecat version

* feat: show usage in UI

* chore: update pipecat
This commit is contained in:
Abhishek 2026-03-16 15:04:08 +05:30 committed by GitHub
parent f075bcb623
commit 494c60d774
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
43 changed files with 2865 additions and 397 deletions

3
.gitignore vendored
View file

@ -14,4 +14,5 @@ prd/
venv/
.venv/
.playwright-mcp
coturn/
coturn/
dograh_pcm_cache/

View file

@ -0,0 +1,102 @@
"""add recording table
Revision ID: e54ddb048535
Revises: 6fd8fac02883
Create Date: 2026-03-12 19:21:53.729888
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "e54ddb048535"
down_revision: Union[str, None] = "6fd8fac02883"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
sa.Enum("s3", "minio", name="recording_storage_backend").create(op.get_bind())
op.create_table(
"workflow_recordings",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("recording_id", sa.String(length=16), nullable=False),
sa.Column("workflow_id", sa.Integer(), nullable=False),
sa.Column("organization_id", sa.Integer(), nullable=False),
sa.Column("tts_provider", sa.String(), nullable=False),
sa.Column("tts_model", sa.String(), nullable=False),
sa.Column("tts_voice_id", sa.String(), nullable=False),
sa.Column("transcript", sa.Text(), nullable=False),
sa.Column("storage_key", sa.String(), nullable=False),
sa.Column(
"storage_backend",
postgresql.ENUM(
"s3", "minio", name="recording_storage_backend", create_type=False
),
server_default=sa.text("'s3'::recording_storage_backend"),
nullable=False,
),
sa.Column(
"recording_metadata",
sa.JSON(),
server_default=sa.text("'{}'::json"),
nullable=False,
),
sa.Column("created_by", sa.Integer(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["created_by"],
["users.id"],
),
sa.ForeignKeyConstraint(
["organization_id"], ["organizations.id"], ondelete="CASCADE"
),
sa.ForeignKeyConstraint(["workflow_id"], ["workflows.id"], ondelete="CASCADE"),
sa.PrimaryKeyConstraint("id"),
)
op.create_index(
"ix_workflow_recordings_org_id",
"workflow_recordings",
["organization_id"],
unique=False,
)
op.create_index(
op.f("ix_workflow_recordings_recording_id"),
"workflow_recordings",
["recording_id"],
unique=True,
)
op.create_index(
"ix_workflow_recordings_tts_scope",
"workflow_recordings",
["workflow_id", "tts_provider", "tts_model", "tts_voice_id"],
unique=False,
)
op.create_index(
"ix_workflow_recordings_workflow_id",
"workflow_recordings",
["workflow_id"],
unique=False,
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_index(
"ix_workflow_recordings_workflow_id", table_name="workflow_recordings"
)
op.drop_index("ix_workflow_recordings_tts_scope", table_name="workflow_recordings")
op.drop_index(
op.f("ix_workflow_recordings_recording_id"), table_name="workflow_recordings"
)
op.drop_index("ix_workflow_recordings_org_id", table_name="workflow_recordings")
op.drop_table("workflow_recordings")
sa.Enum("s3", "minio", name="recording_storage_backend").drop(op.get_bind())
# ### end Alembic commands ###

View file

@ -526,7 +526,7 @@ class CampaignClient(BaseDBClient):
QueuedRunModel.state == "queued",
QueuedRunModel.scheduled_for.is_(None),
)
.order_by(QueuedRunModel.created_at)
.order_by(func.random())
.limit(remaining_slots)
.with_for_update(skip_locked=True)
)

View file

@ -13,6 +13,7 @@ from api.db.tool_client import ToolClient
from api.db.user_client import UserClient
from api.db.webhook_credential_client import WebhookCredentialClient
from api.db.workflow_client import WorkflowClient
from api.db.workflow_recording_client import WorkflowRecordingClient
from api.db.workflow_run_client import WorkflowRunClient
from api.db.workflow_template_client import WorkflowTemplateClient
@ -35,6 +36,7 @@ class DBClient(
WebhookCredentialClient,
ToolClient,
KnowledgeBaseClient,
WorkflowRecordingClient,
):
"""
Unified database client that combines all specialized database operations.

View file

@ -996,6 +996,77 @@ class KnowledgeBaseDocumentModel(Base):
)
class WorkflowRecordingModel(Base):
"""Model for storing audio recordings scoped to a workflow and TTS configuration.
Recordings are used in hybrid prompts where parts of the output are pre-recorded
audio rather than dynamically generated TTS.
"""
__tablename__ = "workflow_recordings"
id = Column(Integer, primary_key=True, index=True)
# Short globally unique ID (e.g. "xbhfha3k") used in prompts
recording_id = Column(String(16), unique=True, nullable=False, index=True)
# Scoping
workflow_id = Column(
Integer, ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False
)
organization_id = Column(
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
)
# TTS configuration scope
tts_provider = Column(String, nullable=False)
tts_model = Column(String, nullable=False)
tts_voice_id = Column(String, nullable=False)
# Content
transcript = Column(Text, nullable=False)
# Storage
storage_key = Column(String, nullable=False)
storage_backend = Column(
Enum("s3", "minio", name="recording_storage_backend"),
nullable=False,
default="s3",
server_default=text("'s3'::recording_storage_backend"),
)
# Extra metadata (file_size_bytes, duration_seconds, original_filename, mime_type, etc.)
recording_metadata = Column(
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
)
# Audit
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
# Soft delete
is_active = Column(Boolean, default=True, nullable=False)
# Relationships
workflow = relationship("WorkflowModel")
organization = relationship("OrganizationModel")
created_by_user = relationship("UserModel")
# Indexes
__table_args__ = (
Index("ix_workflow_recordings_workflow_id", "workflow_id"),
Index("ix_workflow_recordings_org_id", "organization_id"),
Index("ix_workflow_recordings_recording_id", "recording_id"),
Index(
"ix_workflow_recordings_tts_scope",
"workflow_id",
"tts_provider",
"tts_model",
"tts_voice_id",
),
)
class KnowledgeBaseChunkModel(Base):
"""Model for storing document chunks with vector embeddings.

View file

@ -0,0 +1,218 @@
"""Database client for managing workflow recordings."""
import secrets
import string
from typing import List, Optional
from loguru import logger
from sqlalchemy import func, select
from api.db.base_client import BaseDBClient
from api.db.models import WorkflowRecordingModel
def generate_short_id(length: int = 8) -> str:
"""Generate a random lowercase alphanumeric short ID."""
alphabet = string.ascii_lowercase + string.digits
return "".join(secrets.choice(alphabet) for _ in range(length))
class WorkflowRecordingClient(BaseDBClient):
"""Client for managing workflow audio recordings."""
async def create_recording(
self,
recording_id: str,
workflow_id: int,
organization_id: int,
tts_provider: str,
tts_model: str,
tts_voice_id: str,
transcript: str,
storage_key: str,
storage_backend: str,
created_by: int,
metadata: Optional[dict] = None,
) -> WorkflowRecordingModel:
"""Create a new workflow recording record.
Args:
recording_id: Short unique recording identifier
workflow_id: ID of the workflow
organization_id: ID of the organization
tts_provider: TTS provider name
tts_model: TTS model name
tts_voice_id: TTS voice identifier
transcript: User-provided transcript
storage_key: S3/MinIO storage key
storage_backend: Storage backend (s3 or minio)
created_by: ID of the user
metadata: Optional extra metadata
Returns:
The created WorkflowRecordingModel
"""
async with self.async_session() as session:
recording = WorkflowRecordingModel(
recording_id=recording_id,
workflow_id=workflow_id,
organization_id=organization_id,
tts_provider=tts_provider,
tts_model=tts_model,
tts_voice_id=tts_voice_id,
transcript=transcript,
storage_key=storage_key,
storage_backend=storage_backend,
created_by=created_by,
metadata=metadata or {},
)
session.add(recording)
await session.commit()
await session.refresh(recording)
logger.info(
f"Created recording {recording_id} for workflow {workflow_id}, "
f"org {organization_id}"
)
return recording
async def get_recordings_for_workflow(
self,
workflow_id: int,
organization_id: int,
tts_provider: Optional[str] = None,
tts_model: Optional[str] = None,
tts_voice_id: Optional[str] = None,
) -> List[WorkflowRecordingModel]:
"""Get recordings for a workflow, optionally filtered by TTS config.
Args:
workflow_id: ID of the workflow
organization_id: ID of the organization
tts_provider: Optional TTS provider filter
tts_model: Optional TTS model filter
tts_voice_id: Optional TTS voice ID filter
Returns:
List of WorkflowRecordingModel instances
"""
async with self.async_session() as session:
query = select(WorkflowRecordingModel).where(
WorkflowRecordingModel.workflow_id == workflow_id,
WorkflowRecordingModel.organization_id == organization_id,
WorkflowRecordingModel.is_active == True,
)
if tts_provider:
query = query.where(WorkflowRecordingModel.tts_provider == tts_provider)
if tts_model:
query = query.where(WorkflowRecordingModel.tts_model == tts_model)
if tts_voice_id:
query = query.where(WorkflowRecordingModel.tts_voice_id == tts_voice_id)
query = query.order_by(WorkflowRecordingModel.created_at.desc())
result = await session.execute(query)
return list(result.scalars().all())
async def get_recording_by_recording_id(
self,
recording_id: str,
organization_id: int,
) -> Optional[WorkflowRecordingModel]:
"""Get a recording by its short ID.
Args:
recording_id: The short unique recording ID
organization_id: ID of the organization
Returns:
WorkflowRecordingModel if found, None otherwise
"""
async with self.async_session() as session:
query = select(WorkflowRecordingModel).where(
WorkflowRecordingModel.recording_id == recording_id,
WorkflowRecordingModel.organization_id == organization_id,
WorkflowRecordingModel.is_active == True,
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def has_active_recordings(
self,
workflow_id: int,
organization_id: int,
) -> bool:
"""Check if a workflow has any active recordings.
Args:
workflow_id: ID of the workflow
organization_id: ID of the organization
Returns:
True if at least one active recording exists, False otherwise
"""
async with self.async_session() as session:
query = (
select(func.count())
.select_from(WorkflowRecordingModel)
.where(
WorkflowRecordingModel.workflow_id == workflow_id,
WorkflowRecordingModel.organization_id == organization_id,
WorkflowRecordingModel.is_active == True,
)
)
result = await session.execute(query)
return result.scalar_one() > 0
async def check_recording_id_exists(self, recording_id: str) -> bool:
"""Check if a recording ID already exists globally.
Args:
recording_id: The short recording ID to check
Returns:
True if exists, False otherwise
"""
async with self.async_session() as session:
query = select(WorkflowRecordingModel.id).where(
WorkflowRecordingModel.recording_id == recording_id,
)
result = await session.execute(query)
return result.scalar_one_or_none() is not None
async def delete_recording(
self,
recording_id: str,
organization_id: int,
) -> bool:
"""Soft delete a recording.
Args:
recording_id: The short recording ID
organization_id: ID of the organization
Returns:
True if deleted, False if not found
"""
async with self.async_session() as session:
query = select(WorkflowRecordingModel).where(
WorkflowRecordingModel.recording_id == recording_id,
WorkflowRecordingModel.organization_id == organization_id,
)
result = await session.execute(query)
recording = result.scalar_one_or_none()
if not recording:
return False
recording.is_active = False
await session.commit()
logger.info(
f"Deleted recording {recording_id} for organization {organization_id}"
)
return True

View file

@ -24,6 +24,7 @@ from api.routes.user import router as user_router
from api.routes.webrtc_signaling import router as webrtc_signaling_router
from api.routes.workflow import router as workflow_router
from api.routes.workflow_embed import router as workflow_embed_router
from api.routes.workflow_recording import router as workflow_recording_router
router = APIRouter(
tags=["main"],
@ -51,6 +52,7 @@ router.include_router(public_agent_router)
router.include_router(public_download_router)
router.include_router(workflow_embed_router)
router.include_router(knowledge_base_router)
router.include_router(workflow_recording_router)
router.include_router(auth_router)

View file

@ -3,11 +3,14 @@ from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from loguru import logger
from pydantic import BaseModel
from api.constants import DEPLOYMENT_MODE
from api.db import db_client
from api.db.models import UserModel
from api.services.auth.depends import get_user
from api.services.mps_service_key_client import mps_service_key_client
router = APIRouter(prefix="/organizations")
@ -28,6 +31,12 @@ class CurrentUsageResponse(BaseModel):
price_per_second_usd: Optional[float] = None
class MPSCreditsResponse(BaseModel):
total_credits_used: float
remaining_credits: float
total_quota: float
class WorkflowRunUsageResponse(BaseModel):
id: int
workflow_id: int
@ -85,6 +94,42 @@ async def get_current_period_usage(user: UserModel = Depends(get_user)):
raise HTTPException(status_code=500, detail=str(e))
@router.get("/usage/mps-credits", response_model=MPSCreditsResponse)
async def get_mps_credits(user: UserModel = Depends(get_user)):
"""Get aggregated usage and quota from MPS.
OSS users: queries by provider_id (created_by).
Hosted users: queries by organization_id.
"""
try:
if DEPLOYMENT_MODE == "oss":
usage = await mps_service_key_client.get_usage_by_created_by(
str(user.provider_id)
)
else:
if not user.selected_organization_id:
raise HTTPException(
status_code=400, detail="No organization selected"
)
usage = await mps_service_key_client.get_usage_by_organization(
user.selected_organization_id
)
total_used = usage.get("total_credits_used", 0.0)
total_remaining = usage.get("remaining_credits", 0.0)
return MPSCreditsResponse(
total_credits_used=total_used,
remaining_credits=total_remaining,
total_quota=total_used + total_remaining,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to fetch MPS credits: {e}")
raise HTTPException(status_code=500, detail=str(e))
@router.get("/usage/runs", response_model=UsageHistoryResponse)
async def get_usage_history(
start_date: Optional[str] = Query(None, description="ISO format date string"),

View file

@ -0,0 +1,218 @@
"""API routes for workflow recording operations."""
from typing import Annotated, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from loguru import logger
from api.db import db_client
from api.db.workflow_recording_client import generate_short_id
from api.enums import StorageBackend
from api.schemas.workflow_recording import (
RecordingCreateRequestSchema,
RecordingListResponseSchema,
RecordingResponseSchema,
RecordingUploadRequestSchema,
RecordingUploadResponseSchema,
)
from api.services.auth.depends import get_user
from api.services.storage import storage_fs
router = APIRouter(prefix="/workflow-recordings", tags=["workflow-recordings"])
async def _generate_unique_recording_id() -> str:
"""Generate a globally unique short recording ID."""
for _ in range(10):
rid = generate_short_id(8)
exists = await db_client.check_recording_id_exists(rid)
if not exists:
return rid
raise HTTPException(
status_code=500, detail="Failed to generate unique recording ID"
)
def _build_response(rec) -> RecordingResponseSchema:
return RecordingResponseSchema(
id=rec.id,
recording_id=rec.recording_id,
workflow_id=rec.workflow_id,
organization_id=rec.organization_id,
tts_provider=rec.tts_provider,
tts_model=rec.tts_model,
tts_voice_id=rec.tts_voice_id,
transcript=rec.transcript,
storage_key=rec.storage_key,
storage_backend=rec.storage_backend,
metadata=rec.recording_metadata or {},
created_by=rec.created_by,
created_at=rec.created_at,
is_active=rec.is_active,
)
@router.post(
"/upload-url",
response_model=RecordingUploadResponseSchema,
summary="Get presigned URL for recording upload",
)
async def get_upload_url(
request: RecordingUploadRequestSchema,
user=Depends(get_user),
):
"""Generate a presigned PUT URL for uploading an audio recording."""
try:
recording_id = await _generate_unique_recording_id()
storage_key = (
f"recordings/{user.selected_organization_id}"
f"/{request.workflow_id}/{recording_id}"
f"/{request.filename}"
)
upload_url = await storage_fs.aget_presigned_put_url(
file_path=storage_key,
expiration=1800, # 30 minutes
content_type=request.mime_type,
max_size=5_242_880, # 5MB max
)
if not upload_url:
raise HTTPException(
status_code=500, detail="Failed to generate presigned upload URL"
)
logger.info(
f"Generated recording upload URL: {recording_id}, "
f"workflow {request.workflow_id}, org {user.selected_organization_id}"
)
return RecordingUploadResponseSchema(
upload_url=upload_url,
recording_id=recording_id,
storage_key=storage_key,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Error generating recording upload URL: {exc}")
raise HTTPException(
status_code=500, detail="Failed to generate upload URL"
) from exc
@router.post(
"/",
response_model=RecordingResponseSchema,
summary="Create recording record after upload",
)
async def create_recording(
request: RecordingCreateRequestSchema,
user=Depends(get_user),
):
"""Create a recording record after the audio has been uploaded to storage."""
try:
backend = StorageBackend.get_current_backend()
recording = await db_client.create_recording(
recording_id=request.recording_id,
workflow_id=request.workflow_id,
organization_id=user.selected_organization_id,
tts_provider=request.tts_provider,
tts_model=request.tts_model,
tts_voice_id=request.tts_voice_id,
transcript=request.transcript,
storage_key=request.storage_key,
storage_backend=backend.value,
created_by=user.id,
metadata=request.metadata,
)
logger.info(
f"Created recording {request.recording_id} for workflow {request.workflow_id}"
)
return _build_response(recording)
except HTTPException:
raise
except Exception as exc:
logger.error(f"Error creating recording: {exc}")
raise HTTPException(
status_code=500, detail="Failed to create recording"
) from exc
@router.get(
"/",
response_model=RecordingListResponseSchema,
summary="List recordings for a workflow",
)
async def list_recordings(
workflow_id: Annotated[int, Query(description="Workflow ID")],
tts_provider: Annotated[
Optional[str], Query(description="Filter by TTS provider")
] = None,
tts_model: Annotated[
Optional[str], Query(description="Filter by TTS model")
] = None,
tts_voice_id: Annotated[
Optional[str], Query(description="Filter by TTS voice ID")
] = None,
user=Depends(get_user),
):
"""List recordings for a workflow, optionally filtered by TTS configuration."""
try:
recordings = await db_client.get_recordings_for_workflow(
workflow_id=workflow_id,
organization_id=user.selected_organization_id,
tts_provider=tts_provider,
tts_model=tts_model,
tts_voice_id=tts_voice_id,
)
return RecordingListResponseSchema(
recordings=[_build_response(r) for r in recordings],
total=len(recordings),
)
except Exception as exc:
logger.error(f"Error listing recordings: {exc}")
raise HTTPException(
status_code=500, detail="Failed to list recordings"
) from exc
@router.delete(
"/{recording_id}",
summary="Delete a recording",
)
async def delete_recording(
recording_id: str,
user=Depends(get_user),
):
"""Soft delete a recording."""
try:
success = await db_client.delete_recording(
recording_id=recording_id,
organization_id=user.selected_organization_id,
)
if not success:
raise HTTPException(status_code=404, detail="Recording not found")
logger.info(
f"Deleted recording {recording_id}, org {user.selected_organization_id}"
)
return {"success": True, "message": "Recording deleted successfully"}
except HTTPException:
raise
except Exception as exc:
logger.error(f"Error deleting recording: {exc}")
raise HTTPException(
status_code=500, detail="Failed to delete recording"
) from exc

View file

@ -0,0 +1,73 @@
"""Pydantic schemas for workflow recording operations."""
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class RecordingUploadRequestSchema(BaseModel):
"""Request schema for getting a presigned upload URL."""
workflow_id: int = Field(..., description="Workflow ID this recording belongs to")
filename: str = Field(..., description="Original filename of the audio file")
mime_type: str = Field(
default="audio/wav", description="MIME type of the audio file"
)
file_size: int = Field(
...,
gt=0,
le=5_242_880,
description="File size in bytes (max 5MB)",
)
class RecordingUploadResponseSchema(BaseModel):
"""Response schema with presigned upload URL."""
upload_url: str = Field(..., description="Presigned URL for uploading the audio")
recording_id: str = Field(..., description="Short unique recording ID")
storage_key: str = Field(..., description="Storage key where file will be uploaded")
class RecordingCreateRequestSchema(BaseModel):
"""Request schema for creating a recording record after upload."""
recording_id: str = Field(..., description="Short recording ID from upload step")
workflow_id: int = Field(..., description="Workflow ID")
tts_provider: str = Field(..., description="TTS provider (e.g. elevenlabs)")
tts_model: str = Field(..., description="TTS model name")
tts_voice_id: str = Field(..., description="TTS voice identifier")
transcript: str = Field(
..., description="User-provided transcript of the recording"
)
storage_key: str = Field(..., description="Storage key from upload step")
metadata: Optional[Dict[str, Any]] = Field(
default=None, description="Optional metadata (file_size, duration, etc.)"
)
class RecordingResponseSchema(BaseModel):
"""Response schema for a single recording."""
id: int
recording_id: str
workflow_id: int
organization_id: int
tts_provider: str
tts_model: str
tts_voice_id: str
transcript: str
storage_key: str
storage_backend: str
metadata: Dict[str, Any]
created_by: int
created_at: datetime
is_active: bool
class RecordingListResponseSchema(BaseModel):
"""Response schema for list of recordings."""
recordings: List[RecordingResponseSchema]
total: int

View file

@ -1,4 +1,4 @@
from typing import Dict, Optional, TypedDict
from typing import Optional, TypedDict
import openai
from deepgram import DeepgramClient
@ -12,6 +12,7 @@ from api.schemas.user_configuration import (
UserConfiguration,
)
from api.services.configuration.registry import ServiceConfig, ServiceProviders
from api.services.mps_service_key_client import mps_service_key_client
class APIKeyStatus(TypedDict):
@ -25,7 +26,6 @@ class APIKeyStatusResponse(TypedDict):
class UserConfigurationValidator:
def __init__(self):
self._provider_api_key_validity_status: Dict[str, bool] = {}
self._validator_map = {
ServiceProviders.OPENAI.value: self._check_openai_api_key,
ServiceProviders.DEEPGRAM.value: self._check_deepgram_api_key,
@ -73,8 +73,13 @@ class UserConfigurationValidator:
provider = service_config.provider
api_key = service_config.api_key
if not self._check_api_key(provider, api_key):
return [{"model": service_name, "message": f"Invalid {provider} API key"}]
try:
if not self._check_api_key(provider, api_key):
return [
{"model": service_name, "message": f"Invalid {provider} API key"}
]
except ValueError as e:
return [{"model": service_name, "message": str(e)}]
return []
@ -87,40 +92,28 @@ class UserConfigurationValidator:
return validator(provider, api_key)
def _check_openai_api_key(self, model: str, api_key: str) -> bool:
if model in self._provider_api_key_validity_status:
return self._provider_api_key_validity_status[model]
client = openai.OpenAI(api_key=api_key)
try:
client.models.list()
self._provider_api_key_validity_status[model] = True
return True
except openai.AuthenticationError:
self._provider_api_key_validity_status[model] = False
return self._provider_api_key_validity_status[model]
return False
def _check_deepgram_api_key(self, model: str, api_key: str) -> bool:
if model in self._provider_api_key_validity_status:
return self._provider_api_key_validity_status[model]
try:
deepgram = DeepgramClient(api_key=api_key)
deepgram.manage.v1.projects.list()
self._provider_api_key_validity_status[model] = True
return True
except Exception:
self._provider_api_key_validity_status[model] = False
return self._provider_api_key_validity_status[model]
return False
def _check_groq_api_key(self, model: str, api_key: str) -> bool:
if model in self._provider_api_key_validity_status:
return self._provider_api_key_validity_status[model]
client = Groq(api_key=api_key)
try:
client.models.list()
self._provider_api_key_validity_status[model] = True
return True
except Exception:
self._provider_api_key_validity_status[model] = False
return self._provider_api_key_validity_status[model]
return False
def _validate_elevenlabs_api_key(self, model: str, api_key: str) -> bool:
return True
@ -135,7 +128,12 @@ class UserConfigurationValidator:
return True
def _check_dograh_api_key(self, model: str, api_key: str) -> bool:
return True
if api_key.startswith("dgr"):
raise ValueError(
"You provided a Dograh API key (dgr...) instead of a service key. "
"Please use a service key (mps...)."
)
return mps_service_key_client.validate_service_key(api_key)
def _check_sarvam_api_key(self, model: str, api_key: str) -> bool:
return True

View file

@ -285,6 +285,90 @@ class MPSServiceKeyClient:
response=response,
)
async def get_usage_by_created_by(self, created_by: str) -> dict:
"""
Get aggregated usage for all service keys created by a user (OSS mode).
Args:
created_by: The user's provider ID
Returns:
Dictionary containing total_credits_used and remaining_credits
"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/usage/created-by",
json={"created_by": created_by},
headers=self._get_headers(created_by=created_by),
)
if response.status_code == 200:
data = response.json()
return {
"total_credits_used": data.get("total_credits_used", 0.0),
"remaining_credits": data.get("remaining_credits", 0.0),
}
else:
logger.error(
f"Failed to get usage by created_by: {response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get usage by created_by: {response.text}",
request=response.request,
response=response,
)
async def get_usage_by_organization(self, organization_id: int) -> dict:
"""
Get aggregated usage for all service keys belonging to an organization (hosted mode).
Args:
organization_id: The organization's ID
Returns:
Dictionary containing total_credits_used and remaining_credits
"""
async with httpx.AsyncClient(timeout=self.timeout) as client:
response = await client.post(
f"{self.base_url}/api/v1/service-keys/usage/organization",
json={"organization_id": organization_id},
headers=self._get_headers(organization_id=organization_id),
)
if response.status_code == 200:
data = response.json()
return {
"total_credits_used": data.get("total_credits_used", 0.0),
"remaining_credits": data.get("remaining_credits", 0.0),
}
else:
logger.error(
f"Failed to get usage by organization: {response.status_code} - {response.text}"
)
raise httpx.HTTPStatusError(
f"Failed to get usage by organization: {response.text}",
request=response.request,
response=response,
)
def validate_service_key(self, service_key: str) -> bool:
"""
Synchronously validate a Dograh service key by checking usage via MPS.
Returns True if the key is valid, False otherwise.
"""
try:
with httpx.Client(timeout=self.timeout) as client:
response = client.post(
f"{self.base_url}/api/v1/service-keys/usage",
json={"service_key": service_key},
headers=self._get_headers(),
)
return response.status_code == 200
except Exception:
logger.warning("Failed to validate Dograh service key via MPS")
return False
async def get_voices(
self,
provider: str,

View file

@ -39,6 +39,7 @@ def build_pipeline(
pipeline_engine_callback_processor,
pipeline_metrics_aggregator,
voicemail_detector=None,
recording_router=None,
):
"""Build the main pipeline with all components.
@ -47,6 +48,9 @@ def build_pipeline(
voicemail_detector: Optional native pipecat VoicemailDetector. When provided,
inserts voicemail detection after STT. Note: We don't use the TTS gate
to avoid blocking TTS frames during classification.
recording_router: Optional RecordingRouterProcessor. When provided,
inserts between callback processor and TTS to route between
pre-recorded audio playback and dynamic TTS.
"""
# Build processors list with optional voicemail detection
processors = [
@ -66,11 +70,15 @@ def build_pipeline(
processors.append(voicemail_detector.detector())
# Continue with the rest of the pipeline
post_llm = [pipeline_engine_callback_processor]
if recording_router:
post_llm.append(recording_router)
processors.extend(
[
user_context_aggregator,
llm, # LLM
pipeline_engine_callback_processor,
*post_llm,
tts, # TTS
transport.output(), # Transport bot output
audio_buffer, # AudioBufferProcessor - records both input and output audio

View file

@ -37,10 +37,10 @@ from pipecat.frames.frames import (
FunctionCallResultFrame,
InterimTranscriptionFrame,
InterruptionFrame,
LLMTextFrame,
MetricsFrame,
StopFrame,
TranscriptionFrame,
TTSTextFrame,
)
from pipecat.metrics.metrics import TTFBMetricsData
from pipecat.observers.base_observer import BaseObserver, FramePushed
@ -207,7 +207,7 @@ class RealtimeFeedbackObserver(BaseObserver):
)
# Handle bot TTS text - respect pts timing, WebSocket only
# Complete turn text is persisted via register_turn_handlers
elif isinstance(frame, TTSTextFrame):
elif isinstance(frame, LLMTextFrame):
message = {
"type": RealtimeFeedbackType.BOT_TEXT.value,
"payload": {

View file

@ -0,0 +1,338 @@
"""Filesystem-backed cache and audio fetcher for workflow recordings.
Downloads recording files from object storage on first access, converts them
to raw 16-bit mono PCM at the pipeline sample rate via ffmpeg, trims
leading/trailing silence, and caches the processed bytes on disk so
subsequent plays (even from other workers) are instantaneous.
"""
import asyncio
import os
import shutil
import tempfile
from typing import Awaitable, Callable, Optional
import numpy as np
from loguru import logger
from api.constants import APP_ROOT_DIR
from pipecat.audio.utils import SPEAKING_THRESHOLD
# ---------------------------------------------------------------------------
# Filesystem cache directory
# ---------------------------------------------------------------------------
_CACHE_DIR = os.path.join(os.path.dirname(APP_ROOT_DIR), "dograh_pcm_cache")
os.makedirs(_CACHE_DIR, exist_ok=True)
def _cache_path(recording_id: str, sample_rate: int) -> str:
"""Return the on-disk path for a cached PCM file."""
return os.path.join(_CACHE_DIR, f"{recording_id}_{sample_rate}.pcm")
# ---------------------------------------------------------------------------
# Public factory
# ---------------------------------------------------------------------------
def create_recording_audio_fetcher(
organization_id: int,
pipeline_sample_rate: int,
) -> Callable[[str], Awaitable[Optional[bytes]]]:
"""Create an async callback that returns raw PCM bytes for a recording_id.
The returned callable:
1. Checks the filesystem cache (keyed by ``recording_id`` + sample rate).
2. On miss, looks up the recording in the DB, downloads the audio file
from S3/MinIO, converts it to 16-bit mono PCM at *pipeline_sample_rate*,
trims leading/trailing silence, caches the result on disk, and returns it.
Args:
organization_id: Organization owning the recordings.
pipeline_sample_rate: Target PCM sample rate for the pipeline.
Returns:
``async (recording_id: str) -> Optional[bytes]``
"""
from api.db import db_client
from api.services.storage import get_storage_for_backend
# Resolve storage instances once per backend at creation time, not per fetch.
_storage_cache: dict[str, object] = {}
def _get_storage(backend: str):
if backend not in _storage_cache:
_storage_cache[backend] = get_storage_for_backend(backend)
return _storage_cache[backend]
async def fetch(recording_id: str) -> Optional[bytes]:
cached = _cache_path(recording_id, pipeline_sample_rate)
# 1. Serve from filesystem cache
if os.path.exists(cached):
logger.debug(f"Recording {recording_id} served from disk cache")
return _read_file(cached)
# 2. DB lookup
recording = await db_client.get_recording_by_recording_id(
recording_id, organization_id
)
if not recording:
logger.warning(f"Recording {recording_id} not found in database")
return None
# 3. Download, convert, trim, and cache
pcm_data = await _download_and_convert(
recording, pipeline_sample_rate, _get_storage
)
return pcm_data
return fetch
# ---------------------------------------------------------------------------
# Cache warming
# ---------------------------------------------------------------------------
async def warm_recording_cache(
workflow_id: int,
organization_id: int,
pipeline_sample_rate: int,
) -> None:
"""Pre-fetch all active recordings for a workflow into the disk cache.
Launched as a background ``asyncio.Task`` at pipeline startup so that
recordings are ready before the first playback request. Errors are logged
but never propagated a cache miss falls back to the on-demand fetch path.
"""
from api.db import db_client
from api.services.storage import get_storage_for_backend
try:
recordings = await db_client.get_recordings_for_workflow(
workflow_id, organization_id
)
if not recordings:
return
# Skip if every recording is already cached on disk
uncached = [
r
for r in recordings
if not os.path.exists(_cache_path(r.recording_id, pipeline_sample_rate))
]
if not uncached:
logger.debug(f"Recording cache already warm for workflow {workflow_id}")
return
logger.info(
f"Warming recording cache: {len(uncached)}/{len(recordings)} "
f"recording(s) for workflow {workflow_id}"
)
# Resolve storage instances once per backend, not per recording
storage_by_backend: dict[str, object] = {}
def _get_storage(backend: str):
if backend not in storage_by_backend:
storage_by_backend[backend] = get_storage_for_backend(backend)
return storage_by_backend[backend]
for recording in uncached:
try:
pcm_data = await _download_and_convert(
recording, pipeline_sample_rate, _get_storage
)
if pcm_data:
logger.debug(
f"Cache warm: loaded {recording.recording_id} "
f"({len(pcm_data)} bytes)"
)
except Exception:
logger.exception(
f"Cache warm: error processing {recording.recording_id}"
)
logger.info(f"Recording cache warm complete for workflow {workflow_id}")
except Exception:
logger.exception("Recording cache warm failed")
# ---------------------------------------------------------------------------
# Shared download → convert → trim → cache-to-disk helper
# ---------------------------------------------------------------------------
async def _download_and_convert(
recording, sample_rate: int, get_storage_fn
) -> Optional[bytes]:
"""Download a recording from storage, convert to PCM, trim, and cache to disk.
Returns the processed PCM bytes, or None on failure.
"""
ext = _ext_from_key(recording.storage_key)
fd, tmp_path = tempfile.mkstemp(suffix=ext, prefix=f"dograh_dl_{recording.recording_id}_")
os.close(fd)
try:
storage = get_storage_fn(recording.storage_backend)
success = await storage.adownload_file(recording.storage_key, tmp_path)
if not success:
logger.error(f"Failed to download recording {recording.recording_id}")
return None
pcm_data = await _audio_file_to_pcm(tmp_path, sample_rate)
if pcm_data is None:
return None
pcm_data = _trim_silence(pcm_data, sample_rate)
# Write to disk cache atomically (write to tmp then rename)
cached = _cache_path(recording.recording_id, sample_rate)
fd, tmp_cache = tempfile.mkstemp(dir=_CACHE_DIR, suffix=".pcm.tmp")
os.close(fd)
_write_file(tmp_cache, pcm_data)
os.replace(tmp_cache, cached)
return pcm_data
except Exception:
logger.exception(f"Error fetching recording {recording.recording_id}")
return None
finally:
if os.path.exists(tmp_path):
try:
os.unlink(tmp_path)
except OSError:
pass
# ---------------------------------------------------------------------------
# File I/O helpers (run via asyncio.to_thread)
# ---------------------------------------------------------------------------
def _read_file(path: str) -> bytes:
with open(path, "rb") as f:
return f.read()
def _write_file(path: str, data: bytes) -> None:
with open(path, "wb") as f:
f.write(data)
# ---------------------------------------------------------------------------
# Audio conversion
# ---------------------------------------------------------------------------
async def _audio_file_to_pcm(
file_path: str, target_sample_rate: int
) -> Optional[bytes]:
"""Convert an audio file to raw 16-bit mono PCM bytes via ffmpeg."""
ffmpeg = shutil.which("ffmpeg")
if not ffmpeg:
logger.error("ffmpeg not found on PATH — cannot decode recording")
return None
cmd = [
ffmpeg,
"-i",
file_path,
"-f",
"s16le", # raw 16-bit signed little-endian PCM
"-acodec",
"pcm_s16le",
"-ac",
"1", # mono
"-ar",
str(target_sample_rate),
"-loglevel",
"error",
"pipe:1", # output to stdout
]
try:
proc = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await proc.communicate()
if proc.returncode != 0:
logger.error(f"ffmpeg failed (rc={proc.returncode}): {stderr.decode()}")
return None
if not stdout:
logger.error("ffmpeg produced no output")
return None
return stdout
except Exception:
logger.exception("ffmpeg subprocess error")
return None
# ---------------------------------------------------------------------------
# Silence trimming
# ---------------------------------------------------------------------------
def _trim_silence(pcm_data: bytes, sample_rate: int) -> bytes:
"""Trim leading and trailing silence from raw 16-bit mono PCM bytes.
Uses 10ms frames and the same amplitude threshold as pipecat's
``is_silence`` to detect speech boundaries.
"""
data = np.frombuffer(pcm_data, dtype=np.int16)
frame_size = int(sample_rate * 0.01) # 10ms frames
num_frames = len(data) // frame_size
if num_frames == 0:
return pcm_data
# Find first non-silent frame
first_speech = None
for i in range(num_frames):
frame = data[i * frame_size : (i + 1) * frame_size]
if np.abs(frame).max() > SPEAKING_THRESHOLD:
first_speech = i
break
if first_speech is None:
# Entire clip is silence — return as-is to avoid empty audio
return pcm_data
# Find last non-silent frame
last_speech = first_speech
for i in range(num_frames - 1, first_speech - 1, -1):
frame = data[i * frame_size : (i + 1) * frame_size]
if np.abs(frame).max() > SPEAKING_THRESHOLD:
last_speech = i
break
start = first_speech * frame_size
end = (last_speech + 1) * frame_size
trimmed = data[start:end]
trimmed_duration = len(trimmed) / sample_rate
original_duration = len(data) / sample_rate
if original_duration - trimmed_duration > 0.05:
logger.debug(
f"Trimmed silence: {original_duration:.2f}s → {trimmed_duration:.2f}s"
)
return trimmed.tobytes()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _ext_from_key(storage_key: str) -> str:
"""Extract file extension from a storage key, defaulting to .wav."""
_, ext = os.path.splitext(storage_key)
return ext if ext else ".wav"

View file

@ -0,0 +1,254 @@
"""Recording router processor for routing LLM output between TTS and pre-recorded audio.
Sits between the LLM (after pipeline_engine_callbacks_processor) and TTS in the
pipeline. Detects response mode markers ( for TTS, for recording) and routes
accordingly:
- (TTS): Strips the marker, passes remaining text downstream to TTS.
- (Recording): Suppresses TTS, fetches cached audio, pushes
OutputAudioRawFrame downstream.
Pattern modelled after ``pipecat.turns.user_turn_completion_mixin`` buffer
streaming LLM text tokens until the mode marker is detected, then act.
"""
import uuid
from typing import Awaitable, Callable, Optional
from loguru import logger
from api.services.workflow.pipecat_engine_context_composer import (
RECORDING_MARKER,
TTS_MARKER,
)
from pipecat.frames.frames import (
Frame,
InterruptionFrame,
LLMFullResponseEndFrame,
LLMTextFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
class RecordingRouterProcessor(FrameProcessor):
"""Routes LLM responses between TTS and pre-recorded audio playback.
When the LLM prefixes its response with:
- ```` text flows to TTS as normal speech.
- ```` text is suppressed (skip_tts), and the referenced recording is
fetched (with local disk cache) and streamed as ``OutputAudioRawFrame``.
If no marker is detected by the end of the response, text is passed through
to TTS as a graceful degradation.
Args:
audio_sample_rate: Pipeline sample rate for OutputAudioRawFrame.
fetch_recording_audio: Async callback that takes a recording_id and
returns raw 16-bit mono PCM bytes, or None on failure.
"""
def __init__(
self,
*,
audio_sample_rate: int,
fetch_recording_audio: Callable[[str], Awaitable[Optional[bytes]]],
**kwargs,
):
super().__init__(**kwargs)
self._audio_sample_rate = audio_sample_rate
self._fetch_recording_audio = fetch_recording_audio
# Per-response state
self._frame_buffer: list[tuple[LLMTextFrame, FrameDirection]] = []
self._mode: Optional[str] = None # None = detecting, "tts", "recording"
self._recording_id_buffer = ""
# ------------------------------------------------------------------
# Frame dispatch
# ------------------------------------------------------------------
async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)
if isinstance(frame, InterruptionFrame):
self._reset()
await self.push_frame(frame, direction)
elif isinstance(frame, LLMTextFrame):
await self._handle_llm_text(frame, direction)
elif isinstance(frame, LLMFullResponseEndFrame):
await self._handle_response_end(frame, direction)
else:
await self.push_frame(frame, direction)
# ------------------------------------------------------------------
# LLMTextFrame handling
# ------------------------------------------------------------------
async def _handle_llm_text(self, frame: LLMTextFrame, direction: FrameDirection):
# Pass through frames already marked skip_tts (e.g. turn completion ✓)
if frame.skip_tts:
await self.push_frame(frame, direction)
return
# --- TTS mode established: pass text through normally ---
if self._mode == "tts":
await self.push_frame(frame, direction)
return
# --- Recording mode: buffer recording_id, suppress TTS ---
if self._mode == "recording":
self._recording_id_buffer += frame.text
frame.skip_tts = True
await self.push_frame(frame, direction)
return
# --- Detection mode: buffer until marker found ---
self._frame_buffer.append((frame, direction))
buffered_text = self._buffered_text()
# Check for recording marker (●)
if RECORDING_MARKER in buffered_text:
self._mode = "recording"
marker_end = buffered_text.index(RECORDING_MARKER) + len(RECORDING_MARKER)
# Push buffered frames with skip_tts, extract recording_id from post-marker text
cumulative = 0
for buf_frame, buf_dir in self._frame_buffer:
buf_frame.skip_tts = True
frame_start = cumulative
cumulative += len(buf_frame.text)
await self.push_frame(buf_frame, buf_dir)
# Capture any recording_id text after the marker
if cumulative > marker_end:
offset = max(marker_end - frame_start, 0)
remaining = buf_frame.text[offset:]
if not self._recording_id_buffer and remaining.startswith(" "):
remaining = remaining[1:]
self._recording_id_buffer += remaining
self._frame_buffer = []
return
# Check for TTS marker (▸)
if TTS_MARKER in buffered_text:
self._mode = "tts"
marker_end = buffered_text.index(TTS_MARKER) + len(TTS_MARKER)
# Push buffered frames — skip_tts for marker portion, normal for the rest
cumulative = 0
for buf_frame, buf_dir in self._frame_buffer:
frame_start = cumulative
cumulative += len(buf_frame.text)
if cumulative <= marker_end:
# Entirely within marker portion — suppress TTS
buf_frame.skip_tts = True
await self.push_frame(buf_frame, buf_dir)
elif frame_start >= marker_end:
# Entirely after marker — normal TTS speech
if frame_start == marker_end and buf_frame.text.startswith(" "):
buf_frame.text = buf_frame.text[1:]
if buf_frame.text:
await self.push_frame(buf_frame, buf_dir)
else:
# Frame spans the marker boundary — split
offset = marker_end - frame_start
original_text = buf_frame.text
buf_frame.text = original_text[:offset]
buf_frame.skip_tts = True
await self.push_frame(buf_frame, buf_dir)
tts_text = original_text[offset:]
if tts_text.startswith(" "):
tts_text = tts_text[1:]
if tts_text:
await self.push_frame(LLMTextFrame(tts_text), buf_dir)
self._frame_buffer = []
return
# Neither marker found yet — keep buffering (should arrive very soon)
# ------------------------------------------------------------------
# End-of-response handling
# ------------------------------------------------------------------
async def _handle_response_end(
self, frame: LLMFullResponseEndFrame, direction: FrameDirection
):
if self._mode == "recording":
recording_id = self._recording_id_buffer.strip()
if recording_id:
await self._play_recording(recording_id)
else:
logger.warning(
"RecordingRouterProcessor: recording mode but empty recording_id"
)
elif self._mode is None and self._frame_buffer:
# Graceful degradation: no marker detected, pass text to TTS as-is
logger.warning(
"RecordingRouterProcessor: no response mode marker found, "
"passing text to TTS as-is"
)
for buf_frame, buf_dir in self._frame_buffer:
await self.push_frame(buf_frame, buf_dir)
self._reset()
await self.push_frame(frame, direction)
# ------------------------------------------------------------------
# Audio playback
# ------------------------------------------------------------------
async def _play_recording(self, recording_id: str):
"""Fetch recording audio and push TTSStarted → TTSAudioRaw → TTSStopped.
The transport handles chunking automatically. The Started/Stopped
frames ensure downstream processors (transport, audio buffer, observers)
treat this as a proper TTS utterance.
"""
logger.info(f"Playing pre-recorded audio: {recording_id}")
audio_data = await self._fetch_recording_audio(recording_id)
if not audio_data:
logger.warning(
f"Failed to fetch recording {recording_id}, no audio will play"
)
return
context_id = str(uuid.uuid4())
await self.push_frame(TTSStartedFrame(context_id=context_id))
await self.push_frame(
TTSAudioRawFrame(
audio=audio_data,
sample_rate=self._audio_sample_rate,
num_channels=1,
context_id=context_id,
)
)
await self.push_frame(TTSStoppedFrame(context_id=context_id))
duration_secs = len(audio_data) / (self._audio_sample_rate * 2)
logger.debug(
f"Finished pushing recording {recording_id} "
f"({len(audio_data)} bytes, {duration_secs:.1f}s)"
)
# ------------------------------------------------------------------
# State management
# ------------------------------------------------------------------
def _buffered_text(self) -> str:
"""Return concatenated text from the frame buffer."""
return "".join(f.text for f, _ in self._frame_buffer)
def _reset(self):
"""Reset per-response state."""
self._frame_buffer = []
self._mode = None
self._recording_id_buffer = ""

View file

@ -27,6 +27,11 @@ from api.services.pipecat.realtime_feedback_observer import (
RealtimeFeedbackObserver,
register_turn_log_handlers,
)
from api.services.pipecat.recording_audio_cache import (
create_recording_audio_fetcher,
warm_recording_cache,
)
from api.services.pipecat.recording_router_processor import RecordingRouterProcessor
from api.services.pipecat.service_factory import (
create_llm_service,
create_stt_service,
@ -558,6 +563,12 @@ async def _run_pipeline(
embeddings_model = user_config.embeddings.model
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
# Check if the workflow has any active recordings so the engine can
# include recording response mode instructions in all node prompts.
has_recordings = await db_client.has_active_recordings(
workflow_id, workflow.organization_id
)
engine = PipecatEngine(
llm=llm,
workflow=workflow_graph,
@ -567,6 +578,7 @@ async def _run_pipeline(
embeddings_api_key=embeddings_api_key,
embeddings_model=embeddings_model,
embeddings_base_url=embeddings_base_url,
has_recordings=has_recordings,
)
# Create pipeline components
@ -680,6 +692,27 @@ async def _run_pipeline(
abort_immediately=True,
)
# Create recording router if workflow has active recordings
recording_router = None
if has_recordings:
fetch_audio = create_recording_audio_fetcher(
organization_id=workflow.organization_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
recording_router = RecordingRouterProcessor(
audio_sample_rate=audio_config.pipeline_sample_rate,
fetch_recording_audio=fetch_audio,
)
# Warm the recording cache in the background so audio is ready
# before the first playback request.
asyncio.create_task(
warm_recording_cache(
workflow_id=workflow_id,
organization_id=workflow.organization_id,
pipeline_sample_rate=audio_config.pipeline_sample_rate,
)
)
# Build the pipeline with the STT mute filter and context controller
pipeline = build_pipeline(
transport,
@ -692,6 +725,7 @@ async def _run_pipeline(
pipeline_engine_callback_processor,
pipeline_metrics_aggregator,
voicemail_detector=voicemail_detector,
recording_router=recording_router,
)
# Create pipeline task with audio configuration

View file

@ -5,25 +5,32 @@ from loguru import logger
from api.constants import MPS_API_URL
from api.services.configuration.registry import ServiceProviders
from pipecat.services.azure.llm import AzureLLMService
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
from pipecat.services.cartesia.stt import CartesiaSTTService
from pipecat.services.cartesia.tts import CartesiaTTSService
from pipecat.services.deepgram.flux.stt import DeepgramFluxSTTService
from pipecat.services.deepgram.stt import DeepgramSTTService, LiveOptions
from pipecat.services.deepgram.tts import DeepgramTTSService
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings
from pipecat.services.deepgram.flux.stt import (
DeepgramFluxSTTService,
DeepgramFluxSTTSettings,
)
from pipecat.services.deepgram.stt import DeepgramSTTService, DeepgramSTTSettings
from pipecat.services.deepgram.tts import DeepgramTTSService, DeepgramTTSSettings
from pipecat.services.dograh.llm import DograhLLMService
from pipecat.services.dograh.stt import DograhSTTService
from pipecat.services.dograh.tts import DograhTTSService
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
from pipecat.services.google.llm import GoogleLLMService
from pipecat.services.groq.llm import GroqLLMService
from pipecat.services.dograh.stt import DograhSTTService, DograhSTTSettings
from pipecat.services.dograh.tts import DograhTTSService, DograhTTSSettings
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService, ElevenLabsTTSSettings
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
from pipecat.services.openai.base_llm import OpenAILLMSettings
from pipecat.services.openai.llm import OpenAILLMService
from pipecat.services.openai.stt import OpenAISTTService
from pipecat.services.openai.tts import OpenAITTSService
from pipecat.services.openrouter.llm import OpenRouterLLMService
from pipecat.services.sarvam.stt import SarvamSTTService
from pipecat.services.sarvam.tts import SarvamTTSService
from pipecat.services.speechmatics.stt import SpeechmaticsSTTService
from pipecat.services.openai.stt import OpenAISTTService, OpenAISTTSettings
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
from pipecat.services.speechmatics.stt import (
SpeechmaticsSTTService,
SpeechmaticsSTTSettings,
)
from pipecat.transcriptions.language import Language
from pipecat.utils.text.xml_function_tag_filter import XMLFunctionTagFilter
@ -49,8 +56,8 @@ def create_stt_service(
logger.debug("Using DeepGram Flux Model")
return DeepgramFluxSTTService(
api_key=user_config.stt.api_key,
model=user_config.stt.model,
params=DeepgramFluxSTTService.InputParams(
settings=DeepgramFluxSTTSettings(
model=user_config.stt.model,
eot_timeout_ms=3000,
eot_threshold=0.7,
eager_eot_threshold=0.5,
@ -63,23 +70,23 @@ def create_stt_service(
# Other models than flux
# Use language from user config, defaulting to "multi" for multilingual support
language = getattr(user_config.stt, "language", None) or "multi"
live_options = LiveOptions(
language=language,
profanity_filter=False,
endpointing=100,
model=user_config.stt.model,
keyterm=keyterms or [],
)
logger.debug(f"Using DeepGram Model - {user_config.stt.model}")
return DeepgramSTTService(
live_options=live_options,
api_key=user_config.stt.api_key,
settings=DeepgramSTTSettings(
language=language,
profanity_filter=False,
endpointing=100,
model=user_config.stt.model,
keyterm=keyterms or [],
),
should_interrupt=False, # Let UserAggregator take care of sending InterruptionFrame
sample_rate=audio_config.transport_in_sample_rate,
)
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
return OpenAISTTService(
api_key=user_config.stt.api_key, model=user_config.stt.model
api_key=user_config.stt.api_key,
settings=OpenAISTTSettings(model=user_config.stt.model),
)
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
return CartesiaSTTService(
@ -92,8 +99,10 @@ def create_stt_service(
return DograhSTTService(
base_url=base_url,
api_key=user_config.stt.api_key,
model=user_config.stt.model,
language=language,
settings=DograhSTTSettings(
model=user_config.stt.model,
language=language,
),
keyterms=keyterms,
sample_rate=audio_config.transport_in_sample_rate,
)
@ -117,8 +126,10 @@ def create_stt_service(
pipecat_language = language_mapping.get(language, Language.HI_IN)
return SarvamSTTService(
api_key=user_config.stt.api_key,
model=user_config.stt.model,
params=SarvamSTTService.InputParams(language=pipecat_language),
settings=SarvamSTTSettings(
model=user_config.stt.model,
language=pipecat_language,
),
sample_rate=audio_config.transport_in_sample_rate,
)
elif user_config.stt.provider == ServiceProviders.SPEECHMATICS.value:
@ -140,7 +151,7 @@ def create_stt_service(
additional_vocab = [AdditionalVocabEntry(content=term) for term in keyterms]
return SpeechmaticsSTTService(
api_key=user_config.stt.api_key,
params=SpeechmaticsSTTService.InputParams(
settings=SpeechmaticsSTTSettings(
language=language,
operating_point=operating_point,
additional_vocab=additional_vocab,
@ -168,14 +179,16 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
return DeepgramTTSService(
api_key=user_config.tts.api_key,
voice=user_config.tts.voice,
settings=DeepgramTTSSettings(voice=user_config.tts.voice),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
return OpenAITTSService(
api_key=user_config.tts.api_key,
model=user_config.tts.model,
settings=OpenAITTSSettings(model=user_config.tts.model),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
# Backward compatible with older configuration "Name - voice_id"
@ -186,19 +199,25 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return ElevenLabsTTSService(
reconnect_on_error=False,
api_key=user_config.tts.api_key,
voice_id=voice_id,
model=user_config.tts.model,
params=ElevenLabsTTSService.InputParams(
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
settings=ElevenLabsTTSSettings(
voice=voice_id,
model=user_config.tts.model,
stability=0.8,
speed=user_config.tts.speed,
similarity_boost=0.75,
),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.CARTESIA.value:
return CartesiaTTSService(
api_key=user_config.tts.api_key,
voice_id=user_config.tts.voice,
model=user_config.tts.model,
settings=CartesiaTTSSettings(
voice=user_config.tts.voice,
model=user_config.tts.model,
),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
# Convert HTTP URL to WebSocket URL for TTS
@ -206,10 +225,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
return DograhTTSService(
base_url=base_url,
api_key=user_config.tts.api_key,
model=user_config.tts.model,
voice=user_config.tts.voice,
params=DograhTTSService.InputParams(speed=user_config.tts.speed),
settings=DograhTTSSettings(
model=user_config.tts.model,
voice=user_config.tts.voice,
speed=user_config.tts.speed,
),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
# Map Sarvam language code to pipecat Language enum for TTS
@ -232,10 +254,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
voice = getattr(user_config.tts, "voice", None) or "anushka"
return SarvamTTSService(
api_key=user_config.tts.api_key,
model=user_config.tts.model,
voice_id=voice,
params=SarvamTTSService.InputParams(language=pipecat_language),
settings=SarvamTTSSettings(
model=user_config.tts.model,
voice=voice,
language=pipecat_language,
),
text_filters=[xml_function_tag_filter],
silence_time_s=1.0,
)
else:
raise HTTPException(
@ -253,16 +278,15 @@ def create_llm_service(user_config):
if "gpt-5" in model:
return OpenAILLMService(
api_key=user_config.llm.api_key,
model=model,
params=OpenAILLMService.InputParams(
reasoning_effort="minimal", verbosity="low"
settings=OpenAILLMSettings(
model=model,
extra={"reasoning_effort": "minimal", "verbosity": "low"},
),
)
else:
return OpenAILLMService(
api_key=user_config.llm.api_key,
model=model,
params=OpenAILLMService.InputParams(temperature=0.1),
settings=OpenAILLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.GROQ.value:
print(
@ -270,36 +294,30 @@ def create_llm_service(user_config):
)
return GroqLLMService(
api_key=user_config.llm.api_key,
model=model,
params=OpenAILLMService.InputParams(temperature=0.1),
settings=GroqLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.OPENROUTER.value:
return OpenRouterLLMService(
api_key=user_config.llm.api_key,
model=model,
base_url=user_config.llm.base_url,
params=OpenAILLMService.InputParams(temperature=0.1),
settings=OpenRouterLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
# Use the correct InputParams class for Google to avoid propagating OpenAI-specific
# NOT_GIVEN sentinels that break Pydantic validation in GoogleLLMService.
return GoogleLLMService(
api_key=user_config.llm.api_key,
model=model,
params=GoogleLLMService.InputParams(temperature=0.1),
settings=GoogleLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.AZURE.value:
return AzureLLMService(
api_key=user_config.llm.api_key,
endpoint=user_config.llm.endpoint,
model=model, # Azure uses deployment name as model
params=AzureLLMService.InputParams(temperature=0.1),
settings=AzureLLMSettings(model=model, temperature=0.1),
)
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
return DograhLLMService(
base_url=f"{MPS_API_URL}/api/v1/llm",
api_key=user_config.llm.api_key,
model=model,
settings=OpenAILLMSettings(model=model),
)
else:
raise HTTPException(status_code=400, detail="Invalid LLM provider")

View file

@ -5,6 +5,7 @@ from api.services.workflow.disposition_mapper import (
get_organization_id_from_workflow_run,
)
from api.services.workflow.workflow import Node, WorkflowGraph
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.frames.frames import (
BotStartedSpeakingFrame,
BotStoppedSpeakingFrame,
@ -16,6 +17,7 @@ from pipecat.frames.frames import (
from pipecat.pipeline.task import PipelineTask
from pipecat.processors.aggregators.llm_context import LLMContext
from pipecat.services.llm_service import FunctionCallParams
from pipecat.services.settings import LLMSettings
from pipecat.utils.enums import EndTaskReason
if TYPE_CHECKING:
@ -31,18 +33,19 @@ import asyncio
from loguru import logger
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.pipecat_engine_utils import (
from api.services.workflow.pipecat_engine_context_composer import (
compose_functions_for_node,
compose_system_prompt_for_node,
)
from api.services.workflow.pipecat_engine_custom_tools import (
CustomToolManager,
get_function_schema,
render_template,
update_llm_context,
)
from api.services.workflow.pipecat_engine_variable_extractor import (
VariableExtractionManager,
)
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
from api.services.workflow.tools.knowledge_base import (
get_knowledge_base_tool,
retrieve_from_knowledge_base,
)
from api.services.workflow.tools.timezone import (
@ -50,6 +53,7 @@ from api.services.workflow.tools.timezone import (
get_current_time,
get_time_tools,
)
from api.utils.template_renderer import render_template
class PipecatEngine:
@ -68,6 +72,7 @@ class PipecatEngine:
embeddings_api_key: Optional[str] = None,
embeddings_model: Optional[str] = None,
embeddings_base_url: Optional[str] = None,
has_recordings: bool = False,
):
self.task = task
self.llm = llm
@ -113,6 +118,10 @@ class PipecatEngine:
# Audio configuration (set via set_audio_config from _run_pipeline)
self._audio_config = None
# True when the workflow has active recordings; enables recording
# response mode instructions on all nodes for in-context learning.
self._has_recordings: bool = has_recordings
async def _get_organization_id(self) -> Optional[int]:
"""Get and cache the organization ID from workflow run."""
if self._custom_tool_manager:
@ -194,15 +203,14 @@ class PipecatEngine:
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
raise
def _get_function_schema(self, function_name: str, description: str):
"""Thin wrapper around utils.get_function_schema for backwards compatibility."""
async def _update_llm_context(self, system_prompt: str, functions: list[dict]):
"""Update LLM settings with the composed system prompt and tool list."""
return get_function_schema(function_name, description)
await self.llm._update_settings(LLMSettings(system_instruction=system_prompt))
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
"""Delegate context update to the shared workflow.utils implementation."""
update_llm_context(self.context, system_message, functions)
if functions:
tools_schema = ToolsSchema(standard_tools=functions)
self.context.set_tools(tools_schema)
def _format_prompt(self, prompt: str) -> str:
"""Delegate prompt formatting to the shared workflow.utils implementation."""
@ -473,12 +481,19 @@ class PipecatEngine:
if node.document_uuids:
await self._register_knowledge_base_function(node.document_uuids)
# Set up system message and functions
(
system_message,
functions,
) = await self._compose_system_message_functions_for_node(node)
await self._update_llm_context(system_message, functions)
# Compose prompt and functions via the context composer module
system_prompt = compose_system_prompt_for_node(
node=node,
workflow=self.workflow,
format_prompt=self._format_prompt,
has_recordings=self._has_recordings,
)
functions = await compose_functions_for_node(
node=node,
builtin_function_schemas=self.builtin_function_schemas,
custom_tool_manager=self._custom_tool_manager,
)
await self._update_llm_context(system_prompt, functions)
async def set_node(self, node_id: str):
"""
@ -610,62 +625,6 @@ class PipecatEngine:
)
await self.task.queue_frame(frame_to_push)
async def _compose_system_message_functions_for_node(
self, node: "Node"
) -> tuple[list[dict], list[dict]]:
"""Generate the system messages and function schemas for the given node.
This performs the same formatting logic used when entering a node but
does **not** register the functions with the LLM; callers are
responsible for that.
"""
global_prompt = ""
if self.workflow.global_node_id and node.add_global_prompt:
global_node = self.workflow.nodes[self.workflow.global_node_id]
global_prompt = self._format_prompt(global_node.prompt)
functions: list[dict] = []
# Add built-in function schemas (calculator and timezone tools)
functions.extend(self.builtin_function_schemas)
# Add knowledge base retrieval tool if node has documents
if node.document_uuids:
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
kb_schema = get_function_schema(
kb_tool_def["function"]["name"],
kb_tool_def["function"]["description"],
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
required=kb_tool_def["function"]["parameters"].get("required", []),
)
functions.append(kb_schema)
# Add custom tools from node.tool_uuids
if node.tool_uuids and self._custom_tool_manager:
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
node.tool_uuids
)
functions.extend(custom_tool_schemas)
# Transition functions (schema only; registration handled elsewhere)
for outgoing_edge in node.out_edges:
function_schema = self._get_function_schema(
outgoing_edge.get_function_name(), outgoing_edge.condition
)
functions.append(function_schema)
formatted_node_prompt = self._format_prompt(node.prompt)
system_message = {
"role": "system",
"content": "\n\n".join(
p for p in (global_prompt, formatted_node_prompt) if p
),
}
return system_message, functions
async def should_mute_user(self, frame: "Frame") -> bool:
"""
Callback for CallbackUserMuteStrategy to determine if the user should be muted.

View file

@ -0,0 +1,138 @@
"""System prompt and function schema composition for PipecatEngine nodes.
Extracts prompt and function composition logic from PipecatEngine into
reusable functions. Defines recording response mode markers and instructions.
"""
from typing import TYPE_CHECKING, Callable, Optional
if TYPE_CHECKING:
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.workflow import Node, WorkflowGraph
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
from api.services.workflow.tools.knowledge_base import get_knowledge_base_tool
# ---------------------------------------------------------------------------
# Recording response mode markers
# ---------------------------------------------------------------------------
RECORDING_MARKER = "" # Play pre-recorded audio
TTS_MARKER = "" # Generate dynamic TTS text
# ---------------------------------------------------------------------------
# Recording response mode system prompt instructions
# ---------------------------------------------------------------------------
RECORDING_RESPONSE_MODE_INSTRUCTIONS = """\
RESPONSE MODE INSTRUCTIONS - MANDATORY FORMAT:
Every response you generate MUST begin with a response mode indicator.
You have two modes for responding:
1. DYNAMIC SPEECH (): Generate text that will be converted to speech by TTS.
Format: `` followed by a space and your full spoken response.
Example: Hello! How can I help you today?
2. PRE-RECORDED AUDIO (): Play a pre-recorded audio message.
Format: `` followed by a space and ONLY the recording_id. Nothing else.
Example: rec_greeting_01
RULES:
- Your response MUST start with either `` or `` as the very first character.
- For `` (dynamic speech): Follow with a space and your full response text.
- For `` (pre-recorded audio): Follow with a space and ONLY the recording_id. No other text.
- Use `` when a pre-recorded message matches the situation well.
- Use `` when you need to generate a dynamic, contextual response.
- NEVER mix modes in a single response. Choose one."""
def compose_system_prompt_for_node(
*,
node: "Node",
workflow: "WorkflowGraph",
format_prompt: Callable[[str], str],
has_recordings: bool,
) -> str:
"""Compose the full system prompt text for a workflow node.
Combines the global prompt, node-specific prompt, and (when recordings
are enabled anywhere in the workflow) the recording response mode
instructions into a single string.
Args:
node: The workflow node to compose the prompt for.
workflow: The full workflow graph (needed for global node prompt).
format_prompt: Callable to render template variables in prompts.
has_recordings: Whether any node in the workflow uses recordings.
Returns:
The composed system prompt text.
"""
global_prompt = ""
if workflow.global_node_id and node.add_global_prompt:
global_node = workflow.nodes[workflow.global_node_id]
global_prompt = format_prompt(global_node.prompt)
formatted_node_prompt = format_prompt(node.prompt)
parts = [p for p in (global_prompt, formatted_node_prompt) if p]
if has_recordings:
parts.append(RECORDING_RESPONSE_MODE_INSTRUCTIONS)
# TODO: Append per-node available recordings list here once
# Node.recording_ids is populated. The list should include
# recording_id and a short description so the LLM can choose.
return "\n\n".join(parts)
async def compose_functions_for_node(
*,
node: "Node",
builtin_function_schemas: list[dict],
custom_tool_manager: Optional["CustomToolManager"],
) -> list[dict]:
"""Compose the function/tool schemas for a workflow node.
Gathers built-in tools, knowledge-base tools, custom tools,
and transition function schemas into a single list.
Args:
node: The workflow node to compose functions for.
builtin_function_schemas: Pre-computed schemas for built-in tools.
custom_tool_manager: Manager for user-defined custom tools (may be None).
Returns:
A list of function schemas to register with the LLM.
"""
functions: list[dict] = []
# Built-in tools (calculator, timezone)
functions.extend(builtin_function_schemas)
# Knowledge base retrieval tool
if node.document_uuids:
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
kb_schema = get_function_schema(
kb_tool_def["function"]["name"],
kb_tool_def["function"]["description"],
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
required=kb_tool_def["function"]["parameters"].get("required", []),
)
functions.append(kb_schema)
# Custom tools
if node.tool_uuids and custom_tool_manager:
custom_tool_schemas = await custom_tool_manager.get_tool_schemas(
node.tool_uuids
)
functions.extend(custom_tool_schemas)
# Transition function schemas
for outgoing_edge in node.out_edges:
function_schema = get_function_schema(
outgoing_edge.get_function_name(), outgoing_edge.condition
)
functions.append(function_schema)
return functions

View file

@ -10,7 +10,7 @@ import asyncio
import re
import time
import uuid
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from loguru import logger
@ -23,7 +23,6 @@ from api.services.telephony.transfer_event_protocol import TransferContext
from api.services.workflow.disposition_mapper import (
get_organization_id_from_workflow_run,
)
from api.services.workflow.pipecat_engine_utils import get_function_schema
from api.services.workflow.tools.custom_tool import (
execute_http_tool,
tool_to_function_schema,
@ -42,6 +41,29 @@ if TYPE_CHECKING:
from api.services.workflow.pipecat_engine import PipecatEngine
def get_function_schema(
function_name: str,
description: str,
*,
properties: Dict[str, Any] | None = None,
required: List[str] | None = None,
) -> FunctionSchema:
"""Create a FunctionSchema definition that can later be transformed into
the provider-specific format (OpenAI, Gemini, etc.).
The helper keeps the public signature backward-compatible callers that
only pass ``function_name`` and ``description`` continue to work and will
define a parameter-less function.
"""
return FunctionSchema(
name=function_name,
description=description,
properties=properties or {},
required=required or [],
)
class CustomToolManager:
"""Manager for custom tool registration and execution.

View file

@ -1,68 +0,0 @@
from __future__ import annotations
from typing import Any, Dict, List
from api.utils.template_renderer import render_template
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext
__all__ = [
"get_function_schema",
"update_llm_context",
"render_template",
]
def get_function_schema(
function_name: str,
description: str,
*,
properties: Dict[str, Any] | None = None,
required: List[str] | None = None,
) -> FunctionSchema:
"""Create a FunctionSchema definition that can later be transformed into
the provider-specific format (OpenAI, Gemini, etc.).
The helper keeps the public signature backward-compatible callers that
only pass ``function_name`` and ``description`` continue to work and will
define a parameter-less function.
"""
return FunctionSchema(
name=function_name,
description=description,
properties=properties or {},
required=required or [],
)
def update_llm_context(
context: LLMContext,
system_message: Dict[str, Any],
functions: List[FunctionSchema],
) -> None:
"""Update *context* with an up-to-date system message and tool list.
This helper removes any previous system messages before inserting the new
*system_message* at the top of the conversation history and then instructs
the LLM which *functions* (a.k.a. tools) are currently available.
"""
# Wrap the provided function schemas in a ToolsSchema so that the adapter
# associated with the current LLM service can convert them to the correct
# provider-specific representation when required.
tools_schema = ToolsSchema(standard_tools=functions)
previous_interactions = context.messages
# Replace the first message if it's a system message, otherwise prepend.
# Keep any system messages that appear in the middle of the conversation.
if previous_interactions and previous_interactions[0]["role"] == "system":
messages = [system_message] + previous_interactions[1:]
else:
messages = [system_message] + previous_interactions
context.set_messages(messages)
if functions:
context.set_tools(tools_schema)

View file

@ -13,14 +13,12 @@ from unittest.mock import AsyncMock, Mock, patch
import pytest
from api.services.workflow.pipecat_engine_utils import (
get_function_schema,
update_llm_context,
)
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
from api.services.workflow.tools.custom_tool import (
execute_http_tool,
tool_to_function_schema,
)
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.frames.frames import (
FunctionCallInProgressFrame,
FunctionCallResultFrame,
@ -862,11 +860,27 @@ class TestCustomToolManagerUnit:
assert result_received["status"] == "success"
def _update_llm_context(context, system_message, functions):
"""Inline helper replicating the old update_llm_context for tests."""
tools_schema = ToolsSchema(standard_tools=functions)
previous_interactions = context.messages
if previous_interactions and previous_interactions[0]["role"] == "system":
messages = [system_message] + previous_interactions[1:]
else:
messages = [system_message] + previous_interactions
context.set_messages(messages)
if functions:
context.set_tools(tools_schema)
class TestUpdateLLMContext:
"""Tests for update_llm_context function."""
"""Tests for _update_llm_context inline logic."""
def test_replaces_system_message(self):
"""Test that update_llm_context replaces existing system messages."""
"""Test that _update_llm_context replaces existing system messages."""
context = LLMContext()
context.set_messages(
[
@ -877,7 +891,7 @@ class TestUpdateLLMContext:
)
new_system = {"role": "system", "content": "New system message"}
update_llm_context(context, new_system, [])
_update_llm_context(context, new_system, [])
messages = context.messages
# Should have new system message at the start
@ -902,7 +916,7 @@ class TestUpdateLLMContext:
)
new_system = {"role": "system", "content": "New prompt"}
update_llm_context(context, new_system, [])
_update_llm_context(context, new_system, [])
messages = context.messages
assert len(messages) == 5
@ -923,7 +937,7 @@ class TestUpdateLLMContext:
]
new_system = {"role": "system", "content": "New prompt with tools"}
update_llm_context(context, new_system, functions)
_update_llm_context(context, new_system, functions)
# Verify tools were set
tools = context.tools
@ -936,7 +950,7 @@ class TestUpdateLLMContext:
context.set_messages([{"role": "system", "content": "Old"}])
new_system = {"role": "system", "content": "New prompt without tools"}
update_llm_context(context, new_system, [])
_update_llm_context(context, new_system, [])
# Tools should not be set (or remain None)
# Note: The function only calls set_tools if functions is truthy
@ -952,7 +966,7 @@ class TestUpdateLLMContext:
new_system = {"role": "system", "content": "Initial prompt"}
functions = [get_function_schema("test_func", "A test function")]
update_llm_context(context, new_system, functions)
_update_llm_context(context, new_system, functions)
messages = context.messages
assert len(messages) == 1

View file

@ -1,8 +1,8 @@
"""Integration tests for CustomToolManager with update_llm_context.
"""Integration tests for CustomToolManager with LLM context updates.
This module tests the full flow of:
1. CustomToolManager fetching and converting tool schemas
2. update_llm_context setting those tools on the LLM context
2. Setting those tools on the LLM context
3. Verifying the context is properly configured for LLM generation
"""
@ -10,16 +10,32 @@ from unittest.mock import AsyncMock, patch
import pytest
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
from api.services.workflow.pipecat_engine_utils import (
from api.services.workflow.pipecat_engine_custom_tools import (
CustomToolManager,
get_function_schema,
update_llm_context,
)
from api.tests.conftest import MockToolModel
from pipecat.adapters.schemas.function_schema import FunctionSchema
from pipecat.adapters.schemas.tools_schema import ToolsSchema
from pipecat.processors.aggregators.llm_context import LLMContext
def _update_llm_context(context, system_message, functions):
"""Inline helper replicating the update_llm_context logic for tests."""
tools_schema = ToolsSchema(standard_tools=functions)
previous_interactions = context.messages
if previous_interactions and previous_interactions[0]["role"] == "system":
messages = [system_message] + previous_interactions[1:]
else:
messages = [system_message] + previous_interactions
context.set_messages(messages)
if functions:
context.set_tools(tools_schema)
class TestCustomToolManagerContextIntegration:
"""Integration tests for CustomToolManager with LLMContext."""
@ -69,7 +85,7 @@ class TestCustomToolManagerContextIntegration:
"role": "system",
"content": "You are a scheduling assistant with access to weather and booking tools.",
}
update_llm_context(context, new_system, schemas)
_update_llm_context(context, new_system, schemas)
# Verify context was updated correctly
messages = context.messages
@ -195,7 +211,7 @@ class TestCustomToolManagerContextIntegration:
"role": "system",
"content": "Assistant with calculator and weather tools",
}
update_llm_context(context, new_system, all_functions)
_update_llm_context(context, new_system, all_functions)
# Verify all tools are present
tools = context.tools
@ -259,7 +275,7 @@ class TestCustomToolManagerContextIntegration:
)
new_system = {"role": "system", "content": "Updated weather assistant"}
update_llm_context(context, new_system, schemas)
_update_llm_context(context, new_system, schemas)
messages = context.messages
# System + user + assistant(tool_call) + tool + assistant = 5
@ -296,7 +312,7 @@ class TestCustomToolManagerContextIntegration:
context.set_messages([{"role": "system", "content": "Old"}])
new_system = {"role": "system", "content": "No tools available"}
update_llm_context(context, new_system, [])
_update_llm_context(context, new_system, [])
# Context should have updated message but no tools set
assert context.messages[0]["content"] == "No tools available"
@ -362,7 +378,7 @@ class TestCustomToolManagerContextIntegration:
# Update context - pass schema directly
context = LLMContext()
context.set_messages([{"role": "system", "content": "Old"}])
update_llm_context(
_update_llm_context(
context, {"role": "system", "content": "Order assistant"}, schemas
)

View file

@ -68,9 +68,7 @@ class ContextCapturingMockLLM(MockLLMService):
{
"step": self._current_step,
"messages": messages_snapshot,
"system_prompt": messages_snapshot[0]["content"]
if messages_snapshot
else None,
"system_prompt": self._settings.system_instruction,
}
)
@ -101,12 +99,10 @@ class ContextCapturingMockLLM(MockLLMService):
return False
def get_system_prompt_at_step(self, step: int) -> str:
"""Get the system prompt from context at a specific step."""
"""Get the system prompt from settings at a specific step."""
ctx = self.get_context_at_step(step)
if ctx and ctx["messages"]:
first_msg = ctx["messages"][0]
if first_msg.get("role") == "system":
return first_msg.get("content", "")
if ctx:
return ctx.get("system_prompt") or ""
return ""

View file

@ -186,7 +186,7 @@ class TestPipecatEngineToolCalls:
)
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
@pytest.mark.asyncio
async def test_parallel_builtin_and_transition_calls_through_engine_1(
@ -233,7 +233,7 @@ class TestPipecatEngineToolCalls:
)
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
@pytest.mark.asyncio
async def test_parallel_builtin_and_transition_calls_through_engine_with_text(
@ -281,7 +281,7 @@ class TestPipecatEngineToolCalls:
)
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
@pytest.mark.asyncio
async def test_single_transition_call_through_engine(
@ -315,4 +315,4 @@ class TestPipecatEngineToolCalls:
)
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT

@ -1 +1 @@
Subproject commit 10e8ded96672b08503db48c3d34e8345b11be4a2
Subproject commit ac418b4b3e915a801c82f401e3c68f12ecca98cb

View file

@ -0,0 +1,24 @@
"use client";
import { ArrowLeft } from "lucide-react";
import { useRouter } from "next/navigation";
import { Button } from "@/components/ui/button";
export function BackButton() {
const router = useRouter();
return (
<header className="flex items-center border-b px-4 py-3">
<Button
variant="ghost"
size="sm"
onClick={() => router.back()}
className="gap-2"
>
<ArrowLeft className="h-4 w-4" />
Go Back
</Button>
</header>
);
}

View file

@ -2,6 +2,8 @@ import { StackHandler } from "@stackframe/stack";
import { getAuthProvider } from "@/lib/auth/config";
import { BackButton } from "./BackButton";
export default async function Handler(props: unknown) {
const authProvider = await getAuthProvider();
@ -18,9 +20,16 @@ export default async function Handler(props: unknown) {
const { getStackServerApp } = await import("@/lib/auth/server");
const app = await getStackServerApp();
return <StackHandler
fullPage
app={app!}
routeProps={props}
/>;
return (
<div className="flex flex-col h-screen">
<BackButton />
<div className="flex-1 overflow-auto">
<StackHandler
fullPage
app={app!}
routeProps={props}
/>
</div>
</div>
);
}

View file

@ -1,12 +1,12 @@
"use client";
import { Calendar, ChevronLeft, ChevronRight, Globe } from 'lucide-react';
import { ChevronLeft, ChevronRight, Globe } from 'lucide-react';
import { useRouter, useSearchParams } from 'next/navigation';
import { useCallback, useEffect, useId, useState } from 'react';
import TimezoneSelect, { type ITimezoneOption } from 'react-timezone-select';
import { getCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGet, getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet,getUsageHistoryApiV1OrganizationsUsageRunsGet } from '@/client/sdk.gen';
import type { CurrentUsageResponse, DailyUsageBreakdownResponse,UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen';
import { getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet, getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet, getUsageHistoryApiV1OrganizationsUsageRunsGet } from '@/client/sdk.gen';
import type { DailyUsageBreakdownResponse, MpsCreditsResponse, UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen';
import { DailyUsageTable } from '@/components/DailyUsageTable';
import { FilterBuilder } from '@/components/filters/FilterBuilder';
import { MediaPreviewButton, MediaPreviewDialog } from '@/components/MediaPreviewDialog';
@ -37,9 +37,9 @@ export default function UsagePage() {
const { userConfig, saveUserConfig, loading: userConfigLoading, organizationPricing } = useUserConfig();
const auth = useAuth();
// Current usage state
const [currentUsage, setCurrentUsage] = useState<CurrentUsageResponse | null>(null);
const [isLoadingCurrent, setIsLoadingCurrent] = useState(true);
// MPS credits state
const [mpsCredits, setMpsCredits] = useState<MpsCreditsResponse | null>(null);
const [isLoadingCredits, setIsLoadingCredits] = useState(true);
// Usage history state
const [usageHistory, setUsageHistory] = useState<UsageHistoryResponse | null>(null);
@ -68,19 +68,18 @@ export default function UsagePage() {
const [savingTimezone, setSavingTimezone] = useState(false);
const timezoneSelectId = useId(); // Stable ID for react-select to prevent hydration mismatch
// Fetch current usage
const fetchCurrentUsage = useCallback(async () => {
// Fetch MPS credits
const fetchMpsCredits = useCallback(async () => {
if (!auth.isAuthenticated) return;
try {
const response = await getCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGet();
const response = await getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet();
if (response.data) {
setCurrentUsage(response.data);
setMpsCredits(response.data);
}
} catch (error) {
console.error('Failed to fetch current usage:', error);
console.error('Failed to fetch MPS credits:', error);
} finally {
setIsLoadingCurrent(false);
setIsLoadingCredits(false);
}
}, [auth.isAuthenticated]);
@ -195,10 +194,10 @@ export default function UsagePage() {
// Initial load - fetch when auth becomes available
useEffect(() => {
if (auth.isAuthenticated) {
fetchCurrentUsage();
fetchMpsCredits();
fetchUsageHistory(currentPage, activeFilters);
}
}, [auth.isAuthenticated, currentPage, activeFilters, fetchUsageHistory, fetchCurrentUsage]);
}, [auth.isAuthenticated, currentPage, activeFilters, fetchUsageHistory, fetchMpsCredits]);
// Fetch daily usage when organizationPricing becomes available
useEffect(() => {
@ -259,20 +258,6 @@ export default function UsagePage() {
router.push(`/workflow/${run.workflow_id}/run/${run.id}`);
};
// Format date for display with timezone support
const formatDate = (dateString: string) => {
const date = new Date(dateString);
const tzValue = typeof selectedTimezone === 'string' ? selectedTimezone : selectedTimezone.value;
// Use local timezone if none selected (during loading)
const effectiveTz = tzValue || localTimezone;
return date.toLocaleDateString('en-US', {
timeZone: effectiveTz,
year: 'numeric',
month: 'short',
day: 'numeric'
});
};
// Format datetime for display with timezone support
const formatDateTime = (dateString: string) => {
const date = new Date(dateString);
@ -383,68 +368,42 @@ export default function UsagePage() {
</div>
</div>
{/* Current Period Card */}
{/* MPS Credits Card */}
<Card className="mb-6">
<CardHeader>
<CardTitle>Current Billing Period</CardTitle>
<CardTitle>Dograh Model Credits</CardTitle>
<CardDescription>
{currentUsage && `${formatDate(currentUsage.period_start)} - ${formatDate(currentUsage.period_end)}`}
These track usage of Dograh models using Dograh Service Keys.
</CardDescription>
</CardHeader>
<CardContent>
{isLoadingCurrent ? (
{isLoadingCredits ? (
<div className="animate-pulse space-y-4">
<div className="h-4 bg-muted rounded w-1/4"></div>
<div className="h-8 bg-muted rounded"></div>
<div className="h-4 bg-muted rounded w-1/3"></div>
</div>
) : currentUsage ? (
) : mpsCredits ? (
<div className="space-y-4">
<div className="flex justify-between items-baseline">
<div>
{organizationPricing?.price_per_second_usd ? (
<>
<p className="text-2xl font-bold">
${(currentUsage.used_amount_usd || 0).toFixed(2)}
</p>
<p className="text-sm text-muted-foreground">Total Cost (USD)</p>
<p className="text-xs text-muted-foreground mt-1">
Rate: ${(organizationPricing.price_per_second_usd * 60).toFixed(4)}/minute
</p>
</>
) : (
<>
<p className="text-2xl font-bold">
{currentUsage.used_dograh_tokens.toLocaleString()} / {currentUsage.quota_dograh_tokens.toLocaleString()}
</p>
<p className="text-sm text-muted-foreground">Dograh Tokens</p>
</>
)}
<p className="text-2xl font-bold">
{mpsCredits.total_credits_used.toFixed(2)} <span className="text-lg font-normal text-muted-foreground">/ {mpsCredits.total_quota.toFixed(2)}</span>
</p>
<p className="text-sm text-muted-foreground">Credits Used</p>
</div>
<div className="text-right">
<p className="text-lg font-semibold">{mpsCredits.remaining_credits.toFixed(2)}</p>
<p className="text-sm text-muted-foreground">Remaining</p>
</div>
{!organizationPricing?.price_per_second_usd && (
<div className="text-right">
<p className="text-lg font-semibold">{currentUsage.percentage_used}%</p>
<p className="text-sm text-muted-foreground">Used</p>
</div>
)}
</div>
{!organizationPricing?.price_per_second_usd && (
<Progress value={currentUsage.percentage_used} className="h-3" />
{mpsCredits.total_quota > 0 && (
<Progress value={(mpsCredits.total_credits_used / mpsCredits.total_quota) * 100} className="h-3" />
)}
<div className="flex justify-between items-center text-sm text-muted-foreground">
<div className="flex items-center">
<Calendar className="h-4 w-4 mr-1" />
Next refresh: {formatDate(currentUsage.next_refresh_date)}
</div>
<div>
Total Duration: <span className="font-medium text-foreground">{formatDuration(currentUsage.total_duration_seconds)}</span>
</div>
</div>
</div>
) : (
<p className="text-muted-foreground">Unable to load usage data</p>
<p className="text-muted-foreground">No Dograh service keys configured. Set up a service key in your model configuration to see usage.</p>
)}
</CardContent>
</Card>

View file

@ -6,11 +6,11 @@ import {
Panel,
ReactFlow,
} from "@xyflow/react";
import { BookA, BrushCleaning, Maximize2, Minus, Plus, Rocket, Settings, Variable } from 'lucide-react';
import { BookA, BrushCleaning, Maximize2, Mic, Minus, Plus, Rocket, Settings, Variable } from 'lucide-react';
import React, { useEffect, useMemo, useState } from 'react';
import { listDocumentsApiV1KnowledgeBaseDocumentsGet, listToolsApiV1ToolsGet } from '@/client';
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
import { listDocumentsApiV1KnowledgeBaseDocumentsGet, listRecordingsApiV1WorkflowRecordingsGet, listToolsApiV1ToolsGet } from '@/client';
import type { DocumentResponseSchema, RecordingResponseSchema, ToolResponse } from '@/client/types.gen';
import { FlowEdge, FlowNode, NodeType } from "@/components/flow/types";
import { Button } from '@/components/ui/button';
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip';
@ -23,6 +23,7 @@ import { ConfigurationsDialog } from './components/ConfigurationsDialog';
import { DictionaryDialog } from './components/DictionaryDialog';
import { EmbedDialog } from './components/EmbedDialog';
import { PhoneCallDialog } from './components/PhoneCallDialog';
import { RecordingsDialog } from './components/RecordingsDialog';
import { TemplateContextVariablesDialog } from './components/TemplateContextVariablesDialog';
import { WorkflowEditorHeader } from "./components/WorkflowEditorHeader";
import { WorkflowProvider } from "./contexts/WorkflowContext";
@ -67,8 +68,10 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
const [isDictionaryDialogOpen, setIsDictionaryDialogOpen] = useState(false);
const [isEmbedDialogOpen, setIsEmbedDialogOpen] = useState(false);
const [isPhoneCallDialogOpen, setIsPhoneCallDialogOpen] = useState(false);
const [isRecordingsDialogOpen, setIsRecordingsDialogOpen] = useState(false);
const [documents, setDocuments] = useState<DocumentResponseSchema[] | undefined>(undefined);
const [tools, setTools] = useState<ToolResponse[] | undefined>(undefined);
const [recordings, setRecordings] = useState<RecordingResponseSchema[]>([]);
const {
rfInstance,
@ -102,7 +105,7 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
user,
});
// Fetch documents and tools once for the entire workflow
// Fetch documents, tools, and recordings once for the entire workflow
useEffect(() => {
const fetchData = async () => {
try {
@ -119,13 +122,25 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
if (toolsResponse.data) {
setTools(toolsResponse.data);
}
// Fetch recordings for this workflow
try {
const recordingsResponse = await listRecordingsApiV1WorkflowRecordingsGet({
query: { workflow_id: workflowId },
});
if (recordingsResponse.data) {
setRecordings(recordingsResponse.data.recordings);
}
} catch {
// Recordings API may not be available yet; silently ignore
}
} catch (error) {
console.error('Failed to fetch documents and tools:', error);
}
};
fetchData();
}, []);
}, [workflowId]);
// Memoize defaultEdgeOptions to prevent unnecessary re-renders
const defaultEdgeOptions = useMemo(() => ({
@ -137,8 +152,9 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
const workflowContextValue = useMemo(() => ({
saveWorkflow,
documents,
tools
}), [saveWorkflow, documents, tools]);
tools,
recordings,
}), [saveWorkflow, documents, tools, recordings]);
return (
<WorkflowProvider value={workflowContextValue}>
@ -251,6 +267,22 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
</TooltipContent>
</Tooltip>
<Tooltip>
<TooltipTrigger asChild>
<Button
variant="outline"
size="icon"
onClick={() => setIsRecordingsDialogOpen(true)}
className="bg-white shadow-sm hover:shadow-md"
>
<Mic className="h-4 w-4" />
</Button>
</TooltipTrigger>
<TooltipContent side="left">
<p>Recordings</p>
</TooltipContent>
</Tooltip>
<Tooltip>
<TooltipTrigger asChild>
<Button
@ -389,6 +421,13 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
workflowId={workflowId}
user={user}
/>
<RecordingsDialog
open={isRecordingsDialogOpen}
onOpenChange={setIsRecordingsDialogOpen}
workflowId={workflowId}
onRecordingsChange={setRecordings}
/>
</div>
</WorkflowProvider>
);

View file

@ -0,0 +1,314 @@
import { Loader2, Trash2Icon, Upload } from "lucide-react";
import { useCallback, useEffect, useRef, useState } from "react";
import {
createRecordingApiV1WorkflowRecordingsPost,
deleteRecordingApiV1WorkflowRecordingsRecordingIdDelete,
getUploadUrlApiV1WorkflowRecordingsUploadUrlPost,
listRecordingsApiV1WorkflowRecordingsGet,
} from "@/client";
import type { RecordingResponseSchema } from "@/client/types.gen";
import { Button } from "@/components/ui/button";
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { useUserConfig } from "@/context/UserConfigContext";
interface RecordingsDialogProps {
open: boolean;
onOpenChange: (open: boolean) => void;
workflowId: number;
onRecordingsChange?: (recordings: RecordingResponseSchema[]) => void;
}
const MAX_FILE_SIZE = 5 * 1024 * 1024; // 5MB
export const RecordingsDialog = ({
open,
onOpenChange,
workflowId,
onRecordingsChange,
}: RecordingsDialogProps) => {
const { userConfig } = useUserConfig();
const [recordings, setRecordings] = useState<RecordingResponseSchema[]>([]);
const [loading, setLoading] = useState(false);
const [uploading, setUploading] = useState(false);
const [transcript, setTranscript] = useState("");
const [selectedFile, setSelectedFile] = useState<File | null>(null);
const [error, setError] = useState<string | null>(null);
const fileInputRef = useRef<HTMLInputElement>(null);
const ttsProvider = (userConfig?.tts?.provider as string) ?? "";
const ttsModel = (userConfig?.tts?.model as string) ?? "";
const ttsVoiceId = (userConfig?.tts?.voice as string) ?? "";
const fetchRecordings = useCallback(async () => {
if (!workflowId) return;
setLoading(true);
try {
const result = await listRecordingsApiV1WorkflowRecordingsGet({
query: {
workflow_id: workflowId,
tts_provider: ttsProvider || undefined,
tts_model: ttsModel || undefined,
tts_voice_id: ttsVoiceId || undefined,
},
});
const recs = result.data?.recordings ?? [];
setRecordings(recs);
onRecordingsChange?.(recs);
} catch {
setError("Failed to load recordings");
} finally {
setLoading(false);
}
}, [workflowId, ttsProvider, ttsModel, ttsVoiceId, onRecordingsChange]);
useEffect(() => {
if (open) {
fetchRecordings();
setError(null);
setTranscript("");
setSelectedFile(null);
}
}, [open, fetchRecordings]);
const handleUpload = async () => {
if (!selectedFile || !transcript.trim()) return;
if (!ttsProvider || !ttsModel || !ttsVoiceId) {
setError(
"TTS configuration (provider, model, voice) must be set in your user configuration before uploading."
);
return;
}
setUploading(true);
setError(null);
try {
// Step 1: Get presigned URL
const uploadUrlResponse =
await getUploadUrlApiV1WorkflowRecordingsUploadUrlPost({
body: {
workflow_id: workflowId,
filename: selectedFile.name,
mime_type: selectedFile.type || "audio/wav",
file_size: selectedFile.size,
},
});
if (!uploadUrlResponse.data) {
throw new Error("Failed to get upload URL");
}
const { upload_url, recording_id, storage_key } =
uploadUrlResponse.data;
// Step 2: Upload file directly to storage
const uploadResponse = await fetch(upload_url, {
method: "PUT",
body: selectedFile,
headers: {
"Content-Type": selectedFile.type || "audio/wav",
},
});
if (!uploadResponse.ok) {
throw new Error("File upload failed");
}
// Step 3: Create recording record
await createRecordingApiV1WorkflowRecordingsPost({
body: {
recording_id,
workflow_id: workflowId,
tts_provider: ttsProvider,
tts_model: ttsModel,
tts_voice_id: ttsVoiceId,
transcript: transcript.trim(),
storage_key,
metadata: {
original_filename: selectedFile.name,
file_size_bytes: selectedFile.size,
mime_type: selectedFile.type,
},
},
});
// Reset form and refresh list
setTranscript("");
setSelectedFile(null);
if (fileInputRef.current) fileInputRef.current.value = "";
await fetchRecordings();
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to upload recording"
);
} finally {
setUploading(false);
}
};
const handleDelete = async (recordingId: string) => {
try {
await deleteRecordingApiV1WorkflowRecordingsRecordingIdDelete({
path: { recording_id: recordingId },
});
await fetchRecordings();
} catch {
setError("Failed to delete recording");
}
};
return (
<Dialog open={open} onOpenChange={onOpenChange}>
<DialogContent className="max-w-lg max-h-[80vh] overflow-y-auto">
<DialogHeader>
<DialogTitle>Workflow Recordings</DialogTitle>
<DialogDescription>
Upload audio recordings for hybrid prompts. Recordings are
scoped to your current TTS configuration. Use{" "}
<code className="text-xs bg-muted px-1 rounded">@</code> in
prompt fields to insert them.
</DialogDescription>
</DialogHeader>
{/* Current TTS Config */}
<div className="rounded-md border p-3 bg-muted/30 text-sm space-y-1">
<div className="font-medium text-xs text-muted-foreground uppercase tracking-wide">
Current TTS Configuration
</div>
{ttsProvider ? (
<div className="flex flex-wrap gap-2 text-xs">
<span className="bg-background px-2 py-0.5 rounded border">
Provider: {ttsProvider}
</span>
<span className="bg-background px-2 py-0.5 rounded border">
Model: {ttsModel}
</span>
<span className="bg-background px-2 py-0.5 rounded border truncate max-w-[200px]">
VoiceID: {ttsVoiceId}
</span>
</div>
) : (
<p className="text-xs text-destructive">
No TTS configuration found. Set it in Model Configurations.
</p>
)}
</div>
{error && (
<div className="text-sm text-destructive bg-destructive/10 rounded-md p-2">
{error}
</div>
)}
{/* Upload Section */}
<div className="space-y-3 border rounded-md p-3">
<Label className="text-sm font-medium">Upload New Recording</Label>
<div>
<Label className="text-xs text-muted-foreground">
Audio File
</Label>
<Input
ref={fileInputRef}
type="file"
accept="audio/*"
onChange={(e) => {
const file = e.target.files?.[0] ?? null;
if (file && file.size > MAX_FILE_SIZE) {
setError(
`File size (${(file.size / (1024 * 1024)).toFixed(1)}MB) exceeds the maximum allowed size of 5MB.`
);
setSelectedFile(null);
if (fileInputRef.current) fileInputRef.current.value = "";
return;
}
setError(null);
setSelectedFile(file);
}}
className="text-sm"
/>
<p className="text-xs text-muted-foreground mt-1">
Max 5MB
</p>
</div>
<div>
<Label className="text-xs text-muted-foreground">
Transcript
</Label>
<Input
placeholder="What does this recording say?"
value={transcript}
onChange={(e) => setTranscript(e.target.value)}
/>
</div>
<Button
size="sm"
onClick={handleUpload}
disabled={!selectedFile || !transcript.trim() || uploading}
>
{uploading ? (
<Loader2 className="w-4 h-4 mr-1 animate-spin" />
) : (
<Upload className="w-4 h-4 mr-1" />
)}
{uploading ? "Uploading..." : "Upload Recording"}
</Button>
</div>
{/* Recordings List */}
<div className="space-y-2">
<Label className="text-sm font-medium">
Recordings{" "}
{!loading && (
<span className="text-muted-foreground font-normal">
({recordings.length})
</span>
)}
</Label>
{loading ? (
<div className="flex items-center justify-center py-4">
<Loader2 className="w-5 h-5 animate-spin text-muted-foreground" />
</div>
) : recordings.length === 0 ? (
<p className="text-sm text-muted-foreground py-2">
No recordings yet for this TTS configuration.
</p>
) : (
recordings.map((rec) => (
<div
key={rec.recording_id}
className="flex items-start gap-2 p-2 border rounded-md"
>
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2">
<code className="text-xs bg-muted px-1.5 py-0.5 rounded font-mono">
{rec.recording_id}
</code>
</div>
<p className="text-sm text-muted-foreground mt-1 break-all line-clamp-2">
{rec.transcript}
</p>
</div>
<Button
size="sm"
variant="ghost"
onClick={() => handleDelete(rec.recording_id)}
>
<Trash2Icon className="w-4 h-4" />
</Button>
</div>
))
)}
</div>
</DialogContent>
</Dialog>
);
};

View file

@ -1,11 +1,13 @@
import { createContext, useContext } from 'react';
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
import type { RecordingResponseSchema } from '@/client/types.gen';
interface WorkflowContextType {
saveWorkflow: (updateWorkflowDefinition?: boolean) => Promise<void>;
documents?: DocumentResponseSchema[];
tools?: ToolResponse[];
recordings?: RecordingResponseSchema[];
}
const WorkflowContext = createContext<WorkflowContextType | undefined>(undefined);

View file

@ -319,18 +319,10 @@ export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initia
setFeedbackMessages(prev => {
const last = prev[prev.length - 1];
if (last && last.type === 'bot-text' && !last.final) {
// Append to existing bot message with space if needed
const existingText = last.text;
const newText = message.payload.text;
// Add space between chunks if previous doesn't end with space
// and new doesn't start with space or punctuation
const needsSpace = existingText.length > 0 &&
!existingText.endsWith(' ') &&
!newText.startsWith(' ') &&
!/^[.,!?;:]/.test(newText);
// Append to existing bot message
return [
...prev.slice(0, -1),
{ ...last, text: existingText + (needsSpace ? ' ' : '') + newText }
{ ...last, text: last.text + message.payload.text }
];
}
// Start new bot message

View file

@ -69,7 +69,7 @@ export function processTranscriptEvents(events: TranscriptEvent[]): ProcessedMes
} else if (event.type === 'bot-text') {
// Combine consecutive bot-text from the same turn
if (currentBotText && currentBotText.event.turn === event.turn) {
currentBotText.text = currentBotText.text + ' ' + event.text;
currentBotText.text = currentBotText.text + event.text;
} else {
flushBotText();
currentBotText = { event, text: event.text };

View file

@ -16,5 +16,5 @@ import type { ClientOptions } from './types.gen';
export type CreateClientConfig<T extends DefaultClientOptions = ClientOptions> = (override?: Config<DefaultClientOptions & T>) => Config<Required<DefaultClientOptions> & T>;
export const client = createClient(createClientConfig(createConfig<ClientOptions>({
baseUrl: 'http://127.0.0.1:8000'
baseUrl: 'https://app.dograh.com'
})));

File diff suppressed because one or more lines are too long

View file

@ -755,6 +755,12 @@ export type LoginRequest = {
password: string;
};
export type MpsCreditsResponse = {
total_credits_used: number;
remaining_credits: number;
total_quota: number;
};
export type PresignedUploadUrlRequest = {
/**
* CSV filename
@ -790,6 +796,116 @@ export type ProcessDocumentRequestSchema = {
s3_key: string;
};
/**
* Request schema for creating a recording record after upload.
*/
export type RecordingCreateRequestSchema = {
/**
* Short recording ID from upload step
*/
recording_id: string;
/**
* Workflow ID
*/
workflow_id: number;
/**
* TTS provider (e.g. elevenlabs)
*/
tts_provider: string;
/**
* TTS model name
*/
tts_model: string;
/**
* TTS voice identifier
*/
tts_voice_id: string;
/**
* User-provided transcript of the recording
*/
transcript: string;
/**
* Storage key from upload step
*/
storage_key: string;
/**
* Optional metadata (file_size, duration, etc.)
*/
metadata?: {
[key: string]: unknown;
} | null;
};
/**
* Response schema for list of recordings.
*/
export type RecordingListResponseSchema = {
recordings: Array<RecordingResponseSchema>;
total: number;
};
/**
* Response schema for a single recording.
*/
export type RecordingResponseSchema = {
id: number;
recording_id: string;
workflow_id: number;
organization_id: number;
tts_provider: string;
tts_model: string;
tts_voice_id: string;
transcript: string;
storage_key: string;
storage_backend: string;
metadata: {
[key: string]: unknown;
};
created_by: number;
created_at: string;
is_active: boolean;
};
/**
* Request schema for getting a presigned upload URL.
*/
export type RecordingUploadRequestSchema = {
/**
* Workflow ID this recording belongs to
*/
workflow_id: number;
/**
* Original filename of the audio file
*/
filename: string;
/**
* MIME type of the audio file
*/
mime_type?: string;
/**
* File size in bytes (max 5MB)
*/
file_size: number;
};
/**
* Response schema with presigned upload URL.
*/
export type RecordingUploadResponseSchema = {
/**
* Presigned URL for uploading the audio
*/
upload_url: string;
/**
* Short unique recording ID
*/
recording_id: string;
/**
* Storage key where file will be uploaded
*/
storage_key: string;
};
export type RetryConfigRequest = {
enabled?: boolean;
max_retries?: number;
@ -4268,6 +4384,39 @@ export type GetCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGetResponse
export type GetCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGetResponse = GetCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGetResponses[keyof GetCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGetResponses];
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetData = {
body?: never;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path?: never;
query?: never;
url: '/api/v1/organizations/usage/mps-credits';
};
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetError = GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetErrors[keyof GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetErrors];
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses = {
/**
* Successful Response
*/
200: MpsCreditsResponse;
};
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponse = GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses[keyof GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses];
export type GetUsageHistoryApiV1OrganizationsUsageRunsGetData = {
body?: never;
headers?: {
@ -5065,6 +5214,155 @@ export type SearchChunksApiV1KnowledgeBaseSearchPostResponses = {
export type SearchChunksApiV1KnowledgeBaseSearchPostResponse = SearchChunksApiV1KnowledgeBaseSearchPostResponses[keyof SearchChunksApiV1KnowledgeBaseSearchPostResponses];
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostData = {
body: RecordingUploadRequestSchema;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path?: never;
query?: never;
url: '/api/v1/workflow-recordings/upload-url';
};
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostError = GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostErrors[keyof GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostErrors];
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostResponses = {
/**
* Successful Response
*/
200: RecordingUploadResponseSchema;
};
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostResponse = GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostResponses[keyof GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostResponses];
export type ListRecordingsApiV1WorkflowRecordingsGetData = {
body?: never;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path?: never;
query: {
/**
* Workflow ID
*/
workflow_id: number;
/**
* Filter by TTS provider
*/
tts_provider?: string | null;
/**
* Filter by TTS model
*/
tts_model?: string | null;
/**
* Filter by TTS voice ID
*/
tts_voice_id?: string | null;
};
url: '/api/v1/workflow-recordings/';
};
export type ListRecordingsApiV1WorkflowRecordingsGetErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type ListRecordingsApiV1WorkflowRecordingsGetError = ListRecordingsApiV1WorkflowRecordingsGetErrors[keyof ListRecordingsApiV1WorkflowRecordingsGetErrors];
export type ListRecordingsApiV1WorkflowRecordingsGetResponses = {
/**
* Successful Response
*/
200: RecordingListResponseSchema;
};
export type ListRecordingsApiV1WorkflowRecordingsGetResponse = ListRecordingsApiV1WorkflowRecordingsGetResponses[keyof ListRecordingsApiV1WorkflowRecordingsGetResponses];
export type CreateRecordingApiV1WorkflowRecordingsPostData = {
body: RecordingCreateRequestSchema;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path?: never;
query?: never;
url: '/api/v1/workflow-recordings/';
};
export type CreateRecordingApiV1WorkflowRecordingsPostErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type CreateRecordingApiV1WorkflowRecordingsPostError = CreateRecordingApiV1WorkflowRecordingsPostErrors[keyof CreateRecordingApiV1WorkflowRecordingsPostErrors];
export type CreateRecordingApiV1WorkflowRecordingsPostResponses = {
/**
* Successful Response
*/
200: RecordingResponseSchema;
};
export type CreateRecordingApiV1WorkflowRecordingsPostResponse = CreateRecordingApiV1WorkflowRecordingsPostResponses[keyof CreateRecordingApiV1WorkflowRecordingsPostResponses];
export type DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteData = {
body?: never;
headers?: {
authorization?: string | null;
'X-API-Key'?: string | null;
};
path: {
recording_id: string;
};
query?: never;
url: '/api/v1/workflow-recordings/{recording_id}';
};
export type DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteErrors = {
/**
* Not found
*/
404: unknown;
/**
* Validation Error
*/
422: HttpValidationError;
};
export type DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteError = DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteErrors[keyof DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteErrors];
export type DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteResponses = {
/**
* Successful Response
*/
200: unknown;
};
export type SignupApiV1AuthSignupPostData = {
body: SignupRequest;
path?: never;
@ -5180,5 +5478,5 @@ export type HealthApiV1HealthGetResponses = {
export type HealthApiV1HealthGetResponse = HealthApiV1HealthGetResponses[keyof HealthApiV1HealthGetResponses];
export type ClientOptions = {
baseUrl: 'http://127.0.0.1:8000' | (string & {});
baseUrl: 'https://app.dograh.com' | 'http://localhost:8000' | (string & {});
};

View file

@ -0,0 +1,213 @@
import {
type ChangeEvent,
type KeyboardEvent,
useCallback,
useEffect,
useMemo,
useRef,
useState,
} from "react";
import type { RecordingResponseSchema } from "@/client/types.gen";
import { cn } from "@/lib/utils";
export interface MentionItem {
id: string;
name: string;
description: string;
}
interface MentionTextareaProps {
value: string;
onChange: (value: string) => void;
placeholder?: string;
className?: string;
recordings?: RecordingResponseSchema[];
}
export function MentionTextarea({
value,
onChange,
placeholder,
className,
recordings = [],
}: MentionTextareaProps) {
const textareaRef = useRef<HTMLTextAreaElement>(null);
const dropdownRef = useRef<HTMLDivElement>(null);
const [showDropdown, setShowDropdown] = useState(false);
const [query, setQuery] = useState("");
const [mentionStartIndex, setMentionStartIndex] = useState<number | null>(null);
const [selectedIndex, setSelectedIndex] = useState(0);
// Convert recordings to mention items
const items: MentionItem[] = useMemo(
() =>
recordings.map((r) => ({
id: r.recording_id,
name: r.transcript,
description: r.transcript,
})),
[recordings]
);
const filtered = items.filter(
(item) =>
item.name.toLowerCase().includes(query.toLowerCase()) ||
item.id.toLowerCase().includes(query.toLowerCase())
);
const insertMention = useCallback(
(item: MentionItem) => {
if (mentionStartIndex === null) return;
const textarea = textareaRef.current;
if (!textarea) return;
const before = value.slice(0, mentionStartIndex);
const after = value.slice(textarea.selectionStart);
const mentionText = `RECORDING_ID: ${item.id} [ ${item.description} ]`;
const newValue = before + mentionText + after;
onChange(newValue);
setShowDropdown(false);
setQuery("");
setMentionStartIndex(null);
setSelectedIndex(0);
// Restore cursor position after the inserted mention
requestAnimationFrame(() => {
const cursorPos = before.length + mentionText.length;
textarea.focus();
textarea.setSelectionRange(cursorPos, cursorPos);
});
},
[mentionStartIndex, value, onChange]
);
const handleChange = useCallback(
(e: ChangeEvent<HTMLTextAreaElement>) => {
const newValue = e.target.value;
const cursorPos = e.target.selectionStart;
onChange(newValue);
// Look backwards from cursor to find an unmatched "@"
const textBeforeCursor = newValue.slice(0, cursorPos);
const lastAtIndex = textBeforeCursor.lastIndexOf("@");
if (lastAtIndex !== -1) {
const textBetween = textBeforeCursor.slice(lastAtIndex + 1);
// Only trigger if there's no space before the query or it's at the start
const charBeforeAt = lastAtIndex > 0 ? newValue[lastAtIndex - 1] : " ";
if (
(charBeforeAt === " " || charBeforeAt === "\n" || lastAtIndex === 0) &&
!textBetween.includes(" ")
) {
setShowDropdown(true);
setQuery(textBetween);
setMentionStartIndex(lastAtIndex);
setSelectedIndex(0);
return;
}
}
setShowDropdown(false);
setQuery("");
setMentionStartIndex(null);
},
[onChange]
);
const handleKeyDown = useCallback(
(e: KeyboardEvent<HTMLTextAreaElement>) => {
if (!showDropdown || filtered.length === 0) return;
if (e.key === "ArrowDown") {
e.preventDefault();
setSelectedIndex((prev) => (prev + 1) % filtered.length);
} else if (e.key === "ArrowUp") {
e.preventDefault();
setSelectedIndex((prev) => (prev - 1 + filtered.length) % filtered.length);
} else if (e.key === "Enter" || e.key === "Tab") {
e.preventDefault();
insertMention(filtered[selectedIndex]);
} else if (e.key === "Escape") {
e.preventDefault();
setShowDropdown(false);
}
},
[showDropdown, filtered, selectedIndex, insertMention]
);
// Close dropdown when clicking outside
useEffect(() => {
const handleClickOutside = (e: MouseEvent) => {
if (
dropdownRef.current &&
!dropdownRef.current.contains(e.target as Node) &&
textareaRef.current &&
!textareaRef.current.contains(e.target as Node)
) {
setShowDropdown(false);
}
};
document.addEventListener("mousedown", handleClickOutside);
return () => document.removeEventListener("mousedown", handleClickOutside);
}, []);
// Scroll selected item into view
useEffect(() => {
if (!showDropdown || !dropdownRef.current) return;
const selected = dropdownRef.current.querySelector("[data-selected='true']");
selected?.scrollIntoView({ block: "nearest" });
}, [selectedIndex, showDropdown]);
return (
<div className="relative">
<textarea
ref={textareaRef}
value={value}
onChange={handleChange}
onKeyDown={handleKeyDown}
placeholder={placeholder}
className={cn(
"border-input placeholder:text-muted-foreground focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:bg-input/30 flex field-sizing-content min-h-16 w-full rounded-md border bg-transparent px-3 py-2 text-base shadow-xs transition-[color,box-shadow] outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
className
)}
/>
{showDropdown && filtered.length > 0 && (
<div
ref={dropdownRef}
className="absolute z-50 mt-1 w-full max-h-60 overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md"
>
{filtered.map((item, index) => (
<button
key={item.id}
type="button"
data-selected={index === selectedIndex}
className={cn(
"flex w-full flex-col gap-0.5 px-3 py-2 text-left text-sm cursor-pointer hover:bg-accent",
index === selectedIndex && "bg-accent"
)}
onMouseDown={(e) => {
e.preventDefault(); // prevent textarea blur
insertMention(item);
}}
onMouseEnter={() => setSelectedIndex(index)}
>
<div className="flex items-center gap-2">
<code className="text-xs bg-muted px-1 py-0.5 rounded font-mono">
{item.id}
</code>
<span className="font-medium truncate">{item.name}</span>
</div>
</button>
))}
</div>
)}
{showDropdown && filtered.length === 0 && items.length === 0 && (
<div className="absolute z-50 mt-1 w-full rounded-md border bg-popover text-popover-foreground shadow-md p-3 text-sm text-muted-foreground">
No recordings found. Upload recordings via the Recordings panel.
</div>
)}
</div>
);
}

View file

@ -3,9 +3,10 @@ import { Edit, FileText, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-re
import { memo, useCallback, useEffect, useMemo, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
import type { DocumentResponseSchema, RecordingResponseSchema, ToolResponse } from "@/client/types.gen";
import { DocumentBadges } from "@/components/flow/DocumentBadges";
import { DocumentSelector } from "@/components/flow/DocumentSelector";
import { MentionTextarea } from "@/components/flow/MentionTextarea";
import { ToolBadges } from "@/components/flow/ToolBadges";
import { ToolSelector } from "@/components/flow/ToolSelector";
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
@ -42,6 +43,7 @@ interface AgentNodeEditFormProps {
setDocumentUuids: (value: string[]) => void;
tools: ToolResponse[];
documents: DocumentResponseSchema[];
recordings: RecordingResponseSchema[];
}
interface AgentNodeProps extends NodeProps {
@ -50,7 +52,7 @@ interface AgentNodeProps extends NodeProps {
export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
const { saveWorkflow, tools, documents } = useWorkflow();
const { saveWorkflow, tools, documents, recordings } = useWorkflow();
// Form state
const [prompt, setPrompt] = useState(data.prompt);
@ -229,6 +231,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
setDocumentUuids={setDocumentUuids}
tools={tools ?? []}
documents={documents ?? []}
recordings={recordings ?? []}
/>
)}
</NodeEditDialog>
@ -257,6 +260,7 @@ const AgentNodeEditForm = ({
setDocumentUuids,
tools,
documents,
recordings,
}: AgentNodeEditFormProps) => {
const handleVariableNameChange = (idx: number, value: string) => {
const newVars = [...variables];
@ -318,13 +322,12 @@ const AgentNodeEditForm = ({
<Label className="text-xs text-muted-foreground">
Enter the prompt for the agent. This will be used to generate the agent&apos;s response. Prompt engineering&apos;s best practices apply.
</Label>
<Textarea
<MentionTextarea
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
className="min-h-[100px] max-h-[300px] resize-none"
style={{
overflowY: 'auto'
}}
onChange={setPrompt}
className="min-h-[100px] max-h-[300px] resize-none overflow-y-auto"
placeholder="Enter a prompt"
recordings={recordings}
/>
</div>

View file

@ -3,6 +3,8 @@ import { Edit, OctagonX, PlusIcon, Trash2Icon } from "lucide-react";
import { memo, useEffect, useMemo, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { RecordingResponseSchema } from "@/client/types.gen";
import { MentionTextarea } from "@/components/flow/MentionTextarea";
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
@ -29,6 +31,7 @@ interface EndCallEditFormProps {
setVariables: (vars: ExtractionVariable[]) => void;
addGlobalPrompt: boolean;
setAddGlobalPrompt: (value: boolean) => void;
recordings: RecordingResponseSchema[];
}
interface EndCallNodeProps extends NodeProps {
@ -40,7 +43,7 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
id,
additionalData: { is_end: true }
});
const { saveWorkflow } = useWorkflow();
const { saveWorkflow, recordings } = useWorkflow();
// Form state
const [prompt, setPrompt] = useState(data.prompt);
@ -157,6 +160,7 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
setVariables={setVariables}
addGlobalPrompt={addGlobalPrompt}
setAddGlobalPrompt={setAddGlobalPrompt}
recordings={recordings ?? []}
/>
)}
</NodeEditDialog>
@ -177,6 +181,7 @@ const EndCallEditForm = ({
setVariables,
addGlobalPrompt,
setAddGlobalPrompt,
recordings,
}: EndCallEditFormProps) => {
const handleVariableNameChange = (idx: number, value: string) => {
const newVars = [...variables];
@ -216,14 +221,12 @@ const EndCallEditForm = ({
<Label className="text-xs text-muted-foreground">
Enter the prompt for the agent. This will be used to generate the agent&apos;s response. Prompt engineering&apos;s best practices apply.
</Label>
<Textarea
<MentionTextarea
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
className="min-h-[100px] max-h-[300px] resize-none"
style={{
overflowY: 'auto'
}}
placeholder="Enter a dynamic prompt"
onChange={setPrompt}
className="min-h-[100px] max-h-[300px] resize-none overflow-y-auto"
placeholder="Enter a prompt"
recordings={recordings}
/>
<div className="flex items-center space-x-2">
<Switch id="add-global-prompt" checked={addGlobalPrompt} onCheckedChange={setAddGlobalPrompt} />

View file

@ -3,11 +3,12 @@ import { Edit, Headset, Trash2Icon } from "lucide-react";
import { memo, useEffect, useMemo, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { RecordingResponseSchema } from "@/client/types.gen";
import { MentionTextarea } from "@/components/flow/MentionTextarea";
import { FlowNodeData } from "@/components/flow/types";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Textarea } from "@/components/ui/textarea";
import { NODE_DOCUMENTATION_URLS } from "@/constants/documentation";
import { NodeContent } from "./common/NodeContent";
@ -20,6 +21,7 @@ interface GlobalNodeEditFormProps {
setPrompt: (value: string) => void;
name: string;
setName: (value: string) => void;
recordings: RecordingResponseSchema[];
}
interface GlobalNodeProps extends NodeProps {
@ -28,7 +30,7 @@ interface GlobalNodeProps extends NodeProps {
export const GlobalNode = memo(({ data, selected, id }: GlobalNodeProps) => {
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
const { saveWorkflow } = useWorkflow();
const { saveWorkflow, recordings } = useWorkflow();
// Form state
const [prompt, setPrompt] = useState(data.prompt);
@ -118,6 +120,7 @@ export const GlobalNode = memo(({ data, selected, id }: GlobalNodeProps) => {
setPrompt={setPrompt}
name={name}
setName={setName}
recordings={recordings ?? []}
/>
)}
</NodeEditDialog>
@ -129,7 +132,8 @@ const GlobalNodeEditForm = ({
prompt,
setPrompt,
name,
setName
setName,
recordings,
}: GlobalNodeEditFormProps) => {
return (
<div className="grid gap-2">
@ -146,13 +150,12 @@ const GlobalNodeEditForm = ({
<Label className="text-xs text-muted-foreground">
This is the global prompt. This will be added to the system prompt of all the agents.
</Label>
<Textarea
<MentionTextarea
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
className="min-h-[100px] max-h-[300px] resize-none"
style={{
overflowY: 'auto'
}}
onChange={setPrompt}
className="min-h-[100px] max-h-[300px] resize-none overflow-y-auto"
placeholder="Enter a prompt"
recordings={recordings}
/>
</div>
);

View file

@ -4,8 +4,10 @@ import { memo, useCallback, useEffect, useMemo, useState } from "react";
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
import type { RecordingResponseSchema } from "@/client/types.gen";
import { DocumentBadges } from "@/components/flow/DocumentBadges";
import { DocumentSelector } from "@/components/flow/DocumentSelector";
import { MentionTextarea } from "@/components/flow/MentionTextarea";
import { ToolBadges } from "@/components/flow/ToolBadges";
import { ToolSelector } from "@/components/flow/ToolSelector";
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
@ -48,6 +50,7 @@ interface StartCallEditFormProps {
setDocumentUuids: (value: string[]) => void;
tools: ToolResponse[];
documents: DocumentResponseSchema[];
recordings: RecordingResponseSchema[];
}
interface StartCallNodeProps extends NodeProps {
@ -59,7 +62,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
id,
additionalData: { is_start: true }
});
const { saveWorkflow, tools, documents } = useWorkflow();
const { saveWorkflow, tools, documents, recordings } = useWorkflow();
// Form state
const [prompt, setPrompt] = useState(data.prompt ?? "");
@ -248,6 +251,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
setDocumentUuids={setDocumentUuids}
tools={tools ?? []}
documents={documents ?? []}
recordings={recordings ?? []}
/>
)}
</NodeEditDialog>
@ -282,6 +286,7 @@ const StartCallEditForm = ({
setDocumentUuids,
tools,
documents,
recordings,
}: StartCallEditFormProps) => {
const handleVariableNameChange = (idx: number, value: string) => {
const newVars = [...variables];
@ -325,14 +330,12 @@ const StartCallEditForm = ({
<Label className="text-xs text-muted-foreground">
Enter the prompt for the agent. This will be used to generate the agent&apos;s response. Prompt engineering&apos;s best practices apply.
</Label>
<Textarea
<MentionTextarea
value={prompt}
onChange={(e) => setPrompt(e.target.value)}
className="min-h-[100px] max-h-[300px] resize-none"
style={{
overflowY: 'auto'
}}
onChange={setPrompt}
className="min-h-[100px] max-h-[300px] resize-none overflow-y-auto"
placeholder="Enter a prompt"
recordings={recordings}
/>
<div className="flex items-center space-x-2">
<Switch id="allow-interrupt" checked={allowInterrupt} onCheckedChange={setAllowInterrupt} />