mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-25 08:48:13 +02:00
feat: add hybrid text + recording functionality in agents (#191)
* feat: add recording feature in agents * chore: pin pipecat version * feat: show usage in UI * chore: update pipecat
This commit is contained in:
parent
f075bcb623
commit
494c60d774
43 changed files with 2865 additions and 397 deletions
102
api/alembic/versions/e54ddb048535_add_recording_table.py
Normal file
102
api/alembic/versions/e54ddb048535_add_recording_table.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
"""add recording table
|
||||
|
||||
Revision ID: e54ddb048535
|
||||
Revises: 6fd8fac02883
|
||||
Create Date: 2026-03-12 19:21:53.729888
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "e54ddb048535"
|
||||
down_revision: Union[str, None] = "6fd8fac02883"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
sa.Enum("s3", "minio", name="recording_storage_backend").create(op.get_bind())
|
||||
op.create_table(
|
||||
"workflow_recordings",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("recording_id", sa.String(length=16), nullable=False),
|
||||
sa.Column("workflow_id", sa.Integer(), nullable=False),
|
||||
sa.Column("organization_id", sa.Integer(), nullable=False),
|
||||
sa.Column("tts_provider", sa.String(), nullable=False),
|
||||
sa.Column("tts_model", sa.String(), nullable=False),
|
||||
sa.Column("tts_voice_id", sa.String(), nullable=False),
|
||||
sa.Column("transcript", sa.Text(), nullable=False),
|
||||
sa.Column("storage_key", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"storage_backend",
|
||||
postgresql.ENUM(
|
||||
"s3", "minio", name="recording_storage_backend", create_type=False
|
||||
),
|
||||
server_default=sa.text("'s3'::recording_storage_backend"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"recording_metadata",
|
||||
sa.JSON(),
|
||||
server_default=sa.text("'{}'::json"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("created_by", sa.Integer(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["created_by"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"], ["organizations.id"], ondelete="CASCADE"
|
||||
),
|
||||
sa.ForeignKeyConstraint(["workflow_id"], ["workflows.id"], ondelete="CASCADE"),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_workflow_recordings_org_id",
|
||||
"workflow_recordings",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_workflow_recordings_recording_id"),
|
||||
"workflow_recordings",
|
||||
["recording_id"],
|
||||
unique=True,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_workflow_recordings_tts_scope",
|
||||
"workflow_recordings",
|
||||
["workflow_id", "tts_provider", "tts_model", "tts_voice_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
"ix_workflow_recordings_workflow_id",
|
||||
"workflow_recordings",
|
||||
["workflow_id"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_index(
|
||||
"ix_workflow_recordings_workflow_id", table_name="workflow_recordings"
|
||||
)
|
||||
op.drop_index("ix_workflow_recordings_tts_scope", table_name="workflow_recordings")
|
||||
op.drop_index(
|
||||
op.f("ix_workflow_recordings_recording_id"), table_name="workflow_recordings"
|
||||
)
|
||||
op.drop_index("ix_workflow_recordings_org_id", table_name="workflow_recordings")
|
||||
op.drop_table("workflow_recordings")
|
||||
sa.Enum("s3", "minio", name="recording_storage_backend").drop(op.get_bind())
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -526,7 +526,7 @@ class CampaignClient(BaseDBClient):
|
|||
QueuedRunModel.state == "queued",
|
||||
QueuedRunModel.scheduled_for.is_(None),
|
||||
)
|
||||
.order_by(QueuedRunModel.created_at)
|
||||
.order_by(func.random())
|
||||
.limit(remaining_slots)
|
||||
.with_for_update(skip_locked=True)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from api.db.tool_client import ToolClient
|
|||
from api.db.user_client import UserClient
|
||||
from api.db.webhook_credential_client import WebhookCredentialClient
|
||||
from api.db.workflow_client import WorkflowClient
|
||||
from api.db.workflow_recording_client import WorkflowRecordingClient
|
||||
from api.db.workflow_run_client import WorkflowRunClient
|
||||
from api.db.workflow_template_client import WorkflowTemplateClient
|
||||
|
||||
|
|
@ -35,6 +36,7 @@ class DBClient(
|
|||
WebhookCredentialClient,
|
||||
ToolClient,
|
||||
KnowledgeBaseClient,
|
||||
WorkflowRecordingClient,
|
||||
):
|
||||
"""
|
||||
Unified database client that combines all specialized database operations.
|
||||
|
|
|
|||
|
|
@ -996,6 +996,77 @@ class KnowledgeBaseDocumentModel(Base):
|
|||
)
|
||||
|
||||
|
||||
class WorkflowRecordingModel(Base):
|
||||
"""Model for storing audio recordings scoped to a workflow and TTS configuration.
|
||||
|
||||
Recordings are used in hybrid prompts where parts of the output are pre-recorded
|
||||
audio rather than dynamically generated TTS.
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_recordings"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
|
||||
# Short globally unique ID (e.g. "xbhfha3k") used in prompts
|
||||
recording_id = Column(String(16), unique=True, nullable=False, index=True)
|
||||
|
||||
# Scoping
|
||||
workflow_id = Column(
|
||||
Integer, ForeignKey("workflows.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
organization_id = Column(
|
||||
Integer, ForeignKey("organizations.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
# TTS configuration scope
|
||||
tts_provider = Column(String, nullable=False)
|
||||
tts_model = Column(String, nullable=False)
|
||||
tts_voice_id = Column(String, nullable=False)
|
||||
|
||||
# Content
|
||||
transcript = Column(Text, nullable=False)
|
||||
|
||||
# Storage
|
||||
storage_key = Column(String, nullable=False)
|
||||
storage_backend = Column(
|
||||
Enum("s3", "minio", name="recording_storage_backend"),
|
||||
nullable=False,
|
||||
default="s3",
|
||||
server_default=text("'s3'::recording_storage_backend"),
|
||||
)
|
||||
|
||||
# Extra metadata (file_size_bytes, duration_seconds, original_filename, mime_type, etc.)
|
||||
recording_metadata = Column(
|
||||
JSON, nullable=False, default=dict, server_default=text("'{}'::json")
|
||||
)
|
||||
|
||||
# Audit
|
||||
created_by = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
|
||||
# Soft delete
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
# Relationships
|
||||
workflow = relationship("WorkflowModel")
|
||||
organization = relationship("OrganizationModel")
|
||||
created_by_user = relationship("UserModel")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (
|
||||
Index("ix_workflow_recordings_workflow_id", "workflow_id"),
|
||||
Index("ix_workflow_recordings_org_id", "organization_id"),
|
||||
Index("ix_workflow_recordings_recording_id", "recording_id"),
|
||||
Index(
|
||||
"ix_workflow_recordings_tts_scope",
|
||||
"workflow_id",
|
||||
"tts_provider",
|
||||
"tts_model",
|
||||
"tts_voice_id",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class KnowledgeBaseChunkModel(Base):
|
||||
"""Model for storing document chunks with vector embeddings.
|
||||
|
||||
|
|
|
|||
218
api/db/workflow_recording_client.py
Normal file
218
api/db/workflow_recording_client.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
"""Database client for managing workflow recordings."""
|
||||
|
||||
import secrets
|
||||
import string
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import WorkflowRecordingModel
|
||||
|
||||
|
||||
def generate_short_id(length: int = 8) -> str:
|
||||
"""Generate a random lowercase alphanumeric short ID."""
|
||||
alphabet = string.ascii_lowercase + string.digits
|
||||
return "".join(secrets.choice(alphabet) for _ in range(length))
|
||||
|
||||
|
||||
class WorkflowRecordingClient(BaseDBClient):
|
||||
"""Client for managing workflow audio recordings."""
|
||||
|
||||
async def create_recording(
|
||||
self,
|
||||
recording_id: str,
|
||||
workflow_id: int,
|
||||
organization_id: int,
|
||||
tts_provider: str,
|
||||
tts_model: str,
|
||||
tts_voice_id: str,
|
||||
transcript: str,
|
||||
storage_key: str,
|
||||
storage_backend: str,
|
||||
created_by: int,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> WorkflowRecordingModel:
|
||||
"""Create a new workflow recording record.
|
||||
|
||||
Args:
|
||||
recording_id: Short unique recording identifier
|
||||
workflow_id: ID of the workflow
|
||||
organization_id: ID of the organization
|
||||
tts_provider: TTS provider name
|
||||
tts_model: TTS model name
|
||||
tts_voice_id: TTS voice identifier
|
||||
transcript: User-provided transcript
|
||||
storage_key: S3/MinIO storage key
|
||||
storage_backend: Storage backend (s3 or minio)
|
||||
created_by: ID of the user
|
||||
metadata: Optional extra metadata
|
||||
|
||||
Returns:
|
||||
The created WorkflowRecordingModel
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
recording = WorkflowRecordingModel(
|
||||
recording_id=recording_id,
|
||||
workflow_id=workflow_id,
|
||||
organization_id=organization_id,
|
||||
tts_provider=tts_provider,
|
||||
tts_model=tts_model,
|
||||
tts_voice_id=tts_voice_id,
|
||||
transcript=transcript,
|
||||
storage_key=storage_key,
|
||||
storage_backend=storage_backend,
|
||||
created_by=created_by,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
session.add(recording)
|
||||
await session.commit()
|
||||
await session.refresh(recording)
|
||||
|
||||
logger.info(
|
||||
f"Created recording {recording_id} for workflow {workflow_id}, "
|
||||
f"org {organization_id}"
|
||||
)
|
||||
return recording
|
||||
|
||||
async def get_recordings_for_workflow(
|
||||
self,
|
||||
workflow_id: int,
|
||||
organization_id: int,
|
||||
tts_provider: Optional[str] = None,
|
||||
tts_model: Optional[str] = None,
|
||||
tts_voice_id: Optional[str] = None,
|
||||
) -> List[WorkflowRecordingModel]:
|
||||
"""Get recordings for a workflow, optionally filtered by TTS config.
|
||||
|
||||
Args:
|
||||
workflow_id: ID of the workflow
|
||||
organization_id: ID of the organization
|
||||
tts_provider: Optional TTS provider filter
|
||||
tts_model: Optional TTS model filter
|
||||
tts_voice_id: Optional TTS voice ID filter
|
||||
|
||||
Returns:
|
||||
List of WorkflowRecordingModel instances
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowRecordingModel).where(
|
||||
WorkflowRecordingModel.workflow_id == workflow_id,
|
||||
WorkflowRecordingModel.organization_id == organization_id,
|
||||
WorkflowRecordingModel.is_active == True,
|
||||
)
|
||||
|
||||
if tts_provider:
|
||||
query = query.where(WorkflowRecordingModel.tts_provider == tts_provider)
|
||||
if tts_model:
|
||||
query = query.where(WorkflowRecordingModel.tts_model == tts_model)
|
||||
if tts_voice_id:
|
||||
query = query.where(WorkflowRecordingModel.tts_voice_id == tts_voice_id)
|
||||
|
||||
query = query.order_by(WorkflowRecordingModel.created_at.desc())
|
||||
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_recording_by_recording_id(
|
||||
self,
|
||||
recording_id: str,
|
||||
organization_id: int,
|
||||
) -> Optional[WorkflowRecordingModel]:
|
||||
"""Get a recording by its short ID.
|
||||
|
||||
Args:
|
||||
recording_id: The short unique recording ID
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
WorkflowRecordingModel if found, None otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowRecordingModel).where(
|
||||
WorkflowRecordingModel.recording_id == recording_id,
|
||||
WorkflowRecordingModel.organization_id == organization_id,
|
||||
WorkflowRecordingModel.is_active == True,
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def has_active_recordings(
|
||||
self,
|
||||
workflow_id: int,
|
||||
organization_id: int,
|
||||
) -> bool:
|
||||
"""Check if a workflow has any active recordings.
|
||||
|
||||
Args:
|
||||
workflow_id: ID of the workflow
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
True if at least one active recording exists, False otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = (
|
||||
select(func.count())
|
||||
.select_from(WorkflowRecordingModel)
|
||||
.where(
|
||||
WorkflowRecordingModel.workflow_id == workflow_id,
|
||||
WorkflowRecordingModel.organization_id == organization_id,
|
||||
WorkflowRecordingModel.is_active == True,
|
||||
)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one() > 0
|
||||
|
||||
async def check_recording_id_exists(self, recording_id: str) -> bool:
|
||||
"""Check if a recording ID already exists globally.
|
||||
|
||||
Args:
|
||||
recording_id: The short recording ID to check
|
||||
|
||||
Returns:
|
||||
True if exists, False otherwise
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowRecordingModel.id).where(
|
||||
WorkflowRecordingModel.recording_id == recording_id,
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return result.scalar_one_or_none() is not None
|
||||
|
||||
async def delete_recording(
|
||||
self,
|
||||
recording_id: str,
|
||||
organization_id: int,
|
||||
) -> bool:
|
||||
"""Soft delete a recording.
|
||||
|
||||
Args:
|
||||
recording_id: The short recording ID
|
||||
organization_id: ID of the organization
|
||||
|
||||
Returns:
|
||||
True if deleted, False if not found
|
||||
"""
|
||||
async with self.async_session() as session:
|
||||
query = select(WorkflowRecordingModel).where(
|
||||
WorkflowRecordingModel.recording_id == recording_id,
|
||||
WorkflowRecordingModel.organization_id == organization_id,
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
recording = result.scalar_one_or_none()
|
||||
|
||||
if not recording:
|
||||
return False
|
||||
|
||||
recording.is_active = False
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
f"Deleted recording {recording_id} for organization {organization_id}"
|
||||
)
|
||||
return True
|
||||
|
|
@ -24,6 +24,7 @@ from api.routes.user import router as user_router
|
|||
from api.routes.webrtc_signaling import router as webrtc_signaling_router
|
||||
from api.routes.workflow import router as workflow_router
|
||||
from api.routes.workflow_embed import router as workflow_embed_router
|
||||
from api.routes.workflow_recording import router as workflow_recording_router
|
||||
|
||||
router = APIRouter(
|
||||
tags=["main"],
|
||||
|
|
@ -51,6 +52,7 @@ router.include_router(public_agent_router)
|
|||
router.include_router(public_download_router)
|
||||
router.include_router(workflow_embed_router)
|
||||
router.include_router(knowledge_base_router)
|
||||
router.include_router(workflow_recording_router)
|
||||
router.include_router(auth_router)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -3,11 +3,14 @@ from datetime import datetime, timedelta
|
|||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
router = APIRouter(prefix="/organizations")
|
||||
|
||||
|
|
@ -28,6 +31,12 @@ class CurrentUsageResponse(BaseModel):
|
|||
price_per_second_usd: Optional[float] = None
|
||||
|
||||
|
||||
class MPSCreditsResponse(BaseModel):
|
||||
total_credits_used: float
|
||||
remaining_credits: float
|
||||
total_quota: float
|
||||
|
||||
|
||||
class WorkflowRunUsageResponse(BaseModel):
|
||||
id: int
|
||||
workflow_id: int
|
||||
|
|
@ -85,6 +94,42 @@ async def get_current_period_usage(user: UserModel = Depends(get_user)):
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/usage/mps-credits", response_model=MPSCreditsResponse)
|
||||
async def get_mps_credits(user: UserModel = Depends(get_user)):
|
||||
"""Get aggregated usage and quota from MPS.
|
||||
|
||||
OSS users: queries by provider_id (created_by).
|
||||
Hosted users: queries by organization_id.
|
||||
"""
|
||||
try:
|
||||
if DEPLOYMENT_MODE == "oss":
|
||||
usage = await mps_service_key_client.get_usage_by_created_by(
|
||||
str(user.provider_id)
|
||||
)
|
||||
else:
|
||||
if not user.selected_organization_id:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No organization selected"
|
||||
)
|
||||
usage = await mps_service_key_client.get_usage_by_organization(
|
||||
user.selected_organization_id
|
||||
)
|
||||
|
||||
total_used = usage.get("total_credits_used", 0.0)
|
||||
total_remaining = usage.get("remaining_credits", 0.0)
|
||||
|
||||
return MPSCreditsResponse(
|
||||
total_credits_used=total_used,
|
||||
remaining_credits=total_remaining,
|
||||
total_quota=total_used + total_remaining,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch MPS credits: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/usage/runs", response_model=UsageHistoryResponse)
|
||||
async def get_usage_history(
|
||||
start_date: Optional[str] = Query(None, description="ISO format date string"),
|
||||
|
|
|
|||
218
api/routes/workflow_recording.py
Normal file
218
api/routes/workflow_recording.py
Normal file
|
|
@ -0,0 +1,218 @@
|
|||
"""API routes for workflow recording operations."""
|
||||
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.workflow_recording_client import generate_short_id
|
||||
from api.enums import StorageBackend
|
||||
from api.schemas.workflow_recording import (
|
||||
RecordingCreateRequestSchema,
|
||||
RecordingListResponseSchema,
|
||||
RecordingResponseSchema,
|
||||
RecordingUploadRequestSchema,
|
||||
RecordingUploadResponseSchema,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
router = APIRouter(prefix="/workflow-recordings", tags=["workflow-recordings"])
|
||||
|
||||
|
||||
async def _generate_unique_recording_id() -> str:
|
||||
"""Generate a globally unique short recording ID."""
|
||||
for _ in range(10):
|
||||
rid = generate_short_id(8)
|
||||
exists = await db_client.check_recording_id_exists(rid)
|
||||
if not exists:
|
||||
return rid
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate unique recording ID"
|
||||
)
|
||||
|
||||
|
||||
def _build_response(rec) -> RecordingResponseSchema:
|
||||
return RecordingResponseSchema(
|
||||
id=rec.id,
|
||||
recording_id=rec.recording_id,
|
||||
workflow_id=rec.workflow_id,
|
||||
organization_id=rec.organization_id,
|
||||
tts_provider=rec.tts_provider,
|
||||
tts_model=rec.tts_model,
|
||||
tts_voice_id=rec.tts_voice_id,
|
||||
transcript=rec.transcript,
|
||||
storage_key=rec.storage_key,
|
||||
storage_backend=rec.storage_backend,
|
||||
metadata=rec.recording_metadata or {},
|
||||
created_by=rec.created_by,
|
||||
created_at=rec.created_at,
|
||||
is_active=rec.is_active,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload-url",
|
||||
response_model=RecordingUploadResponseSchema,
|
||||
summary="Get presigned URL for recording upload",
|
||||
)
|
||||
async def get_upload_url(
|
||||
request: RecordingUploadRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Generate a presigned PUT URL for uploading an audio recording."""
|
||||
try:
|
||||
recording_id = await _generate_unique_recording_id()
|
||||
|
||||
storage_key = (
|
||||
f"recordings/{user.selected_organization_id}"
|
||||
f"/{request.workflow_id}/{recording_id}"
|
||||
f"/{request.filename}"
|
||||
)
|
||||
|
||||
upload_url = await storage_fs.aget_presigned_put_url(
|
||||
file_path=storage_key,
|
||||
expiration=1800, # 30 minutes
|
||||
content_type=request.mime_type,
|
||||
max_size=5_242_880, # 5MB max
|
||||
)
|
||||
|
||||
if not upload_url:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate presigned upload URL"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Generated recording upload URL: {recording_id}, "
|
||||
f"workflow {request.workflow_id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return RecordingUploadResponseSchema(
|
||||
upload_url=upload_url,
|
||||
recording_id=recording_id,
|
||||
storage_key=storage_key,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error generating recording upload URL: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to generate upload URL"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
|
||||
"/",
|
||||
response_model=RecordingResponseSchema,
|
||||
summary="Create recording record after upload",
|
||||
)
|
||||
async def create_recording(
|
||||
request: RecordingCreateRequestSchema,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Create a recording record after the audio has been uploaded to storage."""
|
||||
try:
|
||||
backend = StorageBackend.get_current_backend()
|
||||
|
||||
recording = await db_client.create_recording(
|
||||
recording_id=request.recording_id,
|
||||
workflow_id=request.workflow_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
tts_provider=request.tts_provider,
|
||||
tts_model=request.tts_model,
|
||||
tts_voice_id=request.tts_voice_id,
|
||||
transcript=request.transcript,
|
||||
storage_key=request.storage_key,
|
||||
storage_backend=backend.value,
|
||||
created_by=user.id,
|
||||
metadata=request.metadata,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Created recording {request.recording_id} for workflow {request.workflow_id}"
|
||||
)
|
||||
|
||||
return _build_response(recording)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error creating recording: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to create recording"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.get(
|
||||
"/",
|
||||
response_model=RecordingListResponseSchema,
|
||||
summary="List recordings for a workflow",
|
||||
)
|
||||
async def list_recordings(
|
||||
workflow_id: Annotated[int, Query(description="Workflow ID")],
|
||||
tts_provider: Annotated[
|
||||
Optional[str], Query(description="Filter by TTS provider")
|
||||
] = None,
|
||||
tts_model: Annotated[
|
||||
Optional[str], Query(description="Filter by TTS model")
|
||||
] = None,
|
||||
tts_voice_id: Annotated[
|
||||
Optional[str], Query(description="Filter by TTS voice ID")
|
||||
] = None,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""List recordings for a workflow, optionally filtered by TTS configuration."""
|
||||
try:
|
||||
recordings = await db_client.get_recordings_for_workflow(
|
||||
workflow_id=workflow_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
tts_provider=tts_provider,
|
||||
tts_model=tts_model,
|
||||
tts_voice_id=tts_voice_id,
|
||||
)
|
||||
|
||||
return RecordingListResponseSchema(
|
||||
recordings=[_build_response(r) for r in recordings],
|
||||
total=len(recordings),
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error listing recordings: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to list recordings"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{recording_id}",
|
||||
summary="Delete a recording",
|
||||
)
|
||||
async def delete_recording(
|
||||
recording_id: str,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Soft delete a recording."""
|
||||
try:
|
||||
success = await db_client.delete_recording(
|
||||
recording_id=recording_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=404, detail="Recording not found")
|
||||
|
||||
logger.info(
|
||||
f"Deleted recording {recording_id}, org {user.selected_organization_id}"
|
||||
)
|
||||
|
||||
return {"success": True, "message": "Recording deleted successfully"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"Error deleting recording: {exc}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to delete recording"
|
||||
) from exc
|
||||
73
api/schemas/workflow_recording.py
Normal file
73
api/schemas/workflow_recording.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
"""Pydantic schemas for workflow recording operations."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RecordingUploadRequestSchema(BaseModel):
|
||||
"""Request schema for getting a presigned upload URL."""
|
||||
|
||||
workflow_id: int = Field(..., description="Workflow ID this recording belongs to")
|
||||
filename: str = Field(..., description="Original filename of the audio file")
|
||||
mime_type: str = Field(
|
||||
default="audio/wav", description="MIME type of the audio file"
|
||||
)
|
||||
file_size: int = Field(
|
||||
...,
|
||||
gt=0,
|
||||
le=5_242_880,
|
||||
description="File size in bytes (max 5MB)",
|
||||
)
|
||||
|
||||
|
||||
class RecordingUploadResponseSchema(BaseModel):
|
||||
"""Response schema with presigned upload URL."""
|
||||
|
||||
upload_url: str = Field(..., description="Presigned URL for uploading the audio")
|
||||
recording_id: str = Field(..., description="Short unique recording ID")
|
||||
storage_key: str = Field(..., description="Storage key where file will be uploaded")
|
||||
|
||||
|
||||
class RecordingCreateRequestSchema(BaseModel):
|
||||
"""Request schema for creating a recording record after upload."""
|
||||
|
||||
recording_id: str = Field(..., description="Short recording ID from upload step")
|
||||
workflow_id: int = Field(..., description="Workflow ID")
|
||||
tts_provider: str = Field(..., description="TTS provider (e.g. elevenlabs)")
|
||||
tts_model: str = Field(..., description="TTS model name")
|
||||
tts_voice_id: str = Field(..., description="TTS voice identifier")
|
||||
transcript: str = Field(
|
||||
..., description="User-provided transcript of the recording"
|
||||
)
|
||||
storage_key: str = Field(..., description="Storage key from upload step")
|
||||
metadata: Optional[Dict[str, Any]] = Field(
|
||||
default=None, description="Optional metadata (file_size, duration, etc.)"
|
||||
)
|
||||
|
||||
|
||||
class RecordingResponseSchema(BaseModel):
|
||||
"""Response schema for a single recording."""
|
||||
|
||||
id: int
|
||||
recording_id: str
|
||||
workflow_id: int
|
||||
organization_id: int
|
||||
tts_provider: str
|
||||
tts_model: str
|
||||
tts_voice_id: str
|
||||
transcript: str
|
||||
storage_key: str
|
||||
storage_backend: str
|
||||
metadata: Dict[str, Any]
|
||||
created_by: int
|
||||
created_at: datetime
|
||||
is_active: bool
|
||||
|
||||
|
||||
class RecordingListResponseSchema(BaseModel):
|
||||
"""Response schema for list of recordings."""
|
||||
|
||||
recordings: List[RecordingResponseSchema]
|
||||
total: int
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Dict, Optional, TypedDict
|
||||
from typing import Optional, TypedDict
|
||||
|
||||
import openai
|
||||
from deepgram import DeepgramClient
|
||||
|
|
@ -12,6 +12,7 @@ from api.schemas.user_configuration import (
|
|||
UserConfiguration,
|
||||
)
|
||||
from api.services.configuration.registry import ServiceConfig, ServiceProviders
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
|
||||
|
||||
class APIKeyStatus(TypedDict):
|
||||
|
|
@ -25,7 +26,6 @@ class APIKeyStatusResponse(TypedDict):
|
|||
|
||||
class UserConfigurationValidator:
|
||||
def __init__(self):
|
||||
self._provider_api_key_validity_status: Dict[str, bool] = {}
|
||||
self._validator_map = {
|
||||
ServiceProviders.OPENAI.value: self._check_openai_api_key,
|
||||
ServiceProviders.DEEPGRAM.value: self._check_deepgram_api_key,
|
||||
|
|
@ -73,8 +73,13 @@ class UserConfigurationValidator:
|
|||
provider = service_config.provider
|
||||
api_key = service_config.api_key
|
||||
|
||||
if not self._check_api_key(provider, api_key):
|
||||
return [{"model": service_name, "message": f"Invalid {provider} API key"}]
|
||||
try:
|
||||
if not self._check_api_key(provider, api_key):
|
||||
return [
|
||||
{"model": service_name, "message": f"Invalid {provider} API key"}
|
||||
]
|
||||
except ValueError as e:
|
||||
return [{"model": service_name, "message": str(e)}]
|
||||
|
||||
return []
|
||||
|
||||
|
|
@ -87,40 +92,28 @@ class UserConfigurationValidator:
|
|||
return validator(provider, api_key)
|
||||
|
||||
def _check_openai_api_key(self, model: str, api_key: str) -> bool:
|
||||
if model in self._provider_api_key_validity_status:
|
||||
return self._provider_api_key_validity_status[model]
|
||||
|
||||
client = openai.OpenAI(api_key=api_key)
|
||||
try:
|
||||
client.models.list()
|
||||
self._provider_api_key_validity_status[model] = True
|
||||
return True
|
||||
except openai.AuthenticationError:
|
||||
self._provider_api_key_validity_status[model] = False
|
||||
return self._provider_api_key_validity_status[model]
|
||||
return False
|
||||
|
||||
def _check_deepgram_api_key(self, model: str, api_key: str) -> bool:
|
||||
if model in self._provider_api_key_validity_status:
|
||||
return self._provider_api_key_validity_status[model]
|
||||
|
||||
try:
|
||||
deepgram = DeepgramClient(api_key=api_key)
|
||||
deepgram.manage.v1.projects.list()
|
||||
self._provider_api_key_validity_status[model] = True
|
||||
return True
|
||||
except Exception:
|
||||
self._provider_api_key_validity_status[model] = False
|
||||
return self._provider_api_key_validity_status[model]
|
||||
return False
|
||||
|
||||
def _check_groq_api_key(self, model: str, api_key: str) -> bool:
|
||||
if model in self._provider_api_key_validity_status:
|
||||
return self._provider_api_key_validity_status[model]
|
||||
|
||||
client = Groq(api_key=api_key)
|
||||
try:
|
||||
client.models.list()
|
||||
self._provider_api_key_validity_status[model] = True
|
||||
return True
|
||||
except Exception:
|
||||
self._provider_api_key_validity_status[model] = False
|
||||
return self._provider_api_key_validity_status[model]
|
||||
return False
|
||||
|
||||
def _validate_elevenlabs_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
|
@ -135,7 +128,12 @@ class UserConfigurationValidator:
|
|||
return True
|
||||
|
||||
def _check_dograh_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
if api_key.startswith("dgr"):
|
||||
raise ValueError(
|
||||
"You provided a Dograh API key (dgr...) instead of a service key. "
|
||||
"Please use a service key (mps...)."
|
||||
)
|
||||
return mps_service_key_client.validate_service_key(api_key)
|
||||
|
||||
def _check_sarvam_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
|
|
|||
|
|
@ -285,6 +285,90 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def get_usage_by_created_by(self, created_by: str) -> dict:
|
||||
"""
|
||||
Get aggregated usage for all service keys created by a user (OSS mode).
|
||||
|
||||
Args:
|
||||
created_by: The user's provider ID
|
||||
|
||||
Returns:
|
||||
Dictionary containing total_credits_used and remaining_credits
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/service-keys/usage/created-by",
|
||||
json={"created_by": created_by},
|
||||
headers=self._get_headers(created_by=created_by),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return {
|
||||
"total_credits_used": data.get("total_credits_used", 0.0),
|
||||
"remaining_credits": data.get("remaining_credits", 0.0),
|
||||
}
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to get usage by created_by: {response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get usage by created_by: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
async def get_usage_by_organization(self, organization_id: int) -> dict:
|
||||
"""
|
||||
Get aggregated usage for all service keys belonging to an organization (hosted mode).
|
||||
|
||||
Args:
|
||||
organization_id: The organization's ID
|
||||
|
||||
Returns:
|
||||
Dictionary containing total_credits_used and remaining_credits
|
||||
"""
|
||||
async with httpx.AsyncClient(timeout=self.timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/api/v1/service-keys/usage/organization",
|
||||
json={"organization_id": organization_id},
|
||||
headers=self._get_headers(organization_id=organization_id),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
return {
|
||||
"total_credits_used": data.get("total_credits_used", 0.0),
|
||||
"remaining_credits": data.get("remaining_credits", 0.0),
|
||||
}
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to get usage by organization: {response.status_code} - {response.text}"
|
||||
)
|
||||
raise httpx.HTTPStatusError(
|
||||
f"Failed to get usage by organization: {response.text}",
|
||||
request=response.request,
|
||||
response=response,
|
||||
)
|
||||
|
||||
def validate_service_key(self, service_key: str) -> bool:
|
||||
"""
|
||||
Synchronously validate a Dograh service key by checking usage via MPS.
|
||||
|
||||
Returns True if the key is valid, False otherwise.
|
||||
"""
|
||||
try:
|
||||
with httpx.Client(timeout=self.timeout) as client:
|
||||
response = client.post(
|
||||
f"{self.base_url}/api/v1/service-keys/usage",
|
||||
json={"service_key": service_key},
|
||||
headers=self._get_headers(),
|
||||
)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
logger.warning("Failed to validate Dograh service key via MPS")
|
||||
return False
|
||||
|
||||
async def get_voices(
|
||||
self,
|
||||
provider: str,
|
||||
|
|
|
|||
|
|
@ -39,6 +39,7 @@ def build_pipeline(
|
|||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=None,
|
||||
recording_router=None,
|
||||
):
|
||||
"""Build the main pipeline with all components.
|
||||
|
||||
|
|
@ -47,6 +48,9 @@ def build_pipeline(
|
|||
voicemail_detector: Optional native pipecat VoicemailDetector. When provided,
|
||||
inserts voicemail detection after STT. Note: We don't use the TTS gate
|
||||
to avoid blocking TTS frames during classification.
|
||||
recording_router: Optional RecordingRouterProcessor. When provided,
|
||||
inserts between callback processor and TTS to route between
|
||||
pre-recorded audio playback and dynamic TTS.
|
||||
"""
|
||||
# Build processors list with optional voicemail detection
|
||||
processors = [
|
||||
|
|
@ -66,11 +70,15 @@ def build_pipeline(
|
|||
processors.append(voicemail_detector.detector())
|
||||
|
||||
# Continue with the rest of the pipeline
|
||||
post_llm = [pipeline_engine_callback_processor]
|
||||
if recording_router:
|
||||
post_llm.append(recording_router)
|
||||
|
||||
processors.extend(
|
||||
[
|
||||
user_context_aggregator,
|
||||
llm, # LLM
|
||||
pipeline_engine_callback_processor,
|
||||
*post_llm,
|
||||
tts, # TTS
|
||||
transport.output(), # Transport bot output
|
||||
audio_buffer, # AudioBufferProcessor - records both input and output audio
|
||||
|
|
|
|||
|
|
@ -37,10 +37,10 @@ from pipecat.frames.frames import (
|
|||
FunctionCallResultFrame,
|
||||
InterimTranscriptionFrame,
|
||||
InterruptionFrame,
|
||||
LLMTextFrame,
|
||||
MetricsFrame,
|
||||
StopFrame,
|
||||
TranscriptionFrame,
|
||||
TTSTextFrame,
|
||||
)
|
||||
from pipecat.metrics.metrics import TTFBMetricsData
|
||||
from pipecat.observers.base_observer import BaseObserver, FramePushed
|
||||
|
|
@ -207,7 +207,7 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
)
|
||||
# Handle bot TTS text - respect pts timing, WebSocket only
|
||||
# Complete turn text is persisted via register_turn_handlers
|
||||
elif isinstance(frame, TTSTextFrame):
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
message = {
|
||||
"type": RealtimeFeedbackType.BOT_TEXT.value,
|
||||
"payload": {
|
||||
|
|
|
|||
338
api/services/pipecat/recording_audio_cache.py
Normal file
338
api/services/pipecat/recording_audio_cache.py
Normal file
|
|
@ -0,0 +1,338 @@
|
|||
"""Filesystem-backed cache and audio fetcher for workflow recordings.
|
||||
|
||||
Downloads recording files from object storage on first access, converts them
|
||||
to raw 16-bit mono PCM at the pipeline sample rate via ffmpeg, trims
|
||||
leading/trailing silence, and caches the processed bytes on disk so
|
||||
subsequent plays (even from other workers) are instantaneous.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import APP_ROOT_DIR
|
||||
from pipecat.audio.utils import SPEAKING_THRESHOLD
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filesystem cache directory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CACHE_DIR = os.path.join(os.path.dirname(APP_ROOT_DIR), "dograh_pcm_cache")
|
||||
os.makedirs(_CACHE_DIR, exist_ok=True)
|
||||
|
||||
|
||||
def _cache_path(recording_id: str, sample_rate: int) -> str:
|
||||
"""Return the on-disk path for a cached PCM file."""
|
||||
return os.path.join(_CACHE_DIR, f"{recording_id}_{sample_rate}.pcm")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public factory
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def create_recording_audio_fetcher(
|
||||
organization_id: int,
|
||||
pipeline_sample_rate: int,
|
||||
) -> Callable[[str], Awaitable[Optional[bytes]]]:
|
||||
"""Create an async callback that returns raw PCM bytes for a recording_id.
|
||||
|
||||
The returned callable:
|
||||
1. Checks the filesystem cache (keyed by ``recording_id`` + sample rate).
|
||||
2. On miss, looks up the recording in the DB, downloads the audio file
|
||||
from S3/MinIO, converts it to 16-bit mono PCM at *pipeline_sample_rate*,
|
||||
trims leading/trailing silence, caches the result on disk, and returns it.
|
||||
|
||||
Args:
|
||||
organization_id: Organization owning the recordings.
|
||||
pipeline_sample_rate: Target PCM sample rate for the pipeline.
|
||||
|
||||
Returns:
|
||||
``async (recording_id: str) -> Optional[bytes]``
|
||||
"""
|
||||
from api.db import db_client
|
||||
from api.services.storage import get_storage_for_backend
|
||||
|
||||
# Resolve storage instances once per backend at creation time, not per fetch.
|
||||
_storage_cache: dict[str, object] = {}
|
||||
|
||||
def _get_storage(backend: str):
|
||||
if backend not in _storage_cache:
|
||||
_storage_cache[backend] = get_storage_for_backend(backend)
|
||||
return _storage_cache[backend]
|
||||
|
||||
async def fetch(recording_id: str) -> Optional[bytes]:
|
||||
cached = _cache_path(recording_id, pipeline_sample_rate)
|
||||
|
||||
# 1. Serve from filesystem cache
|
||||
if os.path.exists(cached):
|
||||
logger.debug(f"Recording {recording_id} served from disk cache")
|
||||
return _read_file(cached)
|
||||
|
||||
# 2. DB lookup
|
||||
recording = await db_client.get_recording_by_recording_id(
|
||||
recording_id, organization_id
|
||||
)
|
||||
if not recording:
|
||||
logger.warning(f"Recording {recording_id} not found in database")
|
||||
return None
|
||||
|
||||
# 3. Download, convert, trim, and cache
|
||||
pcm_data = await _download_and_convert(
|
||||
recording, pipeline_sample_rate, _get_storage
|
||||
)
|
||||
return pcm_data
|
||||
|
||||
return fetch
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache warming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def warm_recording_cache(
|
||||
workflow_id: int,
|
||||
organization_id: int,
|
||||
pipeline_sample_rate: int,
|
||||
) -> None:
|
||||
"""Pre-fetch all active recordings for a workflow into the disk cache.
|
||||
|
||||
Launched as a background ``asyncio.Task`` at pipeline startup so that
|
||||
recordings are ready before the first playback request. Errors are logged
|
||||
but never propagated — a cache miss falls back to the on-demand fetch path.
|
||||
"""
|
||||
from api.db import db_client
|
||||
from api.services.storage import get_storage_for_backend
|
||||
|
||||
try:
|
||||
recordings = await db_client.get_recordings_for_workflow(
|
||||
workflow_id, organization_id
|
||||
)
|
||||
if not recordings:
|
||||
return
|
||||
|
||||
# Skip if every recording is already cached on disk
|
||||
uncached = [
|
||||
r
|
||||
for r in recordings
|
||||
if not os.path.exists(_cache_path(r.recording_id, pipeline_sample_rate))
|
||||
]
|
||||
if not uncached:
|
||||
logger.debug(f"Recording cache already warm for workflow {workflow_id}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Warming recording cache: {len(uncached)}/{len(recordings)} "
|
||||
f"recording(s) for workflow {workflow_id}"
|
||||
)
|
||||
|
||||
# Resolve storage instances once per backend, not per recording
|
||||
storage_by_backend: dict[str, object] = {}
|
||||
|
||||
def _get_storage(backend: str):
|
||||
if backend not in storage_by_backend:
|
||||
storage_by_backend[backend] = get_storage_for_backend(backend)
|
||||
return storage_by_backend[backend]
|
||||
|
||||
for recording in uncached:
|
||||
try:
|
||||
pcm_data = await _download_and_convert(
|
||||
recording, pipeline_sample_rate, _get_storage
|
||||
)
|
||||
if pcm_data:
|
||||
logger.debug(
|
||||
f"Cache warm: loaded {recording.recording_id} "
|
||||
f"({len(pcm_data)} bytes)"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Cache warm: error processing {recording.recording_id}"
|
||||
)
|
||||
|
||||
logger.info(f"Recording cache warm complete for workflow {workflow_id}")
|
||||
except Exception:
|
||||
logger.exception("Recording cache warm failed")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared download → convert → trim → cache-to-disk helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _download_and_convert(
|
||||
recording, sample_rate: int, get_storage_fn
|
||||
) -> Optional[bytes]:
|
||||
"""Download a recording from storage, convert to PCM, trim, and cache to disk.
|
||||
|
||||
Returns the processed PCM bytes, or None on failure.
|
||||
"""
|
||||
ext = _ext_from_key(recording.storage_key)
|
||||
fd, tmp_path = tempfile.mkstemp(suffix=ext, prefix=f"dograh_dl_{recording.recording_id}_")
|
||||
os.close(fd)
|
||||
try:
|
||||
storage = get_storage_fn(recording.storage_backend)
|
||||
success = await storage.adownload_file(recording.storage_key, tmp_path)
|
||||
if not success:
|
||||
logger.error(f"Failed to download recording {recording.recording_id}")
|
||||
return None
|
||||
|
||||
pcm_data = await _audio_file_to_pcm(tmp_path, sample_rate)
|
||||
if pcm_data is None:
|
||||
return None
|
||||
|
||||
pcm_data = _trim_silence(pcm_data, sample_rate)
|
||||
|
||||
# Write to disk cache atomically (write to tmp then rename)
|
||||
cached = _cache_path(recording.recording_id, sample_rate)
|
||||
fd, tmp_cache = tempfile.mkstemp(dir=_CACHE_DIR, suffix=".pcm.tmp")
|
||||
os.close(fd)
|
||||
_write_file(tmp_cache, pcm_data)
|
||||
os.replace(tmp_cache, cached)
|
||||
|
||||
return pcm_data
|
||||
except Exception:
|
||||
logger.exception(f"Error fetching recording {recording.recording_id}")
|
||||
return None
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# File I/O helpers (run via asyncio.to_thread)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _read_file(path: str) -> bytes:
|
||||
with open(path, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def _write_file(path: str, data: bytes) -> None:
|
||||
with open(path, "wb") as f:
|
||||
f.write(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Audio conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _audio_file_to_pcm(
|
||||
file_path: str, target_sample_rate: int
|
||||
) -> Optional[bytes]:
|
||||
"""Convert an audio file to raw 16-bit mono PCM bytes via ffmpeg."""
|
||||
ffmpeg = shutil.which("ffmpeg")
|
||||
if not ffmpeg:
|
||||
logger.error("ffmpeg not found on PATH — cannot decode recording")
|
||||
return None
|
||||
|
||||
cmd = [
|
||||
ffmpeg,
|
||||
"-i",
|
||||
file_path,
|
||||
"-f",
|
||||
"s16le", # raw 16-bit signed little-endian PCM
|
||||
"-acodec",
|
||||
"pcm_s16le",
|
||||
"-ac",
|
||||
"1", # mono
|
||||
"-ar",
|
||||
str(target_sample_rate),
|
||||
"-loglevel",
|
||||
"error",
|
||||
"pipe:1", # output to stdout
|
||||
]
|
||||
|
||||
try:
|
||||
proc = await asyncio.create_subprocess_exec(
|
||||
*cmd,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await proc.communicate()
|
||||
|
||||
if proc.returncode != 0:
|
||||
logger.error(f"ffmpeg failed (rc={proc.returncode}): {stderr.decode()}")
|
||||
return None
|
||||
|
||||
if not stdout:
|
||||
logger.error("ffmpeg produced no output")
|
||||
return None
|
||||
|
||||
return stdout
|
||||
except Exception:
|
||||
logger.exception("ffmpeg subprocess error")
|
||||
return None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Silence trimming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _trim_silence(pcm_data: bytes, sample_rate: int) -> bytes:
|
||||
"""Trim leading and trailing silence from raw 16-bit mono PCM bytes.
|
||||
|
||||
Uses 10ms frames and the same amplitude threshold as pipecat's
|
||||
``is_silence`` to detect speech boundaries.
|
||||
"""
|
||||
data = np.frombuffer(pcm_data, dtype=np.int16)
|
||||
frame_size = int(sample_rate * 0.01) # 10ms frames
|
||||
num_frames = len(data) // frame_size
|
||||
|
||||
if num_frames == 0:
|
||||
return pcm_data
|
||||
|
||||
# Find first non-silent frame
|
||||
first_speech = None
|
||||
for i in range(num_frames):
|
||||
frame = data[i * frame_size : (i + 1) * frame_size]
|
||||
if np.abs(frame).max() > SPEAKING_THRESHOLD:
|
||||
first_speech = i
|
||||
break
|
||||
|
||||
if first_speech is None:
|
||||
# Entire clip is silence — return as-is to avoid empty audio
|
||||
return pcm_data
|
||||
|
||||
# Find last non-silent frame
|
||||
last_speech = first_speech
|
||||
for i in range(num_frames - 1, first_speech - 1, -1):
|
||||
frame = data[i * frame_size : (i + 1) * frame_size]
|
||||
if np.abs(frame).max() > SPEAKING_THRESHOLD:
|
||||
last_speech = i
|
||||
break
|
||||
|
||||
start = first_speech * frame_size
|
||||
end = (last_speech + 1) * frame_size
|
||||
trimmed = data[start:end]
|
||||
|
||||
trimmed_duration = len(trimmed) / sample_rate
|
||||
original_duration = len(data) / sample_rate
|
||||
if original_duration - trimmed_duration > 0.05:
|
||||
logger.debug(
|
||||
f"Trimmed silence: {original_duration:.2f}s → {trimmed_duration:.2f}s"
|
||||
)
|
||||
|
||||
return trimmed.tobytes()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ext_from_key(storage_key: str) -> str:
|
||||
"""Extract file extension from a storage key, defaulting to .wav."""
|
||||
_, ext = os.path.splitext(storage_key)
|
||||
return ext if ext else ".wav"
|
||||
254
api/services/pipecat/recording_router_processor.py
Normal file
254
api/services/pipecat/recording_router_processor.py
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
"""Recording router processor for routing LLM output between TTS and pre-recorded audio.
|
||||
|
||||
Sits between the LLM (after pipeline_engine_callbacks_processor) and TTS in the
|
||||
pipeline. Detects response mode markers (▸ for TTS, ● for recording) and routes
|
||||
accordingly:
|
||||
|
||||
- ▸ (TTS): Strips the marker, passes remaining text downstream to TTS.
|
||||
- ● (Recording): Suppresses TTS, fetches cached audio, pushes
|
||||
OutputAudioRawFrame downstream.
|
||||
|
||||
Pattern modelled after ``pipecat.turns.user_turn_completion_mixin`` – buffer
|
||||
streaming LLM text tokens until the mode marker is detected, then act.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Awaitable, Callable, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
RECORDING_MARKER,
|
||||
TTS_MARKER,
|
||||
)
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InterruptionFrame,
|
||||
LLMFullResponseEndFrame,
|
||||
LLMTextFrame,
|
||||
TTSAudioRawFrame,
|
||||
TTSStartedFrame,
|
||||
TTSStoppedFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class RecordingRouterProcessor(FrameProcessor):
|
||||
"""Routes LLM responses between TTS and pre-recorded audio playback.
|
||||
|
||||
When the LLM prefixes its response with:
|
||||
- ``▸`` – text flows to TTS as normal speech.
|
||||
- ``●`` – text is suppressed (skip_tts), and the referenced recording is
|
||||
fetched (with local disk cache) and streamed as ``OutputAudioRawFrame``.
|
||||
|
||||
If no marker is detected by the end of the response, text is passed through
|
||||
to TTS as a graceful degradation.
|
||||
|
||||
Args:
|
||||
audio_sample_rate: Pipeline sample rate for OutputAudioRawFrame.
|
||||
fetch_recording_audio: Async callback that takes a recording_id and
|
||||
returns raw 16-bit mono PCM bytes, or None on failure.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
audio_sample_rate: int,
|
||||
fetch_recording_audio: Callable[[str], Awaitable[Optional[bytes]]],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._audio_sample_rate = audio_sample_rate
|
||||
self._fetch_recording_audio = fetch_recording_audio
|
||||
|
||||
# Per-response state
|
||||
self._frame_buffer: list[tuple[LLMTextFrame, FrameDirection]] = []
|
||||
self._mode: Optional[str] = None # None = detecting, "tts", "recording"
|
||||
self._recording_id_buffer = ""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Frame dispatch
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
self._reset()
|
||||
await self.push_frame(frame, direction)
|
||||
elif isinstance(frame, LLMTextFrame):
|
||||
await self._handle_llm_text(frame, direction)
|
||||
elif isinstance(frame, LLMFullResponseEndFrame):
|
||||
await self._handle_response_end(frame, direction)
|
||||
else:
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# LLMTextFrame handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_llm_text(self, frame: LLMTextFrame, direction: FrameDirection):
|
||||
# Pass through frames already marked skip_tts (e.g. turn completion ✓)
|
||||
if frame.skip_tts:
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
# --- TTS mode established: pass text through normally ---
|
||||
if self._mode == "tts":
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
# --- Recording mode: buffer recording_id, suppress TTS ---
|
||||
if self._mode == "recording":
|
||||
self._recording_id_buffer += frame.text
|
||||
frame.skip_tts = True
|
||||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
# --- Detection mode: buffer until marker found ---
|
||||
self._frame_buffer.append((frame, direction))
|
||||
buffered_text = self._buffered_text()
|
||||
|
||||
# Check for recording marker (●)
|
||||
if RECORDING_MARKER in buffered_text:
|
||||
self._mode = "recording"
|
||||
marker_end = buffered_text.index(RECORDING_MARKER) + len(RECORDING_MARKER)
|
||||
|
||||
# Push buffered frames with skip_tts, extract recording_id from post-marker text
|
||||
cumulative = 0
|
||||
for buf_frame, buf_dir in self._frame_buffer:
|
||||
buf_frame.skip_tts = True
|
||||
frame_start = cumulative
|
||||
cumulative += len(buf_frame.text)
|
||||
await self.push_frame(buf_frame, buf_dir)
|
||||
|
||||
# Capture any recording_id text after the marker
|
||||
if cumulative > marker_end:
|
||||
offset = max(marker_end - frame_start, 0)
|
||||
remaining = buf_frame.text[offset:]
|
||||
if not self._recording_id_buffer and remaining.startswith(" "):
|
||||
remaining = remaining[1:]
|
||||
self._recording_id_buffer += remaining
|
||||
|
||||
self._frame_buffer = []
|
||||
return
|
||||
|
||||
# Check for TTS marker (▸)
|
||||
if TTS_MARKER in buffered_text:
|
||||
self._mode = "tts"
|
||||
marker_end = buffered_text.index(TTS_MARKER) + len(TTS_MARKER)
|
||||
|
||||
# Push buffered frames — skip_tts for marker portion, normal for the rest
|
||||
cumulative = 0
|
||||
for buf_frame, buf_dir in self._frame_buffer:
|
||||
frame_start = cumulative
|
||||
cumulative += len(buf_frame.text)
|
||||
|
||||
if cumulative <= marker_end:
|
||||
# Entirely within marker portion — suppress TTS
|
||||
buf_frame.skip_tts = True
|
||||
await self.push_frame(buf_frame, buf_dir)
|
||||
elif frame_start >= marker_end:
|
||||
# Entirely after marker — normal TTS speech
|
||||
if frame_start == marker_end and buf_frame.text.startswith(" "):
|
||||
buf_frame.text = buf_frame.text[1:]
|
||||
if buf_frame.text:
|
||||
await self.push_frame(buf_frame, buf_dir)
|
||||
else:
|
||||
# Frame spans the marker boundary — split
|
||||
offset = marker_end - frame_start
|
||||
original_text = buf_frame.text
|
||||
buf_frame.text = original_text[:offset]
|
||||
buf_frame.skip_tts = True
|
||||
await self.push_frame(buf_frame, buf_dir)
|
||||
|
||||
tts_text = original_text[offset:]
|
||||
if tts_text.startswith(" "):
|
||||
tts_text = tts_text[1:]
|
||||
if tts_text:
|
||||
await self.push_frame(LLMTextFrame(tts_text), buf_dir)
|
||||
|
||||
self._frame_buffer = []
|
||||
return
|
||||
|
||||
# Neither marker found yet — keep buffering (should arrive very soon)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# End-of-response handling
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _handle_response_end(
|
||||
self, frame: LLMFullResponseEndFrame, direction: FrameDirection
|
||||
):
|
||||
if self._mode == "recording":
|
||||
recording_id = self._recording_id_buffer.strip()
|
||||
if recording_id:
|
||||
await self._play_recording(recording_id)
|
||||
else:
|
||||
logger.warning(
|
||||
"RecordingRouterProcessor: recording mode but empty recording_id"
|
||||
)
|
||||
|
||||
elif self._mode is None and self._frame_buffer:
|
||||
# Graceful degradation: no marker detected, pass text to TTS as-is
|
||||
logger.warning(
|
||||
"RecordingRouterProcessor: no response mode marker found, "
|
||||
"passing text to TTS as-is"
|
||||
)
|
||||
for buf_frame, buf_dir in self._frame_buffer:
|
||||
await self.push_frame(buf_frame, buf_dir)
|
||||
|
||||
self._reset()
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Audio playback
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _play_recording(self, recording_id: str):
|
||||
"""Fetch recording audio and push TTSStarted → TTSAudioRaw → TTSStopped.
|
||||
|
||||
The transport handles chunking automatically. The Started/Stopped
|
||||
frames ensure downstream processors (transport, audio buffer, observers)
|
||||
treat this as a proper TTS utterance.
|
||||
"""
|
||||
logger.info(f"Playing pre-recorded audio: {recording_id}")
|
||||
|
||||
audio_data = await self._fetch_recording_audio(recording_id)
|
||||
if not audio_data:
|
||||
logger.warning(
|
||||
f"Failed to fetch recording {recording_id}, no audio will play"
|
||||
)
|
||||
return
|
||||
|
||||
context_id = str(uuid.uuid4())
|
||||
await self.push_frame(TTSStartedFrame(context_id=context_id))
|
||||
await self.push_frame(
|
||||
TTSAudioRawFrame(
|
||||
audio=audio_data,
|
||||
sample_rate=self._audio_sample_rate,
|
||||
num_channels=1,
|
||||
context_id=context_id,
|
||||
)
|
||||
)
|
||||
await self.push_frame(TTSStoppedFrame(context_id=context_id))
|
||||
|
||||
duration_secs = len(audio_data) / (self._audio_sample_rate * 2)
|
||||
logger.debug(
|
||||
f"Finished pushing recording {recording_id} "
|
||||
f"({len(audio_data)} bytes, {duration_secs:.1f}s)"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# State management
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _buffered_text(self) -> str:
|
||||
"""Return concatenated text from the frame buffer."""
|
||||
return "".join(f.text for f, _ in self._frame_buffer)
|
||||
|
||||
def _reset(self):
|
||||
"""Reset per-response state."""
|
||||
self._frame_buffer = []
|
||||
self._mode = None
|
||||
self._recording_id_buffer = ""
|
||||
|
|
@ -27,6 +27,11 @@ from api.services.pipecat.realtime_feedback_observer import (
|
|||
RealtimeFeedbackObserver,
|
||||
register_turn_log_handlers,
|
||||
)
|
||||
from api.services.pipecat.recording_audio_cache import (
|
||||
create_recording_audio_fetcher,
|
||||
warm_recording_cache,
|
||||
)
|
||||
from api.services.pipecat.recording_router_processor import RecordingRouterProcessor
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_stt_service,
|
||||
|
|
@ -558,6 +563,12 @@ async def _run_pipeline(
|
|||
embeddings_model = user_config.embeddings.model
|
||||
embeddings_base_url = getattr(user_config.embeddings, "base_url", None)
|
||||
|
||||
# Check if the workflow has any active recordings so the engine can
|
||||
# include recording response mode instructions in all node prompts.
|
||||
has_recordings = await db_client.has_active_recordings(
|
||||
workflow_id, workflow.organization_id
|
||||
)
|
||||
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
workflow=workflow_graph,
|
||||
|
|
@ -567,6 +578,7 @@ async def _run_pipeline(
|
|||
embeddings_api_key=embeddings_api_key,
|
||||
embeddings_model=embeddings_model,
|
||||
embeddings_base_url=embeddings_base_url,
|
||||
has_recordings=has_recordings,
|
||||
)
|
||||
|
||||
# Create pipeline components
|
||||
|
|
@ -680,6 +692,27 @@ async def _run_pipeline(
|
|||
abort_immediately=True,
|
||||
)
|
||||
|
||||
# Create recording router if workflow has active recordings
|
||||
recording_router = None
|
||||
if has_recordings:
|
||||
fetch_audio = create_recording_audio_fetcher(
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
recording_router = RecordingRouterProcessor(
|
||||
audio_sample_rate=audio_config.pipeline_sample_rate,
|
||||
fetch_recording_audio=fetch_audio,
|
||||
)
|
||||
# Warm the recording cache in the background so audio is ready
|
||||
# before the first playback request.
|
||||
asyncio.create_task(
|
||||
warm_recording_cache(
|
||||
workflow_id=workflow_id,
|
||||
organization_id=workflow.organization_id,
|
||||
pipeline_sample_rate=audio_config.pipeline_sample_rate,
|
||||
)
|
||||
)
|
||||
|
||||
# Build the pipeline with the STT mute filter and context controller
|
||||
pipeline = build_pipeline(
|
||||
transport,
|
||||
|
|
@ -692,6 +725,7 @@ async def _run_pipeline(
|
|||
pipeline_engine_callback_processor,
|
||||
pipeline_metrics_aggregator,
|
||||
voicemail_detector=voicemail_detector,
|
||||
recording_router=recording_router,
|
||||
)
|
||||
|
||||
# Create pipeline task with audio configuration
|
||||
|
|
|
|||
|
|
@ -5,25 +5,32 @@ from loguru import logger
|
|||
|
||||
from api.constants import MPS_API_URL
|
||||
from api.services.configuration.registry import ServiceProviders
|
||||
from pipecat.services.azure.llm import AzureLLMService
|
||||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService
|
||||
from pipecat.services.deepgram.flux.stt import DeepgramFluxSTTService
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService, LiveOptions
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings
|
||||
from pipecat.services.deepgram.flux.stt import (
|
||||
DeepgramFluxSTTService,
|
||||
DeepgramFluxSTTSettings,
|
||||
)
|
||||
from pipecat.services.deepgram.stt import DeepgramSTTService, DeepgramSTTSettings
|
||||
from pipecat.services.deepgram.tts import DeepgramTTSService, DeepgramTTSSettings
|
||||
from pipecat.services.dograh.llm import DograhLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService
|
||||
from pipecat.services.dograh.tts import DograhTTSService
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService
|
||||
from pipecat.services.google.llm import GoogleLLMService
|
||||
from pipecat.services.groq.llm import GroqLLMService
|
||||
from pipecat.services.dograh.stt import DograhSTTService, DograhSTTSettings
|
||||
from pipecat.services.dograh.tts import DograhTTSService, DograhTTSSettings
|
||||
from pipecat.services.elevenlabs.tts import ElevenLabsTTSService, ElevenLabsTTSSettings
|
||||
from pipecat.services.google.llm import GoogleLLMService, GoogleLLMSettings
|
||||
from pipecat.services.groq.llm import GroqLLMService, GroqLLMSettings
|
||||
from pipecat.services.openai.base_llm import OpenAILLMSettings
|
||||
from pipecat.services.openai.llm import OpenAILLMService
|
||||
from pipecat.services.openai.stt import OpenAISTTService
|
||||
from pipecat.services.openai.tts import OpenAITTSService
|
||||
from pipecat.services.openrouter.llm import OpenRouterLLMService
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService
|
||||
from pipecat.services.speechmatics.stt import SpeechmaticsSTTService
|
||||
from pipecat.services.openai.stt import OpenAISTTService, OpenAISTTSettings
|
||||
from pipecat.services.openai.tts import OpenAITTSService, OpenAITTSSettings
|
||||
from pipecat.services.openrouter.llm import OpenRouterLLMService, OpenRouterLLMSettings
|
||||
from pipecat.services.sarvam.stt import SarvamSTTService, SarvamSTTSettings
|
||||
from pipecat.services.sarvam.tts import SarvamTTSService, SarvamTTSSettings
|
||||
from pipecat.services.speechmatics.stt import (
|
||||
SpeechmaticsSTTService,
|
||||
SpeechmaticsSTTSettings,
|
||||
)
|
||||
from pipecat.transcriptions.language import Language
|
||||
from pipecat.utils.text.xml_function_tag_filter import XMLFunctionTagFilter
|
||||
|
||||
|
|
@ -49,8 +56,8 @@ def create_stt_service(
|
|||
logger.debug("Using DeepGram Flux Model")
|
||||
return DeepgramFluxSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
params=DeepgramFluxSTTService.InputParams(
|
||||
settings=DeepgramFluxSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
eot_timeout_ms=3000,
|
||||
eot_threshold=0.7,
|
||||
eager_eot_threshold=0.5,
|
||||
|
|
@ -63,23 +70,23 @@ def create_stt_service(
|
|||
# Other models than flux
|
||||
# Use language from user config, defaulting to "multi" for multilingual support
|
||||
language = getattr(user_config.stt, "language", None) or "multi"
|
||||
live_options = LiveOptions(
|
||||
language=language,
|
||||
profanity_filter=False,
|
||||
endpointing=100,
|
||||
model=user_config.stt.model,
|
||||
keyterm=keyterms or [],
|
||||
)
|
||||
logger.debug(f"Using DeepGram Model - {user_config.stt.model}")
|
||||
return DeepgramSTTService(
|
||||
live_options=live_options,
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=DeepgramSTTSettings(
|
||||
language=language,
|
||||
profanity_filter=False,
|
||||
endpointing=100,
|
||||
model=user_config.stt.model,
|
||||
keyterm=keyterms or [],
|
||||
),
|
||||
should_interrupt=False, # Let UserAggregator take care of sending InterruptionFrame
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAISTTService(
|
||||
api_key=user_config.stt.api_key, model=user_config.stt.model
|
||||
api_key=user_config.stt.api_key,
|
||||
settings=OpenAISTTSettings(model=user_config.stt.model),
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.CARTESIA.value:
|
||||
return CartesiaSTTService(
|
||||
|
|
@ -92,8 +99,10 @@ def create_stt_service(
|
|||
return DograhSTTService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
settings=DograhSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=language,
|
||||
),
|
||||
keyterms=keyterms,
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
|
|
@ -117,8 +126,10 @@ def create_stt_service(
|
|||
pipecat_language = language_mapping.get(language, Language.HI_IN)
|
||||
return SarvamSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
model=user_config.stt.model,
|
||||
params=SarvamSTTService.InputParams(language=pipecat_language),
|
||||
settings=SarvamSTTSettings(
|
||||
model=user_config.stt.model,
|
||||
language=pipecat_language,
|
||||
),
|
||||
sample_rate=audio_config.transport_in_sample_rate,
|
||||
)
|
||||
elif user_config.stt.provider == ServiceProviders.SPEECHMATICS.value:
|
||||
|
|
@ -140,7 +151,7 @@ def create_stt_service(
|
|||
additional_vocab = [AdditionalVocabEntry(content=term) for term in keyterms]
|
||||
return SpeechmaticsSTTService(
|
||||
api_key=user_config.stt.api_key,
|
||||
params=SpeechmaticsSTTService.InputParams(
|
||||
settings=SpeechmaticsSTTSettings(
|
||||
language=language,
|
||||
operating_point=operating_point,
|
||||
additional_vocab=additional_vocab,
|
||||
|
|
@ -168,14 +179,16 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
if user_config.tts.provider == ServiceProviders.DEEPGRAM.value:
|
||||
return DeepgramTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
voice=user_config.tts.voice,
|
||||
settings=DeepgramTTSSettings(voice=user_config.tts.voice),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.OPENAI.value:
|
||||
return OpenAITTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model,
|
||||
settings=OpenAITTSSettings(model=user_config.tts.model),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.ELEVENLABS.value:
|
||||
# Backward compatible with older configuration "Name - voice_id"
|
||||
|
|
@ -186,19 +199,25 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
return ElevenLabsTTSService(
|
||||
reconnect_on_error=False,
|
||||
api_key=user_config.tts.api_key,
|
||||
voice_id=voice_id,
|
||||
model=user_config.tts.model,
|
||||
params=ElevenLabsTTSService.InputParams(
|
||||
stability=0.8, speed=user_config.tts.speed, similarity_boost=0.75
|
||||
settings=ElevenLabsTTSSettings(
|
||||
voice=voice_id,
|
||||
model=user_config.tts.model,
|
||||
stability=0.8,
|
||||
speed=user_config.tts.speed,
|
||||
similarity_boost=0.75,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.CARTESIA.value:
|
||||
return CartesiaTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
voice_id=user_config.tts.voice,
|
||||
model=user_config.tts.model,
|
||||
settings=CartesiaTTSSettings(
|
||||
voice=user_config.tts.voice,
|
||||
model=user_config.tts.model,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.DOGRAH.value:
|
||||
# Convert HTTP URL to WebSocket URL for TTS
|
||||
|
|
@ -206,10 +225,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
return DograhTTSService(
|
||||
base_url=base_url,
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
params=DograhTTSService.InputParams(speed=user_config.tts.speed),
|
||||
settings=DograhTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=user_config.tts.voice,
|
||||
speed=user_config.tts.speed,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
elif user_config.tts.provider == ServiceProviders.SARVAM.value:
|
||||
# Map Sarvam language code to pipecat Language enum for TTS
|
||||
|
|
@ -232,10 +254,13 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
voice = getattr(user_config.tts, "voice", None) or "anushka"
|
||||
return SarvamTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
model=user_config.tts.model,
|
||||
voice_id=voice,
|
||||
params=SarvamTTSService.InputParams(language=pipecat_language),
|
||||
settings=SarvamTTSSettings(
|
||||
model=user_config.tts.model,
|
||||
voice=voice,
|
||||
language=pipecat_language,
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
|
@ -253,16 +278,15 @@ def create_llm_service(user_config):
|
|||
if "gpt-5" in model:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
params=OpenAILLMService.InputParams(
|
||||
reasoning_effort="minimal", verbosity="low"
|
||||
settings=OpenAILLMSettings(
|
||||
model=model,
|
||||
extra={"reasoning_effort": "minimal", "verbosity": "low"},
|
||||
),
|
||||
)
|
||||
else:
|
||||
return OpenAILLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
settings=OpenAILLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GROQ.value:
|
||||
print(
|
||||
|
|
@ -270,36 +294,30 @@ def create_llm_service(user_config):
|
|||
)
|
||||
return GroqLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
settings=GroqLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.OPENROUTER.value:
|
||||
return OpenRouterLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
base_url=user_config.llm.base_url,
|
||||
params=OpenAILLMService.InputParams(temperature=0.1),
|
||||
settings=OpenRouterLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.GOOGLE.value:
|
||||
# Use the correct InputParams class for Google to avoid propagating OpenAI-specific
|
||||
# NOT_GIVEN sentinels that break Pydantic validation in GoogleLLMService.
|
||||
return GoogleLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
params=GoogleLLMService.InputParams(temperature=0.1),
|
||||
settings=GoogleLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.AZURE.value:
|
||||
return AzureLLMService(
|
||||
api_key=user_config.llm.api_key,
|
||||
endpoint=user_config.llm.endpoint,
|
||||
model=model, # Azure uses deployment name as model
|
||||
params=AzureLLMService.InputParams(temperature=0.1),
|
||||
settings=AzureLLMSettings(model=model, temperature=0.1),
|
||||
)
|
||||
elif user_config.llm.provider == ServiceProviders.DOGRAH.value:
|
||||
return DograhLLMService(
|
||||
base_url=f"{MPS_API_URL}/api/v1/llm",
|
||||
api_key=user_config.llm.api_key,
|
||||
model=model,
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid LLM provider")
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from api.services.workflow.disposition_mapper import (
|
|||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.workflow import Node, WorkflowGraph
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
BotStartedSpeakingFrame,
|
||||
BotStoppedSpeakingFrame,
|
||||
|
|
@ -16,6 +17,7 @@ from pipecat.frames.frames import (
|
|||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
from pipecat.services.llm_service import FunctionCallParams
|
||||
from pipecat.services.settings import LLMSettings
|
||||
from pipecat.utils.enums import EndTaskReason
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -31,18 +33,19 @@ import asyncio
|
|||
from loguru import logger
|
||||
|
||||
from api.services.workflow import pipecat_engine_callbacks as engine_callbacks
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
from api.services.workflow.pipecat_engine_context_composer import (
|
||||
compose_functions_for_node,
|
||||
compose_system_prompt_for_node,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_custom_tools import (
|
||||
CustomToolManager,
|
||||
get_function_schema,
|
||||
render_template,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_variable_extractor import (
|
||||
VariableExtractionManager,
|
||||
)
|
||||
from api.services.workflow.tools.calculator import get_calculator_tools, safe_calculator
|
||||
from api.services.workflow.tools.knowledge_base import (
|
||||
get_knowledge_base_tool,
|
||||
retrieve_from_knowledge_base,
|
||||
)
|
||||
from api.services.workflow.tools.timezone import (
|
||||
|
|
@ -50,6 +53,7 @@ from api.services.workflow.tools.timezone import (
|
|||
get_current_time,
|
||||
get_time_tools,
|
||||
)
|
||||
from api.utils.template_renderer import render_template
|
||||
|
||||
|
||||
class PipecatEngine:
|
||||
|
|
@ -68,6 +72,7 @@ class PipecatEngine:
|
|||
embeddings_api_key: Optional[str] = None,
|
||||
embeddings_model: Optional[str] = None,
|
||||
embeddings_base_url: Optional[str] = None,
|
||||
has_recordings: bool = False,
|
||||
):
|
||||
self.task = task
|
||||
self.llm = llm
|
||||
|
|
@ -113,6 +118,10 @@ class PipecatEngine:
|
|||
# Audio configuration (set via set_audio_config from _run_pipeline)
|
||||
self._audio_config = None
|
||||
|
||||
# True when the workflow has active recordings; enables recording
|
||||
# response mode instructions on all nodes for in-context learning.
|
||||
self._has_recordings: bool = has_recordings
|
||||
|
||||
async def _get_organization_id(self) -> Optional[int]:
|
||||
"""Get and cache the organization ID from workflow run."""
|
||||
if self._custom_tool_manager:
|
||||
|
|
@ -194,15 +203,14 @@ class PipecatEngine:
|
|||
logger.error(f"Error initializing {self.__class__.__name__}: {e}")
|
||||
raise
|
||||
|
||||
def _get_function_schema(self, function_name: str, description: str):
|
||||
"""Thin wrapper around utils.get_function_schema for backwards compatibility."""
|
||||
async def _update_llm_context(self, system_prompt: str, functions: list[dict]):
|
||||
"""Update LLM settings with the composed system prompt and tool list."""
|
||||
|
||||
return get_function_schema(function_name, description)
|
||||
await self.llm._update_settings(LLMSettings(system_instruction=system_prompt))
|
||||
|
||||
async def _update_llm_context(self, system_message: dict, functions: list[dict]):
|
||||
"""Delegate context update to the shared workflow.utils implementation."""
|
||||
|
||||
update_llm_context(self.context, system_message, functions)
|
||||
if functions:
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
self.context.set_tools(tools_schema)
|
||||
|
||||
def _format_prompt(self, prompt: str) -> str:
|
||||
"""Delegate prompt formatting to the shared workflow.utils implementation."""
|
||||
|
|
@ -473,12 +481,19 @@ class PipecatEngine:
|
|||
if node.document_uuids:
|
||||
await self._register_knowledge_base_function(node.document_uuids)
|
||||
|
||||
# Set up system message and functions
|
||||
(
|
||||
system_message,
|
||||
functions,
|
||||
) = await self._compose_system_message_functions_for_node(node)
|
||||
await self._update_llm_context(system_message, functions)
|
||||
# Compose prompt and functions via the context composer module
|
||||
system_prompt = compose_system_prompt_for_node(
|
||||
node=node,
|
||||
workflow=self.workflow,
|
||||
format_prompt=self._format_prompt,
|
||||
has_recordings=self._has_recordings,
|
||||
)
|
||||
functions = await compose_functions_for_node(
|
||||
node=node,
|
||||
builtin_function_schemas=self.builtin_function_schemas,
|
||||
custom_tool_manager=self._custom_tool_manager,
|
||||
)
|
||||
await self._update_llm_context(system_prompt, functions)
|
||||
|
||||
async def set_node(self, node_id: str):
|
||||
"""
|
||||
|
|
@ -610,62 +625,6 @@ class PipecatEngine:
|
|||
)
|
||||
await self.task.queue_frame(frame_to_push)
|
||||
|
||||
async def _compose_system_message_functions_for_node(
|
||||
self, node: "Node"
|
||||
) -> tuple[list[dict], list[dict]]:
|
||||
"""Generate the system messages and function schemas for the given node.
|
||||
|
||||
This performs the same formatting logic used when entering a node but
|
||||
does **not** register the functions with the LLM; callers are
|
||||
responsible for that.
|
||||
"""
|
||||
|
||||
global_prompt = ""
|
||||
if self.workflow.global_node_id and node.add_global_prompt:
|
||||
global_node = self.workflow.nodes[self.workflow.global_node_id]
|
||||
global_prompt = self._format_prompt(global_node.prompt)
|
||||
|
||||
functions: list[dict] = []
|
||||
|
||||
# Add built-in function schemas (calculator and timezone tools)
|
||||
functions.extend(self.builtin_function_schemas)
|
||||
|
||||
# Add knowledge base retrieval tool if node has documents
|
||||
if node.document_uuids:
|
||||
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
|
||||
kb_schema = get_function_schema(
|
||||
kb_tool_def["function"]["name"],
|
||||
kb_tool_def["function"]["description"],
|
||||
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
|
||||
required=kb_tool_def["function"]["parameters"].get("required", []),
|
||||
)
|
||||
functions.append(kb_schema)
|
||||
|
||||
# Add custom tools from node.tool_uuids
|
||||
if node.tool_uuids and self._custom_tool_manager:
|
||||
custom_tool_schemas = await self._custom_tool_manager.get_tool_schemas(
|
||||
node.tool_uuids
|
||||
)
|
||||
functions.extend(custom_tool_schemas)
|
||||
|
||||
# Transition functions (schema only; registration handled elsewhere)
|
||||
for outgoing_edge in node.out_edges:
|
||||
function_schema = self._get_function_schema(
|
||||
outgoing_edge.get_function_name(), outgoing_edge.condition
|
||||
)
|
||||
functions.append(function_schema)
|
||||
|
||||
formatted_node_prompt = self._format_prompt(node.prompt)
|
||||
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": "\n\n".join(
|
||||
p for p in (global_prompt, formatted_node_prompt) if p
|
||||
),
|
||||
}
|
||||
|
||||
return system_message, functions
|
||||
|
||||
async def should_mute_user(self, frame: "Frame") -> bool:
|
||||
"""
|
||||
Callback for CallbackUserMuteStrategy to determine if the user should be muted.
|
||||
|
|
|
|||
138
api/services/workflow/pipecat_engine_context_composer.py
Normal file
138
api/services/workflow/pipecat_engine_context_composer.py
Normal file
|
|
@ -0,0 +1,138 @@
|
|||
"""System prompt and function schema composition for PipecatEngine nodes.
|
||||
|
||||
Extracts prompt and function composition logic from PipecatEngine into
|
||||
reusable functions. Defines recording response mode markers and instructions.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Callable, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.workflow import Node, WorkflowGraph
|
||||
|
||||
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
|
||||
from api.services.workflow.tools.knowledge_base import get_knowledge_base_tool
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recording response mode markers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RECORDING_MARKER = "●" # Play pre-recorded audio
|
||||
TTS_MARKER = "▸" # Generate dynamic TTS text
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Recording response mode system prompt instructions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
RECORDING_RESPONSE_MODE_INSTRUCTIONS = """\
|
||||
RESPONSE MODE INSTRUCTIONS - MANDATORY FORMAT:
|
||||
Every response you generate MUST begin with a response mode indicator.
|
||||
You have two modes for responding:
|
||||
|
||||
1. DYNAMIC SPEECH (▸): Generate text that will be converted to speech by TTS.
|
||||
Format: `▸` followed by a space and your full spoken response.
|
||||
Example: ▸ Hello! How can I help you today?
|
||||
|
||||
2. PRE-RECORDED AUDIO (●): Play a pre-recorded audio message.
|
||||
Format: `●` followed by a space and ONLY the recording_id. Nothing else.
|
||||
Example: ● rec_greeting_01
|
||||
|
||||
RULES:
|
||||
- Your response MUST start with either `▸` or `●` as the very first character.
|
||||
- For `▸` (dynamic speech): Follow with a space and your full response text.
|
||||
- For `●` (pre-recorded audio): Follow with a space and ONLY the recording_id. No other text.
|
||||
- Use `●` when a pre-recorded message matches the situation well.
|
||||
- Use `▸` when you need to generate a dynamic, contextual response.
|
||||
- NEVER mix modes in a single response. Choose one."""
|
||||
|
||||
|
||||
def compose_system_prompt_for_node(
|
||||
*,
|
||||
node: "Node",
|
||||
workflow: "WorkflowGraph",
|
||||
format_prompt: Callable[[str], str],
|
||||
has_recordings: bool,
|
||||
) -> str:
|
||||
"""Compose the full system prompt text for a workflow node.
|
||||
|
||||
Combines the global prompt, node-specific prompt, and (when recordings
|
||||
are enabled anywhere in the workflow) the recording response mode
|
||||
instructions into a single string.
|
||||
|
||||
Args:
|
||||
node: The workflow node to compose the prompt for.
|
||||
workflow: The full workflow graph (needed for global node prompt).
|
||||
format_prompt: Callable to render template variables in prompts.
|
||||
has_recordings: Whether any node in the workflow uses recordings.
|
||||
|
||||
Returns:
|
||||
The composed system prompt text.
|
||||
"""
|
||||
global_prompt = ""
|
||||
if workflow.global_node_id and node.add_global_prompt:
|
||||
global_node = workflow.nodes[workflow.global_node_id]
|
||||
global_prompt = format_prompt(global_node.prompt)
|
||||
|
||||
formatted_node_prompt = format_prompt(node.prompt)
|
||||
|
||||
parts = [p for p in (global_prompt, formatted_node_prompt) if p]
|
||||
|
||||
if has_recordings:
|
||||
parts.append(RECORDING_RESPONSE_MODE_INSTRUCTIONS)
|
||||
# TODO: Append per-node available recordings list here once
|
||||
# Node.recording_ids is populated. The list should include
|
||||
# recording_id and a short description so the LLM can choose.
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
||||
async def compose_functions_for_node(
|
||||
*,
|
||||
node: "Node",
|
||||
builtin_function_schemas: list[dict],
|
||||
custom_tool_manager: Optional["CustomToolManager"],
|
||||
) -> list[dict]:
|
||||
"""Compose the function/tool schemas for a workflow node.
|
||||
|
||||
Gathers built-in tools, knowledge-base tools, custom tools,
|
||||
and transition function schemas into a single list.
|
||||
|
||||
Args:
|
||||
node: The workflow node to compose functions for.
|
||||
builtin_function_schemas: Pre-computed schemas for built-in tools.
|
||||
custom_tool_manager: Manager for user-defined custom tools (may be None).
|
||||
|
||||
Returns:
|
||||
A list of function schemas to register with the LLM.
|
||||
"""
|
||||
functions: list[dict] = []
|
||||
|
||||
# Built-in tools (calculator, timezone)
|
||||
functions.extend(builtin_function_schemas)
|
||||
|
||||
# Knowledge base retrieval tool
|
||||
if node.document_uuids:
|
||||
kb_tool_def = get_knowledge_base_tool(node.document_uuids)
|
||||
kb_schema = get_function_schema(
|
||||
kb_tool_def["function"]["name"],
|
||||
kb_tool_def["function"]["description"],
|
||||
properties=kb_tool_def["function"]["parameters"].get("properties", {}),
|
||||
required=kb_tool_def["function"]["parameters"].get("required", []),
|
||||
)
|
||||
functions.append(kb_schema)
|
||||
|
||||
# Custom tools
|
||||
if node.tool_uuids and custom_tool_manager:
|
||||
custom_tool_schemas = await custom_tool_manager.get_tool_schemas(
|
||||
node.tool_uuids
|
||||
)
|
||||
functions.extend(custom_tool_schemas)
|
||||
|
||||
# Transition function schemas
|
||||
for outgoing_edge in node.out_edges:
|
||||
function_schema = get_function_schema(
|
||||
outgoing_edge.get_function_name(), outgoing_edge.condition
|
||||
)
|
||||
functions.append(function_schema)
|
||||
|
||||
return functions
|
||||
|
|
@ -10,7 +10,7 @@ import asyncio
|
|||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
|
@ -23,7 +23,6 @@ from api.services.telephony.transfer_event_protocol import TransferContext
|
|||
from api.services.workflow.disposition_mapper import (
|
||||
get_organization_id_from_workflow_run,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_utils import get_function_schema
|
||||
from api.services.workflow.tools.custom_tool import (
|
||||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
|
|
@ -42,6 +41,29 @@ if TYPE_CHECKING:
|
|||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
|
||||
|
||||
def get_function_schema(
|
||||
function_name: str,
|
||||
description: str,
|
||||
*,
|
||||
properties: Dict[str, Any] | None = None,
|
||||
required: List[str] | None = None,
|
||||
) -> FunctionSchema:
|
||||
"""Create a FunctionSchema definition that can later be transformed into
|
||||
the provider-specific format (OpenAI, Gemini, etc.).
|
||||
|
||||
The helper keeps the public signature backward-compatible – callers that
|
||||
only pass ``function_name`` and ``description`` continue to work and will
|
||||
define a parameter-less function.
|
||||
"""
|
||||
|
||||
return FunctionSchema(
|
||||
name=function_name,
|
||||
description=description,
|
||||
properties=properties or {},
|
||||
required=required or [],
|
||||
)
|
||||
|
||||
|
||||
class CustomToolManager:
|
||||
"""Manager for custom tool registration and execution.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,68 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from api.utils.template_renderer import render_template
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
__all__ = [
|
||||
"get_function_schema",
|
||||
"update_llm_context",
|
||||
"render_template",
|
||||
]
|
||||
|
||||
|
||||
def get_function_schema(
|
||||
function_name: str,
|
||||
description: str,
|
||||
*,
|
||||
properties: Dict[str, Any] | None = None,
|
||||
required: List[str] | None = None,
|
||||
) -> FunctionSchema:
|
||||
"""Create a FunctionSchema definition that can later be transformed into
|
||||
the provider-specific format (OpenAI, Gemini, etc.).
|
||||
|
||||
The helper keeps the public signature backward-compatible – callers that
|
||||
only pass ``function_name`` and ``description`` continue to work and will
|
||||
define a parameter-less function.
|
||||
"""
|
||||
|
||||
return FunctionSchema(
|
||||
name=function_name,
|
||||
description=description,
|
||||
properties=properties or {},
|
||||
required=required or [],
|
||||
)
|
||||
|
||||
|
||||
def update_llm_context(
|
||||
context: LLMContext,
|
||||
system_message: Dict[str, Any],
|
||||
functions: List[FunctionSchema],
|
||||
) -> None:
|
||||
"""Update *context* with an up-to-date system message and tool list.
|
||||
|
||||
This helper removes any previous system messages before inserting the new
|
||||
*system_message* at the top of the conversation history and then instructs
|
||||
the LLM which *functions* (a.k.a. tools) are currently available.
|
||||
"""
|
||||
|
||||
# Wrap the provided function schemas in a ToolsSchema so that the adapter
|
||||
# associated with the current LLM service can convert them to the correct
|
||||
# provider-specific representation when required.
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
previous_interactions = context.messages
|
||||
|
||||
# Replace the first message if it's a system message, otherwise prepend.
|
||||
# Keep any system messages that appear in the middle of the conversation.
|
||||
if previous_interactions and previous_interactions[0]["role"] == "system":
|
||||
messages = [system_message] + previous_interactions[1:]
|
||||
else:
|
||||
messages = [system_message] + previous_interactions
|
||||
|
||||
context.set_messages(messages)
|
||||
|
||||
if functions:
|
||||
context.set_tools(tools_schema)
|
||||
|
|
@ -13,14 +13,12 @@ from unittest.mock import AsyncMock, Mock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
get_function_schema,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.services.workflow.pipecat_engine_custom_tools import get_function_schema
|
||||
from api.services.workflow.tools.custom_tool import (
|
||||
execute_http_tool,
|
||||
tool_to_function_schema,
|
||||
)
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.frames.frames import (
|
||||
FunctionCallInProgressFrame,
|
||||
FunctionCallResultFrame,
|
||||
|
|
@ -862,11 +860,27 @@ class TestCustomToolManagerUnit:
|
|||
assert result_received["status"] == "success"
|
||||
|
||||
|
||||
def _update_llm_context(context, system_message, functions):
|
||||
"""Inline helper replicating the old update_llm_context for tests."""
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
previous_interactions = context.messages
|
||||
|
||||
if previous_interactions and previous_interactions[0]["role"] == "system":
|
||||
messages = [system_message] + previous_interactions[1:]
|
||||
else:
|
||||
messages = [system_message] + previous_interactions
|
||||
|
||||
context.set_messages(messages)
|
||||
|
||||
if functions:
|
||||
context.set_tools(tools_schema)
|
||||
|
||||
|
||||
class TestUpdateLLMContext:
|
||||
"""Tests for update_llm_context function."""
|
||||
"""Tests for _update_llm_context inline logic."""
|
||||
|
||||
def test_replaces_system_message(self):
|
||||
"""Test that update_llm_context replaces existing system messages."""
|
||||
"""Test that _update_llm_context replaces existing system messages."""
|
||||
context = LLMContext()
|
||||
context.set_messages(
|
||||
[
|
||||
|
|
@ -877,7 +891,7 @@ class TestUpdateLLMContext:
|
|||
)
|
||||
|
||||
new_system = {"role": "system", "content": "New system message"}
|
||||
update_llm_context(context, new_system, [])
|
||||
_update_llm_context(context, new_system, [])
|
||||
|
||||
messages = context.messages
|
||||
# Should have new system message at the start
|
||||
|
|
@ -902,7 +916,7 @@ class TestUpdateLLMContext:
|
|||
)
|
||||
|
||||
new_system = {"role": "system", "content": "New prompt"}
|
||||
update_llm_context(context, new_system, [])
|
||||
_update_llm_context(context, new_system, [])
|
||||
|
||||
messages = context.messages
|
||||
assert len(messages) == 5
|
||||
|
|
@ -923,7 +937,7 @@ class TestUpdateLLMContext:
|
|||
]
|
||||
|
||||
new_system = {"role": "system", "content": "New prompt with tools"}
|
||||
update_llm_context(context, new_system, functions)
|
||||
_update_llm_context(context, new_system, functions)
|
||||
|
||||
# Verify tools were set
|
||||
tools = context.tools
|
||||
|
|
@ -936,7 +950,7 @@ class TestUpdateLLMContext:
|
|||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
|
||||
new_system = {"role": "system", "content": "New prompt without tools"}
|
||||
update_llm_context(context, new_system, [])
|
||||
_update_llm_context(context, new_system, [])
|
||||
|
||||
# Tools should not be set (or remain None)
|
||||
# Note: The function only calls set_tools if functions is truthy
|
||||
|
|
@ -952,7 +966,7 @@ class TestUpdateLLMContext:
|
|||
new_system = {"role": "system", "content": "Initial prompt"}
|
||||
functions = [get_function_schema("test_func", "A test function")]
|
||||
|
||||
update_llm_context(context, new_system, functions)
|
||||
_update_llm_context(context, new_system, functions)
|
||||
|
||||
messages = context.messages
|
||||
assert len(messages) == 1
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
"""Integration tests for CustomToolManager with update_llm_context.
|
||||
"""Integration tests for CustomToolManager with LLM context updates.
|
||||
|
||||
This module tests the full flow of:
|
||||
1. CustomToolManager fetching and converting tool schemas
|
||||
2. update_llm_context setting those tools on the LLM context
|
||||
2. Setting those tools on the LLM context
|
||||
3. Verifying the context is properly configured for LLM generation
|
||||
"""
|
||||
|
||||
|
|
@ -10,16 +10,32 @@ from unittest.mock import AsyncMock, patch
|
|||
|
||||
import pytest
|
||||
|
||||
from api.services.workflow.pipecat_engine_custom_tools import CustomToolManager
|
||||
from api.services.workflow.pipecat_engine_utils import (
|
||||
from api.services.workflow.pipecat_engine_custom_tools import (
|
||||
CustomToolManager,
|
||||
get_function_schema,
|
||||
update_llm_context,
|
||||
)
|
||||
from api.tests.conftest import MockToolModel
|
||||
from pipecat.adapters.schemas.function_schema import FunctionSchema
|
||||
from pipecat.adapters.schemas.tools_schema import ToolsSchema
|
||||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
|
||||
def _update_llm_context(context, system_message, functions):
|
||||
"""Inline helper replicating the update_llm_context logic for tests."""
|
||||
tools_schema = ToolsSchema(standard_tools=functions)
|
||||
previous_interactions = context.messages
|
||||
|
||||
if previous_interactions and previous_interactions[0]["role"] == "system":
|
||||
messages = [system_message] + previous_interactions[1:]
|
||||
else:
|
||||
messages = [system_message] + previous_interactions
|
||||
|
||||
context.set_messages(messages)
|
||||
|
||||
if functions:
|
||||
context.set_tools(tools_schema)
|
||||
|
||||
|
||||
class TestCustomToolManagerContextIntegration:
|
||||
"""Integration tests for CustomToolManager with LLMContext."""
|
||||
|
||||
|
|
@ -69,7 +85,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
"role": "system",
|
||||
"content": "You are a scheduling assistant with access to weather and booking tools.",
|
||||
}
|
||||
update_llm_context(context, new_system, schemas)
|
||||
_update_llm_context(context, new_system, schemas)
|
||||
|
||||
# Verify context was updated correctly
|
||||
messages = context.messages
|
||||
|
|
@ -195,7 +211,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
"role": "system",
|
||||
"content": "Assistant with calculator and weather tools",
|
||||
}
|
||||
update_llm_context(context, new_system, all_functions)
|
||||
_update_llm_context(context, new_system, all_functions)
|
||||
|
||||
# Verify all tools are present
|
||||
tools = context.tools
|
||||
|
|
@ -259,7 +275,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
)
|
||||
|
||||
new_system = {"role": "system", "content": "Updated weather assistant"}
|
||||
update_llm_context(context, new_system, schemas)
|
||||
_update_llm_context(context, new_system, schemas)
|
||||
|
||||
messages = context.messages
|
||||
# System + user + assistant(tool_call) + tool + assistant = 5
|
||||
|
|
@ -296,7 +312,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
|
||||
new_system = {"role": "system", "content": "No tools available"}
|
||||
update_llm_context(context, new_system, [])
|
||||
_update_llm_context(context, new_system, [])
|
||||
|
||||
# Context should have updated message but no tools set
|
||||
assert context.messages[0]["content"] == "No tools available"
|
||||
|
|
@ -362,7 +378,7 @@ class TestCustomToolManagerContextIntegration:
|
|||
# Update context - pass schema directly
|
||||
context = LLMContext()
|
||||
context.set_messages([{"role": "system", "content": "Old"}])
|
||||
update_llm_context(
|
||||
_update_llm_context(
|
||||
context, {"role": "system", "content": "Order assistant"}, schemas
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -68,9 +68,7 @@ class ContextCapturingMockLLM(MockLLMService):
|
|||
{
|
||||
"step": self._current_step,
|
||||
"messages": messages_snapshot,
|
||||
"system_prompt": messages_snapshot[0]["content"]
|
||||
if messages_snapshot
|
||||
else None,
|
||||
"system_prompt": self._settings.system_instruction,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -101,12 +99,10 @@ class ContextCapturingMockLLM(MockLLMService):
|
|||
return False
|
||||
|
||||
def get_system_prompt_at_step(self, step: int) -> str:
|
||||
"""Get the system prompt from context at a specific step."""
|
||||
"""Get the system prompt from settings at a specific step."""
|
||||
ctx = self.get_context_at_step(step)
|
||||
if ctx and ctx["messages"]:
|
||||
first_msg = ctx["messages"][0]
|
||||
if first_msg.get("role") == "system":
|
||||
return first_msg.get("content", "")
|
||||
if ctx:
|
||||
return ctx.get("system_prompt") or ""
|
||||
return ""
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ class TestPipecatEngineToolCalls:
|
|||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_builtin_and_transition_calls_through_engine_1(
|
||||
|
|
@ -233,7 +233,7 @@ class TestPipecatEngineToolCalls:
|
|||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parallel_builtin_and_transition_calls_through_engine_with_text(
|
||||
|
|
@ -281,7 +281,7 @@ class TestPipecatEngineToolCalls:
|
|||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_transition_call_through_engine(
|
||||
|
|
@ -315,4 +315,4 @@ class TestPipecatEngineToolCalls:
|
|||
)
|
||||
|
||||
# Assert that the context was updated with END_CALL_SYSTEM_PROMPT
|
||||
assert context.messages[0]["content"] == END_CALL_SYSTEM_PROMPT
|
||||
assert llm._settings.system_instruction == END_CALL_SYSTEM_PROMPT
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue