mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-22 08:38:13 +02:00
Merge branch 'main' into feat/telnyx-telephony
This commit is contained in:
commit
9dc64456d8
39 changed files with 1071 additions and 313 deletions
|
|
@ -64,7 +64,7 @@ class WorkflowRecordingClient(BaseDBClient):
|
|||
storage_key=storage_key,
|
||||
storage_backend=storage_backend,
|
||||
created_by=created_by,
|
||||
metadata=metadata or {},
|
||||
recording_metadata=metadata or {},
|
||||
)
|
||||
|
||||
session.add(recording)
|
||||
|
|
|
|||
|
|
@ -40,6 +40,38 @@ class PresignedUploadUrlResponse(BaseModel):
|
|||
router = APIRouter(prefix="/s3", tags=["s3"])
|
||||
|
||||
|
||||
def _extract_org_id_from_key(key: str) -> Optional[int]:
|
||||
"""Try to extract an organization ID from a storage key.
|
||||
|
||||
Matches keys of the form ``{prefix}/{org_id}/...`` where *org_id* is a
|
||||
positive integer. Returns ``None`` when the pattern does not match.
|
||||
"""
|
||||
parts = key.split("/")
|
||||
if len(parts) >= 3 and parts[1].isdigit():
|
||||
return int(parts[1])
|
||||
return None
|
||||
|
||||
|
||||
def _extract_legacy_workflow_run_id(key: str) -> Optional[int]:
|
||||
"""Extract a workflow_run_id from legacy key formats.
|
||||
|
||||
Supports:
|
||||
- ``transcripts/{run_id}.txt``
|
||||
- ``recordings/{run_id}.wav``
|
||||
|
||||
Returns ``None`` when the key does not match a legacy pattern.
|
||||
"""
|
||||
if key.startswith("transcripts/") and key.endswith(".txt"):
|
||||
run_id_str = key[len("transcripts/") : -4]
|
||||
elif key.startswith("recordings/") and key.endswith(".wav"):
|
||||
run_id_str = key[len("recordings/") : -4]
|
||||
else:
|
||||
return None
|
||||
|
||||
return int(run_id_str) if run_id_str.isdigit() else None
|
||||
|
||||
|
||||
# Keep for backward compat with file-metadata endpoint
|
||||
async def _validate_and_extract_workflow_run_id(
|
||||
key: str, allow_special_paths: bool = False
|
||||
) -> Optional[int]:
|
||||
|
|
@ -118,64 +150,68 @@ async def get_signed_url(
|
|||
key: Annotated[str, Query(description="S3 object key")],
|
||||
expires_in: int = 3600,
|
||||
inline: bool = False,
|
||||
storage_backend: Annotated[
|
||||
Optional[str],
|
||||
Query(
|
||||
description="Storage backend to use (e.g. 'minio', 's3'). "
|
||||
"When omitted the backend is inferred from the resource."
|
||||
),
|
||||
] = None,
|
||||
user=Depends(get_user),
|
||||
):
|
||||
"""Return a short-lived signed URL for a transcript or recording file stored on S3.
|
||||
"""Return a short-lived signed URL for a file stored on S3 / MinIO.
|
||||
|
||||
Access Control:
|
||||
* Keys that embed an organization ID (``{prefix}/{org_id}/...``) are
|
||||
authorized by matching the org_id against the requesting user's
|
||||
organization.
|
||||
* Legacy keys (``recordings/{run_id}.wav``, ``transcripts/{run_id}.txt``)
|
||||
are authorized via the workflow run they belong to.
|
||||
* Superusers can request any key.
|
||||
* Regular users can only request resources belonging to **their** workflow runs.
|
||||
"""
|
||||
|
||||
# Validate key and extract workflow_run_id (don't allow special paths for signed URLs)
|
||||
run_id = await _validate_and_extract_workflow_run_id(key, allow_special_paths=False)
|
||||
if run_id is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid key format")
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Authorize
|
||||
# ------------------------------------------------------------------
|
||||
workflow_run = None
|
||||
|
||||
# Authorize and get workflow run
|
||||
workflow_run = await _authorize_and_get_workflow_run(run_id, user)
|
||||
org_id = _extract_org_id_from_key(key)
|
||||
if org_id is not None:
|
||||
# Generic org-based auth
|
||||
if not user.is_superuser and org_id != user.selected_organization_id:
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
else:
|
||||
# Legacy workflow-run-based auth
|
||||
run_id = _extract_legacy_workflow_run_id(key)
|
||||
if run_id is None:
|
||||
raise HTTPException(status_code=400, detail="Invalid key format")
|
||||
workflow_run = await _authorize_and_get_workflow_run(run_id, user)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. Generate the signed URL using the correct storage backend
|
||||
# 2. Resolve storage backend
|
||||
# ------------------------------------------------------------------
|
||||
try:
|
||||
# Use the storage backend recorded when the file was uploaded
|
||||
if (
|
||||
if storage_backend:
|
||||
storage = get_storage_for_backend(storage_backend)
|
||||
elif (
|
||||
workflow_run
|
||||
and hasattr(workflow_run, "storage_backend")
|
||||
and workflow_run.storage_backend
|
||||
):
|
||||
backend = workflow_run.storage_backend
|
||||
storage = get_storage_for_backend(backend)
|
||||
logger.info(
|
||||
f"DOWNLOAD: Using stored {backend} (value: {backend}) for signed URL generation - workflow_run_id: {run_id}, key: {key}"
|
||||
)
|
||||
storage = get_storage_for_backend(workflow_run.storage_backend)
|
||||
else:
|
||||
# Fallback to current storage for legacy records without storage_backend
|
||||
storage = storage_fs
|
||||
current_backend = StorageBackend.get_current_backend()
|
||||
logger.warning(
|
||||
f"DOWNLOAD: No storage_backend found for workflow run {run_id}, falling back to current {current_backend.name} - key: {key}"
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. Generate the signed URL
|
||||
# ------------------------------------------------------------------
|
||||
url = await storage.aget_signed_url(
|
||||
key, expiration=expires_in, force_inline=inline
|
||||
)
|
||||
if not url:
|
||||
raise HTTPException(status_code=500, detail="Failed to generate signed URL")
|
||||
|
||||
# Log successful URL generation
|
||||
backend_info = (
|
||||
f"stored {backend}"
|
||||
if workflow_run
|
||||
and hasattr(workflow_run, "storage_backend")
|
||||
and workflow_run.storage_backend
|
||||
else f"current {StorageBackend.get_current_backend().name}"
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully generated signed URL using {backend_info} - expires in {expires_in}s"
|
||||
)
|
||||
|
||||
logger.info(f"Generated signed URL for key={key}, expires_in={expires_in}s")
|
||||
return {"url": url, "expires_in": expires_in}
|
||||
except ClientError as exc:
|
||||
logger.error(f"Error generating signed URL: {exc}")
|
||||
|
|
|
|||
|
|
@ -2,9 +2,10 @@
|
|||
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi import APIRouter, Depends, File, Form, HTTPException, Query, UploadFile
|
||||
from loguru import logger
|
||||
|
||||
from api.constants import DEPLOYMENT_MODE
|
||||
from api.db import db_client
|
||||
from api.db.workflow_recording_client import generate_short_id
|
||||
from api.enums import StorageBackend
|
||||
|
|
@ -16,6 +17,7 @@ from api.schemas.workflow_recording import (
|
|||
RecordingUploadResponseSchema,
|
||||
)
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.mps_service_key_client import mps_service_key_client
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
router = APIRouter(prefix="/workflow-recordings", tags=["workflow-recordings"])
|
||||
|
|
@ -216,3 +218,42 @@ async def delete_recording(
|
|||
raise HTTPException(
|
||||
status_code=500, detail="Failed to delete recording"
|
||||
) from exc
|
||||
|
||||
|
||||
@router.post(
    "/transcribe",
    summary="Transcribe an audio file",
)
async def transcribe_audio(
    file: UploadFile = File(...),
    language: str = Form("en"),
    user=Depends(get_user),
):
    """Transcribe an uploaded audio file using MPS STT."""
    try:
        payload = await file.read()

        # The two deployment modes differ only in how the caller is
        # identified to MPS: provider id in OSS mode, organization otherwise.
        if DEPLOYMENT_MODE == "oss":
            caller = {"created_by": str(user.provider_id)}
        else:
            caller = {"organization_id": user.selected_organization_id}

        return await mps_service_key_client.transcribe_audio(
            audio_data=payload,
            filename=file.filename or "audio.wav",
            content_type=file.content_type or "audio/wav",
            language=language,
            **caller,
        )
    except Exception as exc:
        # Any failure (upload read or MPS call) is reported as a generic 500.
        logger.error(f"Error transcribing audio: {exc}")
        raise HTTPException(
            status_code=500, detail="Failed to transcribe audio"
        ) from exc
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ class UserConfigurationValidator:
|
|||
ServiceProviders.SPEECHMATICS.value: self._check_speechmatics_api_key,
|
||||
ServiceProviders.CAMB.value: self._check_camb_api_key,
|
||||
ServiceProviders.AWS_BEDROCK.value: self._check_aws_bedrock_api_key,
|
||||
ServiceProviders.SELF_HOSTED.value: self._check_self_hosted_api_key,
|
||||
}
|
||||
|
||||
async def validate(self, configuration: UserConfiguration) -> APIKeyStatusResponse:
|
||||
|
|
@ -74,6 +75,20 @@ class UserConfigurationValidator:
|
|||
|
||||
provider = service_config.provider
|
||||
|
||||
# Self-hosted doesn't require an API key
|
||||
if provider == ServiceProviders.SELF_HOSTED.value:
|
||||
try:
|
||||
if not self._check_self_hosted_api_key(provider, service_config):
|
||||
return [
|
||||
{
|
||||
"model": service_name,
|
||||
"message": f"Invalid {provider} configuration",
|
||||
}
|
||||
]
|
||||
except ValueError as e:
|
||||
return [{"model": service_name, "message": str(e)}]
|
||||
return []
|
||||
|
||||
# AWS Bedrock uses AWS credentials instead of api_key
|
||||
if provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
try:
|
||||
|
|
@ -163,7 +178,12 @@ class UserConfigurationValidator:
|
|||
|
||||
def _check_camb_api_key(self, model: str, api_key: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
def _check_self_hosted_api_key(self, model: str, service_config) -> bool:
|
||||
if not getattr(service_config, "base_url", None):
|
||||
raise ValueError("base_url is required for self-hosted LLM")
|
||||
return True
|
||||
|
||||
def _check_aws_bedrock_api_key(self, model: str, service_config) -> bool:
|
||||
if not service_config.aws_access_key or not service_config.aws_secret_key:
|
||||
raise ValueError("AWS access key and secret key are required for Bedrock")
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ class ServiceProviders(str, Enum):
|
|||
SPEECHMATICS = "speechmatics"
|
||||
CAMB = "camb"
|
||||
AWS_BEDROCK = "aws_bedrock"
|
||||
SELF_HOSTED = "self_hosted"
|
||||
|
||||
|
||||
class BaseServiceConfiguration(BaseModel):
|
||||
|
|
@ -40,6 +41,7 @@ class BaseServiceConfiguration(BaseModel):
|
|||
ServiceProviders.AZURE,
|
||||
ServiceProviders.DOGRAH,
|
||||
ServiceProviders.AWS_BEDROCK,
|
||||
ServiceProviders.SELF_HOSTED,
|
||||
# ServiceProviders.SARVAM,
|
||||
]
|
||||
api_key: str | list[str]
|
||||
|
|
@ -249,6 +251,22 @@ class AWSBedrockLLMConfiguration(BaseLLMConfiguration):
|
|||
api_key: str | list[str] | None = Field(default=None)
|
||||
|
||||
|
||||
SELF_HOSTED_LLM_MODELS = ["llama3", "mistral", "phi3", "qwen2", "gemma2", "deepseek-r1"]
|
||||
|
||||
|
||||
@register_llm
class SelfHostedLLMConfiguration(BaseLLMConfiguration):
    # NOTE: documented with comments rather than a class docstring so the
    # generated model schema stays unchanged.
    provider: Literal[ServiceProviders.SELF_HOSTED] = ServiceProviders.SELF_HOSTED

    # Model name is a plain string; SELF_HOSTED_LLM_MODELS is only attached
    # as schema examples.
    model: str = Field(
        json_schema_extra={"examples": SELF_HOSTED_LLM_MODELS},
        default="llama3",
    )

    # Endpoint of the OpenAI-compatible server.
    base_url: str = Field(
        description="OpenAI-compatible endpoint (Ollama, vLLM, etc.)",
        default="http://localhost:11434/v1",
    )

    # Optional for this provider: defaults to None.
    api_key: str | list[str] | None = Field(default=None)
|
||||
|
||||
|
||||
LLMConfig = Annotated[
|
||||
Union[
|
||||
OpenAILLMService,
|
||||
|
|
@ -258,6 +276,7 @@ LLMConfig = Annotated[
|
|||
AzureLLMService,
|
||||
DograhLLMService,
|
||||
AWSBedrockLLMConfiguration,
|
||||
SelfHostedLLMConfiguration,
|
||||
],
|
||||
Field(discriminator="provider"),
|
||||
]
|
||||
|
|
@ -334,6 +353,12 @@ class CartesiaTTSConfiguration(BaseTTSConfiguration):
|
|||
)
|
||||
voice: str = Field(default="3faa81ae-d3d8-4ab1-9e44-e50e46d33c30")
|
||||
speed: float = Field(default=1.0, ge=0.6, le=1.5, description="Speed of the voice")
|
||||
volume: float = Field(
|
||||
default=1.0,
|
||||
ge=0.5,
|
||||
le=2.0,
|
||||
description="Volume multiplier for generated speech",
|
||||
)
|
||||
|
||||
|
||||
SARVAM_TTS_MODELS = ["bulbul:v2", "bulbul:v3"]
|
||||
|
|
|
|||
|
|
@ -351,6 +351,71 @@ class MPSServiceKeyClient:
|
|||
response=response,
|
||||
)
|
||||
|
||||
async def transcribe_audio(
    self,
    audio_data: bytes,
    filename: str = "audio.wav",
    content_type: str = "audio/wav",
    language: str = "en",
    model: str = "default",
    correlation_id: Optional[str] = None,
    organization_id: Optional[int] = None,
    created_by: Optional[str] = None,
) -> dict:
    """
    Transcribe an audio file via MPS STT API.

    Args:
        audio_data: Raw audio bytes
        filename: Name of the audio file
        content_type: MIME type of the audio (e.g., audio/wav, audio/mp3)
        language: Language code for transcription (default: "en")
        model: Model tier name (default: "default")
        correlation_id: Optional correlation ID for tracking
        organization_id: Organization ID (for authenticated mode)
        created_by: User provider ID (for OSS mode)

    Returns:
        Dictionary containing transcription result with keys like
        'transcript', 'duration_seconds', etc.

    Raises:
        httpx.HTTPStatusError: If the API call fails
    """
    files = {"file": (filename, audio_data, content_type)}
    data = {"language": language, "model": model}
    if correlation_id:
        data["correlation_id"] = correlation_id

    # Drop any Content-Type so httpx sets the correct multipart boundary.
    # Built as a filtered copy: the match is case-insensitive and the dict
    # returned by _get_headers is left untouched.
    headers = {
        name: value
        for name, value in self._get_headers(organization_id, created_by).items()
        if name.lower() != "content-type"
    }

    async with httpx.AsyncClient(timeout=httpx.Timeout(60.0)) as client:
        response = await client.post(
            f"{self.base_url}/api/v1/stt/transcribe",
            files=files,
            data=data,
            headers=headers,
        )

        if response.status_code != 200:
            logger.error(
                f"Failed to transcribe audio: {response.status_code} - {response.text}"
            )
            raise httpx.HTTPStatusError(
                f"Failed to transcribe audio: {response.text}",
                request=response.request,
                response=response,
            )

        return response.json()
|
||||
|
||||
def validate_service_key(self, service_key: str) -> bool:
|
||||
"""
|
||||
Synchronously validate a Dograh service key by checking usage via MPS.
|
||||
|
|
|
|||
|
|
@ -165,49 +165,39 @@ class RealtimeFeedbackObserver(BaseObserver):
|
|||
frame = data.frame
|
||||
frame_direction = data.direction
|
||||
|
||||
logger.trace(f"{self} Received Frame: {frame} Direction: {frame_direction}")
|
||||
|
||||
# Handle pipeline termination - stop clock task
|
||||
if isinstance(frame, (EndFrame, CancelFrame, StopFrame)):
|
||||
await self._cancel_clock_task()
|
||||
return
|
||||
|
||||
# Handle interruptions - clear any queued bot text
|
||||
if isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption()
|
||||
return
|
||||
|
||||
# Bot speaking state - WS only (ephemeral state signals, not persisted)
|
||||
if isinstance(frame, BotStartedSpeakingFrame):
|
||||
await self._send_ws(
|
||||
{"type": RealtimeFeedbackType.BOT_STARTED_SPEAKING.value, "payload": {}}
|
||||
)
|
||||
return
|
||||
if isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._send_ws(
|
||||
{"type": RealtimeFeedbackType.BOT_STOPPED_SPEAKING.value, "payload": {}}
|
||||
)
|
||||
return
|
||||
|
||||
# User mute state - WS only (ephemeral state signals, not persisted)
|
||||
if isinstance(frame, UserMuteStartedFrame):
|
||||
await self._send_ws(
|
||||
{"type": RealtimeFeedbackType.USER_MUTE_STARTED.value, "payload": {}}
|
||||
)
|
||||
return
|
||||
if isinstance(frame, UserMuteStoppedFrame):
|
||||
await self._send_ws(
|
||||
{"type": RealtimeFeedbackType.USER_MUTE_STOPPED.value, "payload": {}}
|
||||
)
|
||||
return
|
||||
|
||||
# Skip already processed frames (frames can be observed multiple times)
|
||||
if frame.id in self._frames_seen:
|
||||
return
|
||||
self._frames_seen.add(frame.id)
|
||||
|
||||
logger.trace(f"{self} Received Frame: {frame} Direction: {frame_direction}")
|
||||
|
||||
# Handle pipeline termination - stop clock task
|
||||
if isinstance(frame, (EndFrame, CancelFrame, StopFrame)):
|
||||
await self._cancel_clock_task()
|
||||
# Handle interruptions - clear any queued bot text
|
||||
elif isinstance(frame, InterruptionFrame):
|
||||
await self._handle_interruption()
|
||||
# Bot speaking state - WS only (ephemeral state signals, not persisted)
|
||||
elif isinstance(frame, BotStartedSpeakingFrame):
|
||||
await self._send_ws(
|
||||
{"type": RealtimeFeedbackType.BOT_STARTED_SPEAKING.value, "payload": {}}
|
||||
)
|
||||
elif isinstance(frame, BotStoppedSpeakingFrame):
|
||||
await self._send_ws(
|
||||
{"type": RealtimeFeedbackType.BOT_STOPPED_SPEAKING.value, "payload": {}}
|
||||
)
|
||||
# User mute state - WS only (ephemeral state signals, not persisted)
|
||||
elif isinstance(frame, UserMuteStartedFrame):
|
||||
await self._send_ws(
|
||||
{"type": RealtimeFeedbackType.USER_MUTE_STARTED.value, "payload": {}}
|
||||
)
|
||||
elif isinstance(frame, UserMuteStoppedFrame):
|
||||
await self._send_ws(
|
||||
{"type": RealtimeFeedbackType.USER_MUTE_STOPPED.value, "payload": {}}
|
||||
)
|
||||
# Handle user transcriptions (interim) - WebSocket only
|
||||
if isinstance(frame, InterimTranscriptionFrame):
|
||||
elif isinstance(frame, InterimTranscriptionFrame):
|
||||
await self._send_ws(
|
||||
{
|
||||
"type": RealtimeFeedbackType.USER_TRANSCRIPTION.value,
|
||||
|
|
|
|||
|
|
@ -66,6 +66,7 @@ class RecordingRouterProcessor(FrameProcessor):
|
|||
self._frame_buffer: list[tuple[LLMTextFrame, FrameDirection]] = []
|
||||
self._mode: Optional[str] = None # None = detecting, "tts", "recording"
|
||||
self._recording_id_buffer = ""
|
||||
self._recording_playback_started = False
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Frame dispatch
|
||||
|
|
@ -99,9 +100,15 @@ class RecordingRouterProcessor(FrameProcessor):
|
|||
await self.push_frame(frame, direction)
|
||||
return
|
||||
|
||||
# --- Recording mode: accumulate recording_id silently ---
|
||||
# --- Recording mode: accumulate text and start playback ASAP ---
|
||||
if self._mode == "recording":
|
||||
self._recording_id_buffer += frame.text
|
||||
if not self._recording_playback_started:
|
||||
buf = self._recording_id_buffer.lstrip()
|
||||
if " " in buf:
|
||||
recording_id = buf.split()[0]
|
||||
self._recording_playback_started = True
|
||||
await self._play_recording(recording_id)
|
||||
return
|
||||
|
||||
# --- Detection mode: buffer until marker found ---
|
||||
|
|
@ -178,16 +185,21 @@ class RecordingRouterProcessor(FrameProcessor):
|
|||
self, frame: LLMFullResponseEndFrame, direction: FrameDirection
|
||||
):
|
||||
if self._mode == "recording":
|
||||
recording_id = self._recording_id_buffer.strip()
|
||||
if recording_id:
|
||||
# Push accumulated text as TTSTextFrame for UI feedback via observer
|
||||
full_text = self._recording_id_buffer.strip()
|
||||
if full_text:
|
||||
recording_id = full_text.split()[0]
|
||||
|
||||
# Push full text (marker + id + transcript) for assistant context
|
||||
await self.push_frame(
|
||||
TTSTextFrame(
|
||||
text=RECORDING_MARKER + self._recording_id_buffer,
|
||||
aggregated_by="recording_router",
|
||||
)
|
||||
)
|
||||
await self._play_recording(recording_id)
|
||||
|
||||
# Fallback: if response ended before a space arrived (no transcript)
|
||||
if not self._recording_playback_started:
|
||||
await self._play_recording(recording_id)
|
||||
else:
|
||||
logger.warning(
|
||||
"RecordingRouterProcessor: recording mode but empty recording_id"
|
||||
|
|
@ -256,3 +268,4 @@ class RecordingRouterProcessor(FrameProcessor):
|
|||
self._frame_buffer = []
|
||||
self._mode = None
|
||||
self._recording_id_buffer = ""
|
||||
self._recording_playback_started = False
|
||||
|
|
|
|||
|
|
@ -8,7 +8,11 @@ from api.services.configuration.registry import ServiceProviders
|
|||
from pipecat.services.aws.llm import AWSBedrockLLMService, AWSBedrockLLMSettings
|
||||
from pipecat.services.azure.llm import AzureLLMService, AzureLLMSettings
|
||||
from pipecat.services.cartesia.stt import CartesiaSTTService
|
||||
from pipecat.services.cartesia.tts import CartesiaTTSService, CartesiaTTSSettings, GenerationConfig
|
||||
from pipecat.services.cartesia.tts import (
|
||||
CartesiaTTSService,
|
||||
CartesiaTTSSettings,
|
||||
GenerationConfig,
|
||||
)
|
||||
from pipecat.services.deepgram.flux.stt import (
|
||||
DeepgramFluxSTTService,
|
||||
DeepgramFluxSTTSettings,
|
||||
|
|
@ -212,13 +216,19 @@ def create_tts_service(user_config, audio_config: "AudioConfig"):
|
|||
)
|
||||
elif user_config.tts.provider == ServiceProviders.CARTESIA.value:
|
||||
speed = getattr(user_config.tts, "speed", None)
|
||||
generation_config = GenerationConfig(speed=speed) if speed and speed != 1.0 else None
|
||||
generation_config = (
|
||||
GenerationConfig(speed=speed) if speed and speed != 1.0 else None
|
||||
)
|
||||
return CartesiaTTSService(
|
||||
api_key=user_config.tts.api_key,
|
||||
settings=CartesiaTTSSettings(
|
||||
voice=user_config.tts.voice,
|
||||
model=user_config.tts.model,
|
||||
**({"generation_config": generation_config} if generation_config else {}),
|
||||
**(
|
||||
{"generation_config": generation_config}
|
||||
if generation_config
|
||||
else {}
|
||||
),
|
||||
),
|
||||
text_filters=[xml_function_tag_filter],
|
||||
silence_time_s=1.0,
|
||||
|
|
@ -353,6 +363,12 @@ def create_llm_service_from_provider(
|
|||
aws_region=aws_region,
|
||||
settings=AWSBedrockLLMSettings(model=model),
|
||||
)
|
||||
elif provider == ServiceProviders.SELF_HOSTED.value:
|
||||
return OpenAILLMService(
|
||||
base_url=base_url or "http://localhost:11434/v1",
|
||||
api_key=api_key or "none",
|
||||
settings=OpenAILLMSettings(model=model),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid LLM provider {provider}")
|
||||
|
||||
|
|
@ -368,6 +384,8 @@ def create_llm_service(user_config):
|
|||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.AZURE.value:
|
||||
kwargs["endpoint"] = user_config.llm.endpoint
|
||||
elif provider == ServiceProviders.SELF_HOSTED.value:
|
||||
kwargs["base_url"] = user_config.llm.base_url
|
||||
elif provider == ServiceProviders.AWS_BEDROCK.value:
|
||||
kwargs["aws_access_key"] = user_config.llm.aws_access_key
|
||||
kwargs["aws_secret_key"] = user_config.llm.aws_secret_key
|
||||
|
|
|
|||
|
|
@ -437,9 +437,7 @@ class PipecatEngine:
|
|||
|
||||
async def _do_extraction():
|
||||
try:
|
||||
logger.debug(
|
||||
f"Starting variable extraction for node: {node.name}"
|
||||
)
|
||||
logger.debug(f"Starting variable extraction for node: {node.name}")
|
||||
extracted_data = (
|
||||
await self._variable_extraction_manager._perform_extraction(
|
||||
extraction_variables, parent_context, extraction_prompt
|
||||
|
|
@ -454,7 +452,9 @@ class PipecatEngine:
|
|||
f"Variable extraction completed for node: {node.name}. Extracted: {extracted_data}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during variable extraction for node {node.name}: {str(e)}")
|
||||
logger.error(
|
||||
f"Error during variable extraction for node {node.name}: {str(e)}"
|
||||
)
|
||||
|
||||
if run_in_background:
|
||||
logger.debug(
|
||||
|
|
@ -497,9 +497,7 @@ class PipecatEngine:
|
|||
logger.error(
|
||||
f"Pending extraction task '{task_name}' failed: {result}"
|
||||
)
|
||||
logger.debug(
|
||||
f"All pending extraction tasks completed in {elapsed:.2f}s"
|
||||
)
|
||||
logger.debug(f"All pending extraction tasks completed in {elapsed:.2f}s")
|
||||
except asyncio.TimeoutError:
|
||||
incomplete = [
|
||||
t.get_name() for t in self._pending_extraction_tasks if not t.done()
|
||||
|
|
|
|||
|
|
@ -34,13 +34,13 @@ You have two modes for responding:
|
|||
Example: ▸ Hello! How can I help you today?
|
||||
|
||||
2. PRE-RECORDED AUDIO (●): Play a pre-recorded audio message.
|
||||
Format: `●` followed by a space and ONLY the recording_id. Nothing else.
|
||||
Example: ● rec_greeting_01
|
||||
Format: `●` followed by a space followed by recording_id followed by provided transcript. Nothing else.
|
||||
Example: ● rec_greeting_01 [ Provided Transcript ]
|
||||
|
||||
RULES:
|
||||
- Your response MUST start with either `▸` or `●` as the very first character.
|
||||
- For `▸` (dynamic speech): Follow with a space and your full response text.
|
||||
- For `●` (pre-recorded audio): Follow with a space and ONLY the recording_id. No other text.
|
||||
- For `●` (pre-recorded audio): Follow with a space and the recording_id and the provided transcript. No other text.
|
||||
- Use `●` when a pre-recorded message matches the situation well.
|
||||
- Use `▸` when you need to generate a dynamic, contextual response.
|
||||
- NEVER mix modes in a single response. Choose one."""
|
||||
|
|
@ -77,11 +77,8 @@ def compose_system_prompt_for_node(
|
|||
|
||||
parts = [p for p in (global_prompt, formatted_node_prompt) if p]
|
||||
|
||||
if has_recordings:
|
||||
if has_recordings and "RECORDING_ID:" in formatted_node_prompt:
|
||||
parts.append(RECORDING_RESPONSE_MODE_INSTRUCTIONS)
|
||||
# TODO: Append per-node available recordings list here once
|
||||
# Node.recording_ids is populated. The list should include
|
||||
# recording_id and a short description so the LLM can choose.
|
||||
|
||||
return "\n\n".join(parts)
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,9 @@ from api.utils.template_renderer import render_template
|
|||
from pipecat.processors.aggregators.llm_context import LLMContext
|
||||
|
||||
|
||||
async def _run_llm_inference(llm, messages: list[dict], system_prompt: str) -> str | None:
|
||||
async def _run_llm_inference(
|
||||
llm, messages: list[dict], system_prompt: str
|
||||
) -> str | None:
|
||||
"""Run a one-shot LLM inference using the pipecat service."""
|
||||
context = LLMContext()
|
||||
context.set_messages(messages)
|
||||
|
|
@ -51,7 +53,10 @@ async def _generate_conversation_summary(
|
|||
]
|
||||
|
||||
try:
|
||||
summary = await _run_llm_inference(llm, messages, CONVERSATION_SUMMARY_SYSTEM_PROMPT) or ""
|
||||
summary = (
|
||||
await _run_llm_inference(llm, messages, CONVERSATION_SUMMARY_SYSTEM_PROMPT)
|
||||
or ""
|
||||
)
|
||||
|
||||
span_name = f"conversation-summary-before-{node_name}"
|
||||
add_qa_span_to_trace(parent_ctx, model, messages, summary, span_name)
|
||||
|
|
|
|||
|
|
@ -154,7 +154,12 @@ async def ensure_node_summaries(
|
|||
try:
|
||||
context = LLMContext()
|
||||
context.set_messages(messages)
|
||||
summary_text = await llm.run_inference(context, system_instruction=NODE_SUMMARY_SYSTEM_PROMPT) or ""
|
||||
summary_text = (
|
||||
await llm.run_inference(
|
||||
context, system_instruction=NODE_SUMMARY_SYSTEM_PROMPT
|
||||
)
|
||||
or ""
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate summary for node {node_id}: {e}")
|
||||
updated_summaries[node_id] = {"summary": ""}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ Covers:
|
|||
"""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
|
@ -17,13 +17,12 @@ from pydantic import ValidationError
|
|||
from api.services.configuration.check_validity import UserConfigurationValidator
|
||||
from api.services.configuration.registry import (
|
||||
CAMB_TTS_MODELS,
|
||||
CambTTSConfiguration,
|
||||
REGISTRY,
|
||||
CambTTSConfiguration,
|
||||
ServiceProviders,
|
||||
ServiceType,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. CambTTSConfiguration model tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue