mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-13 08:15:21 +02:00
feat: add hybrid text + recording functionality in agents (#191)
* feat: add recording feature in agents * chore: pin pipecat version * feat: show usage in UI * chore: update pipecat
This commit is contained in:
parent
f075bcb623
commit
494c60d774
43 changed files with 2865 additions and 397 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -14,4 +14,5 @@ prd/
|
|||
venv/
|
||||
.venv/
|
||||
.playwright-mcp
|
||||
coturn/
|
||||
coturn/
|
||||
dograh_pcm_cache/
|
||||
102
api/alembic/versions/e54ddb048535_add_recording_table.py
Normal file
102
api/alembic/versions/e54ddb048535_add_recording_table.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""add recording table
|
||||
|
||||
Revision ID: e54ddb048535
|
||||
Revises: 6fd8fac02883
|
||||
Create Date: 2026-03-12 19:21:53.729888
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
# `revision` identifies this migration; `down_revision` names the migration
# it is applied on top of.
revision: str = "e54ddb048535"
down_revision: Union[str, None] = "6fd8fac02883"
# No branch label and no cross-revision dependency are declared here.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create the storage-backend enum, the ``workflow_recordings`` table and its indexes."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = "workflow_recordings"
    enum_name = "recording_storage_backend"

    # The enum type is created explicitly first; the column type below is
    # declared with create_type=False so it refers to the existing type.
    sa.Enum("s3", "minio", name=enum_name).create(op.get_bind())
    storage_backend_type = postgresql.ENUM(
        "s3", "minio", name=enum_name, create_type=False
    )

    op.create_table(
        table_name,
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("recording_id", sa.String(length=16), nullable=False),
        sa.Column("workflow_id", sa.Integer(), nullable=False),
        sa.Column("organization_id", sa.Integer(), nullable=False),
        sa.Column("tts_provider", sa.String(), nullable=False),
        sa.Column("tts_model", sa.String(), nullable=False),
        sa.Column("tts_voice_id", sa.String(), nullable=False),
        sa.Column("transcript", sa.Text(), nullable=False),
        sa.Column("storage_key", sa.String(), nullable=False),
        sa.Column(
            "storage_backend",
            storage_backend_type,
            server_default=sa.text("'s3'::recording_storage_backend"),
            nullable=False,
        ),
        sa.Column(
            "recording_metadata",
            sa.JSON(),
            server_default=sa.text("'{}'::json"),
            nullable=False,
        ),
        sa.Column("created_by", sa.Integer(), nullable=False),
        sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("is_active", sa.Boolean(), nullable=False),
        sa.ForeignKeyConstraint(["created_by"], ["users.id"]),
        sa.ForeignKeyConstraint(
            ["organization_id"], ["organizations.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["workflow_id"], ["workflows.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    # (index name, indexed columns, unique?) in the order they are created.
    index_specs = (
        ("ix_workflow_recordings_org_id", ["organization_id"], False),
        (op.f("ix_workflow_recordings_recording_id"), ["recording_id"], True),
        (
            "ix_workflow_recordings_tts_scope",
            ["workflow_id", "tts_provider", "tts_model", "tts_voice_id"],
            False,
        ),
        ("ix_workflow_recordings_workflow_id", ["workflow_id"], False),
    )
    for index_name, indexed_columns, is_unique in index_specs:
        op.create_index(index_name, table_name, indexed_columns, unique=is_unique)
    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``workflow_recordings`` indexes and table, then the storage-backend enum."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = "workflow_recordings"

    # Indexes are dropped in the reverse of the order upgrade() creates them.
    index_names = (
        "ix_workflow_recordings_workflow_id",
        "ix_workflow_recordings_tts_scope",
        op.f("ix_workflow_recordings_recording_id"),
        "ix_workflow_recordings_org_id",
    )
    for index_name in index_names:
        op.drop_index(index_name, table_name=table_name)

    op.drop_table(table_name)

    # The enum type can only go once the table that uses it is gone.
    sa.Enum("s3", "minio", name="recording_storage_backend").drop(op.get_bind())
    # ### end Alembic commands ###
|
||||
|
|
@ -526,7 +526,7 @@ class CampaignClient(BaseDBClient):
|
|||
QueuedRunModel.state == "queued",
|
||||
QueuedRunModel.scheduled_for.is_(None),
|
||||
)
|
||||
.order_by(QueuedRunModel.created_at)
|
||||
.order_by(func.random())
|
||||
.limit(remaining_slots)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from api.db.tool_client import ToolClient
|
|||
from api.db.user_client import UserClient
|
||||
from api.db.webhook_credential_client import WebhookCredentialClient
|
||||
from api.db.workflow_client import WorkflowClient
|
||||
from api.db.workflow_recording_client import WorkflowRecordingClient
|
||||
from api.db.workflow_run_client import WorkflowRunClient
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
|
||||
|
|
@ -35,6 +36,7 @@ class DBClient(
|
|||
WebhookCredentialClient,
|
||||
ToolClient,
|
||||
KnowledgeBaseClient,
|
||||
WorkflowRecordingClient,
|
||||
):
|
||||
"""
|
||||
Unified database client that combines all specialized database operations.
|
||||
|
|
|
|||
|
|
@ -996,6 +996,77 @@ class KnowledgeBaseDocumentModel(Base):
|
|||
)
|
||||
|
||||
|
||||
class WorkflowRecordingModel(Base):
    """Model for storing audio recordings scoped to a workflow and TTS configuration.

    Recordings are used in hybrid prompts where parts of the output are pre-recorded
    audio rather than dynamically generated TTS.
    """

    __tablename__ = "workflow_recordings"

    id = Column(Integer, primary_key=True, index=True)

    # Short globally unique ID (e.g. "xbhfha3k") used in prompts.
    # unique=True together with index=True already produces the unique index
    # "ix_workflow_recordings_recording_id" (SQLAlchemy's default index naming),
    # which is the one the migration creates. It must not be declared again in
    # __table_args__.
    recording_id = Column(String(16), unique=True, nullable=False, index=True)

    # Scoping: rows are removed when the owning workflow or organization is deleted.
    workflow_id = Column(
        Integer, ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False
    )
    organization_id = Column(
        Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
    )

    # TTS configuration scope
    tts_provider = Column(String, nullable=False)
    tts_model = Column(String, nullable=False)
    tts_voice_id = Column(String, nullable=False)

    # Content
    transcript = Column(Text, nullable=False)

    # Storage
    storage_key = Column(String, nullable=False)
    storage_backend = Column(
        Enum("s3", "minio", name="recording_storage_backend"),
        nullable=False,
        default="s3",
        server_default=text("'s3'::recording_storage_backend"),
    )

    # Extra metadata (file_size_bytes, duration_seconds, original_filename, mime_type, etc.)
    # Named recording_metadata because `metadata` is reserved on declarative classes.
    recording_metadata = Column(
        JSON, nullable=False, default=dict, server_default=text("'{}'::json")
    )

    # Audit
    created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))

    # Soft delete
    is_active = Column(Boolean, default=True, nullable=False)

    # Relationships
    workflow = relationship("WorkflowModel")
    organization = relationship("OrganizationModel")
    created_by_user = relationship("UserModel")

    # Indexes. The recording_id index is intentionally absent: the column flags
    # above generate it, and a second Index with the same name would give the
    # table two identically named indexes (one unique, one not).
    __table_args__ = (
        Index("ix_workflow_recordings_workflow_id", "workflow_id"),
        Index("ix_workflow_recordings_org_id", "organization_id"),
        Index(
            "ix_workflow_recordings_tts_scope",
            "workflow_id",
            "tts_provider",
            "tts_model",
            "tts_voice_id",
        ),
    )
|
||||
|
||||
|
||||
class KnowledgeBaseChunkModel(Base):
|
||||
"""Model for storing document chunks with vector embeddings.
|
||||
|
||||
|
|
|
|||
218
api/db/workflow_recording_client.py
Normal file
218
api/db/workflow_recording_client.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
"""Database client for managing workflow recordings."""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import WorkflowRecordingModel
|
||||
|
||||
|
||||
def generate_short_id(length: int = 8) -> str:
    """Return ``length`` random characters drawn from lowercase ASCII letters and digits."""
    pool = string.ascii_lowercase + string.digits
    picked = []
    for _ in range(length):
        # secrets.choice rather than random.choice: IDs should not be predictable.
        picked.append(secrets.choice(pool))
    return "".join(picked)
|
||||
|
||||
|
||||
class WorkflowRecordingClient(BaseDBClient):
    """Client for managing workflow audio recordings."""

    async def create_recording(
        self,
        recording_id: str,
        workflow_id: int,
        organization_id: int,
        tts_provider: str,
        tts_model: str,
        tts_voice_id: str,
        transcript: str,
        storage_key: str,
        storage_backend: str,
        created_by: int,
        metadata: Optional[dict] = None,
    ) -> WorkflowRecordingModel:
        """Create a new workflow recording record.

        Args:
            recording_id: Short unique recording identifier
            workflow_id: ID of the workflow
            organization_id: ID of the organization
            tts_provider: TTS provider name
            tts_model: TTS model name
            tts_voice_id: TTS voice identifier
            transcript: User-provided transcript
            storage_key: S3/MinIO storage key
            storage_backend: Storage backend (s3 or minio)
            created_by: ID of the user
            metadata: Optional extra metadata; stored in the
                ``recording_metadata`` column (empty dict when omitted)

        Returns:
            The created WorkflowRecordingModel
        """
        async with self.async_session() as session:
            recording = WorkflowRecordingModel(
                recording_id=recording_id,
                workflow_id=workflow_id,
                organization_id=organization_id,
                tts_provider=tts_provider,
                tts_model=tts_model,
                tts_voice_id=tts_voice_id,
                transcript=transcript,
                storage_key=storage_key,
                storage_backend=storage_backend,
                created_by=created_by,
                # The mapped column is `recording_metadata`. Passing `metadata=`
                # here would not populate it, so caller-supplied metadata would
                # never be persisted.
                recording_metadata=metadata or {},
            )

            session.add(recording)
            await session.commit()
            # Reload so database-populated values are present on the returned object.
            await session.refresh(recording)

            logger.info(
                f"Created recording {recording_id} for workflow {workflow_id}, "
                f"org {organization_id}"
            )
            return recording

    async def get_recordings_for_workflow(
        self,
        workflow_id: int,
        organization_id: int,
        tts_provider: Optional[str] = None,
        tts_model: Optional[str] = None,
        tts_voice_id: Optional[str] = None,
    ) -> List[WorkflowRecordingModel]:
        """Get active recordings for a workflow, optionally filtered by TTS config.

        Args:
            workflow_id: ID of the workflow
            organization_id: ID of the organization
            tts_provider: Optional TTS provider filter
            tts_model: Optional TTS model filter
            tts_voice_id: Optional TTS voice ID filter

        Returns:
            List of WorkflowRecordingModel instances, newest first
        """
        async with self.async_session() as session:
            query = select(WorkflowRecordingModel).where(
                WorkflowRecordingModel.workflow_id == workflow_id,
                WorkflowRecordingModel.organization_id == organization_id,
                WorkflowRecordingModel.is_active.is_(True),
            )

            # Each TTS filter is applied only when a non-empty value is given.
            if tts_provider:
                query = query.where(WorkflowRecordingModel.tts_provider == tts_provider)
            if tts_model:
                query = query.where(WorkflowRecordingModel.tts_model == tts_model)
            if tts_voice_id:
                query = query.where(WorkflowRecordingModel.tts_voice_id == tts_voice_id)

            query = query.order_by(WorkflowRecordingModel.created_at.desc())

            result = await session.execute(query)
            return list(result.scalars().all())

    async def get_recording_by_recording_id(
        self,
        recording_id: str,
        organization_id: int,
    ) -> Optional[WorkflowRecordingModel]:
        """Get an active recording by its short ID within an organization.

        Args:
            recording_id: The short unique recording ID
            organization_id: ID of the organization

        Returns:
            WorkflowRecordingModel if found, None otherwise
        """
        async with self.async_session() as session:
            query = select(WorkflowRecordingModel).where(
                WorkflowRecordingModel.recording_id == recording_id,
                WorkflowRecordingModel.organization_id == organization_id,
                WorkflowRecordingModel.is_active.is_(True),
            )

            result = await session.execute(query)
            return result.scalar_one_or_none()

    async def has_active_recordings(
        self,
        workflow_id: int,
        organization_id: int,
    ) -> bool:
        """Check if a workflow has any active recordings.

        Args:
            workflow_id: ID of the workflow
            organization_id: ID of the organization

        Returns:
            True if at least one active recording exists, False otherwise
        """
        async with self.async_session() as session:
            query = (
                select(func.count())
                .select_from(WorkflowRecordingModel)
                .where(
                    WorkflowRecordingModel.workflow_id == workflow_id,
                    WorkflowRecordingModel.organization_id == organization_id,
                    WorkflowRecordingModel.is_active.is_(True),
                )
            )
            result = await session.execute(query)
            return result.scalar_one() > 0

    async def check_recording_id_exists(self, recording_id: str) -> bool:
        """Check if a recording ID already exists globally.

        The lookup is not restricted to active rows or to one organization,
        so soft-deleted recordings still count as existing.

        Args:
            recording_id: The short recording ID to check

        Returns:
            True if exists, False otherwise
        """
        async with self.async_session() as session:
            query = select(WorkflowRecordingModel.id).where(
                WorkflowRecordingModel.recording_id == recording_id,
            )
            result = await session.execute(query)
            return result.scalar_one_or_none() is not None

    async def delete_recording(
        self,
        recording_id: str,
        organization_id: int,
    ) -> bool:
        """Soft delete a recording by setting ``is_active`` to False.

        The lookup does not filter on ``is_active``, so calling this on an
        already soft-deleted recording also returns True.

        Args:
            recording_id: The short recording ID
            organization_id: ID of the organization

        Returns:
            True if deleted, False if not found
        """
        async with self.async_session() as session:
            query = select(WorkflowRecordingModel).where(
                WorkflowRecordingModel.recording_id == recording_id,
                WorkflowRecordingModel.organization_id == organization_id,
            )

            result = await session.execute(query)
            recording = result.scalar_one_or_none()

            if not recording:
                return False

            recording.is_active = False
            await session.commit()

            logger.info(
                f"Deleted recording {recording_id} for organization {organization_id}"
            )
            return True
|
||||
|
|
@ -24,6 +24,7 @@ from api.routes.user import router as user_router
|
|||
from api.routes.webrtc_signaling import router as webrtc_signaling_router
|
||||
from api.routes.workflow import router as workflow_router
|
||||
from api.routes.workflow_embed import router as workflow_embed_router
|
||||
from api.routes.workflow_recording import router as workflow_recording_router
|
||||
|
||||
router = APIRouter(
|
||||
tags=["main"],
|
||||
|
|
@ -51,6 +52,7 @@ router.include_router(public_agent_router)
|
|||
router.include_router(public_download_router)
|
||||
router.include_router(workflow_embed_router)
|
||||
router.include_router(knowledge_base_router)
|
||||
router.include_router(workflow_recording_router)
|
||||
router.include_router(auth_router)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,11 +3,14 @@ from datetime import datetime, timedelta
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
router = APIRouter(prefix="/organizations")
|
||||
|
||||
|
|
@ -28,6 +31,12 @@ class CurrentUsageResponse(BaseModel):
|
|||
price_per_second_usd: Optional[float] = None
|
||||
|
||||
|
||||
# Response body of GET /usage/mps-credits.
class MPSCreditsResponse(BaseModel):
    # Credits consumed so far, as reported by MPS.
    total_credits_used: float
    # Credits still available, as reported by MPS.
    remaining_credits: float
    # The route fills this with total_credits_used + remaining_credits.
    total_quota: float
|
||||
|
||||
|
||||
class WorkflowRunUsageResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
|
|
@ -85,6 +94,42 @@ async def get_current_period_usage(user: UserModel = Depends(get_user)):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/usage/mps-credits", response_model=MPSCreditsResponse)
async def get_mps_credits(user: UserModel = Depends(get_user)):
    """Get aggregated usage and quota from MPS.

    OSS users: queries by provider_id (created_by).
    Hosted users: queries by organization_id.
    """
    try:
        if DEPLOYMENT_MODE == "oss":
            usage = await mps_service_key_client.get_usage_by_created_by(
                str(user.provider_id)
            )
        elif user.selected_organization_id:
            usage = await mps_service_key_client.get_usage_by_organization(
                user.selected_organization_id
            )
        else:
            # Hosted mode needs an organization to query usage for.
            raise HTTPException(status_code=400, detail="No organization selected")

        # Missing keys in the MPS payload are treated as zero.
        credits_used = usage.get("total_credits_used", 0.0)
        credits_left = usage.get("remaining_credits", 0.0)

        return MPSCreditsResponse(
            total_credits_used=credits_used,
            remaining_credits=credits_left,
            total_quota=credits_used + credits_left,
        )
    except HTTPException:
        # Deliberate HTTP errors (such as the 400 above) pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"Failed to fetch MPS credits: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/usage/runs", response_model=UsageHistoryResponse)
|
||||
async def get_usage_history(
|
||||
start_date: Optional[str] = Query(None, description="ISO format date string"),
|
||||
|
|
|
|||
218
api/routes/workflow_recording.py
Normal file
218
api/routes/workflow_recording.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
"""API routes for workflow recording operations."""
|
||||
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.workflow_recording_client import generate_short_id
|
||||
from api.enums import StorageBackend
|
||||
from api.schemas.workflow_recording import (
|
||||
RecordingCreateRequestSchema,
|
||||
RecordingListResponseSchema,
|
||||
RecordingResponseSchema,
|
||||
RecordingUploadRequestSchema,
|
||||
RecordingUploadResponseSchema,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
router = APIRouter(prefix="/workflow-recordings", tags=["workflow-recordings"])
|
||||
|
||||
|
||||
async def _generate_unique_recording_id() -> str:
    """Return an 8-character recording ID that is not yet present in the database.

    Raises:
        HTTPException: 500 if ten consecutive candidates are already taken.
    """
    attempts_left = 10
    while attempts_left > 0:
        candidate = generate_short_id(8)
        if not await db_client.check_recording_id_exists(candidate):
            return candidate
        attempts_left -= 1

    raise HTTPException(
        status_code=500, detail="Failed to generate unique recording ID"
    )
|
||||
|
||||
|
||||
def _build_response(rec) -> RecordingResponseSchema:
    """Convert a recording row into the API response schema."""
    fields = {
        "id": rec.id,
        "recording_id": rec.recording_id,
        "workflow_id": rec.workflow_id,
        "organization_id": rec.organization_id,
        "tts_provider": rec.tts_provider,
        "tts_model": rec.tts_model,
        "tts_voice_id": rec.tts_voice_id,
        "transcript": rec.transcript,
        "storage_key": rec.storage_key,
        "storage_backend": rec.storage_backend,
        # The row attribute is recording_metadata; the API exposes it as metadata.
        "metadata": rec.recording_metadata or {},
        "created_by": rec.created_by,
        "created_at": rec.created_at,
        "is_active": rec.is_active,
    }
    return RecordingResponseSchema(**fields)
|
||||
|
||||
|
||||
def _sanitize_upload_filename(filename: str) -> str:
    """Return the last path component of ``filename``, or "" if nothing usable remains."""
    # Treat both slash styles as separators so directory parts never reach the key.
    base = filename.replace("\\", "/").rsplit("/", 1)[-1].strip()
    return "" if base in {".", ".."} else base


@router.post(
    "/upload-url",
    response_model=RecordingUploadResponseSchema,
    summary="Get presigned URL for recording upload",
)
async def get_upload_url(
    request: RecordingUploadRequestSchema,
    user=Depends(get_user),
):
    """Generate a presigned PUT URL for uploading an audio recording."""
    try:
        org_id = user.selected_organization_id
        if not org_id:
            # Without this the key below would be built as "recordings/None/...".
            raise HTTPException(status_code=400, detail="No organization selected")

        # The filename is client-controlled and becomes part of the storage key;
        # keep only its final component so it cannot add extra path segments.
        safe_filename = _sanitize_upload_filename(request.filename)
        if not safe_filename:
            raise HTTPException(status_code=400, detail="Invalid filename")

        recording_id = await _generate_unique_recording_id()

        storage_key = (
            f"recordings/{org_id}"
            f"/{request.workflow_id}/{recording_id}"
            f"/{safe_filename}"
        )

        upload_url = await storage_fs.aget_presigned_put_url(
            file_path=storage_key,
            expiration=1800,  # 30 minutes
            content_type=request.mime_type,
            max_size=5_242_880,  # 5MB max
        )

        if not upload_url:
            raise HTTPException(
                status_code=500, detail="Failed to generate presigned upload URL"
            )

        logger.info(
            f"Generated recording upload URL: {recording_id}, "
            f"workflow {request.workflow_id}, org {org_id}"
        )

        return RecordingUploadResponseSchema(
            upload_url=upload_url,
            recording_id=recording_id,
            storage_key=storage_key,
        )

    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as exc:
        logger.error(f"Error generating recording upload URL: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to generate upload URL"
        ) from exc
|
||||
|
||||
|
||||
@router.post(
    "/",
    response_model=RecordingResponseSchema,
    summary="Create recording record after upload",
)
async def create_recording(
    request: RecordingCreateRequestSchema,
    user=Depends(get_user),
):
    """Create a recording record after the audio has been uploaded to storage."""
    try:
        org_id = user.selected_organization_id
        if not org_id:
            raise HTTPException(status_code=400, detail="No organization selected")

        # storage_key comes from the client. Keys issued by the upload-url
        # endpoint in this module always start with "recordings/<org id>/", so
        # reject anything outside the caller's organization prefix.
        # NOTE(review): ownership of request.workflow_id is not verified here;
        # confirm whether that is enforced elsewhere.
        if not request.storage_key.startswith(f"recordings/{org_id}/"):
            raise HTTPException(status_code=400, detail="Invalid storage key")

        backend = StorageBackend.get_current_backend()

        recording = await db_client.create_recording(
            recording_id=request.recording_id,
            workflow_id=request.workflow_id,
            organization_id=org_id,
            tts_provider=request.tts_provider,
            tts_model=request.tts_model,
            tts_voice_id=request.tts_voice_id,
            transcript=request.transcript,
            storage_key=request.storage_key,
            storage_backend=backend.value,
            created_by=user.id,
            metadata=request.metadata,
        )

        logger.info(
            f"Created recording {request.recording_id} for workflow {request.workflow_id}"
        )

        return _build_response(recording)

    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as exc:
        logger.error(f"Error creating recording: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to create recording"
        ) from exc
|
||||
|
||||
|
||||
@router.get(
    "/",
    response_model=RecordingListResponseSchema,
    summary="List recordings for a workflow",
)
async def list_recordings(
    workflow_id: Annotated[int, Query(description="Workflow ID")],
    tts_provider: Annotated[
        Optional[str], Query(description="Filter by TTS provider")
    ] = None,
    tts_model: Annotated[
        Optional[str], Query(description="Filter by TTS model")
    ] = None,
    tts_voice_id: Annotated[
        Optional[str], Query(description="Filter by TTS voice ID")
    ] = None,
    user=Depends(get_user),
):
    """List recordings for a workflow, optionally filtered by TTS configuration."""
    try:
        rows = await db_client.get_recordings_for_workflow(
            workflow_id=workflow_id,
            organization_id=user.selected_organization_id,
            tts_provider=tts_provider,
            tts_model=tts_model,
            tts_voice_id=tts_voice_id,
        )

        items = []
        for row in rows:
            items.append(_build_response(row))

        # total is the number of rows returned by this query.
        return RecordingListResponseSchema(recordings=items, total=len(rows))

    except Exception as exc:
        logger.error(f"Error listing recordings: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to list recordings"
        ) from exc
|
||||
|
||||
|
||||
@router.delete(
    "/{recording_id}",
    summary="Delete a recording",
)
async def delete_recording(
    recording_id: str,
    user=Depends(get_user),
):
    """Soft delete a recording."""
    try:
        org_id = user.selected_organization_id
        deleted = await db_client.delete_recording(
            recording_id=recording_id,
            organization_id=org_id,
        )

        if deleted:
            logger.info(f"Deleted recording {recording_id}, org {org_id}")
            return {"success": True, "message": "Recording deleted successfully"}

        # No matching recording in the caller's organization.
        raise HTTPException(status_code=404, detail="Recording not found")

    except HTTPException:
        # Let the 404 above pass through unchanged.
        raise
    except Exception as exc:
        logger.error(f"Error deleting recording: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to delete recording"
        ) from exc
|
||||
73
api/schemas/workflow_recording.py
Normal file
73
api/schemas/workflow_recording.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""Pydantic schemas for workflow recording operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RecordingUploadRequestSchema(BaseModel):
    """Request schema for getting a presigned upload URL."""

    # Required: workflow the uploaded recording will be attached to.
    workflow_id: int = Field(..., description="Workflow ID this recording belongs to")
    # Required: client-side name of the file being uploaded.
    filename: str = Field(..., description="Original filename of the audio file")
    # Optional: defaults to "audio/wav" when the client sends no type.
    mime_type: str = Field(
        default="audio/wav", description="MIME type of the audio file"
    )
    # Required: validated to be greater than 0 and at most 5_242_880 bytes.
    file_size: int = Field(
        ...,
        gt=0,
        le=5_242_880,
        description="File size in bytes (max 5MB)",
    )
|
||||
|
||||
|
||||
class RecordingUploadResponseSchema(BaseModel):
    """Response schema with presigned upload URL."""

    # URL the client uploads the audio file to.
    upload_url: str = Field(..., description="Presigned URL for uploading the audio")
    # ID reserved for this recording; sent back when creating the record.
    recording_id: str = Field(..., description="Short unique recording ID")
    # Key the upload URL was issued for; sent back when creating the record.
    storage_key: str = Field(..., description="Storage key where file will be uploaded")
|
||||
|
||||
|
||||
class RecordingCreateRequestSchema(BaseModel):
    """Request schema for creating a recording record after upload."""

    # recording_id and storage_key are the values returned by the upload-URL step.
    recording_id: str = Field(..., description="Short recording ID from upload step")
    workflow_id: int = Field(..., description="Workflow ID")
    # The three tts_* fields identify the TTS configuration the recording is scoped to.
    tts_provider: str = Field(..., description="TTS provider (e.g. elevenlabs)")
    tts_model: str = Field(..., description="TTS model name")
    tts_voice_id: str = Field(..., description="TTS voice identifier")
    transcript: str = Field(
        ..., description="User-provided transcript of the recording"
    )
    storage_key: str = Field(..., description="Storage key from upload step")
    # Optional free-form dictionary; None when the client sends nothing.
    metadata: Optional[Dict[str, Any]] = Field(
        default=None, description="Optional metadata (file_size, duration, etc.)"
    )
|
||||
|
||||
|
||||
class RecordingResponseSchema(BaseModel):
    """Response schema for a single recording."""

    # Numeric primary key of the recording row.
    id: int
    # Short string ID of the recording.
    recording_id: str
    # Owning workflow and organization.
    workflow_id: int
    organization_id: int
    # TTS configuration the recording is scoped to.
    tts_provider: str
    tts_model: str
    tts_voice_id: str
    # Transcript text stored with the recording.
    transcript: str
    # Where the audio lives: object key plus the backend name as a string.
    storage_key: str
    storage_backend: str
    # Free-form extra data; required here, so producers must supply a dict.
    metadata: Dict[str, Any]
    # Audit fields.
    created_by: int
    created_at: datetime
    # Soft-delete flag.
    is_active: bool
|
||||
|
||||
|
||||
class RecordingListResponseSchema(BaseModel):
    """Response schema for list of recordings."""

    # The recordings included in this response.
    recordings: List[RecordingResponseSchema]
    # Count reported alongside the list.
    total: int
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, Optional, TypedDict
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
import openai
|
||||
from deepgram import DeepgramClient
|
||||
|
|
@ -12,6 +12,7 @@ from api.schemas.user_configuration import (
|
|||
UserConfiguration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
class APIKeyStatus(TypedDict):
|
||||
|
|
@ -25,7 +26,6 @@ class APIKeyStatusResponse(TypedDict):
|
|||
|
||||
class UserConfigurationValidator:
|
||||
def __init__(self):
|
||||
self._provider_api_key_validity_status: Dict[str, bool] = {}
|
||||
self._validator_map = {
|
||||
ServiceProviders.OPENAI.value: self._check_openai_api_key,
|
||||
ServiceProviders.DEEPGRAM.value: self._check_deepgram_api_key,
|
||||
|
|
@ -73,8 +73,13 @@ class UserConfigurationValidator:
|
|||
provider = service_config.provider
|
||||
api_key = service_config.api_key
|
||||
|
||||
if not self._check_api_key(provider, api_key):
|
||||
return [{"model": service_name, "message": f"Invalid {provider} API key"}]
|
||||
try:
|
||||
if not self._check_api_key(provider, api_key):
|
||||
return [
|
||||
{"model": service_name, "message": f"Invalid {provider} API key"}
|
||||
]
|
||||
except ValueError as e:
|
||||
return [{"model": service_name, "message": str(e)}]
|
||||
|
||||
return []
|
||||
|
||||
|
|
@ -87,40 +92,28 @@ class UserConfigurationValidator:
|
|||
return validator(provider, api_key)
|
||||
|
||||
def _check_openai_api_key(self, model: str, api_key: str) -> bool:
|
||||
if model in self._provider_api_key_validity_status:
|
||||
return self._provider_api_key_validity_status[model]
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
try:
|
||||
client.models.list()
|
||||
self._provider_api_key_validity_status[model] = True
|
||||
return True
|
||||
except openai.AuthenticationError:
|
||||
self._provider_api_key_validity_status[model] = False
|
||||
return self._provider_api_key_validity_status[model]
|
||||
return False
|
||||
|
||||
def _check_deepgram_api_key(self, model: str, api_key: str) -> bool:
|
||||
if model in self._provider_api_key_validity_status:
|
||||
return self._provider_api_key_validity_status[model]
|
||||
|
||||
try:
|
||||
deepgram = DeepgramClient(api_key=api_key)
|
||||
deepgram.manage.v1.projects.list()
|
||||
self._provider_api_key_validity_status[model] = True
|
||||
return True
|
||||
except Exception:
|
||||
self._provider_api_key_validity_status[model] = False
|
||||
return self._provider_api_key_validity_status[model]
|
||||
return False
|
||||
|
||||
def _check_groq_api_key(self, model: str, api_key: str) -> bool:
|
||||
if model in self._provider_api_key_validity_status:
|
||||
return self._provider_api_key_validity_status[model]
|
||||
|
||||
client = Groq(api_key=api_key)
|
||||
try:
|
||||
client.models.list()
|
||||
self._provider_api_key_validity_status[model] = True
|
||||
return True
|
||||
except Exception:
|
||||
self._provider_api_key_validity_status[model] = False
|
||||
return self._provider_api_key_validity_status[model]
|
||||
return False
|
||||
|
||||
def _validate_elevenlabs_api_key(self, model: str, api_key: str) -> bool:
    """Accept every ElevenLabs key without contacting the provider.

    Both arguments are ignored; the method always reports the key as valid.
    """
    return True
|
||||
|
|
@ -135,7 +128,12 @@ class UserConfigurationValidator:
|
|||
return True
|
||||
|
||||
def _check_dograh_api_key(self, model: str, api_key: str) -> bool:
    """Validate a Dograh service key against MPS.

    Args:
        model: Unused here; kept so every validator shares one signature.
        api_key: The key supplied by the user.

    Returns:
        The boolean result of ``mps_service_key_client.validate_service_key``.

    Raises:
        ValueError: If the key starts with ``dgr``, i.e. the user pasted a
            Dograh API key where a service key (``mps...``) is required.
    """
    # No early ``return True`` here: it would make the prefix check and the
    # MPS round-trip below unreachable and accept every key.
    if api_key.startswith("dgr"):
        raise ValueError(
            "You provided a Dograh API key (dgr...) instead of a service key. "
            "Please use a service key (mps...)."
        )
    return mps_service_key_client.validate_service_key(api_key)
|
||||
|
||||
def _check_sarvam_api_key(self, model: str, api_key: str) -> bool:
    """Accept every Sarvam key without contacting the provider.

    Both arguments are ignored; the method always reports the key as valid.
    """
    return True
|
||||
|
|
|
|||
|
|
@ -285,6 +285,90 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def get_usage_by_created_by(self, created_by: str) -> dict:
    """
    Fetch combined credit usage across every service key a user created (OSS mode).

    Args:
        created_by: The user's provider ID

    Returns:
        Dictionary with ``total_credits_used`` and ``remaining_credits``;
        each falls back to 0.0 when missing from the MPS response.

    Raises:
        httpx.HTTPStatusError: If MPS replies with any status other than 200.
    """
    async with httpx.AsyncClient(timeout=self.timeout) as http:
        endpoint = f"{self.base_url}/api/v1/service-keys/usage/created-by"
        response = await http.post(
            endpoint,
            json={"created_by": created_by},
            headers=self._get_headers(created_by=created_by),
        )

        # Guard clause: anything but 200 is logged and surfaced to the caller.
        if response.status_code != 200:
            logger.error(
                f"Failed to get usage by created_by: {response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to get usage by created_by: {response.text}",
                request=response.request,
                response=response,
            )

        usage = response.json()
        total_used = usage.get("total_credits_used", 0.0)
        remaining = usage.get("remaining_credits", 0.0)
        return {
            "total_credits_used": total_used,
            "remaining_credits": remaining,
        }
|
||||
|
||||
async def get_usage_by_organization(self, organization_id: int) -> dict:
    """
    Fetch combined credit usage across every service key of an organization (hosted mode).

    Args:
        organization_id: The organization's ID

    Returns:
        Dictionary with ``total_credits_used`` and ``remaining_credits``;
        each falls back to 0.0 when missing from the MPS response.

    Raises:
        httpx.HTTPStatusError: If MPS replies with any status other than 200.
    """
    async with httpx.AsyncClient(timeout=self.timeout) as http:
        endpoint = f"{self.base_url}/api/v1/service-keys/usage/organization"
        response = await http.post(
            endpoint,
            json={"organization_id": organization_id},
            headers=self._get_headers(organization_id=organization_id),
        )

        # Guard clause: anything but 200 is logged and surfaced to the caller.
        if response.status_code != 200:
            logger.error(
                f"Failed to get usage by organization: {response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to get usage by organization: {response.text}",
                request=response.request,
                response=response,
            )

        usage = response.json()
        total_used = usage.get("total_credits_used", 0.0)
        remaining = usage.get("remaining_credits", 0.0)
        return {
            "total_credits_used": total_used,
            "remaining_credits": remaining,
        }
|
||||
|
||||
def validate_service_key(self, service_key: str) -> bool:
    """
    Check a Dograh service key with a blocking usage request to MPS.

    The key counts as valid only when MPS answers the usage query with
    HTTP 200. Any exception raised while making the request is logged as a
    warning and reported as an invalid key.

    Returns True if the key is valid, False otherwise.
    """
    try:
        with httpx.Client(timeout=self.timeout) as http:
            endpoint = f"{self.base_url}/api/v1/service-keys/usage"
            reply = http.post(
                endpoint,
                json={"service_key": service_key},
                headers=self._get_headers(),
            )
            is_valid = reply.status_code == 200
            return is_valid
    except Exception:
        logger.warning("Failed to validate Dograh service key via MPS")
        return False
|
||||
|
||||
async def get_voices(
|
||||
self,
|
||||
provider: str,
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ def build_pipeline(
|
|||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=None,
|
||||
recording_router=None,
|
||||
):
|
||||
"""Build the main pipeline with all components.
|
||||
|
||||
|
|
@ -47,6 +48,9 @@ def build_pipeline(
|
|||
voicemail_detector: Optional native pipecat VoicemailDetector. When provided,
|
||||
inserts voicemail detection after STT. Note: We don't use the TTS gate
|
||||
to avoid blocking TTS frames during classification.
|
||||
recording_router: Optional RecordingRouterProcessor. When provided,
|
||||
inserts between callback processor and TTS to route between
|
||||
pre-recorded audio playback and dynamic TTS.
|
||||
"""
|
||||
# Build processors list with optional voicemail detection
|
||||
processors = [
|
||||
|
|
@ -66,11 +70,15 @@ def build_pipeline(
|
|||
processors.append(voicemail_detector.detector())
|
||||
|
||||
# Continue with the rest of the pipeline
|
||||
post_llm = [pipeline_engine_callback_processor]
|
||||
if recording_router:
|
||||
post_llm.append(recording_router)
|
||||
|
||||
processors.extend(
|
||||
[
|
||||
user_context_aggregator,
|
||||
llm, # LLM
|
||||
pipeline_engine_callback_processor,
|
||||
*post_llm,
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
audio_buffer, # AudioBufferProcessor - records both input and output audio
|
||||
|
|
|
|||
|
|
@ -37,10 +37,10 @@ from pipecat.frames.frames import (
|
|||
FunctionCallResultFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMTextFrame,
|
||||
MetricsFrame,
|
||||
StopFrame,
|
||||
TranscriptionFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import TTFBMetricsData
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
|
|
@ -207,7 +207,7 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
)
|
||||
# Handle bot TTS text - respect pts timing, WebSocket only
|
||||
# Complete turn text is persisted via register_turn_handlers
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
message = {
|
||||
"type": RealtimeFeedbackType.BOT_TEXT.value,
|
||||
"payload": {
|
||||
|
|
|
|||
338
api/services/pipecat/recording_audio_cache.py
Normal file
338
api/services/pipecat/recording_audio_cache.py
Normal file
|
|
@ -0,0 +1,338 @@
|
|||
"""Filesystem-backed cache and audio fetcher for workflow recordings.
|
||||
|
||||
Downloads recording files from object storage on first access, converts them
|
||||
to raw 16-bit mono PCM at the pipeline sample rate via ffmpeg, trims
|
||||
leading/trailing silence, and caches the processed bytes on disk so
|
||||
subsequent plays (even from other workers) are instantaneous.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import APP_ROOT_DIR
|
||||
from pipecat.audio.utils import SPEAKING_THRESHOLD
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filesystem cache directory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Cached PCM files live in "dograh_pcm_cache" inside the directory that
# contains APP_ROOT_DIR.
_CACHE_DIR = os.path.join(os.path.dirname(APP_ROOT_DIR), "dograh_pcm_cache")
# Created at import time so cache writes can create temp files in it
# directly; exist_ok makes repeated imports harmless.
os.makedirs(_CACHE_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def _cache_path(recording_id: str, sample_rate: int) -> str:
    """Build the cache file path for one recording at one sample rate.

    The file name is ``<recording_id>_<sample_rate>.pcm`` inside ``_CACHE_DIR``.
    """
    filename = f"{recording_id}_{sample_rate}.pcm"
    return os.path.join(_CACHE_DIR, filename)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_recording_audio_fetcher(
    organization_id: int,
    pipeline_sample_rate: int,
) -> Callable[[str], Awaitable[Optional[bytes]]]:
    """Create an async callback that returns raw PCM bytes for a recording_id.

    The returned callable:
    1. Rejects IDs containing a path component, because the ID is used
       verbatim to build the cache file name.
    2. Tries the filesystem cache (keyed by ``recording_id`` + sample rate),
       reading the file in a worker thread so the event loop is not blocked.
    3. On miss, looks up the recording in the DB for *organization_id* and
       hands it to ``_download_and_convert``, which downloads, converts to
       16-bit mono PCM at *pipeline_sample_rate*, trims silence and caches.

    Args:
        organization_id: Organization owning the recordings.
        pipeline_sample_rate: Target PCM sample rate for the pipeline.

    Returns:
        ``async (recording_id: str) -> Optional[bytes]``; ``None`` means the
        recording was rejected, not found, or could not be processed.
    """
    from api.db import db_client
    from api.services.storage import get_storage_for_backend

    # Storage clients are resolved lazily on first use and memoised per
    # backend for the lifetime of the returned fetcher.
    _storage_cache: dict[str, object] = {}

    def _get_storage(backend: str):
        if backend not in _storage_cache:
            _storage_cache[backend] = get_storage_for_backend(backend)
        return _storage_cache[backend]

    async def fetch(recording_id: str) -> Optional[bytes]:
        # The ID may originate from model output; never let it escape the
        # cache directory when turned into a file path.
        if os.path.basename(recording_id) != recording_id:
            logger.warning(
                f"Rejecting recording id with path components: {recording_id!r}"
            )
            return None

        cached = _cache_path(recording_id, pipeline_sample_rate)

        # 1. Serve from filesystem cache. Read directly instead of checking
        #    os.path.exists first, so a file removed in between is simply a
        #    cache miss, and do the read off the event loop.
        # NOTE(review): a cache hit is returned before the organization-scoped
        # DB lookup, and the cache key does not include the organization —
        # confirm recording IDs are globally unique.
        try:
            pcm_data = await asyncio.to_thread(_read_file, cached)
        except FileNotFoundError:
            pass
        except OSError:
            logger.warning(
                f"Could not read cached recording {recording_id}; re-fetching"
            )
        else:
            logger.debug(f"Recording {recording_id} served from disk cache")
            return pcm_data

        # 2. DB lookup
        recording = await db_client.get_recording_by_recording_id(
            recording_id, organization_id
        )
        if not recording:
            logger.warning(f"Recording {recording_id} not found in database")
            return None

        # 3. Download, convert, trim, and cache
        return await _download_and_convert(
            recording, pipeline_sample_rate, _get_storage
        )

    return fetch
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache warming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def warm_recording_cache(
    workflow_id: int,
    organization_id: int,
    pipeline_sample_rate: int,
) -> None:
    """Populate the disk cache with every recording of a workflow.

    Recordings that already have a cache file for *pipeline_sample_rate* are
    skipped; the rest are downloaded and converted one after another.
    Failures are logged and swallowed, both per recording and for the run as
    a whole, so nothing raised inside the ``try`` reaches the caller.
    """
    from api.db import db_client
    from api.services.storage import get_storage_for_backend

    try:
        all_recordings = await db_client.get_recordings_for_workflow(
            workflow_id, organization_id
        )
        if not all_recordings:
            return

        # Keep only the recordings that have no cache file yet.
        pending = []
        for rec in all_recordings:
            if os.path.exists(_cache_path(rec.recording_id, pipeline_sample_rate)):
                continue
            pending.append(rec)

        if not pending:
            logger.debug(f"Recording cache already warm for workflow {workflow_id}")
            return

        logger.info(
            f"Warming recording cache: {len(pending)}/{len(all_recordings)} "
            f"recording(s) for workflow {workflow_id}"
        )

        # One storage client per backend, created on first use.
        clients: dict[str, object] = {}

        def _get_storage(backend: str):
            try:
                return clients[backend]
            except KeyError:
                clients[backend] = get_storage_for_backend(backend)
                return clients[backend]

        for rec in pending:
            try:
                pcm = await _download_and_convert(
                    rec, pipeline_sample_rate, _get_storage
                )
                if not pcm:
                    continue
                logger.debug(
                    f"Cache warm: loaded {rec.recording_id} "
                    f"({len(pcm)} bytes)"
                )
            except Exception:
                logger.exception(
                    f"Cache warm: error processing {rec.recording_id}"
                )

        logger.info(f"Recording cache warm complete for workflow {workflow_id}")
    except Exception:
        logger.exception("Recording cache warm failed")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared download → convert → trim → cache-to-disk helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _download_and_convert(
    recording, sample_rate: int, get_storage_fn
) -> Optional[bytes]:
    """Download a recording from storage, convert to PCM, trim, and cache to disk.

    Args:
        recording: Object exposing ``recording_id``, ``storage_key`` and
            ``storage_backend``.
        sample_rate: Target PCM sample rate.
        get_storage_fn: Callable mapping a backend name to a storage client
            that provides ``adownload_file(key, path)``.

    Returns:
        The processed PCM bytes, or None when the download or conversion
        fails. A failure to write the cache file is logged but does not
        discard audio that was decoded successfully.
    """
    ext = _ext_from_key(recording.storage_key)
    fd, tmp_path = tempfile.mkstemp(
        suffix=ext, prefix=f"dograh_dl_{recording.recording_id}_"
    )
    os.close(fd)
    try:
        storage = get_storage_fn(recording.storage_backend)
        success = await storage.adownload_file(recording.storage_key, tmp_path)
        if not success:
            logger.error(f"Failed to download recording {recording.recording_id}")
            return None

        pcm_data = await _audio_file_to_pcm(tmp_path, sample_rate)
        if pcm_data is None:
            return None

        pcm_data = _trim_silence(pcm_data, sample_rate)
    except Exception:
        logger.exception(f"Error fetching recording {recording.recording_id}")
        return None
    finally:
        # The downloaded source file is only needed for the conversion.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

    # Write to disk cache atomically (write to tmp then rename). The write
    # runs in a worker thread so the event loop is not blocked.
    cached = _cache_path(recording.recording_id, sample_rate)
    tmp_cache = None
    try:
        fd, tmp_cache = tempfile.mkstemp(dir=_CACHE_DIR, suffix=".pcm.tmp")
        os.close(fd)
        await asyncio.to_thread(_write_file, tmp_cache, pcm_data)
        os.replace(tmp_cache, cached)
        tmp_cache = None  # renamed into place; nothing left to clean up
    except Exception:
        # Caching is an optimisation: keep serving the audio we already have.
        logger.exception(
            f"Failed to cache recording {recording.recording_id}; "
            "serving uncached audio"
        )
    finally:
        # Runs on error and on cancellation, so a half-written temp file
        # never lingers in the cache directory.
        if tmp_cache is not None:
            try:
                os.unlink(tmp_cache)
            except OSError:
                pass

    return pcm_data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File I/O helpers (run via asyncio.to_thread)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _read_file(path: str) -> bytes:
    """Return the entire binary contents of the file at *path*."""
    with open(path, "rb") as source:
        contents = source.read()
    return contents
|
||||
|
||||
|
||||
def _write_file(path: str, data: bytes) -> None:
    """Write *data* to *path* in binary mode, replacing any existing contents."""
    with open(path, "wb") as sink:
        sink.write(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _audio_file_to_pcm(
    file_path: str, target_sample_rate: int
) -> Optional[bytes]:
    """Convert an audio file to raw 16-bit mono PCM bytes via ffmpeg.

    Args:
        file_path: Path of the audio file to decode.
        target_sample_rate: Sample rate of the produced PCM.

    Returns:
        The PCM bytes written by ffmpeg to stdout, or None when ffmpeg is
        not on PATH, exits non-zero, produces no output, or the subprocess
        cannot be run.
    """
    ffmpeg = shutil.which("ffmpeg")
    if not ffmpeg:
        logger.error("ffmpeg not found on PATH — cannot decode recording")
        return None

    cmd = [
        ffmpeg,
        "-i",
        file_path,
        "-f",
        "s16le",  # raw 16-bit signed little-endian PCM
        "-acodec",
        "pcm_s16le",
        "-ac",
        "1",  # mono
        "-ar",
        str(target_sample_rate),
        "-loglevel",
        "error",
        "pipe:1",  # output to stdout
    ]

    proc = None
    try:
        proc = await asyncio.create_subprocess_exec(
            *cmd,
            # ffmpeg reads interactive commands from stdin by default; do not
            # let it inherit the server's stdin.
            stdin=asyncio.subprocess.DEVNULL,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await proc.communicate()

        if proc.returncode != 0:
            # errors="replace": undecodable bytes in ffmpeg's message must not
            # raise and hide the real failure.
            err_text = stderr.decode(errors="replace")
            logger.error(f"ffmpeg failed (rc={proc.returncode}): {err_text}")
            return None

        if not stdout:
            logger.error("ffmpeg produced no output")
            return None

        return stdout
    except Exception:
        logger.exception("ffmpeg subprocess error")
        return None
    finally:
        # Still running here means we were interrupted (e.g. the awaiting
        # task was cancelled); do not leave ffmpeg behind.
        if proc is not None and proc.returncode is None:
            try:
                proc.kill()
            except ProcessLookupError:
                pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Silence trimming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _trim_silence(
    pcm_data: bytes, sample_rate: int, threshold: Optional[float] = None
) -> bytes:
    """Trim leading and trailing silence from raw 16-bit mono PCM bytes.

    The audio is cut into 10ms frames; a frame counts as speech when its
    peak absolute amplitude exceeds the threshold. Everything before the
    first and after the last speech frame is dropped (including a trailing
    partial frame).

    Args:
        pcm_data: Raw little-endian int16 mono samples.
        sample_rate: Sample rate of *pcm_data* in Hz.
        threshold: Peak amplitude above which a frame is speech. Defaults to
            pipecat's ``SPEAKING_THRESHOLD``.

    Returns:
        The trimmed PCM bytes. The input is returned unchanged when it is
        shorter than one frame or contains no speech frame at all, so the
        result is never empty for non-empty input.
    """
    if threshold is None:
        threshold = SPEAKING_THRESHOLD

    samples = np.frombuffer(pcm_data, dtype=np.int16)
    # At least one sample per frame, so rates below 100 Hz cannot produce a
    # zero frame size (and a division by zero).
    frame_size = max(1, int(sample_rate * 0.01))  # 10ms frames
    num_frames = len(samples) // frame_size

    if num_frames == 0:
        return pcm_data

    frames = samples[: num_frames * frame_size].reshape(num_frames, frame_size)
    # Widen before abs(): in int16, abs(-32768) wraps back to -32768 and a
    # full-scale negative peak would be mistaken for silence.
    peaks = np.abs(frames.astype(np.int32)).max(axis=1)
    speech_frames = np.flatnonzero(peaks > threshold)

    if speech_frames.size == 0:
        # Entire clip is silence — return as-is to avoid empty audio
        return pcm_data

    first_speech = int(speech_frames[0])
    last_speech = int(speech_frames[-1])
    trimmed = samples[first_speech * frame_size : (last_speech + 1) * frame_size]

    trimmed_duration = len(trimmed) / sample_rate
    original_duration = len(samples) / sample_rate
    if original_duration - trimmed_duration > 0.05:
        logger.debug(
            f"Trimmed silence: {original_duration:.2f}s → {trimmed_duration:.2f}s"
        )

    return trimmed.tobytes()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ext_from_key(storage_key: str) -> str:
    """Return the storage key's file extension (dot included), or ``.wav`` if it has none."""
    suffix = os.path.splitext(storage_key)[1]
    if suffix:
        return suffix
    return ".wav"
|
||||
254
api/services/pipecat/recording_router_processor.py
Normal file
254
api/services/pipecat/recording_router_processor.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""Recording router processor for routing LLM output between TTS and pre-recorded audio.
|
||||
|
||||
Sits between the LLM (after pipeline_engine_callbacks_processor) and TTS in the
|
||||
pipeline. Detects response mode markers (▸ for TTS, ● for recording) and routes
|
||||
accordingly:
|
||||
|
||||
- ▸ (TTS): Strips the marker, passes remaining text downstream to TTS.
|
||||
- ● (Recording): Suppresses TTS, fetches cached audio, pushes
|
||||
OutputAudioRawFrame downstream.
|
||||
|
||||
Pattern modelled after ``pipecat.turns.user_turn_completion_mixin`` – buffer
|
||||
streaming LLM text tokens until the mode marker is detected, then act.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
RECORDING_MARKER,
|
||||
TTS_MARKER,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMTextFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class RecordingRouterProcessor(FrameProcessor):
    """Routes LLM responses between TTS and pre-recorded audio playback.

    Streaming ``LLMTextFrame`` tokens are buffered until one of the two
    response-mode markers shows up in the buffered text:

    - ``TTS_MARKER`` – the marker portion is pushed with ``skip_tts`` set and
      the text after it flows downstream as normal speech.
    - ``RECORDING_MARKER`` – every text frame of the response is pushed with
      ``skip_tts`` set, the text after the marker is collected as the
      recording id, and at the end of the response the audio is fetched and
      pushed as ``TTSStartedFrame`` / ``TTSAudioRawFrame`` / ``TTSStoppedFrame``.

    If no marker is detected by the end of the response, the buffered text is
    passed through unchanged as a graceful degradation.

    Args:
        audio_sample_rate: Sample rate stamped on the pushed ``TTSAudioRawFrame``.
        fetch_recording_audio: Async callback that takes a recording_id and
            returns raw 16-bit mono PCM bytes, or None on failure.
    """

    def __init__(
        self,
        *,
        audio_sample_rate: int,
        fetch_recording_audio: Callable[[str], Awaitable[Optional[bytes]]],
        **kwargs,
    ):
        super().__init__(**kwargs)
        self._audio_sample_rate = audio_sample_rate
        self._fetch_recording_audio = fetch_recording_audio

        # Per-response state
        # Text frames held back while no marker has been seen yet.
        self._frame_buffer: list[tuple[LLMTextFrame, FrameDirection]] = []
        self._mode: Optional[str] = None  # None = detecting, "tts", "recording"
        # Text collected after the recording marker; stripped and used as the id.
        self._recording_id_buffer = ""

    # ------------------------------------------------------------------
    # Frame dispatch
    # ------------------------------------------------------------------

    async def process_frame(self, frame: Frame, direction: FrameDirection):
        """Dispatch a frame: reset on interruption, route LLM text/end frames, forward the rest."""
        await super().process_frame(frame, direction)

        if isinstance(frame, InterruptionFrame):
            # Drop any buffered text and the detected mode before forwarding.
            self._reset()
            await self.push_frame(frame, direction)
        elif isinstance(frame, LLMTextFrame):
            await self._handle_llm_text(frame, direction)
        elif isinstance(frame, LLMFullResponseEndFrame):
            await self._handle_response_end(frame, direction)
        else:
            await self.push_frame(frame, direction)

    # ------------------------------------------------------------------
    # LLMTextFrame handling
    # ------------------------------------------------------------------

    async def _handle_llm_text(self, frame: LLMTextFrame, direction: FrameDirection):
        """Route one LLM text token according to the current response mode.

        While the mode is undetermined the frame is buffered. As soon as the
        buffered text contains a marker, the buffer is flushed in order and
        the mode is fixed until the next reset. The recording marker is
        checked before the TTS marker.
        """
        # Pass through frames already marked skip_tts (e.g. turn completion ✓)
        if frame.skip_tts:
            await self.push_frame(frame, direction)
            return

        # --- TTS mode established: pass text through normally ---
        if self._mode == "tts":
            await self.push_frame(frame, direction)
            return

        # --- Recording mode: buffer recording_id, suppress TTS ---
        if self._mode == "recording":
            self._recording_id_buffer += frame.text
            frame.skip_tts = True
            await self.push_frame(frame, direction)
            return

        # --- Detection mode: buffer until marker found ---
        self._frame_buffer.append((frame, direction))
        buffered_text = self._buffered_text()

        # Check for recording marker (●)
        if RECORDING_MARKER in buffered_text:
            self._mode = "recording"
            # Index, within the concatenated buffer text, just past the marker.
            marker_end = buffered_text.index(RECORDING_MARKER) + len(RECORDING_MARKER)

            # Push buffered frames with skip_tts, extract recording_id from post-marker text
            cumulative = 0
            for buf_frame, buf_dir in self._frame_buffer:
                buf_frame.skip_tts = True
                frame_start = cumulative
                cumulative += len(buf_frame.text)
                await self.push_frame(buf_frame, buf_dir)

                # Capture any recording_id text after the marker
                if cumulative > marker_end:
                    # 0 when the whole frame lies after the marker.
                    offset = max(marker_end - frame_start, 0)
                    remaining = buf_frame.text[offset:]
                    # Drop a single space directly following the marker.
                    if not self._recording_id_buffer and remaining.startswith(" "):
                        remaining = remaining[1:]
                    self._recording_id_buffer += remaining

            self._frame_buffer = []
            return

        # Check for TTS marker (▸)
        if TTS_MARKER in buffered_text:
            self._mode = "tts"
            # Index, within the concatenated buffer text, just past the marker.
            marker_end = buffered_text.index(TTS_MARKER) + len(TTS_MARKER)

            # Push buffered frames — skip_tts for marker portion, normal for the rest
            cumulative = 0
            for buf_frame, buf_dir in self._frame_buffer:
                frame_start = cumulative
                cumulative += len(buf_frame.text)

                if cumulative <= marker_end:
                    # Entirely within marker portion — suppress TTS
                    buf_frame.skip_tts = True
                    await self.push_frame(buf_frame, buf_dir)
                elif frame_start >= marker_end:
                    # Entirely after marker — normal TTS speech
                    if frame_start == marker_end and buf_frame.text.startswith(" "):
                        buf_frame.text = buf_frame.text[1:]
                    # A frame left empty after removing the space is dropped.
                    if buf_frame.text:
                        await self.push_frame(buf_frame, buf_dir)
                else:
                    # Frame spans the marker boundary — split
                    offset = marker_end - frame_start
                    original_text = buf_frame.text
                    buf_frame.text = original_text[:offset]
                    buf_frame.skip_tts = True
                    await self.push_frame(buf_frame, buf_dir)

                    # The speakable remainder goes out as a fresh frame.
                    tts_text = original_text[offset:]
                    if tts_text.startswith(" "):
                        tts_text = tts_text[1:]
                    if tts_text:
                        await self.push_frame(LLMTextFrame(tts_text), buf_dir)

            self._frame_buffer = []
            return

        # Neither marker found yet — keep buffering (should arrive very soon)

    # ------------------------------------------------------------------
    # End-of-response handling
    # ------------------------------------------------------------------

    async def _handle_response_end(
        self, frame: LLMFullResponseEndFrame, direction: FrameDirection
    ):
        """Finish the response: play the recording or flush unmarked text, then reset.

        The end frame itself is pushed last, after any audio or flushed text.
        """
        if self._mode == "recording":
            recording_id = self._recording_id_buffer.strip()
            if recording_id:
                await self._play_recording(recording_id)
            else:
                logger.warning(
                    "RecordingRouterProcessor: recording mode but empty recording_id"
                )

        elif self._mode is None and self._frame_buffer:
            # Graceful degradation: no marker detected, pass text to TTS as-is
            logger.warning(
                "RecordingRouterProcessor: no response mode marker found, "
                "passing text to TTS as-is"
            )
            for buf_frame, buf_dir in self._frame_buffer:
                await self.push_frame(buf_frame, buf_dir)

        self._reset()
        await self.push_frame(frame, direction)

    # ------------------------------------------------------------------
    # Audio playback
    # ------------------------------------------------------------------

    async def _play_recording(self, recording_id: str):
        """Fetch recording audio and push TTSStarted → TTSAudioRaw → TTSStopped.

        The whole recording is pushed as a single ``TTSAudioRawFrame``; the
        three frames share one freshly generated ``context_id``. The
        Started/Stopped frames ensure downstream processors (transport,
        audio buffer, observers) treat this as a proper TTS utterance.
        Nothing is pushed when the fetch callback returns no audio.
        """
        logger.info(f"Playing pre-recorded audio: {recording_id}")

        audio_data = await self._fetch_recording_audio(recording_id)
        if not audio_data:
            logger.warning(
                f"Failed to fetch recording {recording_id}, no audio will play"
            )
            return

        context_id = str(uuid.uuid4())
        await self.push_frame(TTSStartedFrame(context_id=context_id))
        await self.push_frame(
            TTSAudioRawFrame(
                audio=audio_data,
                sample_rate=self._audio_sample_rate,
                num_channels=1,
                context_id=context_id,
            )
        )
        await self.push_frame(TTSStoppedFrame(context_id=context_id))

        # 2 bytes per sample (16-bit mono PCM).
        duration_secs = len(audio_data) / (self._audio_sample_rate * 2)
        logger.debug(
            f"Finished pushing recording {recording_id} "
            f"({len(audio_data)} bytes, {duration_secs:.1f}s)"
        )

    # ------------------------------------------------------------------
    # State management
    # ------------------------------------------------------------------

    def _buffered_text(self) -> str:
        """Return concatenated text from the frame buffer."""
        return "".join(f.text for f, _ in self._frame_buffer)

    def _reset(self):
        """Reset per-response state."""
        self._frame_buffer = []
        self._mode = None
        self._recording_id_buffer = ""
|
||||
|
|
@ -27,6 +27,11 @@ from api.services.pipecat.realtime_feedback_observer import (
|
|||
RealtimeFeedbackObserver,
|
||||
register_turn_log_handlers,
|
||||
)
|
||||
from api.services.pipecat.recording_audio_cache import (
|
||||
create_recording_audio_fetcher,
|
||||
warm_recording_cache,
|
||||
)
|
||||
from api.services.pipecat.recording_router_processor import RecordingRouterProcessor
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_stt_service,
|
||||
|
|
@ -558,6 +563,12 @@ async def _run_pipeline(
|
|||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
|
||||
# Check if the workflow has any active recordings so the engine can
|
||||
# include recording response mode instructions in all node prompts.
|
||||
has_recordings = await db_client.has_active_recordings(
|
||||
workflow_id, workflow.organization_id
|
||||
)
|
||||
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
workflow=workflow_graph,
|
||||
|
|
@ -567,6 +578,7 @@ async def _run_pipeline(
|
|||
embeddings_api_key=embeddings_api_key,
|
||||
embeddings_model=embeddings_model,
|
||||
embeddings_base_url=embeddings_base_url,
|
||||
has_recordings=has_recordings,
|
||||
)
|
||||
|
||||
# Create pipeline components
|
||||
|
|
@ -680,6 +692,27 @@ async def _run_pipeline(
|
|||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Create recording router if workflow has active recordings
|
||||
recording_router = None
|
||||
if has_recordings:
|
||||
fetch_audio = create_recording_audio_fetcher(
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
recording_router = RecordingRouterProcessor(
|
||||
audio_sample_rate=audio_config.pipeline_sample_rate,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
)
|
||||
# Warm the recording cache in the background so audio is ready
|
||||
# before the first playback request.
|
||||
asyncio.create_task(
|
||||
warm_recording_cache(
|
||||
workflow_id=workflow_id,
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
)
|
||||
|
||||
# Build the pipeline with the STT mute filter and context controller
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
|
|
@ -692,6 +725,7 @@ async def _run_pipeline(
|
|||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=voicemail_detector,
|
||||
recording_router=recording_router,
|
||||
)
|
||||
|
||||
# Create pipeline task with audio configuration
|
||||
|
|
|
|||
|
|
@ -5,25 +5,32 @@ from loguru import logger
|
|||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from pipecat.services.azure.llm import AzureLLMService
|
||||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.flux.stt import DeepgramFluxSTTService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService, LiveOptions
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings
|
||||
from pipecat.services.deepgram.flux.stt import (
|
||||
DeepgramFluxSTTService,
|
||||
DeepgramFluxSTTSettings,
|
||||
)
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService, DeepgramSTTSettings
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService, DeepgramTTSSettings
|
||||
from pipecat.services.dograh.llm import DograhLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService
|
||||
from pipecat.services.dograh.tts import DograhTTSService
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.groq.llm import GroqLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService, DograhSTTSettings
|
||||
from pipecat.services.dograh.tts import DograhTTSService, DograhTTSSettings
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService, ElevenLabsTTSSettings
|
||||
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
|
||||
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import OpenAISTTService
|
||||
from pipecat.services.openai.tts import OpenAITTSService
|
||||
from pipecat.services.openrouter.llm import OpenRouterLLMService
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService
|
||||
from pipecat.services.speechmatics.stt import SpeechmaticsSTTService
|
||||
from pipecat.services.openai.stt import OpenAISTTService, OpenAISTTSettings
|
||||
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
|
||||
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
|
||||
from pipecat.services.speechmatics.stt import (
|
||||
SpeechmaticsSTTService,
|
||||
SpeechmaticsSTTSettings,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.text.xml_function_tag_filter import XMLFunctionTagFilter
|
||||
|
||||
|
|
@ -49,8 +56,8 @@ def create_stt_service(
|
|||
logger.debug("Using DeepGram Flux Model")
|
||||
return DeepgramFluxSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
params=DeepgramFluxSTTService.InputParams(
|
||||
settings=DeepgramFluxSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
eot_timeout_ms=3000,
|
||||
eot_threshold=0.7,
|
||||
eager_eot_threshold=0.5,
|
||||
|
|
@ -63,23 +70,23 @@ def create_stt_service(
|
|||
# Other models than flux
|
||||
# Use language from user config, defaulting to "multi" for multilingual support
|
||||
language = getattr(user_config.stt, "language", None) or "multi"
|
||||
live_options = LiveOptions(
|
||||
language=language,
|
||||
profanity_filter=False,
|
||||
endpointing=100,
|
||||
model=user_config.stt.model,
|
||||
keyterm=keyterms or [],
|
||||
)
|
||||
logger.debug(f"Using DeepGram Model - {user_config.stt.model}")
|
||||
return DeepgramSTTService(
|
||||
live_options=live_options,
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=DeepgramSTTSettings(
|
||||
language=language,
|
||||
profanity_filter=False,
|
||||
endpointing=100,
|
||||
model=user_config.stt.model,
|
||||
keyterm=keyterms or [],
|
||||
),
|
||||
should_interrupt=False, # Let UserAggregator take care of sending InterruptionFrame
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAISTTService(
|
||||
api_key=user_config.stt.api_key, model=user_config.stt.model
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=OpenAISTTSettings(model=user_config.stt.model),
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
|
||||
return CartesiaSTTService(
|
||||
|
|
@ -92,8 +99,10 @@ def create_stt_service(
|
|||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
settings=DograhSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
),
|
||||
keyterms=keyterms,
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
|
|
@ -117,8 +126,10 @@ def create_stt_service(
|
|||
pipecat_language = language_mapping.get(language, Language.HI_IN)
|
||||
return SarvamSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
params=SarvamSTTService.InputParams(language=pipecat_language),
|
||||
settings=SarvamSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=pipecat_language,
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SPEECHMATICS.value:
|
||||
|
|
@ -140,7 +151,7 @@ def create_stt_service(
|
|||
additional_vocab = [AdditionalVocabEntry(content=term) for term in keyterms]
|
||||
return SpeechmaticsSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
params=SpeechmaticsSTTService.InputParams(
|
||||
settings=SpeechmaticsSTTSettings(
|
||||
language=language,
|
||||
operating_point=operating_point,
|
||||
additional_vocab=additional_vocab,
|
||||
|
|
@ -168,14 +179,16 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
|
||||
return DeepgramTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
voice=user_config.tts.voice,
|
||||
settings=DeepgramTTSSettings(voice=user_config.tts.voice),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model,
|
||||
settings=OpenAITTSSettings(model=user_config.tts.model),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
|
||||
# Backward compatible with older configuration "Name - voice_id"
|
||||
|
|
@ -186,19 +199,25 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
return ElevenLabsTTSService(
|
||||
reconnect_on_error=False,
|
||||
api_key=user_config.tts.api_key,
|
||||
voice_id=voice_id,
|
||||
model=user_config.tts.model,
|
||||
params=ElevenLabsTTSService.InputParams(
|
||||
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
|
||||
settings=ElevenLabsTTSSettings(
|
||||
voice=voice_id,
|
||||
model=user_config.tts.model,
|
||||
stability=0.8,
|
||||
speed=user_config.tts.speed,
|
||||
similarity_boost=0.75,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.CARTESIA.value:
|
||||
return CartesiaTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
voice_id=user_config.tts.voice,
|
||||
model=user_config.tts.model,
|
||||
settings=CartesiaTTSSettings(
|
||||
voice=user_config.tts.voice,
|
||||
model=user_config.tts.model,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
|
||||
# Convert HTTP URL to WebSocket URL for TTS
|
||||
|
|
@ -206,10 +225,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
return DograhTTSService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
params=DograhTTSService.InputParams(speed=user_config.tts.speed),
|
||||
settings=DograhTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
speed=user_config.tts.speed,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
|
||||
# Map Sarvam language code to pipecat Language enum for TTS
|
||||
|
|
@ -232,10 +254,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
voice = getattr(user_config.tts, "voice", None) or "anushka"
|
||||
return SarvamTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model,
|
||||
voice_id=voice,
|
||||
params=SarvamTTSService.InputParams(language=pipecat_language),
|
||||
settings=SarvamTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=voice,
|
||||
language=pipecat_language,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -253,16 +278,15 @@ def create_llm_service(user_config):
|
|||
if "gpt-5" in model:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
params=OpenAILLMService.InputParams(
|
||||
reasoning_effort="minimal", verbosity="low"
|
||||
settings=OpenAILLMSettings(
|
||||
model=model,
|
||||
extra={"reasoning_effort": "minimal", "verbosity": "low"},
|
||||
),
|
||||
)
|
||||
else:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
settings=OpenAILLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GROQ.value:
|
||||
print(
|
||||
|
|
@ -270,36 +294,30 @@ def create_llm_service(user_config):
|
|||
)
|
||||
return GroqLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
settings=GroqLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.OPENROUTER.value:
|
||||
return OpenRouterLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
base_url=user_config.llm.base_url,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
settings=OpenRouterLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
|
||||
# Use the correct InputParams class for Google to avoid propagating OpenAI-specific
|
||||
# NOT_GIVEN sentinels that break Pydantic validation in GoogleLLMService.
|
||||
return GoogleLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
params=GoogleLLMService.InputParams(temperature=0.1),
|
||||
settings=GoogleLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.AZURE.value:
|
||||
return AzureLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
endpoint=user_config.llm.endpoint,
|
||||
model=model, # Azure uses deployment name as model
|
||||
params=AzureLLMService.InputParams(temperature=0.1),
|
||||
settings=AzureLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
|
||||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid LLM provider")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from api.services.workflow.disposition_mapper import (
|
|||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.workflow import Node, WorkflowGraph
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
|
|
@ -16,6 +17,7 @@ from pipecat.frames.frames import (
|
|||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -31,18 +33,19 @@ import asyncio
|
|||
from loguru import logger
|
||||
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
compose_functions_for_node,
|
||||
compose_system_prompt_for_node,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_custom_tools import (
|
||||
CustomToolManager,
|
||||
get_function_schema,
|
||||
render_template,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
|
||||
from api.services.workflow.tools.knowledge_base import (
|
||||
get_knowledge_base_tool,
|
||||
retrieve_from_knowledge_base,
|
||||
)
|
||||
from api.services.workflow.tools.timezone import (
|
||||
|
|
@ -50,6 +53,7 @@ from api.services.workflow.tools.timezone import (
|
|||
get_current_time,
|
||||
get_time_tools,
|
||||
)
|
||||
from api.utils.template_renderer import render_template
|
||||
|
||||
|
||||
class PipecatEngine:
|
||||
|
|
@ -68,6 +72,7 @@ class PipecatEngine:
|
|||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
embeddings_base_url: Optional[str] = None,
|
||||
has_recordings: bool = False,
|
||||
):
|
||||
self.task = task
|
||||
self.llm = llm
|
||||
|
|
@ -113,6 +118,10 @@ class PipecatEngine:
|
|||
# Audio configuration (set via set_audio_config from _run_pipeline)
|
||||
self._audio_config = None
|
||||
|
||||
# True when the workflow has active recordings; enables recording
|
||||
# response mode instructions on all nodes for in-context learning.
|
||||
self._has_recordings: bool = has_recordings
|
||||
|
||||
async def _get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._custom_tool_manager:
|
||||
|
|
@ -194,15 +203,14 @@ class PipecatEngine:
|
|||
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
|
||||
raise
|
||||
|
||||
def _get_function_schema(self, function_name: str, description: str):
|
||||
"""Thin wrapper around utils.get_function_schema for backwards compatibility."""
|
||||
async def _update_llm_context(self, system_prompt: str, functions: list[dict]):
|
||||
"""Update LLM settings with the composed system prompt and tool list."""
|
||||
|
||||
return get_function_schema(function_name, description)
|
||||
await self.llm._update_settings(LLMSettings(system_instruction=system_prompt))
|
||||
|
||||
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
|
||||
"""Delegate context update to the shared workflow.utils implementation."""
|
||||
|
||||
update_llm_context(self.context, system_message, functions)
|
||||
if functions:
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
self.context.set_tools(tools_schema)
|
||||
|
||||
def _format_prompt(self, prompt: str) -> str:
|
||||
"""Delegate prompt formatting to the shared workflow.utils implementation."""
|
||||
|
|
@ -473,12 +481,19 @@ class PipecatEngine:
|
|||
if node.document_uuids:
|
||||
await self._register_knowledge_base_function(node.document_uuids)
|
||||
|
||||
# Set up system message and functions
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await self._compose_system_message_functions_for_node(node)
|
||||
await self._update_llm_context(system_message, functions)
|
||||
# Compose prompt and functions via the context composer module
|
||||
system_prompt = compose_system_prompt_for_node(
|
||||
node=node,
|
||||
workflow=self.workflow,
|
||||
format_prompt=self._format_prompt,
|
||||
has_recordings=self._has_recordings,
|
||||
)
|
||||
functions = await compose_functions_for_node(
|
||||
node=node,
|
||||
builtin_function_schemas=self.builtin_function_schemas,
|
||||
custom_tool_manager=self._custom_tool_manager,
|
||||
)
|
||||
await self._update_llm_context(system_prompt, functions)
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
"""
|
||||
|
|
@ -610,62 +625,6 @@ class PipecatEngine:
|
|||
)
|
||||
await self.task.queue_frame(frame_to_push)
|
||||
|
||||
async def _compose_system_message_functions_for_node(
    self, node: "Node"
) -> tuple[dict, list[dict]]:
    """Build the system message and the function schemas for ``node``.

    Nothing is registered with the LLM here; the caller is responsible for
    applying the returned values.

    Args:
        node: The workflow node whose prompt and tools are being composed.

    Returns:
        A ``(system_message, functions)`` pair. ``system_message`` is a single
        ``{"role": "system", "content": ...}`` dict (not a list of messages)
        and ``functions`` is the list of function schemas for the node.
    """
    # The global prompt is only included when the workflow has a global node
    # and this node opted in to it.
    global_prompt = ""
    if self.workflow.global_node_id and node.add_global_prompt:
        global_node = self.workflow.nodes[self.workflow.global_node_id]
        global_prompt = self._format_prompt(global_node.prompt)

    # Start from a copy of the built-in schemas so the shared list on the
    # engine is never mutated.
    functions: list[dict] = list(self.builtin_function_schemas)

    # Knowledge base retrieval tool, only when the node references documents.
    if node.document_uuids:
        kb_function = get_knowledge_base_tool(node.document_uuids)["function"]
        functions.append(
            get_function_schema(
                kb_function["name"],
                kb_function["description"],
                properties=kb_function["parameters"].get("properties", {}),
                required=kb_function["parameters"].get("required", []),
            )
        )

    # Custom tools referenced by node.tool_uuids (needs a tool manager).
    if node.tool_uuids and self._custom_tool_manager:
        functions.extend(
            await self._custom_tool_manager.get_tool_schemas(node.tool_uuids)
        )

    # Transition functions: schema only, registration is handled elsewhere.
    for outgoing_edge in node.out_edges:
        functions.append(
            self._get_function_schema(
                outgoing_edge.get_function_name(), outgoing_edge.condition
            )
        )

    formatted_node_prompt = self._format_prompt(node.prompt)

    # Empty prompt sections are dropped so no stray blank separators appear.
    system_message = {
        "role": "system",
        "content": "\n\n".join(
            p for p in (global_prompt, formatted_node_prompt) if p
        ),
    }

    return system_message, functions
|
||||
|
||||
async def should_mute_user(self, frame: "Frame") -> bool:
|
||||
"""
|
||||
Callback for CallbackUserMuteStrategy to determine if the user should be muted.
|
||||
|
|
|
|||
138
api/services/workflow/pipecat_engine_context_composer.py
Normal file
138
api/services/workflow/pipecat_engine_context_composer.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""System prompt and function schema composition for PipecatEngine nodes.
|
||||
|
||||
Extracts prompt and function composition logic from PipecatEngine into
|
||||
reusable functions. Defines recording response mode markers and instructions.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.workflow import Node, WorkflowGraph
|
||||
|
||||
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
|
||||
from api.services.workflow.tools.knowledge_base import get_knowledge_base_tool
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recording response mode markers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RECORDING_MARKER = "●" # Play pre-recorded audio
|
||||
TTS_MARKER = "▸" # Generate dynamic TTS text
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recording response mode system prompt instructions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RECORDING_RESPONSE_MODE_INSTRUCTIONS = """\
|
||||
RESPONSE MODE INSTRUCTIONS - MANDATORY FORMAT:
|
||||
Every response you generate MUST begin with a response mode indicator.
|
||||
You have two modes for responding:
|
||||
|
||||
1. DYNAMIC SPEECH (▸): Generate text that will be converted to speech by TTS.
|
||||
Format: `▸` followed by a space and your full spoken response.
|
||||
Example: ▸ Hello! How can I help you today?
|
||||
|
||||
2. PRE-RECORDED AUDIO (●): Play a pre-recorded audio message.
|
||||
Format: `●` followed by a space and ONLY the recording_id. Nothing else.
|
||||
Example: ● rec_greeting_01
|
||||
|
||||
RULES:
|
||||
- Your response MUST start with either `▸` or `●` as the very first character.
|
||||
- For `▸` (dynamic speech): Follow with a space and your full response text.
|
||||
- For `●` (pre-recorded audio): Follow with a space and ONLY the recording_id. No other text.
|
||||
- Use `●` when a pre-recorded message matches the situation well.
|
||||
- Use `▸` when you need to generate a dynamic, contextual response.
|
||||
- NEVER mix modes in a single response. Choose one."""
|
||||
|
||||
|
||||
def compose_system_prompt_for_node(
    *,
    node: "Node",
    workflow: "WorkflowGraph",
    format_prompt: Callable[[str], str],
    has_recordings: bool,
) -> str:
    """Assemble the system prompt text for one workflow node.

    The result is built from up to three sections, joined by a blank line:
    the rendered global prompt, the rendered node prompt, and the recording
    response mode instructions. Sections that render to an empty string are
    left out.

    Args:
        node: Node the prompt is composed for; its ``prompt`` and
            ``add_global_prompt`` attributes are read.
        workflow: Workflow graph; ``global_node_id`` and ``nodes`` are used to
            look up the global node's prompt.
        format_prompt: Callable applied to each raw prompt before it is used.
        has_recordings: When true, the recording response mode instructions
            are appended as the final section.

    Returns:
        The composed system prompt text.
    """
    sections: list[str] = []

    # Global prompt first, only if the workflow has a global node and this
    # node asked for it.
    if workflow.global_node_id and node.add_global_prompt:
        rendered_global = format_prompt(workflow.nodes[workflow.global_node_id].prompt)
        if rendered_global:
            sections.append(rendered_global)

    rendered_node = format_prompt(node.prompt)
    if rendered_node:
        sections.append(rendered_node)

    if has_recordings:
        sections.append(RECORDING_RESPONSE_MODE_INSTRUCTIONS)
        # TODO: Append per-node available recordings list here once
        # Node.recording_ids is populated. The list should include
        # recording_id and a short description so the LLM can choose.

    return "\n\n".join(sections)
|
||||
|
||||
|
||||
async def compose_functions_for_node(
    *,
    node: "Node",
    builtin_function_schemas: list[dict],
    custom_tool_manager: Optional["CustomToolManager"],
) -> list[dict]:
    """Collect every function/tool schema that applies to a workflow node.

    Schemas are gathered in a fixed order: built-in tools, the knowledge-base
    retrieval tool, custom tools, then one transition function per outgoing
    edge.

    Args:
        node: Node to compose functions for; ``document_uuids``,
            ``tool_uuids`` and ``out_edges`` are read.
        builtin_function_schemas: Pre-computed schemas for built-in tools.
            The list is copied, not modified.
        custom_tool_manager: Manager used to resolve ``node.tool_uuids`` into
            schemas; custom tools are skipped when it is ``None``.

    Returns:
        A new list of function schemas to register with the LLM.
    """
    # Built-in tools (calculator, timezone) always come first.
    schemas: list[dict] = list(builtin_function_schemas)

    # Knowledge base retrieval tool, only for nodes that reference documents.
    if node.document_uuids:
        kb_function = get_knowledge_base_tool(node.document_uuids)["function"]
        schemas.append(
            get_function_schema(
                kb_function["name"],
                kb_function["description"],
                properties=kb_function["parameters"].get("properties", {}),
                required=kb_function["parameters"].get("required", []),
            )
        )

    # Custom tools need both tool ids on the node and a manager to resolve them.
    if node.tool_uuids and custom_tool_manager:
        schemas += await custom_tool_manager.get_tool_schemas(node.tool_uuids)

    # One parameter-less transition function per outgoing edge.
    schemas.extend(
        get_function_schema(edge.get_function_name(), edge.condition)
        for edge in node.out_edges
    )

    return schemas
|
||||
|
|
@ -10,7 +10,7 @@ import asyncio
|
|||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -23,7 +23,6 @@ from api.services.telephony.transfer_event_protocol import TransferContext
|
|||
from api.services.workflow.disposition_mapper import (
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_utils import get_function_schema
|
||||
from api.services.workflow.tools.custom_tool import (
|
||||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
|
|
@ -42,6 +41,29 @@ if TYPE_CHECKING:
|
|||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
def get_function_schema(
    function_name: str,
    description: str,
    *,
    properties: Dict[str, Any] | None = None,
    required: List[str] | None = None,
) -> FunctionSchema:
    """Build a ``FunctionSchema`` for a tool the LLM may call.

    The schema can later be converted into the provider-specific format
    (OpenAI, Gemini, etc.).

    Args:
        function_name: Name of the function.
        description: Description of the function.
        properties: Parameter definitions; an empty dict is used when this is
            omitted or falsy, which yields a parameter-less function.
        required: Names of required parameters; an empty list is used when
            this is omitted or falsy.

    Returns:
        The constructed ``FunctionSchema``.
    """
    # Normalise the optional arguments up front so the constructor always
    # receives concrete containers.
    schema_properties = properties or {}
    required_names = required or []

    return FunctionSchema(
        name=function_name,
        description=description,
        properties=schema_properties,
        required=required_names,
    )
|
||||
|
||||
|
||||
class CustomToolManager:
|
||||
"""Manager for custom tool registration and execution.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,68 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
__all__ = [
|
||||
"get_function_schema",
|
||||
"update_llm_context",
|
||||
"render_template",
|
||||
]
|
||||
|
||||
|
||||
def get_function_schema(
    function_name: str,
    description: str,
    *,
    properties: Dict[str, Any] | None = None,
    required: List[str] | None = None,
) -> FunctionSchema:
    """Build a ``FunctionSchema`` for a tool the LLM may call.

    The schema can later be converted into the provider-specific format
    (OpenAI, Gemini, etc.).

    Args:
        function_name: Name of the function.
        description: Description of the function.
        properties: Parameter definitions; an empty dict is used when this is
            omitted or falsy, which yields a parameter-less function.
        required: Names of required parameters; an empty list is used when
            this is omitted or falsy.

    Returns:
        The constructed ``FunctionSchema``.
    """
    # Normalise the optional arguments up front so the constructor always
    # receives concrete containers.
    schema_properties = properties or {}
    required_names = required or []

    return FunctionSchema(
        name=function_name,
        description=description,
        properties=schema_properties,
        required=required_names,
    )
|
||||
|
||||
|
||||
def update_llm_context(
    context: LLMContext,
    system_message: Dict[str, Any],
    functions: List[FunctionSchema],
) -> None:
    """Put ``system_message`` at the head of ``context`` and refresh its tools.

    If the first message in the context has role ``"system"`` it is replaced;
    otherwise the new message is prepended. System messages further down the
    history are kept as they are. Tools are only set when ``functions`` is
    non-empty.

    Args:
        context: The LLM context whose messages and tools are updated.
        system_message: The message to place first in the history.
        functions: Function schemas to expose as tools.
    """
    # Wrap the schemas in a ToolsSchema first, before any message is touched,
    # so the LLM adapter can convert them to its provider-specific form.
    tools_schema = ToolsSchema(standard_tools=functions)

    history = context.messages
    starts_with_system = bool(history) and history[0]["role"] == "system"

    # Drop only a leading system message; everything else is preserved.
    remainder = history[1:] if starts_with_system else history
    context.set_messages([system_message] + remainder)

    if functions:
        context.set_tools(tools_schema)
|
||||
|
|
@ -13,14 +13,12 @@ from unittest.mock import AsyncMock, Mock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
get_function_schema,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
|
||||
from api.services.workflow.tools.custom_tool import (
|
||||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
)
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
|
|
@ -862,11 +860,27 @@ class TestCustomToolManagerUnit:
|
|||
assert result_received["status"] == "success"
|
||||
|
||||
|
||||
def _update_llm_context(context, system_message, functions):
    """Test-local copy of the former ``update_llm_context`` helper.

    Replaces a leading system message in ``context`` with ``system_message``
    (or prepends it when the history does not start with one) and sets the
    tools only when ``functions`` is non-empty.
    """
    # Build the tools wrapper before any message is modified.
    tools_schema = ToolsSchema(standard_tools=functions)

    history = context.messages
    starts_with_system = bool(history) and history[0]["role"] == "system"

    # Only a leading system message is dropped; later ones are kept.
    remainder = history[1:] if starts_with_system else history
    context.set_messages([system_message] + remainder)

    if functions:
        context.set_tools(tools_schema)
|
||||
|
||||
|
||||
class TestUpdateLLMContext:
|
||||
"""Tests for update_llm_context function."""
|
||||
"""Tests for _update_llm_context inline logic."""
|
||||
|
||||
def test_replaces_system_message(self):
|
||||
"""Test that update_llm_context replaces existing system messages."""
|
||||
"""Test that _update_llm_context replaces existing system messages."""
|
||||
context = LLMContext()
|
||||
context.set_messages(
|
||||
[
|
||||
|
|
@ -877,7 +891,7 @@ class TestUpdateLLMContext:
|
|||
)
|
||||
|
||||
new_system = {"role": "system", "content": "New system message"}
|
||||
update_llm_context(context, new_system, [])
|
||||
_update_llm_context(context, new_system, [])
|
||||
|
||||
messages = context.messages
|
||||
# Should have new system message at the start
|
||||
|
|
@ -902,7 +916,7 @@ class TestUpdateLLMContext:
|
|||
)
|
||||
|
||||
new_system = {"role": "system", "content": "New prompt"}
|
||||
update_llm_context(context, new_system, [])
|
||||
_update_llm_context(context, new_system, [])
|
||||
|
||||
messages = context.messages
|
||||
assert len(messages) == 5
|
||||
|
|
@ -923,7 +937,7 @@ class TestUpdateLLMContext:
|
|||
]
|
||||
|
||||
new_system = {"role": "system", "content": "New prompt with tools"}
|
||||
update_llm_context(context, new_system, functions)
|
||||
_update_llm_context(context, new_system, functions)
|
||||
|
||||
# Verify tools were set
|
||||
tools = context.tools
|
||||
|
|
@ -936,7 +950,7 @@ class TestUpdateLLMContext:
|
|||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
|
||||
new_system = {"role": "system", "content": "New prompt without tools"}
|
||||
update_llm_context(context, new_system, [])
|
||||
_update_llm_context(context, new_system, [])
|
||||
|
||||
# Tools should not be set (or remain None)
|
||||
# Note: The function only calls set_tools if functions is truthy
|
||||
|
|
@ -952,7 +966,7 @@ class TestUpdateLLMContext:
|
|||
new_system = {"role": "system", "content": "Initial prompt"}
|
||||
functions = [get_function_schema("test_func", "A test function")]
|
||||
|
||||
update_llm_context(context, new_system, functions)
|
||||
_update_llm_context(context, new_system, functions)
|
||||
|
||||
messages = context.messages
|
||||
assert len(messages) == 1
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Integration tests for CustomToolManager with update_llm_context.
|
||||
"""Integration tests for CustomToolManager with LLM context updates.
|
||||
|
||||
This module tests the full flow of:
|
||||
1. CustomToolManager fetching and converting tool schemas
|
||||
2. update_llm_context setting those tools on the LLM context
|
||||
2. Setting those tools on the LLM context
|
||||
3. Verifying the context is properly configured for LLM generation
|
||||
"""
|
||||
|
||||
|
|
@ -10,16 +10,32 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
from api.services.workflow.pipecat_engine_custom_tools import (
|
||||
CustomToolManager,
|
||||
get_function_schema,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.tests.conftest import MockToolModel
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
|
||||
def _update_llm_context(context, system_message, functions):
|
||||
"""Inline helper replicating the update_llm_context logic for tests."""
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
previous_interactions = context.messages
|
||||
|
||||
if previous_interactions and previous_interactions[0]["role"] == "system":
|
||||
messages = [system_message] + previous_interactions[1:]
|
||||
else:
|
||||
messages = [system_message] + previous_interactions
|
||||
|
||||
context.set_messages(messages)
|
||||
|
||||
if functions:
|
||||
context.set_tools(tools_schema)
|
||||
|
||||
|
||||
class TestCustomToolManagerContextIntegration:
|
||||
"""Integration tests for CustomToolManager with LLMContext."""
|
||||
|
||||
|
|
@ -69,7 +85,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
"role": "system",
|
||||
"content": "You are a scheduling assistant with access to weather and booking tools.",
|
||||
}
|
||||
update_llm_context(context, new_system, schemas)
|
||||
_update_llm_context(context, new_system, schemas)
|
||||
|
||||
# Verify context was updated correctly
|
||||
messages = context.messages
|
||||
|
|
@ -195,7 +211,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
"role": "system",
|
||||
"content": "Assistant with calculator and weather tools",
|
||||
}
|
||||
update_llm_context(context, new_system, all_functions)
|
||||
_update_llm_context(context, new_system, all_functions)
|
||||
|
||||
# Verify all tools are present
|
||||
tools = context.tools
|
||||
|
|
@ -259,7 +275,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
)
|
||||
|
||||
new_system = {"role": "system", "content": "Updated weather assistant"}
|
||||
update_llm_context(context, new_system, schemas)
|
||||
_update_llm_context(context, new_system, schemas)
|
||||
|
||||
messages = context.messages
|
||||
# System + user + assistant(tool_call) + tool + assistant = 5
|
||||
|
|
@ -296,7 +312,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
|
||||
new_system = {"role": "system", "content": "No tools available"}
|
||||
update_llm_context(context, new_system, [])
|
||||
_update_llm_context(context, new_system, [])
|
||||
|
||||
# Context should have updated message but no tools set
|
||||
assert context.messages[0]["content"] == "No tools available"
|
||||
|
|
@ -362,7 +378,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
# Update context - pass schema directly
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
update_llm_context(
|
||||
_update_llm_context(
|
||||
context, {"role": "system", "content": "Order assistant"}, schemas
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -68,9 +68,7 @@ class ContextCapturingMockLLM(MockLLMService):
|
|||
{
|
||||
"step": self._current_step,
|
||||
"messages": messages_snapshot,
|
||||
"system_prompt": messages_snapshot[0]["content"]
|
||||
if messages_snapshot
|
||||
else None,
|
||||
"system_prompt": self._settings.system_instruction,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -101,12 +99,10 @@ class ContextCapturingMockLLM(MockLLMService):
|
|||
return False
|
||||
|
||||
def get_system_prompt_at_step(self, step: int) -> str:
|
||||
"""Get the system prompt from context at a specific step."""
|
||||
"""Get the system prompt from settings at a specific step."""
|
||||
ctx = self.get_context_at_step(step)
|
||||
if ctx and ctx["messages"]:
|
||||
first_msg = ctx["messages"][0]
|
||||
if first_msg.get("role") == "system":
|
||||
return first_msg.get("content", "")
|
||||
if ctx:
|
||||
return ctx.get("system_prompt") or ""
|
||||
return ""
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ class TestPipecatEngineToolCalls:
|
|||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_builtin_and_transition_calls_through_engine_1(
|
||||
|
|
@ -233,7 +233,7 @@ class TestPipecatEngineToolCalls:
|
|||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_builtin_and_transition_calls_through_engine_with_text(
|
||||
|
|
@ -281,7 +281,7 @@ class TestPipecatEngineToolCalls:
|
|||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_transition_call_through_engine(
|
||||
|
|
@ -315,4 +315,4 @@ class TestPipecatEngineToolCalls:
|
|||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
|
||||
|
|
|
|||
2
pipecat
2
pipecat
|
|
@ -1 +1 @@
|
|||
Subproject commit 10e8ded96672b08503db48c3d34e8345b11be4a2
|
||||
Subproject commit ac418b4b3e915a801c82f401e3c68f12ecca98cb
|
||||
24
ui/src/app/handler/[...stack]/BackButton.tsx
Normal file
24
ui/src/app/handler/[...stack]/BackButton.tsx
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"use client";
|
||||
|
||||
import { ArrowLeft } from "lucide-react";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
export function BackButton() {
|
||||
const router = useRouter();
|
||||
|
||||
return (
|
||||
<header className="flex items-center border-b px-4 py-3">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => router.back()}
|
||||
className="gap-2"
|
||||
>
|
||||
<ArrowLeft className="h-4 w-4" />
|
||||
Go Back
|
||||
</Button>
|
||||
</header>
|
||||
);
|
||||
}
|
||||
|
|
@ -2,6 +2,8 @@ import { StackHandler } from "@stackframe/stack";
|
|||
|
||||
import { getAuthProvider } from "@/lib/auth/config";
|
||||
|
||||
import { BackButton } from "./BackButton";
|
||||
|
||||
export default async function Handler(props: unknown) {
|
||||
const authProvider = await getAuthProvider();
|
||||
|
||||
|
|
@ -18,9 +20,16 @@ export default async function Handler(props: unknown) {
|
|||
const { getStackServerApp } = await import("@/lib/auth/server");
|
||||
const app = await getStackServerApp();
|
||||
|
||||
return <StackHandler
|
||||
fullPage
|
||||
app={app!}
|
||||
routeProps={props}
|
||||
/>;
|
||||
return (
|
||||
<div className="flex flex-col h-screen">
|
||||
<BackButton />
|
||||
<div className="flex-1 overflow-auto">
|
||||
<StackHandler
|
||||
fullPage
|
||||
app={app!}
|
||||
routeProps={props}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
"use client";
|
||||
|
||||
import { Calendar, ChevronLeft, ChevronRight, Globe } from 'lucide-react';
|
||||
import { ChevronLeft, ChevronRight, Globe } from 'lucide-react';
|
||||
import { useRouter, useSearchParams } from 'next/navigation';
|
||||
import { useCallback, useEffect, useId, useState } from 'react';
|
||||
import TimezoneSelect, { type ITimezoneOption } from 'react-timezone-select';
|
||||
|
||||
import { getCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGet, getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet,getUsageHistoryApiV1OrganizationsUsageRunsGet } from '@/client/sdk.gen';
|
||||
import type { CurrentUsageResponse, DailyUsageBreakdownResponse,UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen';
|
||||
import { getDailyUsageBreakdownApiV1OrganizationsUsageDailyBreakdownGet, getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet, getUsageHistoryApiV1OrganizationsUsageRunsGet } from '@/client/sdk.gen';
|
||||
import type { DailyUsageBreakdownResponse, MpsCreditsResponse, UsageHistoryResponse, WorkflowRunUsageResponse } from '@/client/types.gen';
|
||||
import { DailyUsageTable } from '@/components/DailyUsageTable';
|
||||
import { FilterBuilder } from '@/components/filters/FilterBuilder';
|
||||
import { MediaPreviewButton, MediaPreviewDialog } from '@/components/MediaPreviewDialog';
|
||||
|
|
@ -37,9 +37,9 @@ export default function UsagePage() {
|
|||
const { userConfig, saveUserConfig, loading: userConfigLoading, organizationPricing } = useUserConfig();
|
||||
const auth = useAuth();
|
||||
|
||||
// Current usage state
|
||||
const [currentUsage, setCurrentUsage] = useState<CurrentUsageResponse | null>(null);
|
||||
const [isLoadingCurrent, setIsLoadingCurrent] = useState(true);
|
||||
// MPS credits state
|
||||
const [mpsCredits, setMpsCredits] = useState<MpsCreditsResponse | null>(null);
|
||||
const [isLoadingCredits, setIsLoadingCredits] = useState(true);
|
||||
|
||||
// Usage history state
|
||||
const [usageHistory, setUsageHistory] = useState<UsageHistoryResponse | null>(null);
|
||||
|
|
@ -68,19 +68,18 @@ export default function UsagePage() {
|
|||
const [savingTimezone, setSavingTimezone] = useState(false);
|
||||
const timezoneSelectId = useId(); // Stable ID for react-select to prevent hydration mismatch
|
||||
|
||||
// Fetch current usage
|
||||
const fetchCurrentUsage = useCallback(async () => {
|
||||
// Fetch MPS credits
|
||||
const fetchMpsCredits = useCallback(async () => {
|
||||
if (!auth.isAuthenticated) return;
|
||||
try {
|
||||
const response = await getCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGet();
|
||||
|
||||
const response = await getMpsCreditsApiV1OrganizationsUsageMpsCreditsGet();
|
||||
if (response.data) {
|
||||
setCurrentUsage(response.data);
|
||||
setMpsCredits(response.data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch current usage:', error);
|
||||
console.error('Failed to fetch MPS credits:', error);
|
||||
} finally {
|
||||
setIsLoadingCurrent(false);
|
||||
setIsLoadingCredits(false);
|
||||
}
|
||||
}, [auth.isAuthenticated]);
|
||||
|
||||
|
|
@ -195,10 +194,10 @@ export default function UsagePage() {
|
|||
// Initial load - fetch when auth becomes available
|
||||
useEffect(() => {
|
||||
if (auth.isAuthenticated) {
|
||||
fetchCurrentUsage();
|
||||
fetchMpsCredits();
|
||||
fetchUsageHistory(currentPage, activeFilters);
|
||||
}
|
||||
}, [auth.isAuthenticated, currentPage, activeFilters, fetchUsageHistory, fetchCurrentUsage]);
|
||||
}, [auth.isAuthenticated, currentPage, activeFilters, fetchUsageHistory, fetchMpsCredits]);
|
||||
|
||||
// Fetch daily usage when organizationPricing becomes available
|
||||
useEffect(() => {
|
||||
|
|
@ -259,20 +258,6 @@ export default function UsagePage() {
|
|||
router.push(`/workflow/${run.workflow_id}/run/${run.id}`);
|
||||
};
|
||||
|
||||
// Format date for display with timezone support
|
||||
const formatDate = (dateString: string) => {
|
||||
const date = new Date(dateString);
|
||||
const tzValue = typeof selectedTimezone === 'string' ? selectedTimezone : selectedTimezone.value;
|
||||
// Use local timezone if none selected (during loading)
|
||||
const effectiveTz = tzValue || localTimezone;
|
||||
return date.toLocaleDateString('en-US', {
|
||||
timeZone: effectiveTz,
|
||||
year: 'numeric',
|
||||
month: 'short',
|
||||
day: 'numeric'
|
||||
});
|
||||
};
|
||||
|
||||
// Format datetime for display with timezone support
|
||||
const formatDateTime = (dateString: string) => {
|
||||
const date = new Date(dateString);
|
||||
|
|
@ -383,68 +368,42 @@ export default function UsagePage() {
|
|||
</div>
|
||||
</div>
|
||||
|
||||
{/* Current Period Card */}
|
||||
{/* MPS Credits Card */}
|
||||
<Card className="mb-6">
|
||||
<CardHeader>
|
||||
<CardTitle>Current Billing Period</CardTitle>
|
||||
<CardTitle>Dograh Model Credits</CardTitle>
|
||||
<CardDescription>
|
||||
{currentUsage && `${formatDate(currentUsage.period_start)} - ${formatDate(currentUsage.period_end)}`}
|
||||
These track usage of Dograh models using Dograh Service Keys.
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
<CardContent>
|
||||
{isLoadingCurrent ? (
|
||||
{isLoadingCredits ? (
|
||||
<div className="animate-pulse space-y-4">
|
||||
<div className="h-4 bg-muted rounded w-1/4"></div>
|
||||
<div className="h-8 bg-muted rounded"></div>
|
||||
<div className="h-4 bg-muted rounded w-1/3"></div>
|
||||
</div>
|
||||
) : currentUsage ? (
|
||||
) : mpsCredits ? (
|
||||
<div className="space-y-4">
|
||||
<div className="flex justify-between items-baseline">
|
||||
<div>
|
||||
{organizationPricing?.price_per_second_usd ? (
|
||||
<>
|
||||
<p className="text-2xl font-bold">
|
||||
${(currentUsage.used_amount_usd || 0).toFixed(2)}
|
||||
</p>
|
||||
<p className="text-sm text-muted-foreground">Total Cost (USD)</p>
|
||||
<p className="text-xs text-muted-foreground mt-1">
|
||||
Rate: ${(organizationPricing.price_per_second_usd * 60).toFixed(4)}/minute
|
||||
</p>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<p className="text-2xl font-bold">
|
||||
{currentUsage.used_dograh_tokens.toLocaleString()} / {currentUsage.quota_dograh_tokens.toLocaleString()}
|
||||
</p>
|
||||
<p className="text-sm text-muted-foreground">Dograh Tokens</p>
|
||||
</>
|
||||
)}
|
||||
<p className="text-2xl font-bold">
|
||||
{mpsCredits.total_credits_used.toFixed(2)} <span className="text-lg font-normal text-muted-foreground">/ {mpsCredits.total_quota.toFixed(2)}</span>
|
||||
</p>
|
||||
<p className="text-sm text-muted-foreground">Credits Used</p>
|
||||
</div>
|
||||
<div className="text-right">
|
||||
<p className="text-lg font-semibold">{mpsCredits.remaining_credits.toFixed(2)}</p>
|
||||
<p className="text-sm text-muted-foreground">Remaining</p>
|
||||
</div>
|
||||
{!organizationPricing?.price_per_second_usd && (
|
||||
<div className="text-right">
|
||||
<p className="text-lg font-semibold">{currentUsage.percentage_used}%</p>
|
||||
<p className="text-sm text-muted-foreground">Used</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{!organizationPricing?.price_per_second_usd && (
|
||||
<Progress value={currentUsage.percentage_used} className="h-3" />
|
||||
{mpsCredits.total_quota > 0 && (
|
||||
<Progress value={(mpsCredits.total_credits_used / mpsCredits.total_quota) * 100} className="h-3" />
|
||||
)}
|
||||
|
||||
<div className="flex justify-between items-center text-sm text-muted-foreground">
|
||||
<div className="flex items-center">
|
||||
<Calendar className="h-4 w-4 mr-1" />
|
||||
Next refresh: {formatDate(currentUsage.next_refresh_date)}
|
||||
</div>
|
||||
<div>
|
||||
Total Duration: <span className="font-medium text-foreground">{formatDuration(currentUsage.total_duration_seconds)}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-muted-foreground">Unable to load usage data</p>
|
||||
<p className="text-muted-foreground">No Dograh service keys configured. Set up a service key in your model configuration to see usage.</p>
|
||||
)}
|
||||
</CardContent>
|
||||
</Card>
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@ import {
|
|||
Panel,
|
||||
ReactFlow,
|
||||
} from "@xyflow/react";
|
||||
import { BookA, BrushCleaning, Maximize2, Minus, Plus, Rocket, Settings, Variable } from 'lucide-react';
|
||||
import { BookA, BrushCleaning, Maximize2, Mic, Minus, Plus, Rocket, Settings, Variable } from 'lucide-react';
|
||||
import React, { useEffect, useMemo, useState } from 'react';
|
||||
|
||||
import { listDocumentsApiV1KnowledgeBaseDocumentsGet, listToolsApiV1ToolsGet } from '@/client';
|
||||
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
|
||||
import { listDocumentsApiV1KnowledgeBaseDocumentsGet, listRecordingsApiV1WorkflowRecordingsGet, listToolsApiV1ToolsGet } from '@/client';
|
||||
import type { DocumentResponseSchema, RecordingResponseSchema, ToolResponse } from '@/client/types.gen';
|
||||
import { FlowEdge, FlowNode, NodeType } from "@/components/flow/types";
|
||||
import { Button } from '@/components/ui/button';
|
||||
import { Tooltip, TooltipContent, TooltipProvider, TooltipTrigger } from '@/components/ui/tooltip';
|
||||
|
|
@ -23,6 +23,7 @@ import { ConfigurationsDialog } from './components/ConfigurationsDialog';
|
|||
import { DictionaryDialog } from './components/DictionaryDialog';
|
||||
import { EmbedDialog } from './components/EmbedDialog';
|
||||
import { PhoneCallDialog } from './components/PhoneCallDialog';
|
||||
import { RecordingsDialog } from './components/RecordingsDialog';
|
||||
import { TemplateContextVariablesDialog } from './components/TemplateContextVariablesDialog';
|
||||
import { WorkflowEditorHeader } from "./components/WorkflowEditorHeader";
|
||||
import { WorkflowProvider } from "./contexts/WorkflowContext";
|
||||
|
|
@ -67,8 +68,10 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
const [isDictionaryDialogOpen, setIsDictionaryDialogOpen] = useState(false);
|
||||
const [isEmbedDialogOpen, setIsEmbedDialogOpen] = useState(false);
|
||||
const [isPhoneCallDialogOpen, setIsPhoneCallDialogOpen] = useState(false);
|
||||
const [isRecordingsDialogOpen, setIsRecordingsDialogOpen] = useState(false);
|
||||
const [documents, setDocuments] = useState<DocumentResponseSchema[] | undefined>(undefined);
|
||||
const [tools, setTools] = useState<ToolResponse[] | undefined>(undefined);
|
||||
const [recordings, setRecordings] = useState<RecordingResponseSchema[]>([]);
|
||||
|
||||
const {
|
||||
rfInstance,
|
||||
|
|
@ -102,7 +105,7 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
user,
|
||||
});
|
||||
|
||||
// Fetch documents and tools once for the entire workflow
|
||||
// Fetch documents, tools, and recordings once for the entire workflow
|
||||
useEffect(() => {
|
||||
const fetchData = async () => {
|
||||
try {
|
||||
|
|
@ -119,13 +122,25 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
if (toolsResponse.data) {
|
||||
setTools(toolsResponse.data);
|
||||
}
|
||||
|
||||
// Fetch recordings for this workflow
|
||||
try {
|
||||
const recordingsResponse = await listRecordingsApiV1WorkflowRecordingsGet({
|
||||
query: { workflow_id: workflowId },
|
||||
});
|
||||
if (recordingsResponse.data) {
|
||||
setRecordings(recordingsResponse.data.recordings);
|
||||
}
|
||||
} catch {
|
||||
// Recordings API may not be available yet; silently ignore
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to fetch documents and tools:', error);
|
||||
}
|
||||
};
|
||||
|
||||
fetchData();
|
||||
}, []);
|
||||
}, [workflowId]);
|
||||
|
||||
// Memoize defaultEdgeOptions to prevent unnecessary re-renders
|
||||
const defaultEdgeOptions = useMemo(() => ({
|
||||
|
|
@ -137,8 +152,9 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
const workflowContextValue = useMemo(() => ({
|
||||
saveWorkflow,
|
||||
documents,
|
||||
tools
|
||||
}), [saveWorkflow, documents, tools]);
|
||||
tools,
|
||||
recordings,
|
||||
}), [saveWorkflow, documents, tools, recordings]);
|
||||
|
||||
return (
|
||||
<WorkflowProvider value={workflowContextValue}>
|
||||
|
|
@ -251,6 +267,22 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
</TooltipContent>
|
||||
</Tooltip>
|
||||
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="icon"
|
||||
onClick={() => setIsRecordingsDialogOpen(true)}
|
||||
className="bg-white shadow-sm hover:shadow-md"
|
||||
>
|
||||
<Mic className="h-4 w-4" />
|
||||
</Button>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent side="left">
|
||||
<p>Recordings</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<Button
|
||||
|
|
@ -389,6 +421,13 @@ function RenderWorkflow({ initialWorkflowName, workflowId, initialFlow, initialT
|
|||
workflowId={workflowId}
|
||||
user={user}
|
||||
/>
|
||||
|
||||
<RecordingsDialog
|
||||
open={isRecordingsDialogOpen}
|
||||
onOpenChange={setIsRecordingsDialogOpen}
|
||||
workflowId={workflowId}
|
||||
onRecordingsChange={setRecordings}
|
||||
/>
|
||||
</div>
|
||||
</WorkflowProvider>
|
||||
);
|
||||
|
|
|
|||
314
ui/src/app/workflow/[workflowId]/components/RecordingsDialog.tsx
Normal file
314
ui/src/app/workflow/[workflowId]/components/RecordingsDialog.tsx
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
import { Loader2, Trash2Icon, Upload } from "lucide-react";
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
|
||||
import {
|
||||
createRecordingApiV1WorkflowRecordingsPost,
|
||||
deleteRecordingApiV1WorkflowRecordingsRecordingIdDelete,
|
||||
getUploadUrlApiV1WorkflowRecordingsUploadUrlPost,
|
||||
listRecordingsApiV1WorkflowRecordingsGet,
|
||||
} from "@/client";
|
||||
import type { RecordingResponseSchema } from "@/client/types.gen";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { useUserConfig } from "@/context/UserConfigContext";
|
||||
|
||||
interface RecordingsDialogProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
workflowId: number;
|
||||
onRecordingsChange?: (recordings: RecordingResponseSchema[]) => void;
|
||||
}
|
||||
|
||||
const MAX_FILE_SIZE = 5 * 1024 * 1024; // 5MB
|
||||
|
||||
export const RecordingsDialog = ({
|
||||
open,
|
||||
onOpenChange,
|
||||
workflowId,
|
||||
onRecordingsChange,
|
||||
}: RecordingsDialogProps) => {
|
||||
const { userConfig } = useUserConfig();
|
||||
const [recordings, setRecordings] = useState<RecordingResponseSchema[]>([]);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [uploading, setUploading] = useState(false);
|
||||
const [transcript, setTranscript] = useState("");
|
||||
const [selectedFile, setSelectedFile] = useState<File | null>(null);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const fileInputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const ttsProvider = (userConfig?.tts?.provider as string) ?? "";
|
||||
const ttsModel = (userConfig?.tts?.model as string) ?? "";
|
||||
const ttsVoiceId = (userConfig?.tts?.voice as string) ?? "";
|
||||
|
||||
const fetchRecordings = useCallback(async () => {
|
||||
if (!workflowId) return;
|
||||
setLoading(true);
|
||||
try {
|
||||
const result = await listRecordingsApiV1WorkflowRecordingsGet({
|
||||
query: {
|
||||
workflow_id: workflowId,
|
||||
tts_provider: ttsProvider || undefined,
|
||||
tts_model: ttsModel || undefined,
|
||||
tts_voice_id: ttsVoiceId || undefined,
|
||||
},
|
||||
});
|
||||
const recs = result.data?.recordings ?? [];
|
||||
setRecordings(recs);
|
||||
onRecordingsChange?.(recs);
|
||||
} catch {
|
||||
setError("Failed to load recordings");
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
}, [workflowId, ttsProvider, ttsModel, ttsVoiceId, onRecordingsChange]);
|
||||
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
fetchRecordings();
|
||||
setError(null);
|
||||
setTranscript("");
|
||||
setSelectedFile(null);
|
||||
}
|
||||
}, [open, fetchRecordings]);
|
||||
|
||||
const handleUpload = async () => {
|
||||
if (!selectedFile || !transcript.trim()) return;
|
||||
if (!ttsProvider || !ttsModel || !ttsVoiceId) {
|
||||
setError(
|
||||
"TTS configuration (provider, model, voice) must be set in your user configuration before uploading."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
setUploading(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
// Step 1: Get presigned URL
|
||||
const uploadUrlResponse =
|
||||
await getUploadUrlApiV1WorkflowRecordingsUploadUrlPost({
|
||||
body: {
|
||||
workflow_id: workflowId,
|
||||
filename: selectedFile.name,
|
||||
mime_type: selectedFile.type || "audio/wav",
|
||||
file_size: selectedFile.size,
|
||||
},
|
||||
});
|
||||
|
||||
if (!uploadUrlResponse.data) {
|
||||
throw new Error("Failed to get upload URL");
|
||||
}
|
||||
|
||||
const { upload_url, recording_id, storage_key } =
|
||||
uploadUrlResponse.data;
|
||||
|
||||
// Step 2: Upload file directly to storage
|
||||
const uploadResponse = await fetch(upload_url, {
|
||||
method: "PUT",
|
||||
body: selectedFile,
|
||||
headers: {
|
||||
"Content-Type": selectedFile.type || "audio/wav",
|
||||
},
|
||||
});
|
||||
|
||||
if (!uploadResponse.ok) {
|
||||
throw new Error("File upload failed");
|
||||
}
|
||||
|
||||
// Step 3: Create recording record
|
||||
await createRecordingApiV1WorkflowRecordingsPost({
|
||||
body: {
|
||||
recording_id,
|
||||
workflow_id: workflowId,
|
||||
tts_provider: ttsProvider,
|
||||
tts_model: ttsModel,
|
||||
tts_voice_id: ttsVoiceId,
|
||||
transcript: transcript.trim(),
|
||||
storage_key,
|
||||
metadata: {
|
||||
original_filename: selectedFile.name,
|
||||
file_size_bytes: selectedFile.size,
|
||||
mime_type: selectedFile.type,
|
||||
},
|
||||
},
|
||||
});
|
||||
|
||||
// Reset form and refresh list
|
||||
setTranscript("");
|
||||
setSelectedFile(null);
|
||||
if (fileInputRef.current) fileInputRef.current.value = "";
|
||||
await fetchRecordings();
|
||||
} catch (err) {
|
||||
setError(
|
||||
err instanceof Error ? err.message : "Failed to upload recording"
|
||||
);
|
||||
} finally {
|
||||
setUploading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleDelete = async (recordingId: string) => {
|
||||
try {
|
||||
await deleteRecordingApiV1WorkflowRecordingsRecordingIdDelete({
|
||||
path: { recording_id: recordingId },
|
||||
});
|
||||
await fetchRecordings();
|
||||
} catch {
|
||||
setError("Failed to delete recording");
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||
<DialogContent className="max-w-lg max-h-[80vh] overflow-y-auto">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Workflow Recordings</DialogTitle>
|
||||
<DialogDescription>
|
||||
Upload audio recordings for hybrid prompts. Recordings are
|
||||
scoped to your current TTS configuration. Use{" "}
|
||||
<code className="text-xs bg-muted px-1 rounded">@</code> in
|
||||
prompt fields to insert them.
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
{/* Current TTS Config */}
|
||||
<div className="rounded-md border p-3 bg-muted/30 text-sm space-y-1">
|
||||
<div className="font-medium text-xs text-muted-foreground uppercase tracking-wide">
|
||||
Current TTS Configuration
|
||||
</div>
|
||||
{ttsProvider ? (
|
||||
<div className="flex flex-wrap gap-2 text-xs">
|
||||
<span className="bg-background px-2 py-0.5 rounded border">
|
||||
Provider: {ttsProvider}
|
||||
</span>
|
||||
<span className="bg-background px-2 py-0.5 rounded border">
|
||||
Model: {ttsModel}
|
||||
</span>
|
||||
<span className="bg-background px-2 py-0.5 rounded border truncate max-w-[200px]">
|
||||
VoiceID: {ttsVoiceId}
|
||||
</span>
|
||||
</div>
|
||||
) : (
|
||||
<p className="text-xs text-destructive">
|
||||
No TTS configuration found. Set it in Model Configurations.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="text-sm text-destructive bg-destructive/10 rounded-md p-2">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Upload Section */}
|
||||
<div className="space-y-3 border rounded-md p-3">
|
||||
<Label className="text-sm font-medium">Upload New Recording</Label>
|
||||
<div>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Audio File
|
||||
</Label>
|
||||
<Input
|
||||
ref={fileInputRef}
|
||||
type="file"
|
||||
accept="audio/*"
|
||||
onChange={(e) => {
|
||||
const file = e.target.files?.[0] ?? null;
|
||||
if (file && file.size > MAX_FILE_SIZE) {
|
||||
setError(
|
||||
`File size (${(file.size / (1024 * 1024)).toFixed(1)}MB) exceeds the maximum allowed size of 5MB.`
|
||||
);
|
||||
setSelectedFile(null);
|
||||
if (fileInputRef.current) fileInputRef.current.value = "";
|
||||
return;
|
||||
}
|
||||
setError(null);
|
||||
setSelectedFile(file);
|
||||
}}
|
||||
className="text-sm"
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground mt-1">
|
||||
Max 5MB
|
||||
</p>
|
||||
</div>
|
||||
<div>
|
||||
<Label className="text-xs text-muted-foreground">
|
||||
Transcript
|
||||
</Label>
|
||||
<Input
|
||||
placeholder="What does this recording say?"
|
||||
value={transcript}
|
||||
onChange={(e) => setTranscript(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
onClick={handleUpload}
|
||||
disabled={!selectedFile || !transcript.trim() || uploading}
|
||||
>
|
||||
{uploading ? (
|
||||
<Loader2 className="w-4 h-4 mr-1 animate-spin" />
|
||||
) : (
|
||||
<Upload className="w-4 h-4 mr-1" />
|
||||
)}
|
||||
{uploading ? "Uploading..." : "Upload Recording"}
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{/* Recordings List */}
|
||||
<div className="space-y-2">
|
||||
<Label className="text-sm font-medium">
|
||||
Recordings{" "}
|
||||
{!loading && (
|
||||
<span className="text-muted-foreground font-normal">
|
||||
({recordings.length})
|
||||
</span>
|
||||
)}
|
||||
</Label>
|
||||
{loading ? (
|
||||
<div className="flex items-center justify-center py-4">
|
||||
<Loader2 className="w-5 h-5 animate-spin text-muted-foreground" />
|
||||
</div>
|
||||
) : recordings.length === 0 ? (
|
||||
<p className="text-sm text-muted-foreground py-2">
|
||||
No recordings yet for this TTS configuration.
|
||||
</p>
|
||||
) : (
|
||||
recordings.map((rec) => (
|
||||
<div
|
||||
key={rec.recording_id}
|
||||
className="flex items-start gap-2 p-2 border rounded-md"
|
||||
>
|
||||
<div className="flex-1 min-w-0">
|
||||
<div className="flex items-center gap-2">
|
||||
<code className="text-xs bg-muted px-1.5 py-0.5 rounded font-mono">
|
||||
{rec.recording_id}
|
||||
</code>
|
||||
</div>
|
||||
<p className="text-sm text-muted-foreground mt-1 break-all line-clamp-2">
|
||||
{rec.transcript}
|
||||
</p>
|
||||
</div>
|
||||
<Button
|
||||
size="sm"
|
||||
variant="ghost"
|
||||
onClick={() => handleDelete(rec.recording_id)}
|
||||
>
|
||||
<Trash2Icon className="w-4 h-4" />
|
||||
</Button>
|
||||
</div>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
};
|
||||
|
|
@ -1,11 +1,13 @@
|
|||
import { createContext, useContext } from 'react';
|
||||
|
||||
import type { DocumentResponseSchema, ToolResponse } from '@/client/types.gen';
|
||||
import type { RecordingResponseSchema } from '@/client/types.gen';
|
||||
|
||||
/**
 * Shape of the value carried by WorkflowContext.
 * The three list fields are optional; the node components in this change
 * read them via useWorkflow() and fall back to an empty array when undefined.
 */
interface WorkflowContextType {
|
||||
// Async save callback. NOTE(review): the optional flag presumably controls
// whether the workflow definition itself is updated - confirm against the provider.
saveWorkflow: (updateWorkflowDefinition?: boolean) => Promise<void>;
|
||||
// Documents made available to node edit forms.
documents?: DocumentResponseSchema[];
|
||||
// Tools made available to node edit forms.
tools?: ToolResponse[];
|
||||
// Recordings made available to node edit forms (passed on to MentionTextarea).
recordings?: RecordingResponseSchema[];
|
||||
}
|
||||
|
||||
const WorkflowContext = createContext<WorkflowContextType | undefined>(undefined);
|
||||
|
|
|
|||
|
|
@ -319,18 +319,10 @@ export const useWebSocketRTC = ({ workflowId, workflowRunId, accessToken, initia
|
|||
setFeedbackMessages(prev => {
|
||||
const last = prev[prev.length - 1];
|
||||
if (last && last.type === 'bot-text' && !last.final) {
|
||||
// Append to existing bot message with space if needed
|
||||
const existingText = last.text;
|
||||
const newText = message.payload.text;
|
||||
// Add space between chunks if previous doesn't end with space
|
||||
// and new doesn't start with space or punctuation
|
||||
const needsSpace = existingText.length > 0 &&
|
||||
!existingText.endsWith(' ') &&
|
||||
!newText.startsWith(' ') &&
|
||||
!/^[.,!?;:]/.test(newText);
|
||||
// Append to existing bot message
|
||||
return [
|
||||
...prev.slice(0, -1),
|
||||
{ ...last, text: existingText + (needsSpace ? ' ' : '') + newText }
|
||||
{ ...last, text: last.text + message.payload.text }
|
||||
];
|
||||
}
|
||||
// Start new bot message
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ export function processTranscriptEvents(events: TranscriptEvent[]): ProcessedMes
|
|||
} else if (event.type === 'bot-text') {
|
||||
// Combine consecutive bot-text from the same turn
|
||||
if (currentBotText && currentBotText.event.turn === event.turn) {
|
||||
currentBotText.text = currentBotText.text + ' ' + event.text;
|
||||
currentBotText.text = currentBotText.text + event.text;
|
||||
} else {
|
||||
flushBotText();
|
||||
currentBotText = { event, text: event.text };
|
||||
|
|
|
|||
|
|
@ -16,5 +16,5 @@ import type { ClientOptions } from './types.gen';
|
|||
export type CreateClientConfig<T extends DefaultClientOptions = ClientOptions> = (override?: Config<DefaultClientOptions & T>) => Config<Required<DefaultClientOptions> & T>;
|
||||
|
||||
export const client = createClient(createClientConfig(createConfig<ClientOptions>({
|
||||
baseUrl: 'http://127.0.0.1:8000'
|
||||
baseUrl: 'https://app.dograh.com'
|
||||
})));
|
||||
|
|
|
|||
File diff suppressed because one or more lines are too long
|
|
@ -755,6 +755,12 @@ export type LoginRequest = {
|
|||
password: string;
|
||||
};
|
||||
|
||||
/**
 * Body of the 200 response from `/api/v1/organizations/usage/mps-credits`.
 * NOTE(review): units of the three numbers and the exact relationship between
 * them (e.g. whether remaining = quota - used) are not visible here - confirm
 * against the backend schema.
 */
export type MpsCreditsResponse = {
|
||||
// Credits consumed so far.
total_credits_used: number;
|
||||
// Credits still available.
remaining_credits: number;
|
||||
// Overall credit quota.
total_quota: number;
|
||||
};
|
||||
|
||||
export type PresignedUploadUrlRequest = {
|
||||
/**
|
||||
* CSV filename
|
||||
|
|
@ -790,6 +796,116 @@ export type ProcessDocumentRequestSchema = {
|
|||
s3_key: string;
|
||||
};
|
||||
|
||||
/**
|
||||
* Request schema for creating a recording record after upload.
|
||||
*/
|
||||
export type RecordingCreateRequestSchema = {
|
||||
/**
|
||||
* Short recording ID from upload step
|
||||
*/
|
||||
recording_id: string;
|
||||
/**
|
||||
* Workflow ID
|
||||
*/
|
||||
workflow_id: number;
|
||||
/**
|
||||
* TTS provider (e.g. elevenlabs)
|
||||
*/
|
||||
tts_provider: string;
|
||||
/**
|
||||
* TTS model name
|
||||
*/
|
||||
tts_model: string;
|
||||
/**
|
||||
* TTS voice identifier
|
||||
*/
|
||||
tts_voice_id: string;
|
||||
/**
|
||||
* User-provided transcript of the recording
|
||||
*/
|
||||
transcript: string;
|
||||
/**
|
||||
* Storage key from upload step
|
||||
*/
|
||||
storage_key: string;
|
||||
/**
|
||||
* Optional metadata (file_size, duration, etc.)
|
||||
*/
|
||||
metadata?: {
|
||||
[key: string]: unknown;
|
||||
} | null;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema for list of recordings.
|
||||
*/
|
||||
// Body of the 200 response when listing recordings at `/api/v1/workflow-recordings/`.
export type RecordingListResponseSchema = {
|
||||
// The recordings returned for the requested workflow / TTS filters.
recordings: Array<RecordingResponseSchema>;
|
||||
// NOTE(review): presumably the number of matching recordings; whether it can
// differ from recordings.length is not visible here - confirm with the API.
total: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema for a single recording.
|
||||
*/
|
||||
// One stored workflow recording as returned by the workflow-recordings endpoints.
// Field names mirror the columns of the `workflow_recordings` table.
export type RecordingResponseSchema = {
|
||||
// Numeric row id (distinct from the short string `recording_id`).
id: number;
|
||||
// Short unique recording ID; this is the value used in the delete URL path
// and as the React key / mention id in the UI.
recording_id: string;
|
||||
// Workflow this recording belongs to.
workflow_id: number;
|
||||
// Organization that owns the recording.
organization_id: number;
|
||||
// TTS provider (e.g. elevenlabs).
tts_provider: string;
|
||||
// TTS model name.
tts_model: string;
|
||||
// TTS voice identifier.
tts_voice_id: string;
|
||||
// User-provided transcript of the recording.
transcript: string;
|
||||
// Storage key the audio file was uploaded under.
storage_key: string;
|
||||
// Typed as a plain string here; the database enum allows "s3" and "minio".
storage_backend: string;
|
||||
// Free-form metadata. Always present in responses, whereas the create
// request allows it to be omitted or null.
metadata: {
|
||||
[key: string]: unknown;
|
||||
};
|
||||
// NOTE(review): presumably the id of the user who created the recording - confirm.
created_by: number;
|
||||
// NOTE(review): creation timestamp serialized as a string; format not visible here.
created_at: string;
|
||||
// NOTE(review): looks like a soft-delete / visibility flag - confirm against the backend.
is_active: boolean;
|
||||
};
|
||||
|
||||
/**
|
||||
* Request schema for getting a presigned upload URL.
|
||||
*/
|
||||
export type RecordingUploadRequestSchema = {
|
||||
/**
|
||||
* Workflow ID this recording belongs to
|
||||
*/
|
||||
workflow_id: number;
|
||||
/**
|
||||
* Original filename of the audio file
|
||||
*/
|
||||
filename: string;
|
||||
/**
|
||||
* MIME type of the audio file
|
||||
*/
|
||||
mime_type?: string;
|
||||
/**
|
||||
* File size in bytes (max 5MB)
|
||||
*/
|
||||
file_size: number;
|
||||
};
|
||||
|
||||
/**
|
||||
* Response schema with presigned upload URL.
|
||||
*/
|
||||
export type RecordingUploadResponseSchema = {
|
||||
/**
|
||||
* Presigned URL for uploading the audio
|
||||
*/
|
||||
upload_url: string;
|
||||
/**
|
||||
* Short unique recording ID
|
||||
*/
|
||||
recording_id: string;
|
||||
/**
|
||||
* Storage key where file will be uploaded
|
||||
*/
|
||||
storage_key: string;
|
||||
};
|
||||
|
||||
export type RetryConfigRequest = {
|
||||
enabled?: boolean;
|
||||
max_retries?: number;
|
||||
|
|
@ -4268,6 +4384,39 @@ export type GetCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGetResponse
|
|||
|
||||
export type GetCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGetResponse = GetCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGetResponses[keyof GetCurrentPeriodUsageApiV1OrganizationsUsageCurrentPeriodGetResponses];
|
||||
|
||||
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/api/v1/organizations/usage/mps-credits';
|
||||
};
|
||||
|
||||
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetError = GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetErrors[keyof GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetErrors];
|
||||
|
||||
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: MpsCreditsResponse;
|
||||
};
|
||||
|
||||
export type GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponse = GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses[keyof GetMpsCreditsApiV1OrganizationsUsageMpsCreditsGetResponses];
|
||||
|
||||
export type GetUsageHistoryApiV1OrganizationsUsageRunsGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
|
|
@ -5065,6 +5214,155 @@ export type SearchChunksApiV1KnowledgeBaseSearchPostResponses = {
|
|||
|
||||
export type SearchChunksApiV1KnowledgeBaseSearchPostResponse = SearchChunksApiV1KnowledgeBaseSearchPostResponses[keyof SearchChunksApiV1KnowledgeBaseSearchPostResponses];
|
||||
|
||||
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostData = {
|
||||
body: RecordingUploadRequestSchema;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/api/v1/workflow-recordings/upload-url';
|
||||
};
|
||||
|
||||
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostError = GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostErrors[keyof GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostErrors];
|
||||
|
||||
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: RecordingUploadResponseSchema;
|
||||
};
|
||||
|
||||
export type GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostResponse = GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostResponses[keyof GetUploadUrlApiV1WorkflowRecordingsUploadUrlPostResponses];
|
||||
|
||||
export type ListRecordingsApiV1WorkflowRecordingsGetData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query: {
|
||||
/**
|
||||
* Workflow ID
|
||||
*/
|
||||
workflow_id: number;
|
||||
/**
|
||||
* Filter by TTS provider
|
||||
*/
|
||||
tts_provider?: string | null;
|
||||
/**
|
||||
* Filter by TTS model
|
||||
*/
|
||||
tts_model?: string | null;
|
||||
/**
|
||||
* Filter by TTS voice ID
|
||||
*/
|
||||
tts_voice_id?: string | null;
|
||||
};
|
||||
url: '/api/v1/workflow-recordings/';
|
||||
};
|
||||
|
||||
export type ListRecordingsApiV1WorkflowRecordingsGetErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type ListRecordingsApiV1WorkflowRecordingsGetError = ListRecordingsApiV1WorkflowRecordingsGetErrors[keyof ListRecordingsApiV1WorkflowRecordingsGetErrors];
|
||||
|
||||
export type ListRecordingsApiV1WorkflowRecordingsGetResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: RecordingListResponseSchema;
|
||||
};
|
||||
|
||||
export type ListRecordingsApiV1WorkflowRecordingsGetResponse = ListRecordingsApiV1WorkflowRecordingsGetResponses[keyof ListRecordingsApiV1WorkflowRecordingsGetResponses];
|
||||
|
||||
export type CreateRecordingApiV1WorkflowRecordingsPostData = {
|
||||
body: RecordingCreateRequestSchema;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path?: never;
|
||||
query?: never;
|
||||
url: '/api/v1/workflow-recordings/';
|
||||
};
|
||||
|
||||
export type CreateRecordingApiV1WorkflowRecordingsPostErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type CreateRecordingApiV1WorkflowRecordingsPostError = CreateRecordingApiV1WorkflowRecordingsPostErrors[keyof CreateRecordingApiV1WorkflowRecordingsPostErrors];
|
||||
|
||||
export type CreateRecordingApiV1WorkflowRecordingsPostResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: RecordingResponseSchema;
|
||||
};
|
||||
|
||||
export type CreateRecordingApiV1WorkflowRecordingsPostResponse = CreateRecordingApiV1WorkflowRecordingsPostResponses[keyof CreateRecordingApiV1WorkflowRecordingsPostResponses];
|
||||
|
||||
export type DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteData = {
|
||||
body?: never;
|
||||
headers?: {
|
||||
authorization?: string | null;
|
||||
'X-API-Key'?: string | null;
|
||||
};
|
||||
path: {
|
||||
recording_id: string;
|
||||
};
|
||||
query?: never;
|
||||
url: '/api/v1/workflow-recordings/{recording_id}';
|
||||
};
|
||||
|
||||
export type DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteErrors = {
|
||||
/**
|
||||
* Not found
|
||||
*/
|
||||
404: unknown;
|
||||
/**
|
||||
* Validation Error
|
||||
*/
|
||||
422: HttpValidationError;
|
||||
};
|
||||
|
||||
export type DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteError = DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteErrors[keyof DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteErrors];
|
||||
|
||||
export type DeleteRecordingApiV1WorkflowRecordingsRecordingIdDeleteResponses = {
|
||||
/**
|
||||
* Successful Response
|
||||
*/
|
||||
200: unknown;
|
||||
};
|
||||
|
||||
export type SignupApiV1AuthSignupPostData = {
|
||||
body: SignupRequest;
|
||||
path?: never;
|
||||
|
|
@ -5180,5 +5478,5 @@ export type HealthApiV1HealthGetResponses = {
|
|||
export type HealthApiV1HealthGetResponse = HealthApiV1HealthGetResponses[keyof HealthApiV1HealthGetResponses];
|
||||
|
||||
export type ClientOptions = {
|
||||
baseUrl: 'http://127.0.0.1:8000' | (string & {});
|
||||
baseUrl: 'https://app.dograh.com' | 'http://localhost:8000' | (string & {});
|
||||
};
|
||||
|
|
|
|||
213
ui/src/components/flow/MentionTextarea.tsx
Normal file
213
ui/src/components/flow/MentionTextarea.tsx
Normal file
|
|
@ -0,0 +1,213 @@
|
|||
import {
|
||||
type ChangeEvent,
|
||||
type KeyboardEvent,
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
|
||||
import type { RecordingResponseSchema } from "@/client/types.gen";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
/**
 * One selectable entry in the mention dropdown.
 * MentionTextarea builds these from recordings: `id` is the recording_id and
 * both `name` and `description` are set to the recording's transcript.
 */
export interface MentionItem {
|
||||
// Identifier shown in the dropdown, matched against the query and written into the inserted mention.
id: string;
|
||||
// Label shown next to the id and matched against the typed query.
name: string;
|
||||
// Text placed between the square brackets of the inserted mention.
description: string;
|
||||
}
|
||||
|
||||
// Props for MentionTextarea, a controlled textarea with an "@" recording picker.
interface MentionTextareaProps {
|
||||
// Current text of the textarea (controlled).
value: string;
|
||||
// Receives the full new text (not a change event) on every edit or mention insertion.
onChange: (value: string) => void;
|
||||
// Placeholder forwarded to the underlying <textarea>.
placeholder?: string;
|
||||
// Extra class names merged after the component's default textarea classes.
className?: string;
|
||||
// Recordings offered in the mention dropdown; treated as an empty list when omitted.
recordings?: RecordingResponseSchema[];
|
||||
}
|
||||
|
||||
// Shared fallback so an omitted `recordings` prop keeps a stable identity
// across renders (a fresh `[]` default would invalidate the memoized lists
// on every render).
const NO_RECORDINGS: RecordingResponseSchema[] = [];

/**
 * Controlled textarea with an "@"-triggered picker for workflow recordings.
 *
 * Typing "@" at the very start of the text, or directly after a whitespace
 * character, opens a dropdown of `recordings` filtered (case-insensitively,
 * by transcript or recording id) by whatever is typed after the "@". The
 * mention ends as soon as any whitespace is typed. Choosing an entry with the
 * mouse, Enter or Tab replaces the "@query" span with
 * `RECORDING_ID: <id> [ <transcript> ]` and puts the caret right after it.
 *
 * @param value - Current text (controlled).
 * @param onChange - Called with the full new text on every edit or insertion.
 * @param placeholder - Placeholder forwarded to the textarea.
 * @param className - Extra classes merged after the default textarea classes.
 * @param recordings - Recordings offered in the picker; defaults to none.
 */
export function MentionTextarea({
    value,
    onChange,
    placeholder,
    className,
    recordings = NO_RECORDINGS,
}: MentionTextareaProps) {
    const textareaRef = useRef<HTMLTextAreaElement>(null);
    // Attached to whichever popup is currently rendered (list or empty notice)
    // so the click-outside handler works for both.
    const dropdownRef = useRef<HTMLDivElement>(null);
    const [showDropdown, setShowDropdown] = useState(false);
    const [query, setQuery] = useState("");
    const [mentionStartIndex, setMentionStartIndex] = useState<number | null>(null);
    const [selectedIndex, setSelectedIndex] = useState(0);

    // Convert recordings to mention items
    const items: MentionItem[] = useMemo(
        () =>
            recordings.map((r) => ({
                id: r.recording_id,
                name: r.transcript,
                description: r.transcript,
            })),
        [recordings]
    );

    // Memoized so the keyboard handler below is not recreated on every render.
    const filtered = useMemo(() => {
        const needle = query.toLowerCase();
        return items.filter(
            (item) =>
                item.name.toLowerCase().includes(needle) ||
                item.id.toLowerCase().includes(needle)
        );
    }, [items, query]);

    const closeDropdown = useCallback(() => {
        setShowDropdown(false);
        setQuery("");
        setMentionStartIndex(null);
    }, []);

    const insertMention = useCallback(
        (item: MentionItem) => {
            if (mentionStartIndex === null) return;
            const textarea = textareaRef.current;
            if (!textarea) return;

            // Replace everything from the "@" up to the caret with the mention.
            const before = value.slice(0, mentionStartIndex);
            const after = value.slice(textarea.selectionStart);
            const mentionText = `RECORDING_ID: ${item.id} [ ${item.description} ]`;
            const newValue = before + mentionText + after;

            onChange(newValue);
            closeDropdown();
            setSelectedIndex(0);

            // Restore cursor position after the inserted mention. Deferred a
            // frame so it runs after the controlled value has been re-rendered.
            requestAnimationFrame(() => {
                const cursorPos = before.length + mentionText.length;
                textarea.focus();
                textarea.setSelectionRange(cursorPos, cursorPos);
            });
        },
        [mentionStartIndex, value, onChange, closeDropdown]
    );

    const handleChange = useCallback(
        (e: ChangeEvent<HTMLTextAreaElement>) => {
            const newValue = e.target.value;
            const cursorPos = e.target.selectionStart;
            onChange(newValue);

            // Look backwards from the caret for the nearest "@".
            const textBeforeCursor = newValue.slice(0, cursorPos);
            const atIndex = textBeforeCursor.lastIndexOf("@");

            if (atIndex !== -1) {
                const typedQuery = textBeforeCursor.slice(atIndex + 1);
                // A mention starts only at the beginning of the text or right
                // after whitespace (so e-mail addresses do not trigger it), and
                // it ends at the first whitespace typed after the "@" - this
                // includes newlines and tabs, not just spaces.
                const atWordStart = atIndex === 0 || /\s/.test(newValue[atIndex - 1]);
                if (atWordStart && !/\s/.test(typedQuery)) {
                    setShowDropdown(true);
                    setQuery(typedQuery);
                    setMentionStartIndex(atIndex);
                    setSelectedIndex(0);
                    return;
                }
            }

            closeDropdown();
        },
        [onChange, closeDropdown]
    );

    const handleKeyDown = useCallback(
        (e: KeyboardEvent<HTMLTextAreaElement>) => {
            if (!showDropdown) return;

            // Escape must work even when only the "no recordings" notice is shown.
            if (e.key === "Escape") {
                e.preventDefault();
                setShowDropdown(false);
                return;
            }

            if (filtered.length === 0) return;

            if (e.key === "ArrowDown") {
                e.preventDefault();
                setSelectedIndex((prev) => (prev + 1) % filtered.length);
            } else if (e.key === "ArrowUp") {
                e.preventDefault();
                setSelectedIndex((prev) => (prev - 1 + filtered.length) % filtered.length);
            } else if (e.key === "Enter" || e.key === "Tab") {
                e.preventDefault();
                // The list can shrink (recordings prop changed) while the index
                // is stale; fall back to the first entry instead of undefined.
                insertMention(filtered[selectedIndex] ?? filtered[0]);
            }
        },
        [showDropdown, filtered, selectedIndex, insertMention]
    );

    // Close dropdown when clicking outside
    useEffect(() => {
        const handleClickOutside = (e: MouseEvent) => {
            if (
                dropdownRef.current &&
                !dropdownRef.current.contains(e.target as Node) &&
                textareaRef.current &&
                !textareaRef.current.contains(e.target as Node)
            ) {
                setShowDropdown(false);
            }
        };
        document.addEventListener("mousedown", handleClickOutside);
        return () => document.removeEventListener("mousedown", handleClickOutside);
    }, []);

    // Scroll selected item into view
    useEffect(() => {
        if (!showDropdown || !dropdownRef.current) return;
        const selected = dropdownRef.current.querySelector("[data-selected='true']");
        selected?.scrollIntoView({ block: "nearest" });
    }, [selectedIndex, showDropdown]);

    return (
        <div className="relative">
            <textarea
                ref={textareaRef}
                value={value}
                onChange={handleChange}
                onKeyDown={handleKeyDown}
                placeholder={placeholder}
                className={cn(
                    "border-input placeholder:text-muted-foreground focus-visible:border-ring focus-visible:ring-ring/50 aria-invalid:ring-destructive/20 dark:aria-invalid:ring-destructive/40 aria-invalid:border-destructive dark:bg-input/30 flex field-sizing-content min-h-16 w-full rounded-md border bg-transparent px-3 py-2 text-base shadow-xs transition-[color,box-shadow] outline-none focus-visible:ring-[3px] disabled:cursor-not-allowed disabled:opacity-50 md:text-sm",
                    className
                )}
            />
            {showDropdown && filtered.length > 0 && (
                <div
                    ref={dropdownRef}
                    className="absolute z-50 mt-1 w-full max-h-60 overflow-y-auto rounded-md border bg-popover text-popover-foreground shadow-md"
                >
                    {filtered.map((item, index) => (
                        <button
                            key={item.id}
                            type="button"
                            data-selected={index === selectedIndex}
                            className={cn(
                                "flex w-full flex-col gap-0.5 px-3 py-2 text-left text-sm cursor-pointer hover:bg-accent",
                                index === selectedIndex && "bg-accent"
                            )}
                            onMouseDown={(e) => {
                                e.preventDefault(); // prevent textarea blur
                                insertMention(item);
                            }}
                            onMouseEnter={() => setSelectedIndex(index)}
                        >
                            <div className="flex items-center gap-2">
                                <code className="text-xs bg-muted px-1 py-0.5 rounded font-mono">
                                    {item.id}
                                </code>
                                <span className="font-medium truncate">{item.name}</span>
                            </div>
                        </button>
                    ))}
                </div>
            )}
            {/* Shown only when there are no recordings at all, not when the query simply matches nothing. */}
            {showDropdown && items.length === 0 && (
                <div
                    ref={dropdownRef}
                    className="absolute z-50 mt-1 w-full rounded-md border bg-popover text-popover-foreground shadow-md p-3 text-sm text-muted-foreground"
                >
                    No recordings found. Upload recordings via the Recordings panel.
                </div>
            )}
        </div>
    );
}
|
||||
|
|
@ -3,9 +3,10 @@ import { Edit, FileText, Headset, PlusIcon, Trash2Icon, Wrench } from "lucide-re
|
|||
import { memo, useCallback, useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
|
||||
import type { DocumentResponseSchema, RecordingResponseSchema, ToolResponse } from "@/client/types.gen";
|
||||
import { DocumentBadges } from "@/components/flow/DocumentBadges";
|
||||
import { DocumentSelector } from "@/components/flow/DocumentSelector";
|
||||
import { MentionTextarea } from "@/components/flow/MentionTextarea";
|
||||
import { ToolBadges } from "@/components/flow/ToolBadges";
|
||||
import { ToolSelector } from "@/components/flow/ToolSelector";
|
||||
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
|
||||
|
|
@ -42,6 +43,7 @@ interface AgentNodeEditFormProps {
|
|||
setDocumentUuids: (value: string[]) => void;
|
||||
tools: ToolResponse[];
|
||||
documents: DocumentResponseSchema[];
|
||||
recordings: RecordingResponseSchema[];
|
||||
}
|
||||
|
||||
interface AgentNodeProps extends NodeProps {
|
||||
|
|
@ -50,7 +52,7 @@ interface AgentNodeProps extends NodeProps {
|
|||
|
||||
export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
||||
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
|
||||
const { saveWorkflow, tools, documents } = useWorkflow();
|
||||
const { saveWorkflow, tools, documents, recordings } = useWorkflow();
|
||||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt);
|
||||
|
|
@ -229,6 +231,7 @@ export const AgentNode = memo(({ data, selected, id }: AgentNodeProps) => {
|
|||
setDocumentUuids={setDocumentUuids}
|
||||
tools={tools ?? []}
|
||||
documents={documents ?? []}
|
||||
recordings={recordings ?? []}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -257,6 +260,7 @@ const AgentNodeEditForm = ({
|
|||
setDocumentUuids,
|
||||
tools,
|
||||
documents,
|
||||
recordings,
|
||||
}: AgentNodeEditFormProps) => {
|
||||
const handleVariableNameChange = (idx: number, value: string) => {
|
||||
const newVars = [...variables];
|
||||
|
|
@ -318,13 +322,12 @@ const AgentNodeEditForm = ({
|
|||
<Label className="text-xs text-muted-foreground">
|
||||
Enter the prompt for the agent. This will be used to generate the agent's response. Prompt engineering's best practices apply.
|
||||
</Label>
|
||||
<Textarea
|
||||
<MentionTextarea
|
||||
value={prompt}
|
||||
onChange={(e) => setPrompt(e.target.value)}
|
||||
className="min-h-[100px] max-h-[300px] resize-none"
|
||||
style={{
|
||||
overflowY: 'auto'
|
||||
}}
|
||||
onChange={setPrompt}
|
||||
className="min-h-[100px] max-h-[300px] resize-none overflow-y-auto"
|
||||
placeholder="Enter a prompt"
|
||||
recordings={recordings}
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ import { Edit, OctagonX, PlusIcon, Trash2Icon } from "lucide-react";
|
|||
import { memo, useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { RecordingResponseSchema } from "@/client/types.gen";
|
||||
import { MentionTextarea } from "@/components/flow/MentionTextarea";
|
||||
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
|
|
@ -29,6 +31,7 @@ interface EndCallEditFormProps {
|
|||
setVariables: (vars: ExtractionVariable[]) => void;
|
||||
addGlobalPrompt: boolean;
|
||||
setAddGlobalPrompt: (value: boolean) => void;
|
||||
recordings: RecordingResponseSchema[];
|
||||
}
|
||||
|
||||
interface EndCallNodeProps extends NodeProps {
|
||||
|
|
@ -40,7 +43,7 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
|
|||
id,
|
||||
additionalData: { is_end: true }
|
||||
});
|
||||
const { saveWorkflow } = useWorkflow();
|
||||
const { saveWorkflow, recordings } = useWorkflow();
|
||||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt);
|
||||
|
|
@ -157,6 +160,7 @@ export const EndCall = memo(({ data, selected, id }: EndCallNodeProps) => {
|
|||
setVariables={setVariables}
|
||||
addGlobalPrompt={addGlobalPrompt}
|
||||
setAddGlobalPrompt={setAddGlobalPrompt}
|
||||
recordings={recordings ?? []}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -177,6 +181,7 @@ const EndCallEditForm = ({
|
|||
setVariables,
|
||||
addGlobalPrompt,
|
||||
setAddGlobalPrompt,
|
||||
recordings,
|
||||
}: EndCallEditFormProps) => {
|
||||
const handleVariableNameChange = (idx: number, value: string) => {
|
||||
const newVars = [...variables];
|
||||
|
|
@ -216,14 +221,12 @@ const EndCallEditForm = ({
|
|||
<Label className="text-xs text-muted-foreground">
|
||||
Enter the prompt for the agent. This will be used to generate the agent's response. Prompt engineering's best practices apply.
|
||||
</Label>
|
||||
<Textarea
|
||||
<MentionTextarea
|
||||
value={prompt}
|
||||
onChange={(e) => setPrompt(e.target.value)}
|
||||
className="min-h-[100px] max-h-[300px] resize-none"
|
||||
style={{
|
||||
overflowY: 'auto'
|
||||
}}
|
||||
placeholder="Enter a dynamic prompt"
|
||||
onChange={setPrompt}
|
||||
className="min-h-[100px] max-h-[300px] resize-none overflow-y-auto"
|
||||
placeholder="Enter a prompt"
|
||||
recordings={recordings}
|
||||
/>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Switch id="add-global-prompt" checked={addGlobalPrompt} onCheckedChange={setAddGlobalPrompt} />
|
||||
|
|
|
|||
|
|
@ -3,11 +3,12 @@ import { Edit, Headset, Trash2Icon } from "lucide-react";
|
|||
import { memo, useEffect, useMemo, useState } from "react";
|
||||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { RecordingResponseSchema } from "@/client/types.gen";
|
||||
import { MentionTextarea } from "@/components/flow/MentionTextarea";
|
||||
import { FlowNodeData } from "@/components/flow/types";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import { NODE_DOCUMENTATION_URLS } from "@/constants/documentation";
|
||||
|
||||
import { NodeContent } from "./common/NodeContent";
|
||||
|
|
@ -20,6 +21,7 @@ interface GlobalNodeEditFormProps {
|
|||
setPrompt: (value: string) => void;
|
||||
name: string;
|
||||
setName: (value: string) => void;
|
||||
recordings: RecordingResponseSchema[];
|
||||
}
|
||||
|
||||
interface GlobalNodeProps extends NodeProps {
|
||||
|
|
@ -28,7 +30,7 @@ interface GlobalNodeProps extends NodeProps {
|
|||
|
||||
export const GlobalNode = memo(({ data, selected, id }: GlobalNodeProps) => {
|
||||
const { open, setOpen, handleSaveNodeData, handleDeleteNode } = useNodeHandlers({ id });
|
||||
const { saveWorkflow } = useWorkflow();
|
||||
const { saveWorkflow, recordings } = useWorkflow();
|
||||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt);
|
||||
|
|
@ -118,6 +120,7 @@ export const GlobalNode = memo(({ data, selected, id }: GlobalNodeProps) => {
|
|||
setPrompt={setPrompt}
|
||||
name={name}
|
||||
setName={setName}
|
||||
recordings={recordings ?? []}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -129,7 +132,8 @@ const GlobalNodeEditForm = ({
|
|||
prompt,
|
||||
setPrompt,
|
||||
name,
|
||||
setName
|
||||
setName,
|
||||
recordings,
|
||||
}: GlobalNodeEditFormProps) => {
|
||||
return (
|
||||
<div className="grid gap-2">
|
||||
|
|
@ -146,13 +150,12 @@ const GlobalNodeEditForm = ({
|
|||
<Label className="text-xs text-muted-foreground">
|
||||
This is the global prompt. This will be added to the system prompt of all the agents.
|
||||
</Label>
|
||||
<Textarea
|
||||
<MentionTextarea
|
||||
value={prompt}
|
||||
onChange={(e) => setPrompt(e.target.value)}
|
||||
className="min-h-[100px] max-h-[300px] resize-none"
|
||||
style={{
|
||||
overflowY: 'auto'
|
||||
}}
|
||||
onChange={setPrompt}
|
||||
className="min-h-[100px] max-h-[300px] resize-none overflow-y-auto"
|
||||
placeholder="Enter a prompt"
|
||||
recordings={recordings}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
|
|
|
|||
|
|
@ -4,8 +4,10 @@ import { memo, useCallback, useEffect, useMemo, useState } from "react";
|
|||
|
||||
import { useWorkflow } from "@/app/workflow/[workflowId]/contexts/WorkflowContext";
|
||||
import type { DocumentResponseSchema, ToolResponse } from "@/client/types.gen";
|
||||
import type { RecordingResponseSchema } from "@/client/types.gen";
|
||||
import { DocumentBadges } from "@/components/flow/DocumentBadges";
|
||||
import { DocumentSelector } from "@/components/flow/DocumentSelector";
|
||||
import { MentionTextarea } from "@/components/flow/MentionTextarea";
|
||||
import { ToolBadges } from "@/components/flow/ToolBadges";
|
||||
import { ToolSelector } from "@/components/flow/ToolSelector";
|
||||
import { ExtractionVariable, FlowNodeData } from "@/components/flow/types";
|
||||
|
|
@ -48,6 +50,7 @@ interface StartCallEditFormProps {
|
|||
setDocumentUuids: (value: string[]) => void;
|
||||
tools: ToolResponse[];
|
||||
documents: DocumentResponseSchema[];
|
||||
recordings: RecordingResponseSchema[];
|
||||
}
|
||||
|
||||
interface StartCallNodeProps extends NodeProps {
|
||||
|
|
@ -59,7 +62,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
id,
|
||||
additionalData: { is_start: true }
|
||||
});
|
||||
const { saveWorkflow, tools, documents } = useWorkflow();
|
||||
const { saveWorkflow, tools, documents, recordings } = useWorkflow();
|
||||
|
||||
// Form state
|
||||
const [prompt, setPrompt] = useState(data.prompt ?? "");
|
||||
|
|
@ -248,6 +251,7 @@ export const StartCall = memo(({ data, selected, id }: StartCallNodeProps) => {
|
|||
setDocumentUuids={setDocumentUuids}
|
||||
tools={tools ?? []}
|
||||
documents={documents ?? []}
|
||||
recordings={recordings ?? []}
|
||||
/>
|
||||
)}
|
||||
</NodeEditDialog>
|
||||
|
|
@ -282,6 +286,7 @@ const StartCallEditForm = ({
|
|||
setDocumentUuids,
|
||||
tools,
|
||||
documents,
|
||||
recordings,
|
||||
}: StartCallEditFormProps) => {
|
||||
const handleVariableNameChange = (idx: number, value: string) => {
|
||||
const newVars = [...variables];
|
||||
|
|
@ -325,14 +330,12 @@ const StartCallEditForm = ({
|
|||
<Label className="text-xs text-muted-foreground">
|
||||
Enter the prompt for the agent. This will be used to generate the agent's response. Prompt engineering's best practices apply.
|
||||
</Label>
|
||||
<Textarea
|
||||
<MentionTextarea
|
||||
value={prompt}
|
||||
onChange={(e) => setPrompt(e.target.value)}
|
||||
className="min-h-[100px] max-h-[300px] resize-none"
|
||||
style={{
|
||||
overflowY: 'auto'
|
||||
}}
|
||||
onChange={setPrompt}
|
||||
className="min-h-[100px] max-h-[300px] resize-none overflow-y-auto"
|
||||
placeholder="Enter a prompt"
|
||||
recordings={recordings}
|
||||
/>
|
||||
<div className="flex items-center space-x-2">
|
||||
<Switch id="allow-interrupt" checked={allowInterrupt} onCheckedChange={setAllowInterrupt} />
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue