mirror of
https://github.com/MODSetter/SurfSense.git
synced 2026-06-12 20:45:20 +02:00
Compare commits
73 commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e539311a2 | ||
|
|
c855be8ccd | ||
|
|
cb7cb90732 | ||
|
|
fed83269d0 | ||
|
|
cff721aa42 | ||
|
|
6c8c559254 | ||
|
|
0f73db5aa1 | ||
|
|
f166a532bd | ||
|
|
05190da0a9 | ||
|
|
27218304ae | ||
|
|
7b30a76856 | ||
|
|
aee0c1a3ac | ||
|
|
e4803d4ed3 | ||
|
|
8f80900ab0 | ||
|
|
41f4a58663 | ||
|
|
e7762cda97 | ||
|
|
d27616ad0a | ||
|
|
c3695e7837 | ||
|
|
4dc06fa918 | ||
|
|
741aa8d8f7 | ||
|
|
ca9b157676 | ||
|
|
aa7f14d94f | ||
|
|
f0fc660d70 | ||
|
|
eb56acc407 | ||
|
|
11a6b178a0 | ||
|
|
ccd8209d12 | ||
|
|
1f9fd61c9e | ||
|
|
6f6c056404 | ||
|
|
a3d1fafb0b | ||
|
|
64b36f2622 | ||
|
|
65e511f77b | ||
|
|
c84525897b | ||
|
|
a7407502d3 | ||
|
|
8f38737ad9 | ||
|
|
97ab7a88fd | ||
|
|
003d1d2b95 | ||
|
|
8b52cd0ac9 | ||
|
|
3eb7cdb2d8 | ||
|
|
b7604167d8 | ||
|
|
bae59140a6 | ||
|
|
aa7aa81c16 | ||
|
|
e61308387c | ||
|
|
15e44616f3 | ||
|
|
0bed4a0d38 | ||
|
|
0c7987cd9e | ||
|
|
fa7ab8a06d | ||
|
|
36c201f9e2 | ||
|
|
0c92ee963e | ||
|
|
e926990d8e | ||
|
|
aaa9f01087 | ||
|
|
9d8e4e4f9d | ||
|
|
f61e8af8c0 | ||
|
|
eaaeebc1bb | ||
|
|
467bcd4f7b | ||
|
|
63f5f12834 | ||
|
|
1ebb57e1df | ||
|
|
5d956e8d03 | ||
|
|
89ceae8bab | ||
|
|
7087f7866d | ||
|
|
b2970ba37e | ||
|
|
4271048dcf | ||
|
|
470af28688 | ||
|
|
a3386cd5f9 | ||
|
|
0004abdc79 | ||
|
|
bd6d079030 | ||
|
|
75287020e1 | ||
|
|
ee24925747 | ||
|
|
65b6c2d357 | ||
|
|
73e191af09 | ||
|
|
8dd29fa833 | ||
|
|
4fe216856d | ||
|
|
b4c6061353 | ||
|
|
a024b03fb0 |
225 changed files with 10524 additions and 4528 deletions
2
VERSION
2
VERSION
|
|
@ -1 +1 @@
|
|||
0.0.27
|
||||
0.0.28
|
||||
|
|
|
|||
|
|
@ -166,25 +166,26 @@ EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
|
|||
# REDIS_URL=redis://redis:6379/0
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Stripe (pay-as-you-go page packs, disabled by default)
|
||||
# Stripe (unified credit wallet, disabled by default)
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
# Set TRUE to allow users to buy additional page packs via Stripe Checkout
|
||||
STRIPE_PAGE_BUYING_ENABLED=FALSE
|
||||
# Set TRUE to allow users to buy credit packs via Stripe Checkout. $1 buys
|
||||
# 1_000_000 micro-USD of credit; both ETL page processing and premium turns
|
||||
# debit this balance at the actual per-call provider cost from LiteLLM.
|
||||
STRIPE_CREDIT_BUYING_ENABLED=FALSE
|
||||
# STRIPE_SECRET_KEY=sk_test_...
|
||||
# STRIPE_WEBHOOK_SECRET=whsec_...
|
||||
# STRIPE_PRICE_ID=price_...
|
||||
# STRIPE_PAGES_PER_UNIT=1000
|
||||
# STRIPE_CREDIT_PRICE_ID=price_...
|
||||
# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
# STRIPE_RECONCILIATION_INTERVAL=10m
|
||||
# STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
# STRIPE_RECONCILIATION_BATCH_SIZE=100
|
||||
|
||||
# Premium credit purchases via Stripe ($1 buys 1_000_000 micro-USD of
|
||||
# credit; premium turns debit the actual per-call provider cost
|
||||
# reported by LiteLLM, so cheap and expensive models bill proportionally)
|
||||
# STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||
# STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||
# STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
# Auto-reload: top up via a saved Stripe card when the balance drops below
|
||||
# the user-chosen threshold. Off by default.
|
||||
# AUTO_RELOAD_ENABLED=FALSE
|
||||
# AUTO_RELOAD_MIN_AMOUNT_MICROS=1000000
|
||||
# AUTO_RELOAD_COOLDOWN_MINUTES=10
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# TTS & STT (Text-to-Speech / Speech-to-Text)
|
||||
|
|
@ -407,13 +408,16 @@ SURFSENSE_ENABLE_DOOM_LOOP=true
|
|||
# ACCESS_TOKEN_LIFETIME_SECONDS=86400
|
||||
# REFRESH_TOKEN_LIFETIME_SECONDS=1209600
|
||||
|
||||
# Pages limit per user for ETL (default: unlimited)
|
||||
# PAGES_LIMIT=500
|
||||
# Unified credit wallet starting balance for new users, in micro-USD
|
||||
# (default: $5). Funds both ETL page processing and premium model calls,
|
||||
# debited at the actual per-call provider cost reported by LiteLLM.
|
||||
# DEFAULT_CREDIT_MICROS_BALANCE=5000000
|
||||
|
||||
# Premium credit quota per registered user, in micro-USD (default: $5).
|
||||
# Premium turns are debited at the actual per-call provider cost reported
|
||||
# by LiteLLM. Only applies to models with billing_tier=premium.
|
||||
# PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||
# Debit the credit wallet for ETL page processing. Default FALSE keeps ETL
|
||||
# effectively free for self-hosted installs. 1 page == MICROS_PER_PAGE
|
||||
# micro-USD ($0.001); premium ETL mode is 10x.
|
||||
# ETL_CREDIT_BILLING_ENABLED=FALSE
|
||||
# MICROS_PER_PAGE=1000
|
||||
|
||||
# Safety ceiling on per-call premium reservation, in micro-USD ($1.00 default).
|
||||
# QUOTA_MAX_RESERVE_MICROS=1000000
|
||||
|
|
|
|||
|
|
@ -17,10 +17,14 @@
|
|||
# into the new PostgreSQL 17 stack. The user runs one command for both.
|
||||
# =============================================================================
|
||||
|
||||
# NOTE: Do not use [ValidateSet()] (or other validation attributes without a
|
||||
# valid default) on these params. When the script is piped into iex, PowerShell
|
||||
# applies the attributes to variables in the caller's scope, and an empty
|
||||
# $Variant fails ValidateSet with a ValidationMetadataException. Validate
|
||||
# manually below instead.
|
||||
param(
|
||||
[switch]$NoWatchtower,
|
||||
[int]$WatchtowerInterval = 86400,
|
||||
[ValidateSet("cpu", "cuda", "cuda126")]
|
||||
[string]$Variant,
|
||||
[string]$GpuCount,
|
||||
[switch]$Quiet
|
||||
|
|
@ -40,6 +44,11 @@ $MigrationMode = $false
|
|||
$SetupWatchtower = -not $NoWatchtower
|
||||
$WatchtowerContainer = "watchtower"
|
||||
|
||||
if ($Variant -and $Variant -notin @("cpu", "cuda", "cuda126")) {
|
||||
Write-Host "[SurfSense] ERROR: Invalid -Variant '$Variant'. Use 'cpu', 'cuda', or 'cuda126'." -ForegroundColor Red
|
||||
exit 1
|
||||
}
|
||||
|
||||
if ($GpuCount -and $GpuCount -notmatch '^([0-9]+|all)$') {
|
||||
Write-Host "[SurfSense] ERROR: Invalid -GpuCount '$GpuCount'. Use a number or 'all'." -ForegroundColor Red
|
||||
exit 1
|
||||
|
|
|
|||
|
|
@ -75,23 +75,16 @@ SECRET_KEY=SECRET
|
|||
|
||||
NEXT_FRONTEND_URL=http://localhost:3000
|
||||
|
||||
# Stripe Checkout for pay-as-you-go page packs
|
||||
# Configure STRIPE_PRICE_ID to point at your 1,000-page price in Stripe.
|
||||
# Pages granted per purchase = quantity * STRIPE_PAGES_PER_UNIT.
|
||||
# Stripe Checkout for the unified credit wallet.
|
||||
# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
|
||||
# (default 1_000_000 = $1.00). Both ETL page processing and premium model
|
||||
# turns are billed against this single balance at actual provider cost.
|
||||
STRIPE_SECRET_KEY=sk_test_...
|
||||
STRIPE_WEBHOOK_SECRET=whsec_...
|
||||
STRIPE_PRICE_ID=price_...
|
||||
STRIPE_PAGES_PER_UNIT=1000
|
||||
# Set FALSE to disable new checkout session creation temporarily
|
||||
STRIPE_PAGE_BUYING_ENABLED=TRUE
|
||||
|
||||
# Premium credit purchases via Stripe (for premium-tier model usage).
|
||||
# Each pack grants STRIPE_CREDIT_MICROS_PER_UNIT micro-USD of credit
|
||||
# (default 1_000_000 = $1.00). Premium turns are billed at the actual
|
||||
# per-call provider cost reported by LiteLLM.
|
||||
STRIPE_TOKEN_BUYING_ENABLED=FALSE
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID=price_...
|
||||
STRIPE_CREDIT_PRICE_ID=price_...
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT=1000000
|
||||
# Set FALSE to disable new checkout session creation temporarily
|
||||
STRIPE_CREDIT_BUYING_ENABLED=FALSE
|
||||
|
||||
# Periodic Stripe safety net for purchases left in PENDING (minutes old)
|
||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES=10
|
||||
|
|
@ -221,15 +214,25 @@ VIDEO_PRESENTATION_FPS=30
|
|||
VIDEO_PRESENTATION_DEFAULT_DURATION_IN_FRAMES=300
|
||||
|
||||
|
||||
# (Optional) Maximum pages limit per user for ETL services (default: `999999999` for unlimited in OSS version)
|
||||
PAGES_LIMIT=500
|
||||
# Unified credit wallet starting balance for new users, in micro-USD
|
||||
# (default: 5,000,000 == $5.00). The same balance funds ETL page processing
|
||||
# and premium model calls, debited at actual provider cost.
|
||||
DEFAULT_CREDIT_MICROS_BALANCE=5000000
|
||||
|
||||
# Premium credit quota per registered user, in micro-USD
|
||||
# (default: 5,000,000 == $5.00 of credit). Premium turns are debited at the
|
||||
# actual per-call provider cost reported by LiteLLM, so cheap and expensive
|
||||
# models bill proportionally. Applies only to models with
|
||||
# billing_tier=premium in global_llm_config.yaml.
|
||||
PREMIUM_CREDIT_MICROS_LIMIT=5000000
|
||||
# Debit the credit wallet for ETL page processing. Default FALSE keeps ETL
|
||||
# effectively free for self-hosted/OSS installs; hosted deployments set TRUE.
|
||||
# 1 page == MICROS_PER_PAGE micro-USD ($0.001); premium ETL mode is 10x.
|
||||
ETL_CREDIT_BILLING_ENABLED=FALSE
|
||||
MICROS_PER_PAGE=1000
|
||||
|
||||
# Low-balance warning threshold (micro-USD), surfaced to the UI. Default $0.50.
|
||||
CREDIT_LOW_BALANCE_WARNING_MICROS=500000
|
||||
|
||||
# Auto-reload: automatically top up via a saved Stripe card when the balance
|
||||
# drops below the user-chosen threshold. Off by default.
|
||||
AUTO_RELOAD_ENABLED=FALSE
|
||||
AUTO_RELOAD_MIN_AMOUNT_MICROS=1000000
|
||||
AUTO_RELOAD_COOLDOWN_MINUTES=10
|
||||
|
||||
# Safety ceiling on per-call premium reservation, in micro-USD.
|
||||
# stream_new_chat estimates an upper-bound cost from the model's
|
||||
|
|
|
|||
4
surfsense_backend/.gitignore
vendored
4
surfsense_backend/.gitignore
vendored
|
|
@ -1,12 +1,12 @@
|
|||
.env
|
||||
.venv
|
||||
venv/
|
||||
data/
|
||||
/data/
|
||||
.local_object_store/
|
||||
__pycache__/
|
||||
.flashrank_cache
|
||||
surf_new_backend.egg-info/
|
||||
podcasts/
|
||||
/podcasts/
|
||||
video_presentation_audio/
|
||||
sandbox_files/
|
||||
temp_audio/
|
||||
|
|
|
|||
235
surfsense_backend/alembic/versions/156_unify_credits_wallet.py
Normal file
235
surfsense_backend/alembic/versions/156_unify_credits_wallet.py
Normal file
|
|
@ -0,0 +1,235 @@
|
|||
"""unify page limits and premium credits into a single credit_micros_balance wallet
|
||||
|
||||
Collapses the two separate economies (ETL ``pages_limit``/``pages_used`` and
|
||||
premium ``premium_credit_micros_limit``/``premium_credit_micros_used``) into one
|
||||
USD-micro wallet column ``user.credit_micros_balance`` that decreases on use and
|
||||
increases on purchase / grant. ``premium_credit_micros_reserved`` is kept (renamed
|
||||
to ``credit_micros_reserved``) for in-flight reservation holds.
|
||||
|
||||
Backfill (per existing user row):
|
||||
|
||||
balance = GREATEST(0, premium_credit_micros_limit - premium_credit_micros_used)
|
||||
+ (CASE WHEN pages_limit < 100000000
|
||||
THEN GREATEST(0, pages_limit - pages_used) * 1000
|
||||
ELSE 0 END)
|
||||
|
||||
The ``pages_limit < 100000000`` guard skips the OSS "unlimited" default
|
||||
(``PAGES_LIMIT=999999999``) so self-hosters don't get a ~$1M credit grant.
|
||||
1 page == 1000 micros == $0.001 (matches the prior $1 / 1000 pages price).
|
||||
|
||||
Table / type renames:
|
||||
|
||||
premium_token_purchases -> credit_purchases
|
||||
premiumtokenpurchasestatus (enum)-> creditpurchasestatus
|
||||
user_incentive_tasks.pages_awarded -> credit_micros_awarded (backfilled * 1000)
|
||||
|
||||
The "user" table is in zero_publication's column list, so this migration updates
|
||||
the publication via ``apply_publication`` (canonical reconcile, per migration 155)
|
||||
BEFORE dropping the old columns it referenced.
|
||||
|
||||
IMPORTANT - before AND after running this migration (same as migration 140):
|
||||
1. Stop zero-cache (it holds replication locks that will deadlock DDL)
|
||||
2. Run: alembic upgrade head
|
||||
3. Delete / reset the zero-cache data volume
|
||||
4. Restart zero-cache (it will do a fresh initial sync)
|
||||
|
||||
Revision ID: 156
|
||||
Revises: 155
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
from app.zero_publication import apply_publication
|
||||
|
||||
revision: str = "156"
|
||||
down_revision: str | None = "155"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def _column_exists(conn, table: str, column: str) -> bool:
|
||||
return (
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.columns "
|
||||
"WHERE table_name = :tbl AND column_name = :col "
|
||||
"AND table_schema = current_schema()"
|
||||
),
|
||||
{"tbl": table, "col": column},
|
||||
).fetchone()
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def _table_exists(conn, table: str) -> bool:
|
||||
return (
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.tables "
|
||||
"WHERE table_name = :tbl AND table_schema = current_schema()"
|
||||
),
|
||||
{"tbl": table},
|
||||
).fetchone()
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
def _terminate_blocked_pids(conn, table: str) -> None:
|
||||
"""Kill backends whose locks on *table* would block our AccessExclusiveLock."""
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT pg_terminate_backend(l.pid) "
|
||||
"FROM pg_locks l "
|
||||
"JOIN pg_class c ON c.oid = l.relation "
|
||||
"WHERE c.relname = :tbl "
|
||||
" AND l.pid != pg_backend_pid()"
|
||||
),
|
||||
{"tbl": table},
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 1. Add credit_micros_balance + backfill from both legacy economies.
|
||||
# ------------------------------------------------------------------
|
||||
if not _column_exists(conn, "user", "credit_micros_balance"):
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"credit_micros_balance",
|
||||
sa.BigInteger(),
|
||||
nullable=False,
|
||||
server_default="5000000",
|
||||
),
|
||||
)
|
||||
|
||||
# Backfill only when ALL legacy source columns are present (fresh DBs
|
||||
# created from current models won't have them).
|
||||
if all(
|
||||
_column_exists(conn, "user", col)
|
||||
for col in (
|
||||
"premium_credit_micros_limit",
|
||||
"premium_credit_micros_used",
|
||||
"pages_limit",
|
||||
"pages_used",
|
||||
)
|
||||
):
|
||||
conn.execute(
|
||||
sa.text(
|
||||
'UPDATE "user" SET credit_micros_balance = '
|
||||
"GREATEST(0, premium_credit_micros_limit - premium_credit_micros_used) "
|
||||
"+ (CASE WHEN pages_limit < 100000000 "
|
||||
" THEN GREATEST(0, pages_limit - pages_used) * 1000 "
|
||||
" ELSE 0 END)"
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 2. Rename premium_credit_micros_reserved -> credit_micros_reserved.
|
||||
# ------------------------------------------------------------------
|
||||
if _column_exists(
|
||||
conn, "user", "premium_credit_micros_reserved"
|
||||
) and not _column_exists(conn, "user", "credit_micros_reserved"):
|
||||
op.alter_column(
|
||||
"user",
|
||||
"premium_credit_micros_reserved",
|
||||
new_column_name="credit_micros_reserved",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 3. Reconcile the Zero publication to the new column list
|
||||
# (id, credit_micros_balance) BEFORE dropping the columns it used
|
||||
# to reference, otherwise Postgres rejects the column drops with
|
||||
# "cannot drop column ... referenced by publication".
|
||||
# ------------------------------------------------------------------
|
||||
conn.execute(sa.text("SET lock_timeout = '10s'"))
|
||||
_terminate_blocked_pids(conn, "user")
|
||||
apply_publication(conn)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 4. Drop the legacy quota columns now that nothing references them.
|
||||
# ------------------------------------------------------------------
|
||||
for col in (
|
||||
"premium_credit_micros_limit",
|
||||
"premium_credit_micros_used",
|
||||
"pages_limit",
|
||||
"pages_used",
|
||||
):
|
||||
if _column_exists(conn, "user", col):
|
||||
op.drop_column("user", col)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 5. Rename premium_token_purchases -> credit_purchases and its enum.
|
||||
# ------------------------------------------------------------------
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (
|
||||
SELECT 1 FROM pg_type t
|
||||
JOIN pg_namespace n ON n.oid = t.typnamespace
|
||||
WHERE t.typname = 'premiumtokenpurchasestatus'
|
||||
AND n.nspname = current_schema()
|
||||
)
|
||||
AND NOT EXISTS (
|
||||
SELECT 1 FROM pg_type t
|
||||
JOIN pg_namespace n ON n.oid = t.typnamespace
|
||||
WHERE t.typname = 'creditpurchasestatus'
|
||||
AND n.nspname = current_schema()
|
||||
)
|
||||
THEN
|
||||
ALTER TYPE premiumtokenpurchasestatus RENAME TO creditpurchasestatus;
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
|
||||
if _table_exists(conn, "premium_token_purchases") and not _table_exists(
|
||||
conn, "credit_purchases"
|
||||
):
|
||||
op.rename_table("premium_token_purchases", "credit_purchases")
|
||||
|
||||
# ``source`` distinguishes user checkout from auto-reload top-ups.
|
||||
if _table_exists(conn, "credit_purchases") and not _column_exists(
|
||||
conn, "credit_purchases", "source"
|
||||
):
|
||||
op.add_column(
|
||||
"credit_purchases",
|
||||
sa.Column(
|
||||
"source",
|
||||
sa.String(length=20),
|
||||
nullable=False,
|
||||
server_default="checkout",
|
||||
),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 6. Rename user_incentive_tasks.pages_awarded -> credit_micros_awarded
|
||||
# and convert page counts to micros (1 page == 1000 micros).
|
||||
# ------------------------------------------------------------------
|
||||
if _column_exists(
|
||||
conn, "user_incentive_tasks", "pages_awarded"
|
||||
) and not _column_exists(conn, "user_incentive_tasks", "credit_micros_awarded"):
|
||||
op.alter_column(
|
||||
"user_incentive_tasks",
|
||||
"pages_awarded",
|
||||
new_column_name="credit_micros_awarded",
|
||||
type_=sa.BigInteger(),
|
||||
)
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"UPDATE user_incentive_tasks "
|
||||
"SET credit_micros_awarded = credit_micros_awarded * 1000"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""No-op. This is a one-way data-model unification; the legacy split
|
||||
columns cannot be faithfully reconstructed from a single balance."""
|
||||
|
|
@ -0,0 +1,92 @@
|
|||
"""add auto-reload (off-session Stripe top-up) columns to user
|
||||
|
||||
Adds the saved-card + threshold plumbing that powers feature-flagged credit
|
||||
auto-reload (``AUTO_RELOAD_ENABLED``):
|
||||
|
||||
user.stripe_customer_id (text, nullable)
|
||||
user.auto_reload_enabled (bool, default false)
|
||||
user.auto_reload_threshold_micros (bigint, nullable)
|
||||
user.auto_reload_amount_micros (bigint, nullable)
|
||||
user.auto_reload_payment_method_id (text, nullable)
|
||||
user.auto_reload_failed_at (timestamptz, nullable)
|
||||
|
||||
None of these columns are part of the Zero publication (``USER_COLS`` is
|
||||
``["id", "credit_micros_balance"]``), so this migration does NOT touch the
|
||||
publication and is safe to run without the zero-cache stop/reset dance that
|
||||
migration 156 required.
|
||||
|
||||
The ``credit_purchases.source`` column (``checkout`` | ``auto_reload``) was
|
||||
already added in migration 156, so it is not repeated here.
|
||||
|
||||
Revision ID: 157
|
||||
Revises: 156
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "157"
|
||||
down_revision: str | None = "156"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def _column_exists(conn, table: str, column: str) -> bool:
|
||||
return (
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM information_schema.columns "
|
||||
"WHERE table_name = :tbl AND column_name = :col "
|
||||
"AND table_schema = current_schema()"
|
||||
),
|
||||
{"tbl": table, "col": column},
|
||||
).fetchone()
|
||||
is not None
|
||||
)
|
||||
|
||||
|
||||
_COLUMNS: list[tuple[str, sa.Column]] = [
|
||||
("stripe_customer_id", sa.Column("stripe_customer_id", sa.String(), nullable=True)),
|
||||
(
|
||||
"auto_reload_enabled",
|
||||
sa.Column(
|
||||
"auto_reload_enabled",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
),
|
||||
(
|
||||
"auto_reload_threshold_micros",
|
||||
sa.Column("auto_reload_threshold_micros", sa.BigInteger(), nullable=True),
|
||||
),
|
||||
(
|
||||
"auto_reload_amount_micros",
|
||||
sa.Column("auto_reload_amount_micros", sa.BigInteger(), nullable=True),
|
||||
),
|
||||
(
|
||||
"auto_reload_payment_method_id",
|
||||
sa.Column("auto_reload_payment_method_id", sa.String(), nullable=True),
|
||||
),
|
||||
(
|
||||
"auto_reload_failed_at",
|
||||
sa.Column("auto_reload_failed_at", sa.TIMESTAMP(timezone=True), nullable=True),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
for name, column in _COLUMNS:
|
||||
if not _column_exists(conn, "user", name):
|
||||
op.add_column("user", column)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
for name, _ in reversed(_COLUMNS):
|
||||
if _column_exists(conn, "user", name):
|
||||
op.drop_column("user", name)
|
||||
|
|
@ -0,0 +1,118 @@
|
|||
"""evolve podcasts: expand status lifecycle and add brief/transcript/storage columns
|
||||
|
||||
Revision ID: 158
|
||||
Revises: 157
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "158"
|
||||
down_revision: str | None = "157"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def _drop_podcasts_from_publication() -> None:
|
||||
"""Detach podcasts from zero_publication so status can be retyped.
|
||||
|
||||
Postgres refuses ``ALTER COLUMN ... TYPE`` on a column a publication
|
||||
depends on. Some databases reach this migration with podcasts already
|
||||
published (an interim apply_publication ran during 156); drop it here and
|
||||
let migration 159 reconcile the publication to the canonical shape.
|
||||
"""
|
||||
conn = op.get_bind()
|
||||
published = conn.execute(
|
||||
sa.text(
|
||||
"SELECT 1 FROM pg_publication_tables "
|
||||
"WHERE pubname = 'zero_publication' "
|
||||
"AND schemaname = current_schema() AND tablename = 'podcasts'"
|
||||
)
|
||||
).fetchone()
|
||||
if published:
|
||||
op.execute('ALTER PUBLICATION "zero_publication" DROP TABLE "podcasts";')
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_podcasts_from_publication()
|
||||
|
||||
# Retype the status enum by swapping in a fresh type and casting existing
|
||||
# rows. The legacy transient value 'generating' maps onto 'rendering'.
|
||||
op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_old;")
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TYPE podcast_status AS ENUM (
|
||||
'pending', 'awaiting_brief', 'drafting', 'awaiting_review',
|
||||
'rendering', 'ready', 'failed', 'cancelled'
|
||||
);
|
||||
"""
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE podcasts
|
||||
ALTER COLUMN status TYPE podcast_status
|
||||
USING (
|
||||
CASE status::text
|
||||
WHEN 'generating' THEN 'rendering'
|
||||
ELSE status::text
|
||||
END
|
||||
)::podcast_status;
|
||||
"""
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'pending';")
|
||||
op.execute("DROP TYPE podcast_status_old;")
|
||||
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS source_content TEXT;")
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec JSONB;")
|
||||
op.execute(
|
||||
"ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS spec_version "
|
||||
"INTEGER NOT NULL DEFAULT 1;"
|
||||
)
|
||||
op.execute(
|
||||
"ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS storage_backend VARCHAR(32);"
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS storage_key TEXT;")
|
||||
op.execute(
|
||||
"ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS duration_seconds INTEGER;"
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ADD COLUMN IF NOT EXISTS error TEXT;")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS error;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS duration_seconds;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_key;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS storage_backend;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS spec_version;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS spec;")
|
||||
op.execute("ALTER TABLE podcasts DROP COLUMN IF EXISTS source_content;")
|
||||
|
||||
# Collapse the expanded lifecycle back onto the original four values.
|
||||
op.execute("ALTER TYPE podcast_status RENAME TO podcast_status_new;")
|
||||
op.execute(
|
||||
"CREATE TYPE podcast_status AS ENUM "
|
||||
"('pending', 'generating', 'ready', 'failed');"
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status DROP DEFAULT;")
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE podcasts
|
||||
ALTER COLUMN status TYPE podcast_status
|
||||
USING (
|
||||
CASE status::text
|
||||
WHEN 'awaiting_brief' THEN 'pending'
|
||||
WHEN 'drafting' THEN 'generating'
|
||||
WHEN 'awaiting_review' THEN 'generating'
|
||||
WHEN 'rendering' THEN 'generating'
|
||||
WHEN 'cancelled' THEN 'failed'
|
||||
ELSE status::text
|
||||
END
|
||||
)::podcast_status;
|
||||
"""
|
||||
)
|
||||
op.execute("ALTER TABLE podcasts ALTER COLUMN status SET DEFAULT 'ready';")
|
||||
op.execute("DROP TYPE podcast_status_new;")
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
"""publish podcasts to zero_publication
|
||||
|
||||
Reconciles ``zero_publication`` after migration 158 added the lifecycle columns,
|
||||
so the frontend observes podcast status and the reviewable brief by push.
|
||||
|
||||
Revision ID: 159
|
||||
Revises: 158
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
from app.zero_publication import apply_publication
|
||||
|
||||
revision: str = "159"
|
||||
down_revision: str | None = "158"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
apply_publication(op.get_bind())
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""No-op. Historical publication shapes are immutable."""
|
||||
|
|
@ -126,23 +126,25 @@ user: "Create issues in Linear for each of these five bugs: <list>"
|
|||
|
||||
<example>
|
||||
user: "Make a 30-second podcast of this conversation."
|
||||
→ Celery-backed deliverable. The `deliverables` subagent dispatches the
|
||||
Celery job and then **waits for it to finish** before returning. The
|
||||
call may take 10-60 seconds (or longer for video presentations) —
|
||||
that is intentional, not a hang. You always get back one of two
|
||||
Receipt shapes:
|
||||
→ Podcast deliverable. The `deliverables` subagent sets the podcast up and
|
||||
returns **immediately** — generation does not happen during the call. A
|
||||
live card in the chat takes over from there: the user reviews the brief
|
||||
(language, voices, length) on the card, and the episode drafts and
|
||||
renders automatically after they approve.
|
||||
task(deliverables, "Generate a podcast titled '<title>' from the
|
||||
following content. Use a 30-second style brief. Return the podcast
|
||||
id and title.\n\n<source content>")
|
||||
following content. Aim for a 30-second style brief. Return the
|
||||
podcast id and title.\n\n<source content>")
|
||||
Outcomes:
|
||||
- **`status="success"`**: the audio is saved. Tell the user the
|
||||
podcast is **ready** and quote the `external_id` / `preview` so
|
||||
they can find it in the podcast panel.
|
||||
- **`status="success"`**: the podcast is set up. Do NOT describe its
|
||||
current status or promise it is ready — the card tracks progress
|
||||
live and will outlive whatever you say. Just point the user at the
|
||||
card in the chat.
|
||||
- **`status="failed"`**: surface the Receipt's `error` field
|
||||
verbatim. Do NOT silently re-dispatch — the backend already tried
|
||||
and reported a real error.
|
||||
Same two-way pattern applies to video presentations (which take
|
||||
longer to render, but still return a terminal status). If a
|
||||
Video presentations differ: that Celery-backed call **waits for the
|
||||
render to finish** before returning (possibly minutes — intentional,
|
||||
not a hang) and ends with a terminal status. If a
|
||||
`task(deliverables, ...)` invocation itself times out at the subagent
|
||||
layer (separate from the Receipt), that's an operator-side problem
|
||||
with the subagent invoke timeout, not a deliverable failure — pass
|
||||
|
|
|
|||
|
|
@ -1,11 +1,10 @@
|
|||
"""Factory for a podcast-generation tool.
|
||||
|
||||
Dispatches the heavy generation to Celery and then polls the podcast row
|
||||
until it reaches a terminal status (READY/FAILED). The tool always
|
||||
returns a real terminal ``Receipt`` — never a pending one. The wait is
|
||||
bounded by the existing per-invocation safety net
|
||||
(``SURFSENSE_SUBAGENT_INVOKE_TIMEOUT_SECONDS`` in multi-agent mode,
|
||||
HTTP / process lifetime in single-agent mode).
|
||||
Creates the podcast and proposes its brief (language, voices, length) inline,
|
||||
then returns immediately with the row awaiting review. Everything after —
|
||||
brief approval, drafting, rendering — happens on the live podcast card, so
|
||||
this tool never blocks on generation and the chat text must not describe a
|
||||
status that the card will outgrow.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -18,13 +17,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.command import with_receipt
|
||||
from app.agents.chat.multi_agent_chat.shared.receipts.receipt import make_receipt
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait import (
|
||||
wait_for_deliverable,
|
||||
)
|
||||
from app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.tools.thread_resolver import (
|
||||
resolve_root_thread_id,
|
||||
)
|
||||
from app.db import Podcast, PodcastStatus, shielded_async_session
|
||||
from app.db import PodcastStatus, shielded_async_session
|
||||
from app.podcasts.generation.brief import propose_brief
|
||||
from app.podcasts.service import PodcastService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -45,7 +43,7 @@ def create_generate_podcast_tool(
|
|||
user_prompt: str | None = None,
|
||||
) -> Command:
|
||||
"""
|
||||
Generate a podcast from the provided content.
|
||||
Prepare a podcast from the provided content for the user to review.
|
||||
|
||||
Use this tool when the user asks to create, generate, or make a podcast.
|
||||
Common triggers include phrases like:
|
||||
|
|
@ -55,100 +53,59 @@ def create_generate_podcast_tool(
|
|||
- "Make a podcast about..."
|
||||
- "Turn this into a podcast"
|
||||
|
||||
This sets up the podcast and proposes its brief (language, voices,
|
||||
length). The user reviews the brief on the live podcast card in the
|
||||
chat; after approval the episode drafts and renders automatically.
|
||||
Generation does not start here, and the card tracks all progress — do
|
||||
not describe the podcast's current status in your reply.
|
||||
|
||||
Args:
|
||||
source_content: The text content to convert into a podcast.
|
||||
podcast_title: Title for the podcast (default: "SurfSense Podcast")
|
||||
user_prompt: Optional instructions for podcast style, tone, or format.
|
||||
user_prompt: Optional steer for what the episode should focus on.
|
||||
|
||||
Returns:
|
||||
A dictionary containing:
|
||||
- status: PodcastStatus value (pending, generating, or failed)
|
||||
- podcast_id: The podcast ID for polling (when status is pending or generating)
|
||||
- title: The podcast title
|
||||
- message: Status message (or "error" field if status is failed)
|
||||
- status: the podcast lifecycle status (awaiting_brief on success)
|
||||
- podcast_id: the podcast ID to review in the panel
|
||||
- title: the podcast title
|
||||
- message: what the user should do next (or "error" when failed)
|
||||
"""
|
||||
try:
|
||||
# One DB session per tool call so parallel invocations never share an AsyncSession.
|
||||
async with shielded_async_session() as session:
|
||||
podcast = Podcast(
|
||||
service = PodcastService(session)
|
||||
podcast = await service.create(
|
||||
title=podcast_title,
|
||||
status=PodcastStatus.PENDING,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=resolve_root_thread_id(runtime, thread_id),
|
||||
)
|
||||
session.add(podcast)
|
||||
podcast.source_content = source_content
|
||||
spec = await propose_brief(
|
||||
session,
|
||||
search_space_id=search_space_id,
|
||||
focus=user_prompt,
|
||||
)
|
||||
await service.attach_brief(podcast, spec)
|
||||
await session.commit()
|
||||
await session.refresh(podcast)
|
||||
podcast_id = podcast.id
|
||||
|
||||
from app.tasks.celery_tasks.podcast_tasks import (
|
||||
generate_content_podcast_task,
|
||||
)
|
||||
|
||||
task = generate_content_podcast_task.delay(
|
||||
podcast_id=podcast_id,
|
||||
source_content=source_content,
|
||||
search_space_id=search_space_id,
|
||||
user_prompt=user_prompt,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"[generate_podcast] Created podcast %s, task: %s",
|
||||
"[generate_podcast] Prepared podcast %s awaiting brief review",
|
||||
podcast_id,
|
||||
task.id,
|
||||
)
|
||||
|
||||
# Wait until the Celery worker flips the row to a terminal
|
||||
# state. The wait is bounded only by the subagent invoke
|
||||
# timeout (multi-agent) or HTTP lifetime (single-agent) —
|
||||
# see app.agents.chat.multi_agent_chat.subagents.builtins.deliverables.deliverable_wait for details.
|
||||
terminal_status, columns, elapsed = await wait_for_deliverable(
|
||||
model=Podcast,
|
||||
row_id=podcast_id,
|
||||
columns=[Podcast.status, Podcast.file_location],
|
||||
terminal_statuses={PodcastStatus.READY, PodcastStatus.FAILED},
|
||||
)
|
||||
|
||||
if terminal_status == PodcastStatus.READY:
|
||||
file_location = columns[1] if columns else None
|
||||
logger.info(
|
||||
"[generate_podcast] Podcast %s READY in %.2fs (file=%s)",
|
||||
podcast_id,
|
||||
elapsed,
|
||||
file_location,
|
||||
)
|
||||
payload: dict[str, Any] = {
|
||||
"status": PodcastStatus.READY.value,
|
||||
"podcast_id": podcast_id,
|
||||
"title": podcast_title,
|
||||
"file_location": file_location,
|
||||
"message": ("Podcast generated and saved to your podcast panel."),
|
||||
}
|
||||
return with_receipt(
|
||||
payload=payload,
|
||||
receipt=make_receipt(
|
||||
route="deliverables",
|
||||
type="podcast",
|
||||
operation="generate",
|
||||
status="success",
|
||||
external_id=str(podcast_id),
|
||||
preview=podcast_title,
|
||||
),
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
|
||||
# Only other terminal state is FAILED.
|
||||
logger.warning(
|
||||
"[generate_podcast] Podcast %s FAILED in %.2fs",
|
||||
podcast_id,
|
||||
elapsed,
|
||||
)
|
||||
err = "Background worker reported FAILED status for this podcast."
|
||||
payload = {
|
||||
"status": PodcastStatus.FAILED.value,
|
||||
payload: dict[str, Any] = {
|
||||
"status": PodcastStatus.AWAITING_BRIEF.value,
|
||||
"podcast_id": podcast_id,
|
||||
"title": podcast_title,
|
||||
"error": err,
|
||||
"message": (
|
||||
"Podcast set up. The card in the chat handles the rest: "
|
||||
"the user reviews the brief (language, voices, length) "
|
||||
"there, and the episode drafts and renders automatically "
|
||||
"after approval. The card tracks progress live, so do not "
|
||||
"state the podcast's current status in your reply."
|
||||
),
|
||||
}
|
||||
return with_receipt(
|
||||
payload=payload,
|
||||
|
|
@ -156,10 +113,9 @@ def create_generate_podcast_tool(
|
|||
route="deliverables",
|
||||
type="podcast",
|
||||
operation="generate",
|
||||
status="failed",
|
||||
status="success",
|
||||
external_id=str(podcast_id),
|
||||
preview=podcast_title,
|
||||
error=err,
|
||||
),
|
||||
tool_call_id=runtime.tool_call_id,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,8 +0,0 @@
|
|||
"""New LangGraph Agent.
|
||||
|
||||
This module defines a custom graph.
|
||||
"""
|
||||
|
||||
from .graph import graph
|
||||
|
||||
__all__ = ["graph"]
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
"""Define the configurable parameters for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class Configuration:
|
||||
"""The configuration for the agent."""
|
||||
|
||||
# Changeme: Add configurable values here!
|
||||
# these values can be pre-set when you
|
||||
# create assistants (https://langchain-ai.github.io/langgraph/cloud/how-tos/configuration_cloud/)
|
||||
# and when you invoke the graph
|
||||
podcast_title: str
|
||||
search_space_id: int
|
||||
user_prompt: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> Configuration:
|
||||
"""Create a Configuration instance from a RunnableConfig object."""
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
_fields = {f.name for f in fields(cls) if f.init}
|
||||
return cls(**{k: v for k, v in configurable.items() if k in _fields})
|
||||
|
|
@ -1,29 +0,0 @@
|
|||
from langgraph.graph import StateGraph
|
||||
|
||||
from .configuration import Configuration
|
||||
from .nodes import create_merged_podcast_audio, create_podcast_transcript
|
||||
from .state import State
|
||||
|
||||
|
||||
def build_graph():
|
||||
# Define a new graph
|
||||
workflow = StateGraph(State, config_schema=Configuration)
|
||||
|
||||
# Add the node to the graph
|
||||
workflow.add_node("create_podcast_transcript", create_podcast_transcript)
|
||||
workflow.add_node("create_merged_podcast_audio", create_merged_podcast_audio)
|
||||
|
||||
# Set the entrypoint as `call_model`
|
||||
workflow.add_edge("__start__", "create_podcast_transcript")
|
||||
workflow.add_edge("create_podcast_transcript", "create_merged_podcast_audio")
|
||||
workflow.add_edge("create_merged_podcast_audio", "__end__")
|
||||
|
||||
# Compile the workflow into an executable graph
|
||||
graph = workflow.compile()
|
||||
graph.name = "Surfsense Podcaster" # This defines the custom name in LangSmith
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
# Compile the graph once when the module is loaded
|
||||
graph = build_graph()
|
||||
|
|
@ -1,195 +0,0 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from ffmpeg.asyncio import FFmpeg
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from litellm import aspeech
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.services.kokoro_tts_service import get_kokoro_tts_service
|
||||
from app.services.llm_service import get_agent_llm
|
||||
from app.utils.content_utils import extract_text_content, strip_markdown_fences
|
||||
|
||||
from .configuration import Configuration
|
||||
from .prompts import get_podcast_generation_prompt
|
||||
from .state import PodcastTranscriptEntry, PodcastTranscripts, State
|
||||
from .utils import get_voice_for_provider
|
||||
|
||||
|
||||
async def create_podcast_transcript(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Generate the podcast transcript from the source content."""
|
||||
configuration = Configuration.from_runnable_config(config)
|
||||
search_space_id = configuration.search_space_id
|
||||
user_prompt = configuration.user_prompt
|
||||
|
||||
llm = await get_agent_llm(state.db_session, search_space_id)
|
||||
if not llm:
|
||||
error_message = f"No agent LLM configured for search space {search_space_id}"
|
||||
print(error_message)
|
||||
raise RuntimeError(error_message)
|
||||
|
||||
prompt = get_podcast_generation_prompt(user_prompt)
|
||||
messages = [
|
||||
SystemMessage(content=prompt),
|
||||
HumanMessage(
|
||||
content=f"<source_content>{state.source_content}</source_content>"
|
||||
),
|
||||
]
|
||||
llm_response = await llm.ainvoke(messages)
|
||||
|
||||
# Reasoning models may return content as blocks; normalise to a string.
|
||||
content = strip_markdown_fences(extract_text_content(llm_response.content))
|
||||
|
||||
try:
|
||||
podcast_transcript = PodcastTranscripts.model_validate(json.loads(content))
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e:
|
||||
print(f"Direct JSON parsing failed, trying fallback approach: {e!s}")
|
||||
|
||||
try:
|
||||
json_start = content.find("{")
|
||||
json_end = content.rfind("}") + 1
|
||||
if json_start >= 0 and json_end > json_start:
|
||||
json_str = content[json_start:json_end]
|
||||
parsed_data = json.loads(json_str)
|
||||
podcast_transcript = PodcastTranscripts.model_validate(parsed_data)
|
||||
print("Successfully parsed podcast transcript using fallback approach")
|
||||
else:
|
||||
error_message = f"Could not find valid JSON in LLM response. Raw response: {content}"
|
||||
print(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except (json.JSONDecodeError, TypeError, ValueError) as e2:
|
||||
error_message = f"Error parsing LLM response (fallback also failed): {e2!s}"
|
||||
print(f"Error parsing LLM response: {e2!s}")
|
||||
print(f"Raw response: {content}")
|
||||
raise
|
||||
|
||||
return {"podcast_transcript": podcast_transcript.podcast_transcripts}
|
||||
|
||||
|
||||
async def create_merged_podcast_audio(
|
||||
state: State, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Generate audio for each transcript and merge them into a single podcast file."""
|
||||
starting_transcript = PodcastTranscriptEntry(
|
||||
speaker_id=1, dialog="Welcome to Surfsense Podcast."
|
||||
)
|
||||
|
||||
transcript = state.podcast_transcript
|
||||
|
||||
# transcript may be a PodcastTranscripts object or already a list.
|
||||
if hasattr(transcript, "podcast_transcripts"):
|
||||
transcript_entries = transcript.podcast_transcripts
|
||||
else:
|
||||
transcript_entries = transcript
|
||||
|
||||
merged_transcript = [starting_transcript, *transcript_entries]
|
||||
|
||||
temp_dir = Path("temp_audio")
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
session_id = str(uuid.uuid4())
|
||||
output_path = f"podcasts/{session_id}_podcast.mp3"
|
||||
os.makedirs("podcasts", exist_ok=True)
|
||||
|
||||
audio_files = []
|
||||
|
||||
async def generate_speech_for_segment(segment, index):
|
||||
if hasattr(segment, "speaker_id"):
|
||||
speaker_id = segment.speaker_id
|
||||
dialog = segment.dialog
|
||||
else:
|
||||
speaker_id = segment.get("speaker_id", 0)
|
||||
dialog = segment.get("dialog", "")
|
||||
|
||||
voice = get_voice_for_provider(app_config.TTS_SERVICE, speaker_id)
|
||||
|
||||
if app_config.TTS_SERVICE == "local/kokoro":
|
||||
filename = f"{temp_dir}/{session_id}_{index}.wav"
|
||||
else:
|
||||
filename = f"{temp_dir}/{session_id}_{index}.mp3"
|
||||
|
||||
try:
|
||||
if app_config.TTS_SERVICE == "local/kokoro":
|
||||
kokoro_service = await get_kokoro_tts_service(
|
||||
lang_code="a"
|
||||
) # American English
|
||||
audio_path = await kokoro_service.generate_speech(
|
||||
text=dialog, voice=voice, speed=1.0, output_path=filename
|
||||
)
|
||||
return audio_path
|
||||
else:
|
||||
if app_config.TTS_SERVICE_API_BASE:
|
||||
response = await aspeech(
|
||||
model=app_config.TTS_SERVICE,
|
||||
api_base=app_config.TTS_SERVICE_API_BASE,
|
||||
api_key=app_config.TTS_SERVICE_API_KEY,
|
||||
voice=voice,
|
||||
input=dialog,
|
||||
max_retries=2,
|
||||
timeout=600,
|
||||
)
|
||||
else:
|
||||
response = await aspeech(
|
||||
model=app_config.TTS_SERVICE,
|
||||
api_key=app_config.TTS_SERVICE_API_KEY,
|
||||
voice=voice,
|
||||
input=dialog,
|
||||
max_retries=2,
|
||||
timeout=600,
|
||||
)
|
||||
|
||||
with open(filename, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
return filename
|
||||
except Exception as e:
|
||||
print(f"Error generating speech for segment {index}: {e!s}")
|
||||
raise
|
||||
|
||||
tasks = [
|
||||
generate_speech_for_segment(segment, i)
|
||||
for i, segment in enumerate(merged_transcript)
|
||||
]
|
||||
audio_files = await asyncio.gather(*tasks)
|
||||
|
||||
try:
|
||||
ffmpeg = FFmpeg().option("y")
|
||||
for audio_file in audio_files:
|
||||
ffmpeg = ffmpeg.input(audio_file)
|
||||
|
||||
filter_complex = []
|
||||
for i in range(len(audio_files)):
|
||||
filter_complex.append(f"[{i}:0]")
|
||||
|
||||
filter_complex_str = (
|
||||
"".join(filter_complex) + f"concat=n={len(audio_files)}:v=0:a=1[outa]"
|
||||
)
|
||||
ffmpeg = ffmpeg.option("filter_complex", filter_complex_str)
|
||||
ffmpeg = ffmpeg.output(output_path, map="[outa]")
|
||||
await ffmpeg.execute()
|
||||
|
||||
print(f"Successfully created podcast audio: {output_path}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error merging audio files: {e!s}")
|
||||
raise
|
||||
finally:
|
||||
for audio_file in audio_files:
|
||||
try:
|
||||
os.remove(audio_file)
|
||||
except Exception as e:
|
||||
print(f"Error removing audio file {audio_file}: {e!s}")
|
||||
pass
|
||||
|
||||
return {
|
||||
"podcast_transcript": merged_transcript,
|
||||
"final_podcast_file_path": output_path,
|
||||
}
|
||||
|
|
@ -1,122 +0,0 @@
|
|||
import datetime
|
||||
|
||||
|
||||
def get_podcast_generation_prompt(user_prompt: str | None = None):
|
||||
return f"""
|
||||
Today's date: {datetime.datetime.now().strftime("%Y-%m-%d")}
|
||||
<podcast_generation_system>
|
||||
You are a master podcast scriptwriter, adept at transforming diverse input content into a lively, engaging, and natural-sounding conversation between two distinct podcast hosts. Your primary objective is to craft authentic, flowing dialogue that captures the spontaneity and chemistry of a real podcast discussion, completely avoiding any hint of robotic scripting or stiff formality. Think dynamic interplay, not just information delivery.
|
||||
|
||||
{
|
||||
f'''
|
||||
You **MUST** strictly adhere to the following user instruction while generating the podcast script:
|
||||
<user_instruction>
|
||||
{user_prompt}
|
||||
</user_instruction>
|
||||
'''
|
||||
if user_prompt
|
||||
else ""
|
||||
}
|
||||
|
||||
<input>
|
||||
- '<source_content>': A block of text containing the information to be discussed in the podcast. This could be research findings, an article summary, a detailed outline, user chat history related to the topic, or any other relevant raw information. The content might be unstructured but serves as the factual basis for the podcast dialogue.
|
||||
</input>
|
||||
|
||||
<output_format>
|
||||
A JSON object containing the podcast transcript with alternating speakers:
|
||||
{{
|
||||
"podcast_transcripts": [
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Speaker 0 dialog here"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Speaker 1 dialog here"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Speaker 0 dialog here"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Speaker 1 dialog here"
|
||||
}}
|
||||
]
|
||||
}}
|
||||
</output_format>
|
||||
|
||||
<guidelines>
|
||||
1. **Establish Distinct & Consistent Host Personas:**
|
||||
* **Speaker 0 (Lead Host):** Drives the conversation forward, introduces segments, poses key questions derived from the source content, and often summarizes takeaways. Maintain a guiding, clear, and engaging tone.
|
||||
* **Speaker 1 (Co-Host/Expert):** Offers deeper insights, provides alternative viewpoints or elaborations on the source content, asks clarifying or challenging questions, and shares relevant anecdotes or examples. Adopt a complementary tone (e.g., analytical, enthusiastic, reflective, slightly skeptical).
|
||||
* **Consistency is Key:** Ensure each speaker maintains their distinct voice, vocabulary choice, sentence structure, and perspective throughout the entire script. Avoid having them sound interchangeable. Their interaction should feel like a genuine partnership.
|
||||
|
||||
2. **Craft Natural & Dynamic Dialogue:**
|
||||
* **Emulate Real Conversation:** Use contractions (e.g., "don't", "it's"), interjections ("Oh!", "Wow!", "Hmm"), discourse markers ("you know", "right?", "well"), and occasional natural pauses or filler words. Avoid overly formal language or complex sentence structures typical of written text.
|
||||
* **Foster Interaction & Chemistry:** Write dialogue where speakers genuinely react *to each other*. They should build on points ("Exactly, and that reminds me..."), ask follow-up questions ("Could you expand on that?"), express agreement/disagreement respectfully ("That's a fair point, but have you considered...?"), and show active listening.
|
||||
* **Vary Rhythm & Pace:** Mix short, punchy lines with longer, more explanatory ones. Vary sentence beginnings. Use questions to break up exposition. The rhythm should feel spontaneous, not monotonous.
|
||||
* **Inject Personality & Relatability:** Allow for appropriate humor, moments of surprise or curiosity, brief personal reflections ("I actually experienced something similar..."), or relatable asides that fit the hosts' personas and the topic. Lightly reference past discussions if it enhances context ("Remember last week when we touched on...?").
|
||||
|
||||
3. **Structure for Flow and Listener Engagement:**
|
||||
* **Natural Beginning:** Start with dialogue that flows naturally after an introduction (which will be added manually). Avoid redundant greetings or podcast name mentions since these will be added separately.
|
||||
* **Logical Progression & Signposting:** Guide the listener through the information smoothly. Use clear transitions to link different ideas or segments ("So, now that we've covered X, let's dive into Y...", "That actually brings me to another key finding..."). Ensure topics flow logically from one to the next.
|
||||
* **Meaningful Conclusion:** Summarize the key takeaways or main points discussed, reinforcing the core message derived from the source content. End with a final thought, a lingering question for the audience, or a brief teaser for what's next, providing a sense of closure. Avoid abrupt endings.
|
||||
|
||||
4. **Integrate Source Content Seamlessly & Accurately:**
|
||||
* **Translate, Don't Recite:** Rephrase information from the `<source_content>` into conversational language suitable for each host's persona. Avoid directly copying dense sentences or technical jargon without explanation. The goal is discussion, not narration.
|
||||
* **Explain & Contextualize:** Use analogies, simple examples, storytelling, or have one host ask clarifying questions (acting as a listener surrogate) to break down complex ideas from the source.
|
||||
* **Weave Information Naturally:** Integrate facts, data, or key points from the source *within* the dialogue, not as standalone, undigested blocks. Attribute information conversationally where appropriate ("The research mentioned...", "Apparently, the key factor is...").
|
||||
* **Balance Depth & Accessibility:** Ensure the conversation is informative and factually accurate based on the source content, but prioritize clear communication and engaging delivery over exhaustive technical detail. Make it understandable and interesting for a general audience.
|
||||
|
||||
5. **Length & Pacing:**
|
||||
* **Six-Minute Duration:** Create a transcript that, when read at a natural speaking pace, would result in approximately 6 minutes of audio. Typically, this means around 1000 words total (based on average speaking rate of 150 words per minute).
|
||||
* **Concise Speaking Turns:** Keep most speaking turns relatively brief and focused. Aim for a natural back-and-forth rhythm rather than extended monologues.
|
||||
* **Essential Content Only:** Prioritize the most important information from the source content. Focus on quality over quantity, ensuring every line contributes meaningfully to the topic.
|
||||
</guidelines>
|
||||
|
||||
<examples>
|
||||
Input: "Quantum computing uses quantum bits or qubits which can exist in multiple states simultaneously due to superposition."
|
||||
|
||||
Output:
|
||||
{{
|
||||
"podcast_transcripts": [
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Today we're diving into the mind-bending world of quantum computing. You know, this is a topic I've been excited to cover for weeks."
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Same here! And I know our listeners have been asking for it. But I have to admit, the concept of quantum computing makes my head spin a little. Can we start with the basics?"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Absolutely. So regular computers use bits, right? Little on-off switches that are either 1 or 0. But quantum computers use something called qubits, and this is where it gets fascinating."
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Wait, what makes qubits so special compared to regular bits?"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "The magic is in something called superposition. These qubits can exist in multiple states at the same time, not just 1 or 0."
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "That sounds impossible! How would you even picture that?"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 0,
|
||||
"dialog": "Think of it like a coin spinning in the air. Before it lands, is it heads or tails?"
|
||||
}},
|
||||
{{
|
||||
"speaker_id": 1,
|
||||
"dialog": "Well, it's... neither? Or I guess both, until it lands? Oh, I think I see where you're going with this."
|
||||
}}
|
||||
]
|
||||
}}
|
||||
</examples>
|
||||
|
||||
Transform the source material into a lively and engaging podcast conversation. Craft dialogue that showcases authentic host chemistry and natural interaction (including occasional disagreement, building on points, or asking follow-up questions). Use varied speech patterns reflecting real human conversation, ensuring the final script effectively educates *and* entertains the listener while keeping within a 5-minute audio duration.
|
||||
</podcast_generation_system>
|
||||
"""
|
||||
|
|
@ -1,43 +0,0 @@
|
|||
"""Define the state structures for the agent."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
|
||||
class PodcastTranscriptEntry(BaseModel):
|
||||
"""
|
||||
Represents a single entry in a podcast transcript.
|
||||
"""
|
||||
|
||||
speaker_id: int = Field(..., description="The ID of the speaker (0 or 1)")
|
||||
dialog: str = Field(..., description="The dialog text spoken by the speaker")
|
||||
|
||||
|
||||
class PodcastTranscripts(BaseModel):
|
||||
"""
|
||||
Represents the full podcast transcript structure.
|
||||
"""
|
||||
|
||||
podcast_transcripts: list[PodcastTranscriptEntry] = Field(
|
||||
..., description="List of transcript entries with alternating speakers"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class State:
|
||||
"""Defines the input state for the agent, representing a narrower interface to the outside world.
|
||||
|
||||
This class is used to define the initial state and structure of incoming data.
|
||||
See: https://langchain-ai.github.io/langgraph/concepts/low_level/#state
|
||||
for more information.
|
||||
"""
|
||||
|
||||
# Runtime context
|
||||
db_session: AsyncSession
|
||||
source_content: str
|
||||
podcast_transcript: list[PodcastTranscriptEntry] | None = None
|
||||
final_podcast_file_path: str | None = None
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
def get_voice_for_provider(provider: str, speaker_id: int) -> dict | str:
|
||||
"""
|
||||
Get the appropriate voice configuration based on the TTS provider and speaker ID.
|
||||
|
||||
Args:
|
||||
provider: The TTS provider (e.g., "openai/tts-1", "vertex_ai/test")
|
||||
speaker_id: The ID of the speaker (0-5)
|
||||
|
||||
Returns:
|
||||
Voice configuration - string for OpenAI, dict for Vertex AI
|
||||
"""
|
||||
if provider == "local/kokoro":
|
||||
# Kokoro voice mapping - https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices
|
||||
kokoro_voices = {
|
||||
0: "am_adam", # Default/intro voice
|
||||
1: "af_bella", # First speaker
|
||||
}
|
||||
return kokoro_voices.get(speaker_id, "af_heart")
|
||||
|
||||
# Extract provider type from the model string
|
||||
provider_type = (
|
||||
provider.split("/")[0].lower() if "/" in provider else provider.lower()
|
||||
)
|
||||
|
||||
if provider_type == "openai":
|
||||
# OpenAI voice mapping - simple string values
|
||||
openai_voices = {
|
||||
0: "alloy", # Default/intro voice
|
||||
1: "echo", # First speaker
|
||||
2: "fable", # Second speaker
|
||||
3: "onyx", # Third speaker
|
||||
4: "nova", # Fourth speaker
|
||||
5: "shimmer", # Fifth speaker
|
||||
}
|
||||
return openai_voices.get(speaker_id, "alloy")
|
||||
|
||||
elif provider_type == "vertex_ai":
|
||||
# Vertex AI voice mapping - dict with languageCode and name
|
||||
vertex_voices = {
|
||||
0: {
|
||||
"languageCode": "en-US",
|
||||
"name": "en-US-Studio-O",
|
||||
},
|
||||
1: {
|
||||
"languageCode": "en-US",
|
||||
"name": "en-US-Studio-M",
|
||||
},
|
||||
2: {
|
||||
"languageCode": "en-UK",
|
||||
"name": "en-UK-Studio-A",
|
||||
},
|
||||
3: {
|
||||
"languageCode": "en-UK",
|
||||
"name": "en-UK-Studio-B",
|
||||
},
|
||||
4: {
|
||||
"languageCode": "en-AU",
|
||||
"name": "en-AU-Studio-A",
|
||||
},
|
||||
5: {
|
||||
"languageCode": "en-AU",
|
||||
"name": "en-AU-Studio-B",
|
||||
},
|
||||
}
|
||||
return vertex_voices.get(speaker_id, vertex_voices[0])
|
||||
elif provider_type == "azure":
|
||||
# OpenAI voice mapping - simple string values
|
||||
azure_voices = {
|
||||
0: "alloy", # Default/intro voice
|
||||
1: "echo", # First speaker
|
||||
2: "fable", # Second speaker
|
||||
3: "onyx", # Third speaker
|
||||
4: "nova", # Fourth speaker
|
||||
5: "shimmer", # Fifth speaker
|
||||
}
|
||||
return azure_voices.get(speaker_id, "alloy")
|
||||
|
||||
else:
|
||||
# Default fallback to OpenAI format for unknown providers
|
||||
default_voices = {
|
||||
0: {},
|
||||
1: {},
|
||||
}
|
||||
return default_voices.get(speaker_id, default_voices[0])
|
||||
|
|
@ -1,8 +1,7 @@
|
|||
"""Video Presentation LangGraph Agent.
|
||||
|
||||
This module defines a graph for generating video presentations
|
||||
from source content, similar to the podcaster agent but producing
|
||||
slide-based video presentations with TTS narration.
|
||||
This module defines a graph for generating slide-based video presentations
|
||||
from source content, with TTS narration per slide.
|
||||
"""
|
||||
|
||||
from .graph import graph
|
||||
|
|
|
|||
|
|
@ -181,7 +181,8 @@ celery_app = Celery(
|
|||
backend=CELERY_RESULT_BACKEND,
|
||||
include=[
|
||||
"app.tasks.celery_tasks.document_tasks",
|
||||
"app.tasks.celery_tasks.podcast_tasks",
|
||||
"app.podcasts.tasks.draft",
|
||||
"app.podcasts.tasks.render",
|
||||
"app.tasks.celery_tasks.video_presentation_tasks",
|
||||
"app.tasks.celery_tasks.connector_tasks",
|
||||
"app.tasks.celery_tasks.obsidian_tasks",
|
||||
|
|
@ -189,6 +190,7 @@ celery_app = Celery(
|
|||
"app.tasks.celery_tasks.document_reindex_tasks",
|
||||
"app.tasks.celery_tasks.stale_notification_cleanup_task",
|
||||
"app.tasks.celery_tasks.stripe_reconciliation_task",
|
||||
"app.tasks.celery_tasks.auto_reload_task",
|
||||
"app.tasks.celery_tasks.gateway_tasks",
|
||||
"app.automations.tasks.execute_run",
|
||||
"app.automations.triggers.builtin.schedule.selector",
|
||||
|
|
@ -281,16 +283,9 @@ celery_app.conf.beat_schedule = {
|
|||
"expires": 60, # Task expires after 60 seconds if not picked up
|
||||
},
|
||||
},
|
||||
# Reconcile Stripe purchases that were paid but remained pending
|
||||
"reconcile-pending-stripe-page-purchases": {
|
||||
"task": "reconcile_pending_stripe_page_purchases",
|
||||
"schedule": crontab(**stripe_reconciliation_schedule_params),
|
||||
"options": {
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
"reconcile-pending-stripe-token-purchases": {
|
||||
"task": "reconcile_pending_stripe_token_purchases",
|
||||
# Reconcile Stripe credit purchases that were paid but remained pending
|
||||
"reconcile-pending-stripe-credit-purchases": {
|
||||
"task": "reconcile_pending_stripe_credit_purchases",
|
||||
"schedule": crontab(**stripe_reconciliation_schedule_params),
|
||||
"options": {
|
||||
"expires": 60,
|
||||
|
|
|
|||
|
|
@ -640,14 +640,9 @@ class Config:
|
|||
)
|
||||
GATEWAY_DISCORD_REDIRECT_URI = os.getenv("GATEWAY_DISCORD_REDIRECT_URI")
|
||||
|
||||
# Stripe checkout for pay-as-you-go page packs
|
||||
# Stripe checkout (shared secrets for the unified credit wallet)
|
||||
STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY")
|
||||
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET")
|
||||
STRIPE_PRICE_ID = os.getenv("STRIPE_PRICE_ID")
|
||||
STRIPE_PAGES_PER_UNIT = int(os.getenv("STRIPE_PAGES_PER_UNIT", "1000"))
|
||||
STRIPE_PAGE_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_PAGE_BUYING_ENABLED", "TRUE").upper() == "TRUE"
|
||||
)
|
||||
STRIPE_RECONCILIATION_LOOKBACK_MINUTES = int(
|
||||
os.getenv("STRIPE_RECONCILIATION_LOOKBACK_MINUTES", "10")
|
||||
)
|
||||
|
|
@ -655,27 +650,56 @@ class Config:
|
|||
os.getenv("STRIPE_RECONCILIATION_BATCH_SIZE", "100")
|
||||
)
|
||||
|
||||
# Premium credit (micro-USD) quota settings.
|
||||
# Unified credit wallet (micro-USD) settings.
|
||||
#
|
||||
# Storage unit is integer micro-USD (1_000_000 = $1.00). The legacy
|
||||
# ``PREMIUM_TOKEN_LIMIT`` and ``STRIPE_TOKENS_PER_UNIT`` env vars are
|
||||
# still honoured for one release as fall-back values — the prior
|
||||
# $1-per-1M-tokens Stripe price means every existing value maps 1:1
|
||||
# to micros, so operators upgrading without changing their .env still
|
||||
# get correct behaviour. A startup deprecation warning fires below if
|
||||
# they're set.
|
||||
PREMIUM_CREDIT_MICROS_LIMIT = int(
|
||||
os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
|
||||
# Storage unit is integer micro-USD (1_000_000 = $1.00). A single
|
||||
# ``credit_micros_balance`` funds both ETL page processing and premium
|
||||
# model calls. New users start with ``DEFAULT_CREDIT_MICROS_BALANCE``
|
||||
# ($5 by default).
|
||||
#
|
||||
# Legacy env names (``PREMIUM_CREDIT_MICROS_LIMIT`` / ``PREMIUM_TOKEN_LIMIT``,
|
||||
# ``STRIPE_PREMIUM_TOKEN_PRICE_ID``, ``STRIPE_CREDIT_MICROS_PER_UNIT`` /
|
||||
# ``STRIPE_TOKENS_PER_UNIT``, ``STRIPE_TOKEN_BUYING_ENABLED``) are still
|
||||
# honoured as fall-backs for one release; deprecation warnings fire below.
|
||||
DEFAULT_CREDIT_MICROS_BALANCE = int(
|
||||
os.getenv("DEFAULT_CREDIT_MICROS_BALANCE")
|
||||
or os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
|
||||
or os.getenv("PREMIUM_TOKEN_LIMIT", "5000000")
|
||||
)
|
||||
STRIPE_PREMIUM_TOKEN_PRICE_ID = os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID")
|
||||
STRIPE_CREDIT_PRICE_ID = os.getenv("STRIPE_CREDIT_PRICE_ID") or os.getenv(
|
||||
"STRIPE_PREMIUM_TOKEN_PRICE_ID"
|
||||
)
|
||||
STRIPE_CREDIT_MICROS_PER_UNIT = int(
|
||||
os.getenv("STRIPE_CREDIT_MICROS_PER_UNIT")
|
||||
or os.getenv("STRIPE_TOKENS_PER_UNIT", "1000000")
|
||||
)
|
||||
STRIPE_TOKEN_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE").upper() == "TRUE"
|
||||
STRIPE_CREDIT_BUYING_ENABLED = (
|
||||
os.getenv("STRIPE_CREDIT_BUYING_ENABLED")
|
||||
or os.getenv("STRIPE_TOKEN_BUYING_ENABLED", "FALSE")
|
||||
).upper() == "TRUE"
|
||||
|
||||
# ETL page processing debits the credit wallet only when enabled. Defaults
|
||||
# to FALSE so self-hosted / OSS installs keep effectively-free ETL; hosted
|
||||
# deployments set this TRUE. 1 page == ``MICROS_PER_PAGE`` micro-USD.
|
||||
ETL_CREDIT_BILLING_ENABLED = (
|
||||
os.getenv("ETL_CREDIT_BILLING_ENABLED", "FALSE").upper() == "TRUE"
|
||||
)
|
||||
MICROS_PER_PAGE = int(os.getenv("MICROS_PER_PAGE", "1000"))
|
||||
|
||||
# Low-balance WARNING threshold (micro-USD). Surfaced by the quota service
|
||||
# so the UI can nudge the user to top up / enable auto-reload. $0.50.
|
||||
CREDIT_LOW_BALANCE_WARNING_MICROS = int(
|
||||
os.getenv("CREDIT_LOW_BALANCE_WARNING_MICROS", "500000")
|
||||
)
|
||||
|
||||
# Auto-reload (off-session Stripe top-up) feature flag and guards.
|
||||
AUTO_RELOAD_ENABLED = os.getenv("AUTO_RELOAD_ENABLED", "FALSE").upper() == "TRUE"
|
||||
# Minimum configurable reload amount (micro-USD). $1.00 to match pack pricing.
|
||||
AUTO_RELOAD_MIN_AMOUNT_MICROS = int(
|
||||
os.getenv("AUTO_RELOAD_MIN_AMOUNT_MICROS", "1000000")
|
||||
)
|
||||
# Cooldown so a burst of debits can't fire multiple charges (minutes).
|
||||
AUTO_RELOAD_COOLDOWN_MINUTES = int(os.getenv("AUTO_RELOAD_COOLDOWN_MINUTES", "10"))
|
||||
|
||||
# Safety ceiling on the per-call premium reservation. ``stream_new_chat``
|
||||
# estimates an upper-bound cost from ``litellm.get_model_info`` x the
|
||||
|
|
@ -685,14 +709,13 @@ class Config:
|
|||
# reserve_tokens ≈ $0.36) with headroom.
|
||||
QUOTA_MAX_RESERVE_MICROS = int(os.getenv("QUOTA_MAX_RESERVE_MICROS", "1000000"))
|
||||
|
||||
if os.getenv("PREMIUM_TOKEN_LIMIT") and not os.getenv(
|
||||
"PREMIUM_CREDIT_MICROS_LIMIT"
|
||||
):
|
||||
if (
|
||||
os.getenv("PREMIUM_TOKEN_LIMIT") or os.getenv("PREMIUM_CREDIT_MICROS_LIMIT")
|
||||
) and not os.getenv("DEFAULT_CREDIT_MICROS_BALANCE"):
|
||||
print(
|
||||
"Warning: PREMIUM_TOKEN_LIMIT is deprecated; rename to "
|
||||
"PREMIUM_CREDIT_MICROS_LIMIT (1:1 numerical mapping under the "
|
||||
"current Stripe price). The old key will be removed in a "
|
||||
"future release."
|
||||
"Warning: PREMIUM_TOKEN_LIMIT / PREMIUM_CREDIT_MICROS_LIMIT are "
|
||||
"deprecated; rename to DEFAULT_CREDIT_MICROS_BALANCE. The old keys "
|
||||
"will be removed in a future release."
|
||||
)
|
||||
if os.getenv("STRIPE_TOKENS_PER_UNIT") and not os.getenv(
|
||||
"STRIPE_CREDIT_MICROS_PER_UNIT"
|
||||
|
|
@ -702,6 +725,22 @@ class Config:
|
|||
"STRIPE_CREDIT_MICROS_PER_UNIT (1:1 numerical mapping). "
|
||||
"The old key will be removed in a future release."
|
||||
)
|
||||
if os.getenv("STRIPE_PREMIUM_TOKEN_PRICE_ID") and not os.getenv(
|
||||
"STRIPE_CREDIT_PRICE_ID"
|
||||
):
|
||||
print(
|
||||
"Warning: STRIPE_PREMIUM_TOKEN_PRICE_ID is deprecated; rename to "
|
||||
"STRIPE_CREDIT_PRICE_ID. The old key will be removed in a future "
|
||||
"release."
|
||||
)
|
||||
if os.getenv("STRIPE_TOKEN_BUYING_ENABLED") and not os.getenv(
|
||||
"STRIPE_CREDIT_BUYING_ENABLED"
|
||||
):
|
||||
print(
|
||||
"Warning: STRIPE_TOKEN_BUYING_ENABLED is deprecated; rename to "
|
||||
"STRIPE_CREDIT_BUYING_ENABLED. The old key will be removed in a "
|
||||
"future release."
|
||||
)
|
||||
|
||||
# Anonymous / no-login mode settings
|
||||
NOLOGIN_MODE_ENABLED = os.getenv("NOLOGIN_MODE_ENABLED", "FALSE").upper() == "TRUE"
|
||||
|
|
@ -903,9 +942,6 @@ class Config:
|
|||
# ETL Service
|
||||
ETL_SERVICE = os.getenv("ETL_SERVICE")
|
||||
|
||||
# Pages limit for ETL services (default to very high number for OSS unlimited usage)
|
||||
PAGES_LIMIT = int(os.getenv("PAGES_LIMIT", "999999999"))
|
||||
|
||||
if ETL_SERVICE == "UNSTRUCTURED":
|
||||
# Unstructured API Key
|
||||
UNSTRUCTURED_API_KEY = os.getenv("UNSTRUCTURED_API_KEY")
|
||||
|
|
|
|||
|
|
@ -114,13 +114,6 @@ class SearchSourceConnectorType(StrEnum):
|
|||
COMPOSIO_GOOGLE_CALENDAR_CONNECTOR = "COMPOSIO_GOOGLE_CALENDAR_CONNECTOR"
|
||||
|
||||
|
||||
class PodcastStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
GENERATING = "generating"
|
||||
READY = "ready"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class VideoPresentationStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
GENERATING = "generating"
|
||||
|
|
@ -320,7 +313,7 @@ class PagePurchaseStatus(StrEnum):
|
|||
FAILED = "failed"
|
||||
|
||||
|
||||
class PremiumTokenPurchaseStatus(StrEnum):
|
||||
class CreditPurchaseStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
|
@ -332,26 +325,27 @@ INCENTIVE_TASKS_CONFIG = {
|
|||
IncentiveTaskType.GITHUB_STAR: {
|
||||
"title": "Star our GitHub repository",
|
||||
"description": "Show your support by starring SurfSense on GitHub",
|
||||
"pages_reward": 30,
|
||||
# Credit reward in USD micro-units (1_000_000 == $1.00). $0.03.
|
||||
"credit_micros_reward": 30000,
|
||||
"action_url": "https://github.com/MODSetter/SurfSense",
|
||||
},
|
||||
IncentiveTaskType.REDDIT_FOLLOW: {
|
||||
"title": "Join our Subreddit",
|
||||
"description": "Join the SurfSense community on Reddit",
|
||||
"pages_reward": 30,
|
||||
"credit_micros_reward": 30000,
|
||||
"action_url": "https://www.reddit.com/r/SurfSense/",
|
||||
},
|
||||
IncentiveTaskType.DISCORD_JOIN: {
|
||||
"title": "Join our Discord",
|
||||
"description": "Join the SurfSense community on Discord",
|
||||
"pages_reward": 40,
|
||||
"credit_micros_reward": 40000,
|
||||
"action_url": "https://discord.gg/ejRNvftDp9",
|
||||
},
|
||||
# Future tasks can be configured here:
|
||||
# IncentiveTaskType.GITHUB_ISSUE: {
|
||||
# "title": "Create an issue",
|
||||
# "description": "Help improve SurfSense by reporting bugs or suggesting features",
|
||||
# "pages_reward": 50,
|
||||
# "credit_micros_reward": 50000,
|
||||
# "action_url": "https://github.com/MODSetter/SurfSense/issues/new/choose",
|
||||
# },
|
||||
}
|
||||
|
|
@ -1536,41 +1530,6 @@ class Chunk(BaseModel, TimestampMixin):
|
|||
document = relationship("Document", back_populates="chunks")
|
||||
|
||||
|
||||
class Podcast(BaseModel, TimestampMixin):
|
||||
"""Podcast model for storing generated podcasts."""
|
||||
|
||||
__tablename__ = "podcasts"
|
||||
|
||||
title = Column(String(500), nullable=False)
|
||||
podcast_transcript = Column(JSONB, nullable=True)
|
||||
file_location = Column(Text, nullable=True)
|
||||
status = Column(
|
||||
SQLAlchemyEnum(
|
||||
PodcastStatus,
|
||||
name="podcast_status",
|
||||
create_type=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
default=PodcastStatus.READY,
|
||||
server_default="ready",
|
||||
index=True,
|
||||
)
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="podcasts")
|
||||
|
||||
thread_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_threads.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
thread = relationship("NewChatThread")
|
||||
|
||||
|
||||
class VideoPresentation(BaseModel, TimestampMixin):
|
||||
"""Video presentation model for storing AI-generated video presentations.
|
||||
|
||||
|
|
@ -2069,7 +2028,7 @@ class UserIncentiveTask(BaseModel, TimestampMixin):
|
|||
"""
|
||||
Tracks completed incentive tasks for users.
|
||||
Each user can only complete each task type once.
|
||||
When a task is completed, the user's pages_limit is increased.
|
||||
When a task is completed, the user's credit_micros_balance is increased.
|
||||
"""
|
||||
|
||||
__tablename__ = "user_incentive_tasks"
|
||||
|
|
@ -2088,7 +2047,8 @@ class UserIncentiveTask(BaseModel, TimestampMixin):
|
|||
index=True,
|
||||
)
|
||||
task_type = Column(SQLAlchemyEnum(IncentiveTaskType), nullable=False, index=True)
|
||||
pages_awarded = Column(Integer, nullable=False)
|
||||
# Credit reward granted in USD micro-units (1_000_000 == $1.00).
|
||||
credit_micros_awarded = Column(BigInteger, nullable=False)
|
||||
completed_at = Column(
|
||||
TIMESTAMP(timezone=True),
|
||||
nullable=False,
|
||||
|
|
@ -2131,18 +2091,18 @@ class PagePurchase(Base, TimestampMixin):
|
|||
user = relationship("User", back_populates="page_purchases")
|
||||
|
||||
|
||||
class PremiumTokenPurchase(Base, TimestampMixin):
|
||||
"""Tracks Stripe checkout sessions used to grant additional premium credit (USD micro-units).
|
||||
class CreditPurchase(Base, TimestampMixin):
|
||||
"""Tracks Stripe checkout sessions used to grant credit (USD micro-units).
|
||||
|
||||
Note: the table name is preserved (``premium_token_purchases``) for
|
||||
operational continuity even though the unit is now USD micro-credits
|
||||
instead of raw tokens. The ``credit_micros_granted`` column replaced
|
||||
the legacy ``tokens_granted`` in migration 140; the stored values
|
||||
were not transformed because the prior $1 = 1M tokens Stripe price
|
||||
makes the unit conversion 1:1 numerically.
|
||||
Renamed from ``premium_token_purchases`` in migration 156 as part of the
|
||||
unified-credits wallet. ``credit_micros_granted`` stores the USD-micro
|
||||
amount added to ``user.credit_micros_balance`` on fulfillment.
|
||||
|
||||
``source`` distinguishes a user-initiated checkout from an automatic
|
||||
off-session top-up (auto-reload), added in the auto-reload migration.
|
||||
"""
|
||||
|
||||
__tablename__ = "premium_token_purchases"
|
||||
__tablename__ = "credit_purchases"
|
||||
__allow_unmapped__ = True
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
|
|
@ -2160,15 +2120,18 @@ class PremiumTokenPurchase(Base, TimestampMixin):
|
|||
credit_micros_granted = Column(BigInteger, nullable=False)
|
||||
amount_total = Column(Integer, nullable=True)
|
||||
currency = Column(String(10), nullable=True)
|
||||
source = Column(
|
||||
String(20), nullable=False, default="checkout", server_default="checkout"
|
||||
)
|
||||
status = Column(
|
||||
SQLAlchemyEnum(PremiumTokenPurchaseStatus),
|
||||
SQLAlchemyEnum(CreditPurchaseStatus),
|
||||
nullable=False,
|
||||
default=PremiumTokenPurchaseStatus.PENDING,
|
||||
default=CreditPurchaseStatus.PENDING,
|
||||
index=True,
|
||||
)
|
||||
completed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
user = relationship("User", back_populates="premium_token_purchases")
|
||||
user = relationship("User", back_populates="credit_purchases")
|
||||
|
||||
|
||||
class SearchSpaceRole(BaseModel, TimestampMixin):
|
||||
|
|
@ -2448,33 +2411,40 @@ if config.AUTH_TYPE == "GOOGLE":
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
premium_token_purchases = relationship(
|
||||
"PremiumTokenPurchase",
|
||||
credit_purchases = relationship(
|
||||
"CreditPurchase",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=config.PAGES_LIMIT,
|
||||
server_default=str(config.PAGES_LIMIT),
|
||||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_credit_micros_limit = Column(
|
||||
# Unified credit wallet (USD micro-units, 1_000_000 == $1.00).
|
||||
# Decreases on use (ETL pages + premium model calls), increases on
|
||||
# purchase / incentive grant / auto-reload. May dip slightly negative
|
||||
# when an actual cost exceeds its pre-charge estimate; UI clamps at $0.
|
||||
credit_micros_balance = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||
default=config.DEFAULT_CREDIT_MICROS_BALANCE,
|
||||
server_default=str(config.DEFAULT_CREDIT_MICROS_BALANCE),
|
||||
)
|
||||
premium_credit_micros_used = Column(
|
||||
# In-flight reservation holds (released/settled at finalize).
|
||||
credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
premium_credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
|
||||
# Auto-reload (off-session Stripe top-up), behind AUTO_RELOAD_ENABLED.
|
||||
# ``stripe_customer_id`` + ``auto_reload_payment_method_id`` are the
|
||||
# saved-card plumbing; thresholds are micro-USD. ``auto_reload_failed_at``
|
||||
# is set (and ``auto_reload_enabled`` flipped off) when an off-session
|
||||
# charge is declined so the UI can prompt the user to fix their card.
|
||||
stripe_customer_id = Column(String, nullable=True)
|
||||
auto_reload_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
)
|
||||
auto_reload_threshold_micros = Column(BigInteger, nullable=True)
|
||||
auto_reload_amount_micros = Column(BigInteger, nullable=True)
|
||||
auto_reload_payment_method_id = Column(String, nullable=True)
|
||||
auto_reload_failed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
# User profile from OAuth
|
||||
display_name = Column(String, nullable=True)
|
||||
|
|
@ -2587,33 +2557,40 @@ else:
|
|||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
premium_token_purchases = relationship(
|
||||
"PremiumTokenPurchase",
|
||||
credit_purchases = relationship(
|
||||
"CreditPurchase",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
)
|
||||
|
||||
# Page usage tracking for ETL services
|
||||
pages_limit = Column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
default=config.PAGES_LIMIT,
|
||||
server_default=str(config.PAGES_LIMIT),
|
||||
)
|
||||
pages_used = Column(Integer, nullable=False, default=0, server_default="0")
|
||||
|
||||
premium_credit_micros_limit = Column(
|
||||
# Unified credit wallet (USD micro-units, 1_000_000 == $1.00).
|
||||
# Decreases on use (ETL pages + premium model calls), increases on
|
||||
# purchase / incentive grant / auto-reload. May dip slightly negative
|
||||
# when an actual cost exceeds its pre-charge estimate; UI clamps at $0.
|
||||
credit_micros_balance = Column(
|
||||
BigInteger,
|
||||
nullable=False,
|
||||
default=config.PREMIUM_CREDIT_MICROS_LIMIT,
|
||||
server_default=str(config.PREMIUM_CREDIT_MICROS_LIMIT),
|
||||
default=config.DEFAULT_CREDIT_MICROS_BALANCE,
|
||||
server_default=str(config.DEFAULT_CREDIT_MICROS_BALANCE),
|
||||
)
|
||||
premium_credit_micros_used = Column(
|
||||
# In-flight reservation holds (released/settled at finalize).
|
||||
credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
)
|
||||
premium_credit_micros_reserved = Column(
|
||||
BigInteger, nullable=False, default=0, server_default="0"
|
||||
|
||||
# Auto-reload (off-session Stripe top-up), behind AUTO_RELOAD_ENABLED.
|
||||
# ``stripe_customer_id`` + ``auto_reload_payment_method_id`` are the
|
||||
# saved-card plumbing; thresholds are micro-USD. ``auto_reload_failed_at``
|
||||
# is set (and ``auto_reload_enabled`` flipped off) when an off-session
|
||||
# charge is declined so the UI can prompt the user to fix their card.
|
||||
stripe_customer_id = Column(String, nullable=True)
|
||||
auto_reload_enabled = Column(
|
||||
Boolean, nullable=False, default=False, server_default="false"
|
||||
)
|
||||
auto_reload_threshold_micros = Column(BigInteger, nullable=True)
|
||||
auto_reload_amount_micros = Column(BigInteger, nullable=True)
|
||||
auto_reload_payment_method_id = Column(String, nullable=True)
|
||||
auto_reload_failed_at = Column(TIMESTAMP(timezone=True), nullable=True)
|
||||
|
||||
# User profile (can be set manually for non-OAuth users)
|
||||
display_name = Column(String, nullable=True)
|
||||
|
|
@ -2889,6 +2866,10 @@ from app.automations.persistence import ( # noqa: E402, F401
|
|||
)
|
||||
from app.file_storage.persistence import DocumentFile # noqa: E402, F401
|
||||
from app.notifications.persistence import Notification # noqa: E402, F401
|
||||
from app.podcasts.persistence import ( # noqa: E402, F401
|
||||
Podcast,
|
||||
PodcastStatus,
|
||||
)
|
||||
|
||||
engine = create_async_engine(
|
||||
DATABASE_URL,
|
||||
|
|
|
|||
|
|
@ -275,7 +275,7 @@ async def list_notifications(
|
|||
query = query.where(unread_filter)
|
||||
count_query = count_query.where(unread_filter)
|
||||
elif filter == "errors":
|
||||
error_filter = (Notification.type == "page_limit_exceeded") | (
|
||||
error_filter = (Notification.type == "insufficient_credits") | (
|
||||
Notification.notification_metadata["status"].astext == "failed"
|
||||
)
|
||||
query = query.where(error_filter)
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ CATEGORY_TYPES: dict[str, tuple[str, ...]] = {
|
|||
"connector_indexing",
|
||||
"connector_deletion",
|
||||
"document_processing",
|
||||
"page_limit_exceeded",
|
||||
"insufficient_credits",
|
||||
"auto_reload_failed",
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,11 +10,12 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.notifications.persistence import Notification
|
||||
from app.notifications.service.handlers import (
|
||||
AutoReloadFailedNotificationHandler,
|
||||
CommentReplyNotificationHandler,
|
||||
ConnectorIndexingNotificationHandler,
|
||||
DocumentProcessingNotificationHandler,
|
||||
InsufficientCreditsNotificationHandler,
|
||||
MentionNotificationHandler,
|
||||
PageLimitNotificationHandler,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
|
@ -27,7 +28,8 @@ class NotificationService:
|
|||
document_processing = DocumentProcessingNotificationHandler()
|
||||
mention = MentionNotificationHandler()
|
||||
comment_reply = CommentReplyNotificationHandler()
|
||||
page_limit = PageLimitNotificationHandler()
|
||||
insufficient_credits = InsufficientCreditsNotificationHandler()
|
||||
auto_reload_failed = AutoReloadFailedNotificationHandler()
|
||||
|
||||
@staticmethod
|
||||
async def create_notification(
|
||||
|
|
|
|||
|
|
@ -2,16 +2,18 @@
|
|||
|
||||
from __future__ import annotations
|
||||
|
||||
from .auto_reload_failed import AutoReloadFailedNotificationHandler
|
||||
from .comment_reply import CommentReplyNotificationHandler
|
||||
from .connector_indexing import ConnectorIndexingNotificationHandler
|
||||
from .document_processing import DocumentProcessingNotificationHandler
|
||||
from .insufficient_credits import InsufficientCreditsNotificationHandler
|
||||
from .mention import MentionNotificationHandler
|
||||
from .page_limit import PageLimitNotificationHandler
|
||||
|
||||
__all__ = [
|
||||
"AutoReloadFailedNotificationHandler",
|
||||
"CommentReplyNotificationHandler",
|
||||
"ConnectorIndexingNotificationHandler",
|
||||
"DocumentProcessingNotificationHandler",
|
||||
"InsufficientCreditsNotificationHandler",
|
||||
"MentionNotificationHandler",
|
||||
"PageLimitNotificationHandler",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,54 @@
|
|||
"""Notifications for failed off-session credit auto-reload charges."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.notifications.persistence import Notification
|
||||
from app.notifications.service.base import BaseNotificationHandler
|
||||
from app.notifications.service.messages import auto_reload_failed as msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoReloadFailedNotificationHandler(BaseNotificationHandler):
|
||||
"""Notifications for declined auto-reload top-ups."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("auto_reload_failed")
|
||||
|
||||
async def notify_auto_reload_failed(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
amount_micros: int,
|
||||
payment_intent_id: str | None = None,
|
||||
reason: str | None = None,
|
||||
) -> Notification:
|
||||
"""Notify that an off-session auto-reload charge was declined.
|
||||
|
||||
Not tied to a search space (``search_space_id`` is None); the action
|
||||
links to the billing settings so the user can fix their card.
|
||||
"""
|
||||
op_id = msg.operation_id(payment_intent_id or "")
|
||||
title, message = msg.summary(amount_micros, reason)
|
||||
|
||||
return await self.find_or_create_notification(
|
||||
session=session,
|
||||
user_id=user_id,
|
||||
operation_id=op_id,
|
||||
title=title,
|
||||
message=message,
|
||||
search_space_id=None,
|
||||
initial_metadata={
|
||||
"amount_micros": amount_micros,
|
||||
"payment_intent_id": payment_intent_id,
|
||||
"status": "failed",
|
||||
"error_type": "auto_reload_failed",
|
||||
"action_url": "/dashboard",
|
||||
"action_label": "Update card",
|
||||
},
|
||||
)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
"""Notifications for exceeding the page limit."""
|
||||
"""Notifications for running out of credit during document processing."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -9,46 +9,42 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from app.notifications.persistence import Notification
|
||||
from app.notifications.service.base import BaseNotificationHandler
|
||||
from app.notifications.service.messages import page_limit as msg
|
||||
from app.notifications.service.messages import insufficient_credits as msg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PageLimitNotificationHandler(BaseNotificationHandler):
|
||||
"""Notifications for exceeding the page limit."""
|
||||
class InsufficientCreditsNotificationHandler(BaseNotificationHandler):
|
||||
"""Notifications for running out of credit during document processing."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__("page_limit_exceeded")
|
||||
super().__init__("insufficient_credits")
|
||||
|
||||
async def notify_page_limit_exceeded(
|
||||
async def notify_insufficient_credits(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
user_id: UUID,
|
||||
document_name: str,
|
||||
document_type: str,
|
||||
search_space_id: int,
|
||||
pages_used: int,
|
||||
pages_limit: int,
|
||||
pages_to_add: int,
|
||||
balance_micros: int,
|
||||
required_micros: int,
|
||||
) -> Notification:
|
||||
"""Notify that a document was blocked by the page limit."""
|
||||
"""Notify that a document was blocked by insufficient credit."""
|
||||
operation_id = msg.operation_id(document_name, search_space_id)
|
||||
title, message = msg.summary(
|
||||
document_name, pages_used, pages_limit, pages_to_add
|
||||
)
|
||||
title, message = msg.summary(document_name, balance_micros, required_micros)
|
||||
|
||||
metadata = {
|
||||
"operation_id": operation_id,
|
||||
"document_name": document_name,
|
||||
"document_type": document_type,
|
||||
"pages_used": pages_used,
|
||||
"pages_limit": pages_limit,
|
||||
"pages_to_add": pages_to_add,
|
||||
"balance_micros": balance_micros,
|
||||
"required_micros": required_micros,
|
||||
"status": "failed",
|
||||
"error_type": "page_limit_exceeded",
|
||||
"error_type": "insufficient_credits",
|
||||
# Where the inbox item links to.
|
||||
"action_url": f"/dashboard/{search_space_id}/more-pages",
|
||||
"action_label": "Upgrade Plan",
|
||||
"action_url": f"/dashboard/{search_space_id}/buy-more",
|
||||
"action_label": "Buy credits",
|
||||
}
|
||||
|
||||
notification = Notification(
|
||||
|
|
@ -63,6 +59,7 @@ class PageLimitNotificationHandler(BaseNotificationHandler):
|
|||
await session.commit()
|
||||
await session.refresh(notification)
|
||||
logger.info(
|
||||
f"Created page_limit_exceeded notification {notification.id} for user {user_id}"
|
||||
f"Created insufficient_credits notification {notification.id} "
|
||||
f"for user {user_id}"
|
||||
)
|
||||
return notification
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
"""Pure presentation logic for auto-reload-failure notifications."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
def operation_id(payment_intent_id: str) -> str:
|
||||
"""Build a unique id for an auto-reload-failure notification.
|
||||
|
||||
Keyed on the failed PaymentIntent so retries of the same charge collapse
|
||||
into a single inbox item rather than spamming the user.
|
||||
"""
|
||||
if payment_intent_id:
|
||||
return f"auto_reload_failed_{payment_intent_id}"
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f")
|
||||
return f"auto_reload_failed_{timestamp}"
|
||||
|
||||
|
||||
def summary(amount_micros: int, reason: str | None) -> tuple[str, str]:
|
||||
"""Compute the title and message for a failed off-session auto-reload charge."""
|
||||
amount_usd = max(0, amount_micros) / 1_000_000
|
||||
title = "Auto-reload failed"
|
||||
base = (
|
||||
f"We couldn't automatically add ${amount_usd:.2f} of credit because your "
|
||||
"saved card was declined. Auto-reload has been turned off — update your "
|
||||
"card and re-enable it to keep topping up automatically."
|
||||
)
|
||||
if reason:
|
||||
base = f"{base} (Reason: {reason}.)"
|
||||
return title, base
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
"""Pure presentation logic for insufficient-credit notifications."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.notifications.service.messages.text import truncate
|
||||
|
||||
|
||||
def operation_id(document_name: str, search_space_id: int) -> str:
|
||||
"""Build a unique id for an insufficient-credits notification."""
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f")
|
||||
doc_hash = hashlib.md5(document_name.encode()).hexdigest()[:8]
|
||||
return f"insufficient_credits_{search_space_id}_{timestamp}_{doc_hash}"
|
||||
|
||||
|
||||
def summary(
|
||||
document_name: str, balance_micros: int, required_micros: int
|
||||
) -> tuple[str, str]:
|
||||
"""Compute the title and message for a blocked-by-insufficient-credits document."""
|
||||
display_name = truncate(document_name, 40)
|
||||
title = f"Insufficient credits: {display_name}"
|
||||
balance_usd = max(0, balance_micros) / 1_000_000
|
||||
required_usd = max(0, required_micros) / 1_000_000
|
||||
message = (
|
||||
f"This document costs about ${required_usd:.2f} to process but you have "
|
||||
f"${balance_usd:.2f} of credit left. Add more credits to continue."
|
||||
)
|
||||
return title, message
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
"""Pure presentation logic for page-limit notifications."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.notifications.service.messages.text import truncate
|
||||
|
||||
|
||||
def operation_id(document_name: str, search_space_id: int) -> str:
|
||||
"""Build a unique id for a page-limit notification."""
|
||||
timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S_%f")
|
||||
doc_hash = hashlib.md5(document_name.encode()).hexdigest()[:8]
|
||||
return f"page_limit_{search_space_id}_{timestamp}_{doc_hash}"
|
||||
|
||||
|
||||
def summary(
|
||||
document_name: str, pages_used: int, pages_limit: int, pages_to_add: int
|
||||
) -> tuple[str, str]:
|
||||
"""Compute the title and message for a blocked-by-page-limit document."""
|
||||
display_name = truncate(document_name, 40)
|
||||
title = f"Page limit exceeded: {display_name}"
|
||||
message = f"This document has ~{pages_to_add} page(s) but you've used {pages_used}/{pages_limit} pages. Upgrade to process more documents."
|
||||
return title, message
|
||||
|
|
@ -10,7 +10,8 @@ NotificationType = Literal[
|
|||
"document_processing",
|
||||
"new_mention",
|
||||
"comment_reply",
|
||||
"page_limit_exceeded",
|
||||
"insufficient_credits",
|
||||
"auto_reload_failed",
|
||||
]
|
||||
|
||||
NotificationCategory = Literal["comments", "status"]
|
||||
|
|
|
|||
9
surfsense_backend/app/podcasts/__init__.py
Normal file
9
surfsense_backend/app/podcasts/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""Podcast feature: brief resolution, transcript drafting, and audio rendering.
|
||||
|
||||
Owns the ``podcasts`` table model, which :mod:`app.db` re-exports so existing
|
||||
``from app.db import Podcast`` imports keep resolving.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
__all__: list[str] = []
|
||||
7
surfsense_backend/app/podcasts/api/__init__.py
Normal file
7
surfsense_backend/app/podcasts/api/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""HTTP API for the podcast lifecycle."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .routes import router
|
||||
|
||||
__all__ = ["router"]
|
||||
337
surfsense_backend/app/podcasts/api/routes.py
Normal file
337
surfsense_backend/app/podcasts/api/routes.py
Normal file
|
|
@ -0,0 +1,337 @@
|
|||
"""HTTP surface for the podcast lifecycle.
|
||||
|
||||
Status is observed by the frontend through Zero, so these routes are about
|
||||
actions (create, edit/approve the brief, regenerate, cancel) and audio delivery.
|
||||
Each mutating route performs the guarded transition via the service, commits,
|
||||
then enqueues the matching Celery task; lifecycle errors map to 409/422.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.db import (
|
||||
Permission,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.podcasts.generation.brief import propose_brief
|
||||
from app.podcasts.persistence import Podcast, PodcastRepository
|
||||
from app.podcasts.service import (
|
||||
InvalidTransitionError,
|
||||
PodcastService,
|
||||
PreconditionFailedError,
|
||||
SpecConflictError,
|
||||
)
|
||||
from app.podcasts.storage import open_audio_stream, purge_audio
|
||||
from app.podcasts.tasks import draft_transcript_task
|
||||
from app.podcasts.tts import get_text_to_speech
|
||||
from app.podcasts.voices import (
|
||||
get_voice_catalog,
|
||||
provider_from_service,
|
||||
render_voice_preview,
|
||||
)
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
from .schemas import (
|
||||
CreatePodcastRequest,
|
||||
PodcastDetail,
|
||||
PodcastSummary,
|
||||
UpdateSpecRequest,
|
||||
VoiceOption,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/podcasts", response_model=list[PodcastSummary])
|
||||
async def list_podcasts(
|
||||
search_space_id: int | None = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
|
||||
if search_space_id is not None:
|
||||
await _require(session, user, search_space_id, Permission.PODCASTS_READ)
|
||||
query = (
|
||||
select(Podcast)
|
||||
.where(Podcast.search_space_id == search_space_id)
|
||||
.order_by(Podcast.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
query = (
|
||||
select(Podcast)
|
||||
.join(SearchSpace)
|
||||
.join(SearchSpaceMembership)
|
||||
.where(SearchSpaceMembership.user_id == user.id)
|
||||
.order_by(Podcast.created_at.desc())
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
result = await session.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@router.get("/podcasts/voices", response_model=list[VoiceOption])
|
||||
async def list_voices(language: str | None = None):
|
||||
"""Voices the active TTS provider offers, optionally filtered by language."""
|
||||
if not app_config.TTS_SERVICE:
|
||||
raise HTTPException(status_code=503, detail="No TTS provider configured")
|
||||
|
||||
provider = provider_from_service(app_config.TTS_SERVICE)
|
||||
catalog = get_voice_catalog()
|
||||
voices = (
|
||||
catalog.for_language(provider, language)
|
||||
if language
|
||||
else catalog.for_provider(provider)
|
||||
)
|
||||
return [
|
||||
VoiceOption(
|
||||
voice_id=v.voice_id,
|
||||
display_name=v.display_name,
|
||||
language=v.language,
|
||||
gender=v.gender.value,
|
||||
)
|
||||
for v in voices
|
||||
]
|
||||
|
||||
|
||||
@router.get("/podcasts/voices/{voice_id}/preview")
|
||||
async def preview_voice(
|
||||
voice_id: str,
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""A short audio sample of a voice, so users pick by sound."""
|
||||
if not app_config.TTS_SERVICE:
|
||||
raise HTTPException(status_code=503, detail="No TTS provider configured")
|
||||
|
||||
provider = provider_from_service(app_config.TTS_SERVICE)
|
||||
try:
|
||||
voice = get_voice_catalog().get(voice_id)
|
||||
except KeyError:
|
||||
raise HTTPException(status_code=404, detail="Unknown voice") from None
|
||||
if voice.provider is not provider:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="Voice not offered by the active TTS provider"
|
||||
)
|
||||
|
||||
data, content_type = await render_voice_preview(voice, get_text_to_speech())
|
||||
return Response(content=data, media_type=content_type)
|
||||
|
||||
|
||||
@router.post("/podcasts", response_model=PodcastDetail, status_code=201)
|
||||
async def create_podcast(
|
||||
body: CreatePodcastRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
await _require(session, user, body.search_space_id, Permission.PODCASTS_CREATE)
|
||||
|
||||
service = PodcastService(session)
|
||||
podcast = await service.create(
|
||||
title=body.title,
|
||||
search_space_id=body.search_space_id,
|
||||
thread_id=body.thread_id,
|
||||
)
|
||||
podcast.source_content = body.source_content
|
||||
|
||||
spec = await propose_brief(
|
||||
session,
|
||||
search_space_id=body.search_space_id,
|
||||
speaker_count=body.speaker_count,
|
||||
min_minutes=body.min_minutes,
|
||||
max_minutes=body.max_minutes,
|
||||
focus=body.focus,
|
||||
)
|
||||
await service.attach_brief(podcast, spec)
|
||||
await session.commit()
|
||||
return PodcastDetail.of(podcast)
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}", response_model=PodcastDetail)
|
||||
async def get_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
|
||||
return PodcastDetail.of(podcast)
|
||||
|
||||
|
||||
@router.patch("/podcasts/{podcast_id}/spec", response_model=PodcastDetail)
|
||||
async def update_spec(
|
||||
podcast_id: int,
|
||||
body: UpdateSpecRequest,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).update_spec(
|
||||
podcast, body.spec, body.expected_version
|
||||
)
|
||||
await session.commit()
|
||||
return PodcastDetail.of(podcast)
|
||||
|
||||
|
||||
@router.post("/podcasts/{podcast_id}/brief/approve", response_model=PodcastDetail)
|
||||
async def approve_brief(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Approve the brief and start drafting the transcript."""
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).begin_drafting(podcast)
|
||||
await session.commit()
|
||||
draft_transcript_task.delay(podcast.id, podcast.search_space_id)
|
||||
return PodcastDetail.of(podcast)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/podcasts/{podcast_id}/transcript/regenerate", response_model=PodcastDetail
|
||||
)
|
||||
async def regenerate_transcript(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Reopen the brief gate for a fresh take; drafting waits for re-approval."""
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).regenerate(podcast)
|
||||
await session.commit()
|
||||
return PodcastDetail.of(podcast)
|
||||
|
||||
|
||||
@router.post("/podcasts/{podcast_id}/regenerate/revert", response_model=PodcastDetail)
|
||||
async def revert_regeneration(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""Back out of a regeneration and return to the finished episode."""
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).revert_regeneration(podcast)
|
||||
await session.commit()
|
||||
return PodcastDetail.of(podcast)
|
||||
|
||||
|
||||
@router.post("/podcasts/{podcast_id}/cancel", response_model=PodcastDetail)
|
||||
async def cancel_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_UPDATE)
|
||||
async with _lifecycle_errors():
|
||||
await PodcastService(session).cancel(podcast)
|
||||
await session.commit()
|
||||
return PodcastDetail.of(podcast)
|
||||
|
||||
|
||||
@router.delete("/podcasts/{podcast_id}", response_model=dict)
|
||||
async def delete_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_DELETE)
|
||||
await purge_audio(podcast)
|
||||
await session.delete(podcast)
|
||||
await session.commit()
|
||||
return {"message": "Podcast deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}/stream")
|
||||
async def stream_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
podcast = await _load(session, user, podcast_id, Permission.PODCASTS_READ)
|
||||
|
||||
if podcast.storage_key:
|
||||
return StreamingResponse(
|
||||
open_audio_stream(podcast),
|
||||
media_type="audio/mpeg",
|
||||
headers={"Accept-Ranges": "bytes"},
|
||||
)
|
||||
|
||||
# Back-compat: rows rendered before the storage migration kept a local path.
|
||||
if podcast.file_location and os.path.isfile(podcast.file_location):
|
||||
path = podcast.file_location
|
||||
|
||||
def iterfile():
|
||||
with open(path, mode="rb") as handle:
|
||||
yield from handle
|
||||
|
||||
return StreamingResponse(
|
||||
iterfile(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Disposition": f"inline; filename={Path(path).name}",
|
||||
},
|
||||
)
|
||||
|
||||
raise HTTPException(status_code=404, detail="Podcast audio not found")
|
||||
|
||||
|
||||
async def _require(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
search_space_id: int,
|
||||
permission: Permission,
|
||||
) -> None:
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
permission.value,
|
||||
"You don't have permission for podcasts in this search space",
|
||||
)
|
||||
|
||||
|
||||
async def _load(
|
||||
session: AsyncSession,
|
||||
user: User,
|
||||
podcast_id: int,
|
||||
permission: Permission,
|
||||
) -> Podcast:
|
||||
podcast = await PodcastRepository(session).get(podcast_id)
|
||||
if podcast is None:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
await _require(session, user, podcast.search_space_id, permission)
|
||||
return podcast
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _lifecycle_errors() -> AsyncIterator[None]:
|
||||
"""Map service lifecycle errors onto HTTP responses."""
|
||||
try:
|
||||
yield
|
||||
except (SpecConflictError, InvalidTransitionError) as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except PreconditionFailedError as exc:
|
||||
raise HTTPException(status_code=422, detail=str(exc)) from exc
|
||||
97
surfsense_backend/app/podcasts/api/schemas.py
Normal file
97
surfsense_backend/app/podcasts/api/schemas.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
"""Request and response shapes for the podcast API.
|
||||
|
||||
Read models surface the lifecycle state the frontend can't derive from Zero (the
|
||||
deserialized brief and transcript); the action requests carry just what each
|
||||
guarded transition needs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from app.podcasts.persistence import Podcast, PodcastStatus
|
||||
from app.podcasts.schemas import PodcastSpec, Transcript
|
||||
from app.podcasts.service import has_stored_episode, read_spec, read_transcript
|
||||
|
||||
# Defaults applied when a create request omits brief sizing; the brief gate lets
|
||||
# the user adjust before any cost is incurred.
|
||||
DEFAULT_SPEAKER_COUNT = 2
|
||||
DEFAULT_MIN_MINUTES = 10
|
||||
DEFAULT_MAX_MINUTES = 20
|
||||
|
||||
|
||||
class CreatePodcastRequest(BaseModel):
|
||||
"""Create a podcast and kick off brief proposal."""
|
||||
|
||||
title: str = Field(..., min_length=1, max_length=500)
|
||||
search_space_id: int
|
||||
source_content: str = Field(..., min_length=1)
|
||||
thread_id: int | None = None
|
||||
speaker_count: int = Field(default=DEFAULT_SPEAKER_COUNT, ge=1, le=6)
|
||||
min_minutes: int = Field(default=DEFAULT_MIN_MINUTES, ge=1)
|
||||
max_minutes: int = Field(default=DEFAULT_MAX_MINUTES, ge=1)
|
||||
focus: str | None = Field(default=None, max_length=2000)
|
||||
|
||||
|
||||
class UpdateSpecRequest(BaseModel):
|
||||
"""Replace the brief at the gate, guarded by the expected version."""
|
||||
|
||||
spec: PodcastSpec
|
||||
expected_version: int = Field(..., ge=1)
|
||||
|
||||
|
||||
class VoiceOption(BaseModel):
|
||||
"""One selectable voice surfaced to the brief editor."""
|
||||
|
||||
voice_id: str
|
||||
display_name: str
|
||||
language: str
|
||||
gender: str
|
||||
|
||||
|
||||
class PodcastSummary(BaseModel):
|
||||
"""Lightweight list item."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
title: str
|
||||
status: PodcastStatus
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
|
||||
|
||||
class PodcastDetail(BaseModel):
|
||||
"""Full podcast state for the detail view and action responses."""
|
||||
|
||||
id: int
|
||||
title: str
|
||||
status: PodcastStatus
|
||||
spec_version: int
|
||||
spec: PodcastSpec | None
|
||||
transcript: Transcript | None
|
||||
has_audio: bool
|
||||
duration_seconds: int | None
|
||||
error: str | None
|
||||
created_at: datetime
|
||||
search_space_id: int
|
||||
thread_id: int | None
|
||||
|
||||
@classmethod
|
||||
def of(cls, podcast: Podcast) -> PodcastDetail:
|
||||
return cls(
|
||||
id=podcast.id,
|
||||
title=podcast.title,
|
||||
status=PodcastStatus(podcast.status),
|
||||
spec_version=podcast.spec_version,
|
||||
spec=read_spec(podcast),
|
||||
transcript=read_transcript(podcast),
|
||||
has_audio=has_stored_episode(podcast),
|
||||
duration_seconds=podcast.duration_seconds,
|
||||
error=podcast.error,
|
||||
created_at=podcast.created_at,
|
||||
search_space_id=podcast.search_space_id,
|
||||
thread_id=podcast.thread_id,
|
||||
)
|
||||
19
surfsense_backend/app/podcasts/generation/__init__.py
Normal file
19
surfsense_backend/app/podcasts/generation/__init__.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""Generation: the controlled graphs that produce a brief and a transcript.
|
||||
|
||||
``brief`` proposes a reviewable spec from deterministic defaults; ``transcript``
|
||||
is the LLM-driven step, drafting long-form dialogue outline-first.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .brief import BriefConfig, BriefState, build_brief_graph
|
||||
from .transcript import TranscriptConfig, TranscriptState, build_transcript_graph
|
||||
|
||||
__all__ = [
|
||||
"BriefConfig",
|
||||
"BriefState",
|
||||
"TranscriptConfig",
|
||||
"TranscriptState",
|
||||
"build_brief_graph",
|
||||
"build_transcript_graph",
|
||||
]
|
||||
10
surfsense_backend/app/podcasts/generation/brief/__init__.py
Normal file
10
surfsense_backend/app/podcasts/generation/brief/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""Brief planning: propose a reviewable spec from last-used preferences."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .config import BriefConfig
|
||||
from .graph import build_brief_graph
|
||||
from .propose import propose_brief
|
||||
from .state import BriefState
|
||||
|
||||
__all__ = ["BriefConfig", "BriefState", "build_brief_graph", "propose_brief"]
|
||||
30
surfsense_backend/app/podcasts/generation/brief/config.py
Normal file
30
surfsense_backend/app/podcasts/generation/brief/config.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
"""Configurable inputs for the brief-planning graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field, fields
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
# Sensible defaults for a fresh brief; the user adjusts the range at the gate.
|
||||
DEFAULT_SPEAKER_COUNT = 2
|
||||
DEFAULT_MIN_MINUTES = 10
|
||||
DEFAULT_MAX_MINUTES = 20
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class BriefConfig:
|
||||
"""Signals used to propose a brief; everything here is non-LLM context."""
|
||||
|
||||
speaker_count: int = DEFAULT_SPEAKER_COUNT
|
||||
min_minutes: int = DEFAULT_MIN_MINUTES
|
||||
max_minutes: int = DEFAULT_MAX_MINUTES
|
||||
focus: str | None = None
|
||||
last_used_language: str | None = None
|
||||
last_used_voices: list[str] = field(default_factory=list)
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(cls, config: RunnableConfig | None = None) -> BriefConfig:
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
names = {f.name for f in fields(cls) if f.init}
|
||||
return cls(**{k: v for k, v in configurable.items() if k in names})
|
||||
25
surfsense_backend/app/podcasts/generation/brief/graph.py
Normal file
25
surfsense_backend/app/podcasts/generation/brief/graph.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
"""The brief-planning graph: propose a reviewable spec from defaults."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from .config import BriefConfig
|
||||
from .nodes import propose_spec
|
||||
from .state import BriefState
|
||||
|
||||
|
||||
def build_brief_graph():
|
||||
workflow = StateGraph(BriefState, config_schema=BriefConfig)
|
||||
|
||||
workflow.add_node("propose_spec", propose_spec)
|
||||
|
||||
workflow.add_edge("__start__", "propose_spec")
|
||||
workflow.add_edge("propose_spec", "__end__")
|
||||
|
||||
graph = workflow.compile()
|
||||
graph.name = "Surfsense Podcast Brief"
|
||||
return graph
|
||||
|
||||
|
||||
graph = build_brief_graph()
|
||||
119
surfsense_backend/app/podcasts/generation/brief/nodes.py
Normal file
119
surfsense_backend/app/podcasts/generation/brief/nodes.py
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
"""Brief-planning node: propose a full spec from deterministic defaults.
|
||||
|
||||
``propose_spec`` is pure resolution — it never spends tokens. It reuses the
|
||||
user's last-used language/voices when available and otherwise falls back to
|
||||
English, so the brief gate opens pre-filled and the common case needs no edits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from app.config import config as app_config
|
||||
from app.podcasts.resolution import (
|
||||
DEFAULT_LANGUAGE,
|
||||
LanguageContext,
|
||||
resolve_language,
|
||||
resolve_voices,
|
||||
)
|
||||
from app.podcasts.schemas import (
|
||||
DurationTarget,
|
||||
PodcastSpec,
|
||||
PodcastStyle,
|
||||
SpeakerRole,
|
||||
SpeakerSpec,
|
||||
normalize_language_tag,
|
||||
)
|
||||
from app.podcasts.voices import (
|
||||
TtsProvider,
|
||||
VoiceCatalog,
|
||||
get_voice_catalog,
|
||||
provider_from_service,
|
||||
)
|
||||
|
||||
from .config import BriefConfig
|
||||
from .state import BriefState
|
||||
|
||||
# Default role per speaker slot; extra speakers beyond the list fall back to guest.
|
||||
_ROLE_BY_SLOT = (
|
||||
SpeakerRole.HOST,
|
||||
SpeakerRole.GUEST,
|
||||
SpeakerRole.EXPERT,
|
||||
SpeakerRole.COHOST,
|
||||
SpeakerRole.NARRATOR,
|
||||
)
|
||||
|
||||
|
||||
def propose_spec(state: BriefState, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""Build a complete :class:`PodcastSpec` from the resolved defaults."""
|
||||
brief = BriefConfig.from_runnable_config(config)
|
||||
provider = _active_provider()
|
||||
catalog = get_voice_catalog()
|
||||
|
||||
language = _supported_language(
|
||||
last_used=brief.last_used_language,
|
||||
provider=provider,
|
||||
catalog=catalog,
|
||||
)
|
||||
voices = resolve_voices(
|
||||
catalog=catalog,
|
||||
provider=provider,
|
||||
language=language,
|
||||
speaker_count=brief.speaker_count,
|
||||
preferred=brief.last_used_voices,
|
||||
)
|
||||
|
||||
speakers = [
|
||||
SpeakerSpec(
|
||||
slot=slot,
|
||||
name=_default_name(slot),
|
||||
role=_role_for(slot),
|
||||
voice_id=voice.voice_id,
|
||||
)
|
||||
for slot, voice in enumerate(voices)
|
||||
]
|
||||
spec = PodcastSpec(
|
||||
language=language,
|
||||
style=PodcastStyle.CONVERSATIONAL,
|
||||
speakers=speakers,
|
||||
duration=DurationTarget(
|
||||
min_minutes=brief.min_minutes, max_minutes=brief.max_minutes
|
||||
),
|
||||
focus=brief.focus,
|
||||
)
|
||||
return {"spec": spec}
|
||||
|
||||
|
||||
def _active_provider() -> TtsProvider:
|
||||
service = app_config.TTS_SERVICE
|
||||
if not service:
|
||||
raise ValueError("TTS_SERVICE is not configured")
|
||||
return provider_from_service(service)
|
||||
|
||||
|
||||
def _supported_language(
|
||||
*,
|
||||
last_used: str | None,
|
||||
provider: TtsProvider,
|
||||
catalog: VoiceCatalog,
|
||||
) -> str:
|
||||
raw = resolve_language(LanguageContext(last_used=last_used))
|
||||
try:
|
||||
language = normalize_language_tag(raw)
|
||||
except ValueError:
|
||||
language = DEFAULT_LANGUAGE
|
||||
if not catalog.supports_language(provider, language):
|
||||
return DEFAULT_LANGUAGE
|
||||
return language
|
||||
|
||||
|
||||
def _role_for(slot: int) -> SpeakerRole:
|
||||
return _ROLE_BY_SLOT[slot] if slot < len(_ROLE_BY_SLOT) else SpeakerRole.GUEST
|
||||
|
||||
|
||||
def _default_name(slot: int) -> str:
|
||||
role = _role_for(slot)
|
||||
label = role.value.replace("cohost", "co-host").title()
|
||||
return label if slot < len(_ROLE_BY_SLOT) else f"{label} {slot}"
|
||||
40
surfsense_backend/app/podcasts/generation/brief/propose.py
Normal file
40
surfsense_backend/app/podcasts/generation/brief/propose.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Propose a podcast's initial brief spec."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.podcasts.persistence import PodcastRepository
|
||||
from app.podcasts.schemas import PodcastSpec
|
||||
from app.podcasts.service import preferences_from
|
||||
|
||||
from .config import DEFAULT_MAX_MINUTES, DEFAULT_MIN_MINUTES, DEFAULT_SPEAKER_COUNT
|
||||
from .graph import graph as brief_graph
|
||||
from .state import BriefState
|
||||
|
||||
|
||||
async def propose_brief(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
search_space_id: int,
|
||||
speaker_count: int = DEFAULT_SPEAKER_COUNT,
|
||||
min_minutes: int = DEFAULT_MIN_MINUTES,
|
||||
max_minutes: int = DEFAULT_MAX_MINUTES,
|
||||
focus: str | None = None,
|
||||
) -> PodcastSpec:
|
||||
"""Reuse the last-used language and voices, else English; return the spec."""
|
||||
last_language, last_voices = preferences_from(
|
||||
await PodcastRepository(session).latest_with_spec(search_space_id)
|
||||
)
|
||||
config = {
|
||||
"configurable": {
|
||||
"speaker_count": speaker_count,
|
||||
"min_minutes": min_minutes,
|
||||
"max_minutes": max_minutes,
|
||||
"focus": focus,
|
||||
"last_used_language": last_language,
|
||||
"last_used_voices": last_voices,
|
||||
}
|
||||
}
|
||||
result = await brief_graph.ainvoke(BriefState(), config=config)
|
||||
return result["spec"]
|
||||
14
surfsense_backend/app/podcasts/generation/brief/state.py
Normal file
14
surfsense_backend/app/podcasts/generation/brief/state.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
"""Mutable state threaded through the brief-planning graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from app.podcasts.schemas import PodcastSpec
|
||||
|
||||
|
||||
@dataclass
|
||||
class BriefState:
|
||||
"""The proposed spec the graph produces; inputs arrive via the config."""
|
||||
|
||||
spec: PodcastSpec | None = None
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
"""Prompt builders for the generation graphs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .draft_segment import draft_segment_prompt
|
||||
from .plan_outline import plan_outline_prompt
|
||||
from .speakers import render_speaker_roster
|
||||
|
||||
__all__ = [
|
||||
"draft_segment_prompt",
|
||||
"plan_outline_prompt",
|
||||
"render_speaker_roster",
|
||||
]
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
"""Prompt for drafting one outline segment into dialogue turns.
|
||||
|
||||
Each segment is drafted on its own so long episodes stay coherent and within
|
||||
context limits. A short recap of the preceding dialogue is passed in so the new
|
||||
segment continues naturally instead of restarting. The model must write in the
|
||||
episode language and attribute every line to a real speaker slot.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.podcasts.schemas import PodcastSpec
|
||||
|
||||
from .speakers import render_speaker_roster
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.podcasts.generation.transcript.planning import OutlineSegment
|
||||
|
||||
|
||||
def draft_segment_prompt(
|
||||
*,
|
||||
spec: PodcastSpec,
|
||||
segment: OutlineSegment,
|
||||
position: int,
|
||||
total: int,
|
||||
recap: str | None,
|
||||
) -> str:
|
||||
talking_points = "\n".join(f"- {point}" for point in segment.talking_points)
|
||||
recap_block = (
|
||||
f"\nRecap of the conversation so far (continue from here, do not repeat "
|
||||
f"it):\n{recap}\n"
|
||||
if recap
|
||||
else "\nThis is the opening segment; begin the conversation naturally.\n"
|
||||
)
|
||||
return f"""\
|
||||
You are scripting natural, engaging podcast dialogue for segment {position} of \
|
||||
{total}.
|
||||
|
||||
Write entirely in {spec.language}. The format is {spec.style.value}.
|
||||
Speakers — attribute every line using these exact slot numbers:
|
||||
{render_speaker_roster(spec)}
|
||||
{recap_block}
|
||||
This segment is "{segment.title}". Cover these points using only facts grounded \
|
||||
in the provided source content:
|
||||
{talking_points}
|
||||
|
||||
Aim for about {segment.target_words} words of dialogue. Keep turns conversational \
|
||||
and varied; speakers should react to each other rather than deliver monologues. \
|
||||
Do not add greetings or sign-offs unless this is the first or last segment.
|
||||
|
||||
Respond with strict JSON and nothing else:
|
||||
{{"turns": [{{"speaker": <slot>, "text": "..."}}]}}
|
||||
"""
|
||||
|
|
@ -0,0 +1,47 @@
|
|||
"""Prompt for planning a long-form podcast outline before drafting dialogue.
|
||||
|
||||
Outlining first is what makes long-form reliable: a single LLM call cannot hold
|
||||
a coherent one- to two-hour script, but it can plan segments that are then
|
||||
drafted independently against a shared plan. The prompt is told the target
|
||||
length so the number and size of segments scale with the requested duration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.podcasts.schemas import PodcastSpec
|
||||
|
||||
from .speakers import render_speaker_roster
|
||||
|
||||
|
||||
def plan_outline_prompt(
|
||||
*,
|
||||
spec: PodcastSpec,
|
||||
target_words: int,
|
||||
suggested_segments: int,
|
||||
focus: str | None,
|
||||
) -> str:
|
||||
focus_block = (
|
||||
f"\nThe user asked the episode to focus on:\n{focus}\n" if focus else ""
|
||||
)
|
||||
return f"""\
|
||||
You are a podcast showrunner planning the structure of an episode before any \
|
||||
dialogue is written.
|
||||
|
||||
The episode language is {spec.language}. The format is {spec.style.value}.
|
||||
Speakers (refer to them by these slots later):
|
||||
{render_speaker_roster(spec)}
|
||||
{focus_block}
|
||||
Plan an outline that, when fully drafted, reaches roughly {target_words} words \
|
||||
of spoken dialogue (about {suggested_segments} segments). Each segment is one \
|
||||
coherent beat of the conversation: an opening, distinct topic areas grounded in \
|
||||
the source content, and a closing.
|
||||
|
||||
For each segment provide:
|
||||
- title: a short label for the beat
|
||||
- talking_points: 2-5 concrete points to cover, drawn from the source content
|
||||
- target_words: how many words of dialogue this segment should run (the sum \
|
||||
across segments should approximate {target_words})
|
||||
|
||||
Respond with strict JSON and nothing else:
|
||||
{{"segments": [{{"title": "...", "talking_points": ["..."], "target_words": 0}}]}}
|
||||
"""
|
||||
|
|
@ -0,0 +1,18 @@
|
|||
"""Render a spec's speaker roster for prompts.
|
||||
|
||||
The drafting prompts must reference speakers by the exact ``slot`` the renderer
|
||||
expects, so this is the single place that formats that roster — keeping the
|
||||
slot contract identical across every prompt that mentions speakers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.podcasts.schemas import PodcastSpec
|
||||
|
||||
|
||||
def render_speaker_roster(spec: PodcastSpec) -> str:
|
||||
lines = [
|
||||
f"- slot {speaker.slot} — {speaker.name} (role: {speaker.role.value})"
|
||||
for speaker in spec.speakers
|
||||
]
|
||||
return "\n".join(lines)
|
||||
50
surfsense_backend/app/podcasts/generation/structured.py
Normal file
50
surfsense_backend/app/podcasts/generation/structured.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""Parse a model's reply into a Pydantic shape, tolerating chatty output.
|
||||
|
||||
Agent LLMs return JSON wrapped in prose, markdown fences, or reasoning blocks,
|
||||
so a plain ``model_validate_json`` is unreliable. Centralising the tolerant
|
||||
parse here keeps every generation node validating replies the same way.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from app.utils.content_utils import extract_text_content, strip_markdown_fences
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
T = TypeVar("T", bound=BaseModel)
|
||||
|
||||
|
||||
class StructuredOutputError(RuntimeError):
|
||||
"""The model reply could not be parsed into the expected shape."""
|
||||
|
||||
|
||||
async def invoke_json[T: BaseModel](
|
||||
llm, messages: list[BaseMessage], model: type[T]
|
||||
) -> T:
|
||||
"""Invoke ``llm`` and validate its reply as ``model``."""
|
||||
response = await llm.ainvoke(messages)
|
||||
content = strip_markdown_fences(extract_text_content(response.content))
|
||||
|
||||
try:
|
||||
return model.model_validate_json(content)
|
||||
except (ValidationError, ValueError):
|
||||
pass
|
||||
|
||||
start = content.find("{")
|
||||
end = content.rfind("}") + 1
|
||||
if 0 <= start < end:
|
||||
try:
|
||||
return model.model_validate_json(content[start:end])
|
||||
except (ValidationError, ValueError) as exc:
|
||||
raise StructuredOutputError(
|
||||
f"could not parse {model.__name__} from model reply"
|
||||
) from exc
|
||||
|
||||
raise StructuredOutputError(
|
||||
f"no JSON object found for {model.__name__} in model reply"
|
||||
)
|
||||
|
|
@ -0,0 +1,17 @@
|
|||
"""Transcript drafting: outline-first, long-form dialogue generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .config import TranscriptConfig
|
||||
from .graph import build_transcript_graph
|
||||
from .planning import Outline, OutlineSegment, SegmentDraft
|
||||
from .state import TranscriptState
|
||||
|
||||
__all__ = [
|
||||
"Outline",
|
||||
"OutlineSegment",
|
||||
"SegmentDraft",
|
||||
"TranscriptConfig",
|
||||
"TranscriptState",
|
||||
"build_transcript_graph",
|
||||
]
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
"""Configurable inputs for the transcript-drafting graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, fields
|
||||
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from app.podcasts.schemas import PodcastSpec
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class TranscriptConfig:
|
||||
"""The approved spec and user focus that drive drafting."""
|
||||
|
||||
search_space_id: int
|
||||
spec: PodcastSpec
|
||||
focus: str | None = None
|
||||
|
||||
@classmethod
|
||||
def from_runnable_config(
|
||||
cls, config: RunnableConfig | None = None
|
||||
) -> TranscriptConfig:
|
||||
configurable = (config.get("configurable") or {}) if config else {}
|
||||
names = {f.name for f in fields(cls) if f.init}
|
||||
return cls(**{k: v for k, v in configurable.items() if k in names})
|
||||
|
|
@ -0,0 +1,29 @@
|
|||
"""The transcript-drafting graph: outline, draft segments, finalize."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from .config import TranscriptConfig
|
||||
from .nodes import draft_segments, finalize, plan_outline
|
||||
from .state import TranscriptState
|
||||
|
||||
|
||||
def build_transcript_graph():
|
||||
workflow = StateGraph(TranscriptState, config_schema=TranscriptConfig)
|
||||
|
||||
workflow.add_node("plan_outline", plan_outline)
|
||||
workflow.add_node("draft_segments", draft_segments)
|
||||
workflow.add_node("finalize", finalize)
|
||||
|
||||
workflow.add_edge("__start__", "plan_outline")
|
||||
workflow.add_edge("plan_outline", "draft_segments")
|
||||
workflow.add_edge("draft_segments", "finalize")
|
||||
workflow.add_edge("finalize", "__end__")
|
||||
|
||||
graph = workflow.compile()
|
||||
graph.name = "Surfsense Podcast Transcript"
|
||||
return graph
|
||||
|
||||
|
||||
graph = build_transcript_graph()
|
||||
127
surfsense_backend/app/podcasts/generation/transcript/nodes.py
Normal file
127
surfsense_backend/app/podcasts/generation/transcript/nodes.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""Transcript-drafting nodes: plan an outline, draft each beat, then assemble.
|
||||
|
||||
Long-form is produced beat-by-beat: a single call plans the structure, then each
|
||||
segment is drafted on its own with a recap of what came before so the script
|
||||
stays coherent without holding the whole episode in one context window.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
|
||||
from app.podcasts.schemas import PodcastSpec, Transcript, TranscriptTurn
|
||||
from app.services.llm_service import get_agent_llm
|
||||
|
||||
from ..prompts import draft_segment_prompt, plan_outline_prompt
|
||||
from ..structured import invoke_json
|
||||
from .config import TranscriptConfig
|
||||
from .planning import Outline, SegmentDraft
|
||||
from .state import TranscriptState
|
||||
|
||||
# Average speaking rate; converts target minutes to a target word count.
|
||||
_WORDS_PER_MINUTE = 150
|
||||
# Rough words per outline segment, used to suggest how many segments to plan.
|
||||
_WORDS_PER_SEGMENT = 250
|
||||
# Cap on source text sent per LLM call to bound tokens on large sources.
|
||||
_SOURCE_BUDGET_CHARS = 12000
|
||||
# How much prior dialogue to recap into each segment for continuity.
|
||||
_RECAP_CHARS = 800
|
||||
|
||||
|
||||
async def plan_outline(
|
||||
state: TranscriptState, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Plan the segment structure sized to the spec's target duration."""
|
||||
tc = TranscriptConfig.from_runnable_config(config)
|
||||
llm = await _require_llm(state, tc)
|
||||
|
||||
target_words = round(tc.spec.duration.midpoint_minutes * _WORDS_PER_MINUTE)
|
||||
suggested_segments = max(1, round(target_words / _WORDS_PER_SEGMENT))
|
||||
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content=plan_outline_prompt(
|
||||
spec=tc.spec,
|
||||
target_words=target_words,
|
||||
suggested_segments=suggested_segments,
|
||||
focus=tc.focus,
|
||||
)
|
||||
),
|
||||
HumanMessage(content=_source_block(state.source_content)),
|
||||
]
|
||||
outline = await invoke_json(llm, messages, Outline)
|
||||
return {"outline": outline}
|
||||
|
||||
|
||||
async def draft_segments(
|
||||
state: TranscriptState, config: RunnableConfig
|
||||
) -> dict[str, Any]:
|
||||
"""Draft each outline segment in order, carrying a running recap."""
|
||||
tc = TranscriptConfig.from_runnable_config(config)
|
||||
llm = await _require_llm(state, tc)
|
||||
outline = state.outline
|
||||
if outline is None:
|
||||
raise RuntimeError("draft_segments requires an outline")
|
||||
|
||||
source_block = _source_block(state.source_content)
|
||||
turns: list[TranscriptTurn] = []
|
||||
total = len(outline.segments)
|
||||
|
||||
for index, segment in enumerate(outline.segments):
|
||||
messages = [
|
||||
SystemMessage(
|
||||
content=draft_segment_prompt(
|
||||
spec=tc.spec,
|
||||
segment=segment,
|
||||
position=index + 1,
|
||||
total=total,
|
||||
recap=_recap(turns, tc.spec),
|
||||
)
|
||||
),
|
||||
HumanMessage(content=source_block),
|
||||
]
|
||||
draft = await invoke_json(llm, messages, SegmentDraft)
|
||||
turns.extend(_valid_turns(draft, tc.spec))
|
||||
|
||||
return {"drafted_turns": turns}
|
||||
|
||||
|
||||
def finalize(state: TranscriptState, config: RunnableConfig) -> dict[str, Any]:
|
||||
"""Assemble drafted turns into a validated transcript."""
|
||||
if not state.drafted_turns:
|
||||
raise RuntimeError("drafting produced no usable dialogue")
|
||||
return {"transcript": Transcript(turns=state.drafted_turns)}
|
||||
|
||||
|
||||
async def _require_llm(state: TranscriptState, tc: TranscriptConfig):
|
||||
llm = await get_agent_llm(state.db_session, tc.search_space_id)
|
||||
if llm is None:
|
||||
raise RuntimeError(
|
||||
f"no agent LLM configured for search space {tc.search_space_id}"
|
||||
)
|
||||
return llm
|
||||
|
||||
|
||||
def _source_block(source_content: str) -> str:
|
||||
sample = (source_content or "")[:_SOURCE_BUDGET_CHARS]
|
||||
return f"<source_content>{sample}</source_content>"
|
||||
|
||||
|
||||
def _valid_turns(draft: SegmentDraft, spec: PodcastSpec) -> list[TranscriptTurn]:
|
||||
# Drop any turn the model attributed to a slot the spec doesn't define, so a
|
||||
# stray attribution can't break rendering downstream.
|
||||
valid_slots = {speaker.slot for speaker in spec.speakers}
|
||||
return [turn for turn in draft.turns if turn.speaker in valid_slots]
|
||||
|
||||
|
||||
def _recap(turns: list[TranscriptTurn], spec: PodcastSpec) -> str | None:
|
||||
if not turns:
|
||||
return None
|
||||
names = {speaker.slot: speaker.name for speaker in spec.speakers}
|
||||
rendered = "\n".join(
|
||||
f"{names.get(turn.speaker, turn.speaker)}: {turn.text}" for turn in turns
|
||||
)
|
||||
return rendered[-_RECAP_CHARS:]
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
"""Internal shapes the transcript graph passes between its nodes.
|
||||
|
||||
These are generation-time artifacts (the outline and per-segment drafts), not
|
||||
persisted or API-facing. Segment drafts reuse :class:`TranscriptTurn` so the
|
||||
speaker-slot contract and turn validation are identical to the final transcript.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.podcasts.schemas import TranscriptTurn
|
||||
|
||||
|
||||
class OutlineSegment(BaseModel):
|
||||
"""One planned beat of the conversation, drafted independently."""
|
||||
|
||||
title: str = Field(..., min_length=1)
|
||||
talking_points: list[str] = Field(default_factory=list)
|
||||
target_words: int = Field(..., ge=1)
|
||||
|
||||
|
||||
class Outline(BaseModel):
|
||||
"""The full plan: ordered segments sized to the target duration."""
|
||||
|
||||
segments: list[OutlineSegment] = Field(..., min_length=1)
|
||||
|
||||
|
||||
class SegmentDraft(BaseModel):
|
||||
"""The dialogue a single segment produced."""
|
||||
|
||||
turns: list[TranscriptTurn] = Field(default_factory=list)
|
||||
|
|
@ -0,0 +1,22 @@
|
|||
"""Mutable state threaded through the transcript-drafting graph."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.podcasts.schemas import Transcript, TranscriptTurn
|
||||
|
||||
from .planning import Outline
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptState:
|
||||
"""Source content plus the intermediate and final drafting artifacts."""
|
||||
|
||||
db_session: AsyncSession
|
||||
source_content: str
|
||||
outline: Outline | None = None
|
||||
drafted_turns: list[TranscriptTurn] = field(default_factory=list)
|
||||
transcript: Transcript | None = None
|
||||
9
surfsense_backend/app/podcasts/persistence/__init__.py
Normal file
9
surfsense_backend/app/podcasts/persistence/__init__.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
"""Models, enums, and data access for the podcasts table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .enums import PodcastStatus
|
||||
from .models import Podcast
|
||||
from .repository import PodcastRepository
|
||||
|
||||
__all__ = ["Podcast", "PodcastRepository", "PodcastStatus"]
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
"""Enums for the podcasts table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .podcast_status import PodcastStatus
|
||||
|
||||
__all__ = ["PodcastStatus"]
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
"""Podcast generation lifecycle.
|
||||
|
||||
The status drives a guarded state machine. A podcast is proposed (``PENDING``),
|
||||
gets a reviewable brief (``AWAITING_BRIEF``), is drafted into a transcript
|
||||
(``DRAFTING``), then rendered to audio (``RENDERING`` → ``READY``). ``FAILED``
|
||||
and ``CANCELLED`` are terminal; a ``READY`` episode can be sent back to the
|
||||
brief gate for regeneration, and an in-flight regeneration can be reverted to
|
||||
``READY`` while the previous audio still exists. ``AWAITING_REVIEW`` is
|
||||
retained for legacy rows but
|
||||
never entered anymore — the brief is the only approval gate. The Python enum is
|
||||
kept in lockstep with the ``podcast_status`` Postgres type via its paired
|
||||
migration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class PodcastStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
AWAITING_BRIEF = "awaiting_brief"
|
||||
DRAFTING = "drafting"
|
||||
AWAITING_REVIEW = "awaiting_review"
|
||||
RENDERING = "rendering"
|
||||
READY = "ready"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
@property
|
||||
def is_terminal(self) -> bool:
|
||||
"""Whether no further transition is possible from this state."""
|
||||
return self in _TERMINAL
|
||||
|
||||
@property
|
||||
def is_gate(self) -> bool:
|
||||
"""Whether this state waits on user input before proceeding."""
|
||||
return self in _GATES
|
||||
|
||||
|
||||
_TERMINAL = frozenset({PodcastStatus.FAILED, PodcastStatus.CANCELLED})
|
||||
_GATES = frozenset({PodcastStatus.AWAITING_BRIEF, PodcastStatus.AWAITING_REVIEW})
|
||||
82
surfsense_backend/app/podcasts/persistence/models.py
Normal file
82
surfsense_backend/app/podcasts/persistence/models.py
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
"""``podcasts`` table: a generated podcast, its brief, transcript, and state."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
Enum as SQLAlchemyEnum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import relationship
|
||||
|
||||
from app.db import BaseModel, TimestampMixin
|
||||
|
||||
from .enums import PodcastStatus
|
||||
|
||||
|
||||
class Podcast(BaseModel, TimestampMixin):
|
||||
"""A podcast across its whole lifecycle: brief, transcript, audio, status.
|
||||
|
||||
``spec`` (the reviewable brief) and ``podcast_transcript`` are JSONB so the
|
||||
flexible Pydantic shapes can evolve without migrations. ``spec_version``
|
||||
backs optimistic concurrency on brief edits. Rendered audio lives in the
|
||||
object store, addressed by ``storage_backend`` + ``storage_key`` rather than
|
||||
a raw path.
|
||||
"""
|
||||
|
||||
__tablename__ = "podcasts"
|
||||
|
||||
title = Column(String(500), nullable=False)
|
||||
|
||||
status = Column(
|
||||
SQLAlchemyEnum(
|
||||
PodcastStatus,
|
||||
name="podcast_status",
|
||||
create_type=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
default=PodcastStatus.PENDING,
|
||||
server_default=PodcastStatus.PENDING.value,
|
||||
index=True,
|
||||
)
|
||||
|
||||
# The source material the episode is generated from. Persisted because
|
||||
# drafting happens after the brief gate, long after creation.
|
||||
source_content = Column(Text, nullable=True)
|
||||
|
||||
# The reviewable brief (PodcastSpec); null until the brief gate is reached.
|
||||
spec = Column(JSONB, nullable=True)
|
||||
# Bumped on every spec edit; guards concurrent edits at the brief gate.
|
||||
spec_version = Column(Integer, nullable=False, default=1, server_default="1")
|
||||
|
||||
# The drafted dialogue (Transcript); null until drafting completes.
|
||||
podcast_transcript = Column(JSONB, nullable=True)
|
||||
|
||||
# Where the rendered audio lives in the object store; null until READY.
|
||||
storage_backend = Column(String(32), nullable=True)
|
||||
storage_key = Column(Text, nullable=True)
|
||||
duration_seconds = Column(Integer, nullable=True)
|
||||
|
||||
# Human-readable reason when status is FAILED.
|
||||
error = Column(Text, nullable=True)
|
||||
|
||||
# Legacy local audio path; retained for back-compat until cutover.
|
||||
file_location = Column(Text, nullable=True)
|
||||
|
||||
search_space_id = Column(
|
||||
Integer, ForeignKey("searchspaces.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
search_space = relationship("SearchSpace", back_populates="podcasts")
|
||||
|
||||
thread_id = Column(
|
||||
Integer,
|
||||
ForeignKey("new_chat_threads.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
index=True,
|
||||
)
|
||||
thread = relationship("NewChatThread")
|
||||
46
surfsense_backend/app/podcasts/persistence/repository.py
Normal file
46
surfsense_backend/app/podcasts/persistence/repository.py
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
"""Data access for the ``podcasts`` table.
|
||||
|
||||
A thin async repository so the service and tasks never write raw queries. It
|
||||
only loads and persists rows; lifecycle rules and (de)serialization live in the
|
||||
service.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .models import Podcast
|
||||
|
||||
|
||||
class PodcastRepository:
|
||||
"""Loads and stores :class:`Podcast` rows for one session."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def get(self, podcast_id: int) -> Podcast | None:
|
||||
return await self._session.get(Podcast, podcast_id)
|
||||
|
||||
async def add(self, podcast: Podcast) -> Podcast:
|
||||
"""Persist a new row and assign its primary key."""
|
||||
self._session.add(podcast)
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
async def latest_with_spec(self, search_space_id: int) -> Podcast | None:
|
||||
"""Most recent podcast in the space that has a stored brief.
|
||||
|
||||
Used to seed language/voice defaults for a new podcast from what the
|
||||
user chose last.
|
||||
"""
|
||||
result = await self._session.execute(
|
||||
select(Podcast)
|
||||
.where(
|
||||
Podcast.search_space_id == search_space_id,
|
||||
Podcast.spec.is_not(None),
|
||||
)
|
||||
.order_by(Podcast.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return result.scalars().first()
|
||||
12
surfsense_backend/app/podcasts/rendering/__init__.py
Normal file
12
surfsense_backend/app/podcasts/rendering/__init__.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""Rendering: synthesise and merge an approved transcript into audio.
|
||||
|
||||
The :class:`PodcastRenderer` is the public entry point; the segment cache and
|
||||
FFmpeg merge are implementation details it owns.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .errors import RenderError
|
||||
from .renderer import PodcastRenderer, RenderedPodcast
|
||||
|
||||
__all__ = ["PodcastRenderer", "RenderError", "RenderedPodcast"]
|
||||
53
surfsense_backend/app/podcasts/rendering/cache.py
Normal file
53
surfsense_backend/app/podcasts/rendering/cache.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Content-addressed cache for synthesised segments.
|
||||
|
||||
Each segment's audio is keyed by everything that determines its bytes (voice,
|
||||
language, speed, text). Keeping the cache in a stable per-podcast directory
|
||||
makes re-renders cheap: changing one speaker's voice only misses that speaker's
|
||||
turns, and a worker restart mid-render resumes from whatever was already
|
||||
written. The key intentionally excludes the segment's position so identical
|
||||
lines (e.g. repeated "Right.") synthesise once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from app.podcasts.tts import SynthesisRequest
|
||||
|
||||
|
||||
class SegmentCache:
|
||||
"""On-disk store of segment audio, addressed by request content hash."""
|
||||
|
||||
def __init__(self, root: Path) -> None:
|
||||
self._root = root
|
||||
self._root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def key(self, request: SynthesisRequest) -> str:
|
||||
"""A stable hash of the inputs that determine the synthesised bytes."""
|
||||
material = json.dumps(
|
||||
{
|
||||
"voice": request.voice,
|
||||
"language": request.language,
|
||||
"speed": request.speed,
|
||||
"text": request.text,
|
||||
},
|
||||
sort_keys=True,
|
||||
ensure_ascii=True,
|
||||
)
|
||||
return hashlib.sha256(material.encode("utf-8")).hexdigest()
|
||||
|
||||
def path(self, key: str, container: str) -> Path:
|
||||
return self._root / f"{key}.{container}"
|
||||
|
||||
def get(self, key: str, container: str) -> Path | None:
|
||||
"""Return the cached segment path, or ``None`` on a miss."""
|
||||
path = self.path(key, container)
|
||||
return path if path.exists() else None
|
||||
|
||||
def put(self, key: str, container: str, data: bytes) -> Path:
|
||||
"""Write ``data`` for ``key`` and return its path."""
|
||||
path = self.path(key, container)
|
||||
path.write_bytes(data)
|
||||
return path
|
||||
11
surfsense_backend/app/podcasts/rendering/errors.py
Normal file
11
surfsense_backend/app/podcasts/rendering/errors.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
"""Failures raised while rendering a transcript to audio."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class RenderError(RuntimeError):
|
||||
"""Rendering could not produce a final audio file.
|
||||
|
||||
Wraps both per-segment synthesis failures and the merge step so the render
|
||||
task sees one failure type regardless of where it originated.
|
||||
"""
|
||||
48
surfsense_backend/app/podcasts/rendering/merge.py
Normal file
48
surfsense_backend/app/podcasts/rendering/merge.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
"""Concatenate ordered segment files into a single MP3.
|
||||
|
||||
Uses FFmpeg's concat *demuxer* (a list file of inputs) rather than a
|
||||
``filter_complex`` graph. The demuxer takes one ``-i`` no matter how many
|
||||
segments there are, so an hour-long episode with thousands of turns never hits
|
||||
command-line length limits. Output is always re-encoded to MP3 for a uniform
|
||||
artifact regardless of the source container (Kokoro WAV or hosted MP3).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from ffmpeg.asyncio import FFmpeg
|
||||
|
||||
from .errors import RenderError
|
||||
|
||||
|
||||
async def concat_to_mp3(segment_paths: list[Path], output_path: Path) -> None:
|
||||
"""Merge ``segment_paths`` in order into ``output_path`` as MP3."""
|
||||
if not segment_paths:
|
||||
raise RenderError("cannot merge an empty list of segments")
|
||||
|
||||
list_file = output_path.with_name(f"{output_path.stem}.concat.txt")
|
||||
list_file.write_text(_concat_list(segment_paths), encoding="utf-8")
|
||||
|
||||
try:
|
||||
ffmpeg = (
|
||||
FFmpeg()
|
||||
.option("y")
|
||||
.input(str(list_file), f="concat", safe=0)
|
||||
.output(str(output_path), {"c:a": "libmp3lame"})
|
||||
)
|
||||
await ffmpeg.execute()
|
||||
except Exception as exc:
|
||||
raise RenderError(f"audio merge failed: {exc}") from exc
|
||||
finally:
|
||||
list_file.unlink(missing_ok=True)
|
||||
|
||||
|
||||
def _concat_list(segment_paths: list[Path]) -> str:
|
||||
# The concat demuxer reads `file '<path>'` lines; single quotes in a path
|
||||
# are escaped per its quoting rules ('\'').
|
||||
lines = []
|
||||
for path in segment_paths:
|
||||
escaped = str(path.resolve()).replace("'", "'\\''")
|
||||
lines.append(f"file '{escaped}'")
|
||||
return "\n".join(lines) + "\n"
|
||||
155
surfsense_backend/app/podcasts/rendering/renderer.py
Normal file
155
surfsense_backend/app/podcasts/rendering/renderer.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
"""Render an approved transcript into a single podcast audio file.
|
||||
|
||||
The renderer is the only place that turns dialogue into sound. It maps each
|
||||
turn to its speaker's voice, synthesises segments concurrently (capped, served
|
||||
from the segment cache when possible, and coalesced so identical lines render
|
||||
once), then merges them in order. It takes a settled spec + transcript and
|
||||
returns bytes; persistence and lifecycle transitions belong to the service.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from app.podcasts.schemas import PodcastSpec, Transcript, TranscriptTurn
|
||||
from app.podcasts.tts import SynthesisRequest, TextToSpeech, TextToSpeechError
|
||||
from app.podcasts.voices import VoiceCatalog
|
||||
|
||||
from .cache import SegmentCache
|
||||
from .errors import RenderError
|
||||
from .merge import concat_to_mp3
|
||||
|
||||
# Bounds how many segments synthesise at once. Protects hosted-provider rate
|
||||
# limits and avoids thrashing the local Kokoro pipeline; the renderer is I/O- or
|
||||
# model-bound per segment, so a small pool already saturates throughput.
|
||||
DEFAULT_MAX_CONCURRENCY = 4
|
||||
|
||||
_MERGED_FILENAME = "podcast.mp3"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class RenderedPodcast:
|
||||
"""The finished episode: encoded bytes plus their container."""
|
||||
|
||||
data: bytes
|
||||
container: str
|
||||
|
||||
|
||||
class PodcastRenderer:
|
||||
"""Synthesises and merges a transcript using one TTS provider."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tts: TextToSpeech,
|
||||
catalog: VoiceCatalog,
|
||||
max_concurrency: int = DEFAULT_MAX_CONCURRENCY,
|
||||
) -> None:
|
||||
self._tts = tts
|
||||
self._catalog = catalog
|
||||
self._max_concurrency = max_concurrency
|
||||
|
||||
async def render(
|
||||
self,
|
||||
*,
|
||||
spec: PodcastSpec,
|
||||
transcript: Transcript,
|
||||
workdir: Path,
|
||||
) -> RenderedPodcast:
|
||||
"""Produce the merged MP3 for ``transcript`` under ``spec``.
|
||||
|
||||
``workdir`` holds the segment cache and merge output; reusing the same
|
||||
directory across renders is what makes voice edits cheap.
|
||||
"""
|
||||
cache = SegmentCache(workdir / "segments")
|
||||
requests = [self._request_for(spec, turn) for turn in transcript.turns]
|
||||
|
||||
# Concurrency primitives are created per render so each call is bound to
|
||||
# the event loop running it (Celery tasks may use a fresh loop).
|
||||
synthesizer = _SegmentSynthesizer(self._tts, cache, self._max_concurrency)
|
||||
segment_paths = await asyncio.gather(
|
||||
*(synthesizer.segment(request) for request in requests)
|
||||
)
|
||||
|
||||
output_path = workdir / _MERGED_FILENAME
|
||||
await concat_to_mp3(list(segment_paths), output_path)
|
||||
return RenderedPodcast(data=output_path.read_bytes(), container="mp3")
|
||||
|
||||
def _request_for(self, spec: PodcastSpec, turn: TranscriptTurn) -> SynthesisRequest:
|
||||
try:
|
||||
speaker = spec.speaker_for(turn.speaker)
|
||||
except KeyError as exc:
|
||||
raise RenderError(
|
||||
f"transcript references unknown speaker slot {turn.speaker}"
|
||||
) from exc
|
||||
try:
|
||||
voice = self._catalog.get(speaker.voice_id)
|
||||
except KeyError as exc:
|
||||
raise RenderError(f"unknown voice {speaker.voice_id!r}") from exc
|
||||
return SynthesisRequest(
|
||||
text=turn.text, voice=voice.native_ref, language=spec.language
|
||||
)
|
||||
|
||||
|
||||
class _SegmentSynthesizer:
|
||||
"""Per-render synthesis coordinator: caps concurrency and dedupes work.
|
||||
|
||||
Beyond the on-disk cache (which serves cross-render reuse), this coalesces
|
||||
identical segments that race within one render so the same line is voiced
|
||||
once even when several turns request it simultaneously.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, tts: TextToSpeech, cache: SegmentCache, max_concurrency: int
|
||||
) -> None:
|
||||
self._tts = tts
|
||||
self._cache = cache
|
||||
self._container = tts.container
|
||||
self._semaphore = asyncio.Semaphore(max_concurrency)
|
||||
self._inflight: dict[str, asyncio.Future[Path]] = {}
|
||||
self._inflight_lock = asyncio.Lock()
|
||||
|
||||
async def segment(self, request: SynthesisRequest) -> Path:
|
||||
key = self._cache.key(request)
|
||||
cached = self._cache.get(key, self._container)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
async with self._inflight_lock:
|
||||
future = self._inflight.get(key)
|
||||
owner = future is None
|
||||
if owner:
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self._inflight[key] = future
|
||||
|
||||
# The owner runs the work and publishes the outcome on the shared future;
|
||||
# every caller (owner included) reads it back via ``await future`` so the
|
||||
# result is retrieved exactly once-or-more and never left dangling.
|
||||
if owner:
|
||||
try:
|
||||
path = await self._synthesize(request, key)
|
||||
except BaseException as exc:
|
||||
future.set_exception(exc)
|
||||
else:
|
||||
future.set_result(path)
|
||||
finally:
|
||||
await self._forget(key)
|
||||
|
||||
return await future
|
||||
|
||||
async def _synthesize(self, request: SynthesisRequest, key: str) -> Path:
|
||||
async with self._semaphore:
|
||||
cached = self._cache.get(key, self._container)
|
||||
if cached is not None:
|
||||
return cached
|
||||
try:
|
||||
audio = await self._tts.synthesize(request)
|
||||
except TextToSpeechError as exc:
|
||||
raise RenderError(f"segment synthesis failed: {exc}") from exc
|
||||
return self._cache.put(key, audio.container, audio.data)
|
||||
|
||||
async def _forget(self, key: str) -> None:
|
||||
async with self._inflight_lock:
|
||||
self._inflight.pop(key, None)
|
||||
27
surfsense_backend/app/podcasts/resolution/__init__.py
Normal file
27
surfsense_backend/app/podcasts/resolution/__init__.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Resolution: deterministic default chains for a fresh brief.
|
||||
|
||||
Turns the user's last-used preferences into concrete language and voice
|
||||
defaults, so the brief gate opens pre-filled and most users approve without
|
||||
editing.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .language import (
|
||||
DEFAULT_LANGUAGE,
|
||||
DEFAULT_LANGUAGE_CHAIN,
|
||||
LanguageContext,
|
||||
LanguageResolver,
|
||||
resolve_language,
|
||||
)
|
||||
from .voices import VoiceResolutionError, resolve_voices
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_LANGUAGE",
|
||||
"DEFAULT_LANGUAGE_CHAIN",
|
||||
"LanguageContext",
|
||||
"LanguageResolver",
|
||||
"VoiceResolutionError",
|
||||
"resolve_language",
|
||||
"resolve_voices",
|
||||
]
|
||||
64
surfsense_backend/app/podcasts/resolution/language.py
Normal file
64
surfsense_backend/app/podcasts/resolution/language.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
"""Resolve the brief's language without spending tokens at the gate.
|
||||
|
||||
The chain mirrors the agreed policy: reuse the language the user last chose, and
|
||||
otherwise default to English (which the user can still override in the brief). We
|
||||
deliberately never guess the language from the source content — proposing a
|
||||
language the user did not ask for is worse than a predictable default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
# What a brand-new user with no signal gets, and what every chain ends on.
|
||||
DEFAULT_LANGUAGE = "en"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class LanguageContext:
|
||||
"""Signals available when proposing a language for a fresh podcast."""
|
||||
|
||||
last_used: str | None = None
|
||||
|
||||
|
||||
class LanguageResolver(ABC):
|
||||
"""One step in the language fallback chain."""
|
||||
|
||||
@abstractmethod
|
||||
def resolve(self, context: LanguageContext) -> str | None:
|
||||
"""Return a language tag, or ``None`` to defer to the next resolver."""
|
||||
|
||||
|
||||
class LastUsedLanguage(LanguageResolver):
|
||||
"""Reuse the language from the user's previous podcast."""
|
||||
|
||||
def resolve(self, context: LanguageContext) -> str | None:
|
||||
return context.last_used
|
||||
|
||||
|
||||
class DefaultLanguage(LanguageResolver):
|
||||
"""Terminal step: always yields the default so the chain never fails."""
|
||||
|
||||
def resolve(self, context: LanguageContext) -> str | None:
|
||||
return DEFAULT_LANGUAGE
|
||||
|
||||
|
||||
# Order encodes the policy; prepend stronger signals here as they appear.
|
||||
DEFAULT_LANGUAGE_CHAIN: tuple[LanguageResolver, ...] = (
|
||||
LastUsedLanguage(),
|
||||
DefaultLanguage(),
|
||||
)
|
||||
|
||||
|
||||
def resolve_language(
|
||||
context: LanguageContext,
|
||||
chain: tuple[LanguageResolver, ...] = DEFAULT_LANGUAGE_CHAIN,
|
||||
) -> str:
|
||||
"""Walk ``chain`` and return the first language a resolver yields."""
|
||||
for resolver in chain:
|
||||
language = resolver.resolve(context)
|
||||
if language:
|
||||
return language.strip()
|
||||
# The default resolver guarantees a value; this guards a misconfigured chain.
|
||||
return DEFAULT_LANGUAGE
|
||||
79
surfsense_backend/app/podcasts/resolution/voices.py
Normal file
79
surfsense_backend/app/podcasts/resolution/voices.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
"""Assign a default voice to each speaker for the resolved language.
|
||||
|
||||
The default chain reuses the user's previously chosen voices where they are
|
||||
still valid for the new language/provider, then fills any remaining speakers
|
||||
with distinct catalog voices (preferring an unused gender so a two-speaker
|
||||
episode sounds like two people). The user can override any of these in the
|
||||
brief; this only seeds sensible defaults so most briefs need no edits.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from app.podcasts.voices import CatalogVoice, TtsProvider, VoiceCatalog
|
||||
|
||||
|
||||
class VoiceResolutionError(RuntimeError):
|
||||
"""No catalog voice exists for the requested provider and language."""
|
||||
|
||||
|
||||
def resolve_voices(
|
||||
*,
|
||||
catalog: VoiceCatalog,
|
||||
provider: TtsProvider,
|
||||
language: str,
|
||||
speaker_count: int,
|
||||
preferred: Sequence[str] | None = None,
|
||||
) -> list[CatalogVoice]:
|
||||
"""Return one :class:`CatalogVoice` per speaker, in slot order.
|
||||
|
||||
``preferred`` is the user's last-used voice ids (by slot); any that no
|
||||
longer fit the provider/language are silently dropped and replaced.
|
||||
"""
|
||||
if speaker_count < 1:
|
||||
raise ValueError("speaker_count must be >= 1")
|
||||
|
||||
available = catalog.for_language(provider, language)
|
||||
if not available:
|
||||
raise VoiceResolutionError(
|
||||
f"{provider.value} has no voice for language {language!r}"
|
||||
)
|
||||
|
||||
preferred = preferred or ()
|
||||
by_id = {voice.voice_id: voice for voice in available}
|
||||
|
||||
assignment: list[CatalogVoice] = []
|
||||
used_ids: set[str] = set()
|
||||
used_genders: set = set()
|
||||
|
||||
for slot in range(speaker_count):
|
||||
reuse_id = preferred[slot] if slot < len(preferred) else None
|
||||
if reuse_id and reuse_id in by_id and reuse_id not in used_ids:
|
||||
voice = by_id[reuse_id]
|
||||
else:
|
||||
voice = _pick_distinct(available, used_ids, used_genders)
|
||||
assignment.append(voice)
|
||||
used_ids.add(voice.voice_id)
|
||||
used_genders.add(voice.gender)
|
||||
|
||||
return assignment
|
||||
|
||||
|
||||
def _pick_distinct(
|
||||
available: list[CatalogVoice],
|
||||
used_ids: set[str],
|
||||
used_genders: set,
|
||||
) -> CatalogVoice:
|
||||
"""Pick a fresh voice, preferring an unused gender, then any unused voice.
|
||||
|
||||
Falls back to the first catalog voice when speakers outnumber distinct
|
||||
voices, so resolution always assigns every speaker rather than failing.
|
||||
"""
|
||||
fresh = [v for v in available if v.voice_id not in used_ids]
|
||||
if fresh:
|
||||
for voice in fresh:
|
||||
if voice.gender not in used_genders:
|
||||
return voice
|
||||
return fresh[0]
|
||||
return available[0]
|
||||
24
surfsense_backend/app/podcasts/schemas/__init__.py
Normal file
24
surfsense_backend/app/podcasts/schemas/__init__.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"""Pydantic shapes for the podcast brief and transcript."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .spec import (
|
||||
DurationTarget,
|
||||
PodcastSpec,
|
||||
PodcastStyle,
|
||||
SpeakerRole,
|
||||
SpeakerSpec,
|
||||
normalize_language_tag,
|
||||
)
|
||||
from .transcript import Transcript, TranscriptTurn
|
||||
|
||||
__all__ = [
|
||||
"DurationTarget",
|
||||
"PodcastSpec",
|
||||
"PodcastStyle",
|
||||
"SpeakerRole",
|
||||
"SpeakerSpec",
|
||||
"Transcript",
|
||||
"TranscriptTurn",
|
||||
"normalize_language_tag",
|
||||
]
|
||||
166
surfsense_backend/app/podcasts/schemas/spec.py
Normal file
166
surfsense_backend/app/podcasts/schemas/spec.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
"""The brief: the editable configuration a user approves before drafting.
|
||||
|
||||
A :class:`PodcastSpec` front-loads every decision that drives token or audio
|
||||
cost (language, speakers, voices, style, target length) so the expensive
|
||||
drafting and rendering steps run once against settled inputs. It is stored as
|
||||
JSONB on the ``podcasts`` row and round-trips through the review API.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
|
||||
|
||||
# A speaker count beyond this is almost never a real podcast and explodes the
|
||||
# voice/turn-attribution space, so we reject it at the brief gate.
|
||||
MAX_SPEAKERS = 6
|
||||
|
||||
# Long-form is a goal, but an open-ended upper bound invites runaway TTS bills.
|
||||
# One day of audio is a generous ceiling that still blocks obvious mistakes.
|
||||
MAX_DURATION_MINUTES = 24 * 60
|
||||
|
||||
# BCP-47 primary subtag plus optional region (e.g. ``en``, ``en-US``, ``pt-BR``).
|
||||
# Kept deliberately permissive: the voice catalog, not the brief, decides which
|
||||
# languages can actually be synthesised. Casing is normalised after matching.
|
||||
_LANGUAGE_TAG = re.compile(r"^[A-Za-z]{2,3}(-[A-Za-z0-9]{2,8})*$")
|
||||
|
||||
|
||||
def normalize_language_tag(value: str) -> str:
|
||||
"""Validate and canonicalise a BCP-47 tag (lowercased primary subtag).
|
||||
|
||||
Shared with the generation layer so resolved and user-entered languages are
|
||||
normalised identically before they reach a :class:`PodcastSpec`.
|
||||
"""
|
||||
cleaned = value.strip()
|
||||
if not _LANGUAGE_TAG.match(cleaned):
|
||||
raise ValueError(f"not a valid BCP-47 language tag: {value!r}")
|
||||
primary, _, rest = cleaned.partition("-")
|
||||
return primary.lower() if not rest else f"{primary.lower()}-{rest}"
|
||||
|
||||
|
||||
class SpeakerRole(StrEnum):
|
||||
"""How a speaker functions in the conversation, used to steer drafting."""
|
||||
|
||||
HOST = "host"
|
||||
COHOST = "cohost"
|
||||
GUEST = "guest"
|
||||
EXPERT = "expert"
|
||||
NARRATOR = "narrator"
|
||||
|
||||
|
||||
class PodcastStyle(StrEnum):
|
||||
"""The conversational format the transcript should follow."""
|
||||
|
||||
CONVERSATIONAL = "conversational"
|
||||
INTERVIEW = "interview"
|
||||
DEBATE = "debate"
|
||||
MONOLOGUE = "monologue"
|
||||
NARRATIVE = "narrative"
|
||||
|
||||
|
||||
class SpeakerSpec(BaseModel):
|
||||
"""One voice in the podcast: who they are and which TTS voice renders them.
|
||||
|
||||
``slot`` is the stable join key. Transcript turns reference a speaker by
|
||||
``slot`` and the renderer resolves ``voice_id`` for that same slot, so the
|
||||
two never drift even if speakers are reordered in the brief.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
slot: int = Field(
|
||||
..., ge=0, description="Stable index a transcript turn references"
|
||||
)
|
||||
name: str = Field(..., min_length=1, max_length=120)
|
||||
role: SpeakerRole
|
||||
voice_id: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Catalog voice id valid for the spec's language and provider",
|
||||
)
|
||||
|
||||
@field_validator("name", "voice_id")
|
||||
@classmethod
|
||||
def _strip_required_text(cls, value: str) -> str:
|
||||
cleaned = value.strip()
|
||||
if not cleaned:
|
||||
raise ValueError("must not be blank")
|
||||
return cleaned
|
||||
|
||||
|
||||
class DurationTarget(BaseModel):
|
||||
"""The desired finished length as an inclusive minute range.
|
||||
|
||||
Drafting aims for the midpoint and treats the bounds as soft guardrails;
|
||||
storing a range (rather than a point) keeps long-form expectations honest
|
||||
without pretending we can hit an exact runtime.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
min_minutes: int = Field(..., ge=1, le=MAX_DURATION_MINUTES)
|
||||
max_minutes: int = Field(..., ge=1, le=MAX_DURATION_MINUTES)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_order(self) -> DurationTarget:
|
||||
if self.max_minutes < self.min_minutes:
|
||||
raise ValueError("max_minutes must be >= min_minutes")
|
||||
return self
|
||||
|
||||
@property
|
||||
def midpoint_minutes(self) -> float:
|
||||
"""The runtime drafting should aim for within the range."""
|
||||
return (self.min_minutes + self.max_minutes) / 2
|
||||
|
||||
|
||||
class PodcastSpec(BaseModel):
|
||||
"""The full brief approved before any tokens or audio are spent."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
language: str = Field(..., description="BCP-47 tag, e.g. 'en', 'en-US', 'pt-BR'")
|
||||
style: PodcastStyle = PodcastStyle.CONVERSATIONAL
|
||||
speakers: list[SpeakerSpec] = Field(..., min_length=1, max_length=MAX_SPEAKERS)
|
||||
duration: DurationTarget
|
||||
focus: str | None = Field(
|
||||
default=None,
|
||||
max_length=2000,
|
||||
description="Optional user steer for what the episode should emphasise",
|
||||
)
|
||||
|
||||
@field_validator("language")
|
||||
@classmethod
|
||||
def _normalise_language(cls, value: str) -> str:
|
||||
return normalize_language_tag(value)
|
||||
|
||||
@field_validator("focus")
|
||||
@classmethod
|
||||
def _blank_focus_is_none(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
cleaned = value.strip()
|
||||
return cleaned or None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_speaker_slots(self) -> PodcastSpec:
|
||||
slots = [speaker.slot for speaker in self.speakers]
|
||||
if len(slots) != len(set(slots)):
|
||||
raise ValueError("speaker slots must be unique")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_style_speakers(self) -> PodcastSpec:
|
||||
# One voice is what "monologue" means; letting extra speakers through
|
||||
# would force drafting to silently pick a winner.
|
||||
if self.style is PodcastStyle.MONOLOGUE and len(self.speakers) != 1:
|
||||
raise ValueError("a monologue has exactly one speaker")
|
||||
return self
|
||||
|
||||
def speaker_for(self, slot: int) -> SpeakerSpec:
|
||||
"""Return the speaker bound to ``slot`` or raise if none matches."""
|
||||
for speaker in self.speakers:
|
||||
if speaker.slot == slot:
|
||||
return speaker
|
||||
raise KeyError(f"no speaker for slot {slot}")
|
||||
41
surfsense_backend/app/podcasts/schemas/transcript.py
Normal file
41
surfsense_backend/app/podcasts/schemas/transcript.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
"""The transcript: ordered dialogue turns drafting produces for review.
|
||||
|
||||
A :class:`Transcript` is the reviewable artifact at the go/no-go gate and the
|
||||
exact input the renderer turns into audio. Each turn names a speaker by the
|
||||
``slot`` defined in the :class:`~app.podcasts.schemas.spec.PodcastSpec`, so the
|
||||
renderer can resolve the right voice without re-attributing anything.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
|
||||
class TranscriptTurn(BaseModel):
|
||||
"""A single spoken line by one speaker."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
speaker: int = Field(..., ge=0, description="The PodcastSpec speaker slot speaking")
|
||||
text: str = Field(..., min_length=1)
|
||||
|
||||
@field_validator("text")
|
||||
@classmethod
|
||||
def _strip_text(cls, value: str) -> str:
|
||||
cleaned = value.strip()
|
||||
if not cleaned:
|
||||
raise ValueError("turn text must not be blank")
|
||||
return cleaned
|
||||
|
||||
|
||||
class Transcript(BaseModel):
|
||||
"""The full ordered dialogue for an episode."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
turns: list[TranscriptTurn] = Field(..., min_length=1)
|
||||
|
||||
@property
|
||||
def word_count(self) -> int:
|
||||
"""Total spoken words, used to estimate runtime against the brief."""
|
||||
return sum(len(turn.text.split()) for turn in self.turns)
|
||||
255
surfsense_backend/app/podcasts/service.py
Normal file
255
surfsense_backend/app/podcasts/service.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
"""The podcast lifecycle authority: every status change goes through here.
|
||||
|
||||
The service owns the state machine. Each method names a real lifecycle step,
|
||||
validates it against the allowed-transition table, and (de)serializes the brief
|
||||
and transcript to/from their JSONB columns. It deliberately does not enqueue
|
||||
Celery work — callers transition the row here, then schedule the next task — so
|
||||
the rules stay testable and free of task-queue coupling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.podcasts.persistence import Podcast, PodcastRepository, PodcastStatus
|
||||
from app.podcasts.schemas import PodcastSpec, Transcript, TranscriptTurn
|
||||
|
||||
_MAX_ERROR_CHARS = 2000
|
||||
|
||||
# The only status changes the machine permits. Terminal states have no exits.
|
||||
_ALLOWED: dict[PodcastStatus, frozenset[PodcastStatus]] = {
|
||||
PodcastStatus.PENDING: frozenset(
|
||||
{PodcastStatus.AWAITING_BRIEF, PodcastStatus.FAILED, PodcastStatus.CANCELLED}
|
||||
),
|
||||
# The READY exits below exist for reverting a regeneration; the audio
|
||||
# guard for that lives in revert_regeneration.
|
||||
PodcastStatus.AWAITING_BRIEF: frozenset(
|
||||
{
|
||||
PodcastStatus.DRAFTING,
|
||||
PodcastStatus.READY,
|
||||
PodcastStatus.FAILED,
|
||||
PodcastStatus.CANCELLED,
|
||||
}
|
||||
),
|
||||
PodcastStatus.DRAFTING: frozenset(
|
||||
{
|
||||
PodcastStatus.RENDERING,
|
||||
PodcastStatus.READY,
|
||||
PodcastStatus.FAILED,
|
||||
PodcastStatus.CANCELLED,
|
||||
}
|
||||
),
|
||||
# Never entered anymore (the transcript gate was dropped); kept with exits
|
||||
# so legacy rows aren't stranded.
|
||||
PodcastStatus.AWAITING_REVIEW: frozenset(
|
||||
{PodcastStatus.AWAITING_BRIEF, PodcastStatus.FAILED, PodcastStatus.CANCELLED}
|
||||
),
|
||||
PodcastStatus.RENDERING: frozenset(
|
||||
{PodcastStatus.READY, PodcastStatus.FAILED, PodcastStatus.CANCELLED}
|
||||
),
|
||||
# Not terminal: regeneration reopens the brief gate so the user can tweak
|
||||
# the spec before a new take is drafted.
|
||||
PodcastStatus.READY: frozenset({PodcastStatus.AWAITING_BRIEF}),
|
||||
PodcastStatus.FAILED: frozenset(),
|
||||
PodcastStatus.CANCELLED: frozenset(),
|
||||
}
|
||||
|
||||
|
||||
class PodcastError(RuntimeError):
|
||||
"""Base class for lifecycle errors."""
|
||||
|
||||
|
||||
class InvalidTransitionError(PodcastError):
|
||||
"""A requested status change is not permitted from the current state."""
|
||||
|
||||
|
||||
class SpecConflictError(PodcastError):
|
||||
"""A spec edit raced another: the expected version is stale."""
|
||||
|
||||
def __init__(self, expected: int, actual: int) -> None:
|
||||
super().__init__(
|
||||
f"spec version conflict: expected {expected}, current is {actual}"
|
||||
)
|
||||
self.expected = expected
|
||||
self.actual = actual
|
||||
|
||||
|
||||
class PreconditionFailedError(PodcastError):
|
||||
"""A transition's data precondition (brief/transcript present) is unmet."""
|
||||
|
||||
|
||||
class PodcastService:
|
||||
"""Drives one podcast through its lifecycle within a single session."""
|
||||
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
self._repo = PodcastRepository(session)
|
||||
|
||||
async def create(
|
||||
self, *, title: str, search_space_id: int, thread_id: int | None = None
|
||||
) -> Podcast:
|
||||
"""Create a fresh podcast in ``PENDING`` awaiting its brief."""
|
||||
podcast = Podcast(
|
||||
title=title,
|
||||
search_space_id=search_space_id,
|
||||
thread_id=thread_id,
|
||||
status=PodcastStatus.PENDING,
|
||||
spec_version=1,
|
||||
)
|
||||
return await self._repo.add(podcast)
|
||||
|
||||
async def attach_brief(self, podcast: Podcast, spec: PodcastSpec) -> Podcast:
|
||||
"""Record the proposed brief and open the review gate."""
|
||||
self._transition(podcast, PodcastStatus.AWAITING_BRIEF)
|
||||
podcast.spec = spec.model_dump(mode="json")
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
async def update_spec(
|
||||
self, podcast: Podcast, spec: PodcastSpec, expected_version: int
|
||||
) -> Podcast:
|
||||
"""Edit the brief at the gate, guarded by optimistic concurrency."""
|
||||
if _status(podcast) is not PodcastStatus.AWAITING_BRIEF:
|
||||
raise InvalidTransitionError(
|
||||
f"the brief can only be edited while awaiting_brief, "
|
||||
f"not {_status(podcast).value}"
|
||||
)
|
||||
if expected_version != podcast.spec_version:
|
||||
raise SpecConflictError(expected_version, podcast.spec_version)
|
||||
podcast.spec = spec.model_dump(mode="json")
|
||||
podcast.spec_version += 1
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
async def begin_drafting(self, podcast: Podcast) -> Podcast:
|
||||
"""Approve the brief and start transcript drafting."""
|
||||
if podcast.spec is None:
|
||||
raise PreconditionFailedError("cannot draft without a brief")
|
||||
self._transition(podcast, PodcastStatus.DRAFTING)
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
async def attach_transcript(
|
||||
self, podcast: Podcast, transcript: Transcript
|
||||
) -> Podcast:
|
||||
"""Record the drafted transcript and move straight to rendering."""
|
||||
self._transition(podcast, PodcastStatus.RENDERING)
|
||||
podcast.podcast_transcript = transcript.model_dump(mode="json")
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
# Guards regenerate beyond the transition table: from PENDING the
|
||||
# AWAITING_BRIEF target is also legal, but there it means attaching a brief.
|
||||
_REGENERABLE = frozenset({PodcastStatus.READY, PodcastStatus.AWAITING_REVIEW})
|
||||
|
||||
async def regenerate(self, podcast: Podcast) -> Podcast:
|
||||
"""Reopen the brief gate; the saved spec becomes the new starting point."""
|
||||
if _status(podcast) not in self._REGENERABLE:
|
||||
raise InvalidTransitionError(
|
||||
f"nothing to regenerate from {_status(podcast).value}"
|
||||
)
|
||||
# Legacy episodes finished before briefs existed; a gate with nothing
|
||||
# to review would strand them.
|
||||
if podcast.spec is None:
|
||||
raise PreconditionFailedError("cannot regenerate without a brief")
|
||||
self._transition(podcast, PodcastStatus.AWAITING_BRIEF)
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
async def revert_regeneration(self, podcast: Podcast) -> Podcast:
|
||||
"""Back out of a regeneration and fall back to the stored episode.
|
||||
|
||||
Regeneration keeps the rendered audio until a new take replaces it, so
|
||||
any point before that commit is a free change of mind. A fresh podcast
|
||||
has no regeneration to revert and is rejected.
|
||||
"""
|
||||
if not has_stored_episode(podcast):
|
||||
raise InvalidTransitionError("no finished episode to fall back to")
|
||||
self._transition(podcast, PodcastStatus.READY)
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
async def attach_audio(
|
||||
self,
|
||||
podcast: Podcast,
|
||||
*,
|
||||
storage_backend: str,
|
||||
storage_key: str,
|
||||
duration_seconds: int | None = None,
|
||||
) -> Podcast:
|
||||
"""Record rendered audio and mark the podcast ready."""
|
||||
self._transition(podcast, PodcastStatus.READY)
|
||||
podcast.storage_backend = storage_backend
|
||||
podcast.storage_key = storage_key
|
||||
podcast.duration_seconds = duration_seconds
|
||||
podcast.error = None
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
async def fail(self, podcast: Podcast, error: str) -> Podcast:
|
||||
"""Move a non-terminal podcast to ``FAILED`` with a reason."""
|
||||
self._transition(podcast, PodcastStatus.FAILED)
|
||||
podcast.error = (error or "")[:_MAX_ERROR_CHARS] or None
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
async def cancel(self, podcast: Podcast) -> Podcast:
|
||||
"""Cancel a podcast that has produced nothing the user could keep.
|
||||
|
||||
No user action may destroy playable audio: once an episode exists,
|
||||
backing out goes through revert_regeneration instead.
|
||||
"""
|
||||
if has_stored_episode(podcast):
|
||||
raise InvalidTransitionError(
|
||||
"a finished episode exists; revert the regeneration instead"
|
||||
)
|
||||
self._transition(podcast, PodcastStatus.CANCELLED)
|
||||
await self._session.flush()
|
||||
return podcast
|
||||
|
||||
def _transition(self, podcast: Podcast, target: PodcastStatus) -> None:
|
||||
current = _status(podcast)
|
||||
if target not in _ALLOWED[current]:
|
||||
raise InvalidTransitionError(
|
||||
f"{current.value} -> {target.value} is not allowed"
|
||||
)
|
||||
podcast.status = target
|
||||
|
||||
|
||||
def _status(podcast: Podcast) -> PodcastStatus:
|
||||
return PodcastStatus(podcast.status)
|
||||
|
||||
|
||||
def has_stored_episode(podcast: Podcast) -> bool:
|
||||
"""Whether finished audio is stored (``file_location`` covers legacy rows)."""
|
||||
return bool(podcast.storage_key or podcast.file_location)
|
||||
|
||||
|
||||
def read_spec(podcast: Podcast) -> PodcastSpec | None:
|
||||
"""Deserialize the stored brief, or ``None`` if not yet proposed."""
|
||||
return PodcastSpec.model_validate(podcast.spec) if podcast.spec else None
|
||||
|
||||
|
||||
def read_transcript(podcast: Podcast) -> Transcript | None:
|
||||
"""Deserialize the stored transcript, or ``None`` if not yet drafted."""
|
||||
raw = podcast.podcast_transcript
|
||||
if not raw:
|
||||
return None
|
||||
# Rows from before the lifecycle rework stored a bare turn list with
|
||||
# different field names; they must keep reading, not fail validation.
|
||||
if isinstance(raw, list):
|
||||
return Transcript(
|
||||
turns=[
|
||||
TranscriptTurn(speaker=turn["speaker_id"], text=turn["dialog"])
|
||||
for turn in raw
|
||||
]
|
||||
)
|
||||
return Transcript.model_validate(raw)
|
||||
|
||||
|
||||
def preferences_from(podcast: Podcast | None) -> tuple[str | None, list[str]]:
|
||||
"""Extract reusable (language, voice_ids) defaults from a prior podcast."""
|
||||
spec = read_spec(podcast) if podcast is not None else None
|
||||
if spec is None:
|
||||
return None, []
|
||||
return spec.language, [speaker.voice_id for speaker in spec.speakers]
|
||||
53
surfsense_backend/app/podcasts/storage.py
Normal file
53
surfsense_backend/app/podcasts/storage.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Durable storage for rendered podcast audio.
|
||||
|
||||
Wraps the shared :class:`StorageBackend` so the rest of the module never deals
|
||||
with object keys directly. Audio is stored under a per-podcast key, streamed for
|
||||
download, and purged when a podcast is deleted.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncIterator
|
||||
|
||||
from app.file_storage.factory import get_storage_backend
|
||||
from app.podcasts.persistence import Podcast
|
||||
|
||||
_AUDIO_CONTENT_TYPE = "audio/mpeg"
|
||||
|
||||
|
||||
def build_audio_key(*, search_space_id: int, podcast_id: int) -> str:
|
||||
"""Object key for a podcast's audio.
|
||||
|
||||
Shape: ``podcasts/{search_space_id}/{podcast_id}/{uuid}.mp3``. The uuid lets
|
||||
a re-render write a fresh object before the old one is purged.
|
||||
"""
|
||||
return f"podcasts/{search_space_id}/{podcast_id}/{uuid.uuid4().hex}.mp3"
|
||||
|
||||
|
||||
async def store_audio(
|
||||
*, search_space_id: int, podcast_id: int, data: bytes
|
||||
) -> tuple[str, str]:
|
||||
"""Persist audio bytes and return ``(backend_name, storage_key)``."""
|
||||
backend = get_storage_backend()
|
||||
key = build_audio_key(search_space_id=search_space_id, podcast_id=podcast_id)
|
||||
await backend.put(key, data, content_type=_AUDIO_CONTENT_TYPE)
|
||||
return backend.backend_name, key
|
||||
|
||||
|
||||
def open_audio_stream(podcast: Podcast) -> AsyncIterator[bytes]:
|
||||
"""Stream a ready podcast's audio bytes. Raises if it has none."""
|
||||
if not podcast.storage_key:
|
||||
raise FileNotFoundError(f"podcast {podcast.id} has no stored audio")
|
||||
return get_storage_backend().open_stream(podcast.storage_key)
|
||||
|
||||
|
||||
async def purge_audio(podcast: Podcast) -> None:
|
||||
"""Delete a podcast's stored audio if present; a missing object is fine."""
|
||||
await purge_audio_object(podcast.storage_key)
|
||||
|
||||
|
||||
async def purge_audio_object(key: str | None) -> None:
|
||||
"""Delete a stored audio object by key, e.g. the one a re-render replaced."""
|
||||
if key:
|
||||
await get_storage_backend().delete(key)
|
||||
17
surfsense_backend/app/podcasts/tasks/__init__.py
Normal file
17
surfsense_backend/app/podcasts/tasks/__init__.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""Celery tasks driving the podcast lifecycle across its expensive phases.
|
||||
|
||||
One task per heavy async phase: draft the transcript (LLM) and render the audio
|
||||
(TTS). The brief is deterministic and proposed inline at create time, so it has
|
||||
no task. Each task is enqueued by the API after it performs the guarded status
|
||||
transition, and each pushes its result onto the row for the frontend to observe.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .draft import draft_transcript_task
|
||||
from .render import render_audio_task
|
||||
|
||||
__all__ = [
|
||||
"draft_transcript_task",
|
||||
"render_audio_task",
|
||||
]
|
||||
100
surfsense_backend/app/podcasts/tasks/draft.py
Normal file
100
surfsense_backend/app/podcasts/tasks/draft.py
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
"""Transcript-drafting task: DRAFTING -> RENDERING.
|
||||
|
||||
The expensive, LLM-heavy step, so it runs under ``billable_call``. The API has
|
||||
already moved the row to DRAFTING and stored the approved brief; this task
|
||||
drafts the long-form transcript and chains straight into the render — the brief
|
||||
gate is the only approval in the lifecycle.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.config import config as app_config
|
||||
from app.podcasts.generation.transcript.graph import graph as transcript_graph
|
||||
from app.podcasts.generation.transcript.state import TranscriptState
|
||||
from app.podcasts.persistence import PodcastRepository
|
||||
from app.podcasts.service import PodcastService, read_spec
|
||||
from app.services.billable_calls import (
|
||||
BillingSettlementError,
|
||||
QuotaInsufficientError,
|
||||
_resolve_agent_billing_for_search_space,
|
||||
billable_call,
|
||||
)
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
from .render import render_audio_task
|
||||
from .runtime import billable_session, mark_failed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(name="podcast.draft_transcript", bind=True)
|
||||
def draft_transcript_task(self, podcast_id: int, search_space_id: int) -> dict:
|
||||
try:
|
||||
return run_async_celery_task(
|
||||
lambda: _draft_transcript(podcast_id, search_space_id)
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("Podcast %s drafting failed: %s", podcast_id, exc)
|
||||
message = str(exc)
|
||||
run_async_celery_task(lambda: mark_failed(podcast_id, message))
|
||||
return {"status": "failed", "podcast_id": podcast_id}
|
||||
|
||||
|
||||
async def _draft_transcript(podcast_id: int, search_space_id: int) -> dict:
|
||||
async with get_celery_session_maker()() as session:
|
||||
repo = PodcastRepository(session)
|
||||
service = PodcastService(session)
|
||||
podcast = await repo.get(podcast_id)
|
||||
if podcast is None:
|
||||
raise ValueError(f"podcast {podcast_id} not found")
|
||||
|
||||
spec = read_spec(podcast)
|
||||
if spec is None:
|
||||
raise ValueError(f"podcast {podcast_id} has no approved brief")
|
||||
|
||||
owner_id, tier, base_model = await _resolve_agent_billing_for_search_space(
|
||||
session, search_space_id, thread_id=podcast.thread_id
|
||||
)
|
||||
|
||||
state = TranscriptState(
|
||||
db_session=session, source_content=podcast.source_content or ""
|
||||
)
|
||||
config = {
|
||||
"configurable": {
|
||||
"search_space_id": search_space_id,
|
||||
"spec": spec,
|
||||
"focus": spec.focus,
|
||||
}
|
||||
}
|
||||
|
||||
try:
|
||||
async with billable_call(
|
||||
user_id=owner_id,
|
||||
search_space_id=search_space_id,
|
||||
billing_tier=tier,
|
||||
base_model=base_model,
|
||||
quota_reserve_micros_override=app_config.QUOTA_DEFAULT_PODCAST_RESERVE_MICROS,
|
||||
usage_type="podcast_generation",
|
||||
call_details={"podcast_id": podcast_id, "title": podcast.title},
|
||||
billable_session_factory=billable_session,
|
||||
):
|
||||
result = await transcript_graph.ainvoke(state, config=config)
|
||||
except QuotaInsufficientError:
|
||||
await service.fail(podcast, "premium quota exhausted")
|
||||
await session.commit()
|
||||
return {"status": "failed", "podcast_id": podcast_id, "reason": "quota"}
|
||||
except BillingSettlementError:
|
||||
await service.fail(podcast, "billing settlement failed")
|
||||
await session.commit()
|
||||
return {"status": "failed", "podcast_id": podcast_id, "reason": "billing"}
|
||||
|
||||
await service.attach_transcript(podcast, result["transcript"])
|
||||
await session.commit()
|
||||
|
||||
# Enqueue only after the transaction is committed, so the render worker can
|
||||
# never pick up a row whose transcript isn't visible yet.
|
||||
render_audio_task.delay(podcast_id)
|
||||
return {"status": "rendering", "podcast_id": podcast_id}
|
||||
88
surfsense_backend/app/podcasts/tasks/render.py
Normal file
88
surfsense_backend/app/podcasts/tasks/render.py
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
"""Audio-rendering task: RENDERING -> READY.
|
||||
|
||||
Synthesises and merges the approved transcript, stores the MP3 in the object
|
||||
store, and marks the podcast ready. The working directory is stable per podcast
|
||||
so a re-render (e.g. after a voice change) reuses the segment cache.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.celery_app import celery_app
|
||||
from app.podcasts.persistence import PodcastRepository
|
||||
from app.podcasts.rendering import PodcastRenderer
|
||||
from app.podcasts.service import (
|
||||
InvalidTransitionError,
|
||||
PodcastService,
|
||||
read_spec,
|
||||
read_transcript,
|
||||
)
|
||||
from app.podcasts.storage import purge_audio_object, store_audio
|
||||
from app.podcasts.tts import get_text_to_speech
|
||||
from app.podcasts.voices import get_voice_catalog
|
||||
from app.tasks.celery_tasks import get_celery_session_maker, run_async_celery_task
|
||||
|
||||
from .runtime import mark_failed
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_WORKDIR_BASE = Path(tempfile.gettempdir()) / "surfsense_podcasts"
|
||||
|
||||
|
||||
@celery_app.task(name="podcast.render_audio", bind=True)
|
||||
def render_audio_task(self, podcast_id: int) -> dict:
|
||||
try:
|
||||
return run_async_celery_task(lambda: _render_audio(podcast_id))
|
||||
except Exception as exc:
|
||||
logger.error("Podcast %s render failed: %s", podcast_id, exc)
|
||||
message = str(exc)
|
||||
run_async_celery_task(lambda: mark_failed(podcast_id, message))
|
||||
return {"status": "failed", "podcast_id": podcast_id}
|
||||
|
||||
|
||||
async def _render_audio(podcast_id: int) -> dict:
|
||||
async with get_celery_session_maker()() as session:
|
||||
repo = PodcastRepository(session)
|
||||
podcast = await repo.get(podcast_id)
|
||||
if podcast is None:
|
||||
raise ValueError(f"podcast {podcast_id} not found")
|
||||
|
||||
spec = read_spec(podcast)
|
||||
transcript = read_transcript(podcast)
|
||||
if spec is None or transcript is None:
|
||||
raise ValueError(f"podcast {podcast_id} is missing brief or transcript")
|
||||
|
||||
renderer = PodcastRenderer(
|
||||
tts=get_text_to_speech(), catalog=get_voice_catalog()
|
||||
)
|
||||
workdir = _WORKDIR_BASE / str(podcast_id)
|
||||
workdir.mkdir(parents=True, exist_ok=True)
|
||||
rendered = await renderer.render(
|
||||
spec=spec, transcript=transcript, workdir=workdir
|
||||
)
|
||||
|
||||
superseded_key = podcast.storage_key
|
||||
|
||||
backend_name, key = await store_audio(
|
||||
search_space_id=podcast.search_space_id,
|
||||
podcast_id=podcast_id,
|
||||
data=rendered.data,
|
||||
)
|
||||
try:
|
||||
await PodcastService(session).attach_audio(
|
||||
podcast, storage_backend=backend_name, storage_key=key
|
||||
)
|
||||
await session.commit()
|
||||
except InvalidTransitionError:
|
||||
# A user back-out won the race (e.g. the regeneration was
|
||||
# reverted): drop the stale render and leave the row alone.
|
||||
await purge_audio_object(key)
|
||||
return {"status": "superseded", "podcast_id": podcast_id}
|
||||
|
||||
# Purge only after the new audio is committed, so a failed re-render never
|
||||
# destroys the episode the user can still play.
|
||||
await purge_audio_object(superseded_key)
|
||||
return {"status": "ready", "podcast_id": podcast_id}
|
||||
40
surfsense_backend/app/podcasts/tasks/runtime.py
Normal file
40
surfsense_backend/app/podcasts/tasks/runtime.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Shared plumbing for the podcast Celery tasks.
|
||||
|
||||
Each task runs its async body via :func:`run_async_celery_task` and, on any
|
||||
failure, records the reason on the row through the lifecycle service. Marking
|
||||
failed is best-effort: a podcast that already reached a terminal state is left
|
||||
untouched rather than forced.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from app.podcasts.persistence import PodcastRepository
|
||||
from app.podcasts.service import PodcastError, PodcastService
|
||||
from app.tasks.celery_tasks import get_celery_session_maker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def billable_session():
|
||||
"""Session factory for ``billable_call`` inside the worker loop."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def mark_failed(podcast_id: int, error: str) -> None:
|
||||
"""Best-effort: move a non-terminal podcast to FAILED with ``error``."""
|
||||
async with get_celery_session_maker()() as session:
|
||||
repo = PodcastRepository(session)
|
||||
podcast = await repo.get(podcast_id)
|
||||
if podcast is None:
|
||||
return
|
||||
try:
|
||||
await PodcastService(session).fail(podcast, error)
|
||||
await session.commit()
|
||||
except PodcastError:
|
||||
# Already terminal (e.g. cancelled): nothing to record.
|
||||
logger.info("Podcast %s already terminal; not marking failed", podcast_id)
|
||||
22
surfsense_backend/app/podcasts/tts/__init__.py
Normal file
22
surfsense_backend/app/podcasts/tts/__init__.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""Text-to-speech: a per-segment synthesis port with provider adapters.
|
||||
|
||||
Callers depend on :class:`TextToSpeech` and obtain the configured provider from
|
||||
:func:`get_text_to_speech`; the concrete Kokoro/LiteLLM adapters stay private.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .audio import SynthesizedAudio
|
||||
from .errors import TextToSpeechError
|
||||
from .factory import get_text_to_speech
|
||||
from .port import TextToSpeech
|
||||
from .request import SynthesisRequest, VoiceRef
|
||||
|
||||
__all__ = [
|
||||
"SynthesisRequest",
|
||||
"SynthesizedAudio",
|
||||
"TextToSpeech",
|
||||
"TextToSpeechError",
|
||||
"VoiceRef",
|
||||
"get_text_to_speech",
|
||||
]
|
||||
3
surfsense_backend/app/podcasts/tts/adapters/__init__.py
Normal file
3
surfsense_backend/app/podcasts/tts/adapters/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
"""Per-provider TextToSpeech implementations."""
|
||||
|
||||
from __future__ import annotations
|
||||
109
surfsense_backend/app/podcasts/tts/adapters/kokoro.py
Normal file
109
surfsense_backend/app/podcasts/tts/adapters/kokoro.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
"""Local Kokoro adapter: on-box synthesis, no network or per-segment cost.
|
||||
|
||||
Kokoro selects its language model by a single-letter ``lang_code``, so this
|
||||
adapter maps the brief's BCP-47 tag to that code and caches one pipeline per
|
||||
code (pipeline construction loads weights and is expensive). Pipelines run in a
|
||||
thread pool because Kokoro is synchronous; the renderer caps how many segments
|
||||
synthesise at once.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import io
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..audio import SynthesizedAudio
|
||||
from ..errors import TextToSpeechError
|
||||
from ..port import TextToSpeech
|
||||
from ..request import SynthesisRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kokoro import KPipeline
|
||||
|
||||
# Kokoro emits 24 kHz mono PCM regardless of voice.
|
||||
_SAMPLE_RATE = 24000
|
||||
|
||||
# BCP-47 primary subtag -> Kokoro language code. English defaults to American;
|
||||
# the en-GB region override below switches it to British.
|
||||
_LANG_CODE_BY_PRIMARY = {
|
||||
"en": "a",
|
||||
"es": "e",
|
||||
"fr": "f",
|
||||
"hi": "h",
|
||||
"it": "i",
|
||||
"ja": "j",
|
||||
"pt": "p",
|
||||
"zh": "z",
|
||||
}
|
||||
|
||||
|
||||
class KokoroTextToSpeech(TextToSpeech):
|
||||
"""Synthesises segments with locally hosted Kokoro pipelines."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pipelines: dict[str, KPipeline] = {}
|
||||
|
||||
@property
|
||||
def container(self) -> str:
|
||||
return "wav"
|
||||
|
||||
async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio:
|
||||
if not isinstance(request.voice, str):
|
||||
raise TextToSpeechError("Kokoro voices are named by string, not a mapping")
|
||||
|
||||
pipeline = self._pipeline_for(request.language)
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
generator = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: pipeline(
|
||||
request.text,
|
||||
voice=request.voice,
|
||||
speed=request.speed,
|
||||
split_pattern=r"\n+",
|
||||
),
|
||||
)
|
||||
segments = [audio for _gs, _ps, audio in generator]
|
||||
except Exception as exc:
|
||||
raise TextToSpeechError(f"Kokoro synthesis failed: {exc}") from exc
|
||||
|
||||
if not segments:
|
||||
raise TextToSpeechError("Kokoro produced no audio for the text")
|
||||
|
||||
return SynthesizedAudio(
|
||||
data=_encode_wav(segments, _SAMPLE_RATE),
|
||||
container="wav",
|
||||
sample_rate=_SAMPLE_RATE,
|
||||
)
|
||||
|
||||
def _pipeline_for(self, language: str) -> KPipeline:
|
||||
lang_code = _lang_code(language)
|
||||
pipeline = self._pipelines.get(lang_code)
|
||||
if pipeline is None:
|
||||
from kokoro import KPipeline
|
||||
|
||||
pipeline = KPipeline(lang_code=lang_code)
|
||||
self._pipelines[lang_code] = pipeline
|
||||
return pipeline
|
||||
|
||||
|
||||
def _lang_code(language: str) -> str:
|
||||
normalised = language.strip().lower()
|
||||
if normalised.startswith("en-gb") or normalised == "en-uk":
|
||||
return "b"
|
||||
primary = normalised.partition("-")[0]
|
||||
code = _LANG_CODE_BY_PRIMARY.get(primary)
|
||||
if code is None:
|
||||
raise TextToSpeechError(f"Kokoro has no language model for {language!r}")
|
||||
return code
|
||||
|
||||
|
||||
def _encode_wav(segments: list, sample_rate: int) -> bytes:
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
|
||||
waveform = segments[0] if len(segments) == 1 else np.concatenate(segments)
|
||||
buffer = io.BytesIO()
|
||||
sf.write(buffer, waveform, sample_rate, format="WAV")
|
||||
return buffer.getvalue()
|
||||
67
surfsense_backend/app/podcasts/tts/adapters/litellm.py
Normal file
67
surfsense_backend/app/podcasts/tts/adapters/litellm.py
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
"""LiteLLM adapter: hosted TTS (OpenAI, Azure, Vertex AI) via one ``aspeech`` call.
|
||||
|
||||
LiteLLM normalises every hosted provider behind the same ``aspeech`` surface,
|
||||
so a single adapter covers them all. The provider is encoded in the model
|
||||
string (e.g. ``openai/tts-1``, ``vertex_ai/...``) and the voice reference is
|
||||
whatever that provider expects, which the catalog already supplies.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..audio import SynthesizedAudio
|
||||
from ..errors import TextToSpeechError
|
||||
from ..port import TextToSpeech
|
||||
from ..request import SynthesisRequest
|
||||
|
||||
# Hosted providers return MP3-encoded bytes from ``aspeech``.
|
||||
_CONTAINER = "mp3"
|
||||
|
||||
# A long single segment still finishes well under this; retries absorb transient
|
||||
# upstream failures without failing the whole render.
|
||||
_TIMEOUT_SECONDS = 600
|
||||
_MAX_RETRIES = 2
|
||||
|
||||
|
||||
class LiteLlmTextToSpeech(TextToSpeech):
|
||||
"""Synthesises segments through any LiteLLM-supported hosted TTS model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
api_base: str | None = None,
|
||||
api_key: str | None = None,
|
||||
) -> None:
|
||||
self._model = model
|
||||
self._api_base = api_base
|
||||
self._api_key = api_key
|
||||
|
||||
@property
|
||||
def container(self) -> str:
|
||||
return _CONTAINER
|
||||
|
||||
async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio:
|
||||
from litellm import aspeech
|
||||
|
||||
kwargs = {
|
||||
"model": self._model,
|
||||
"voice": request.voice,
|
||||
"input": request.text,
|
||||
"max_retries": _MAX_RETRIES,
|
||||
"timeout": _TIMEOUT_SECONDS,
|
||||
}
|
||||
if self._api_base:
|
||||
kwargs["api_base"] = self._api_base
|
||||
if self._api_key:
|
||||
kwargs["api_key"] = self._api_key
|
||||
|
||||
try:
|
||||
response = await aspeech(**kwargs)
|
||||
except Exception as exc:
|
||||
raise TextToSpeechError(f"{self._model} synthesis failed: {exc}") from exc
|
||||
|
||||
data = getattr(response, "content", None)
|
||||
if not data:
|
||||
raise TextToSpeechError(f"{self._model} returned no audio")
|
||||
|
||||
return SynthesizedAudio(data=data, container=_CONTAINER)
|
||||
19
surfsense_backend/app/podcasts/tts/audio.py
Normal file
19
surfsense_backend/app/podcasts/tts/audio.py
Normal file
|
|
@ -0,0 +1,19 @@
|
|||
"""The bytes a TTS provider returns for one segment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SynthesizedAudio:
|
||||
"""Encoded audio for a single segment, ready to cache and concatenate.
|
||||
|
||||
``container`` is the file extension the bytes are encoded as (``"wav"`` or
|
||||
``"mp3"``); the renderer uses it to name the on-disk segment so FFmpeg can
|
||||
demux the right format during merge.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
container: str
|
||||
sample_rate: int | None = None
|
||||
13
surfsense_backend/app/podcasts/tts/errors.py
Normal file
13
surfsense_backend/app/podcasts/tts/errors.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
"""Failures raised by the TTS layer."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class TextToSpeechError(RuntimeError):
|
||||
"""A provider failed to synthesise a segment.
|
||||
|
||||
Raised for both configuration faults (an unusable voice reference) and
|
||||
provider faults (the upstream call errored or returned no audio), so the
|
||||
renderer can fail the segment without unwrapping provider-specific
|
||||
exceptions.
|
||||
"""
|
||||
38
surfsense_backend/app/podcasts/tts/factory.py
Normal file
38
surfsense_backend/app/podcasts/tts/factory.py
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
"""Resolve the configured :class:`TextToSpeech` as a process-wide singleton."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
from .port import TextToSpeech
|
||||
|
||||
# Sentinel model string that selects the local Kokoro pipeline; anything else is
|
||||
# treated as a LiteLLM-hosted model (``openai/...``, ``vertex_ai/...``, etc.).
|
||||
KOKORO_SERVICE = "local/kokoro"
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_text_to_speech() -> TextToSpeech:
|
||||
"""Build the provider selected by ``TTS_SERVICE`` (adapters lazy-imported).
|
||||
|
||||
Cached because the Kokoro adapter holds loaded pipelines that must be reused
|
||||
across segments and requests rather than rebuilt per call.
|
||||
"""
|
||||
from app.config import config as app_config
|
||||
|
||||
service = app_config.TTS_SERVICE
|
||||
if not service:
|
||||
raise ValueError("TTS_SERVICE is not configured")
|
||||
|
||||
if service == KOKORO_SERVICE:
|
||||
from .adapters.kokoro import KokoroTextToSpeech
|
||||
|
||||
return KokoroTextToSpeech()
|
||||
|
||||
from .adapters.litellm import LiteLlmTextToSpeech
|
||||
|
||||
return LiteLlmTextToSpeech(
|
||||
model=service,
|
||||
api_base=app_config.TTS_SERVICE_API_BASE,
|
||||
api_key=app_config.TTS_SERVICE_API_KEY,
|
||||
)
|
||||
31
surfsense_backend/app/podcasts/tts/port.py
Normal file
31
surfsense_backend/app/podcasts/tts/port.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
"""The TTS contract: turn one segment of text into encoded audio."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from .audio import SynthesizedAudio
|
||||
from .request import SynthesisRequest
|
||||
|
||||
|
||||
class TextToSpeech(ABC):
|
||||
"""Synthesises a single segment; one implementation per provider.
|
||||
|
||||
The contract is intentionally per-segment rather than per-episode: it keeps
|
||||
each call independently cacheable and lets the renderer cap concurrency and
|
||||
retry segments in isolation. Stitching segments into one file is the
|
||||
renderer's job, not the provider's.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def container(self) -> str:
|
||||
"""File extension/container this provider emits (e.g. ``"mp3"``)."""
|
||||
|
||||
@abstractmethod
|
||||
async def synthesize(self, request: SynthesisRequest) -> SynthesizedAudio:
|
||||
"""Voice ``request.text`` and return its encoded audio.
|
||||
|
||||
Raises :class:`~app.podcasts.tts.errors.TextToSpeechError` on any
|
||||
provider or configuration failure.
|
||||
"""
|
||||
22
surfsense_backend/app/podcasts/tts/request.py
Normal file
22
surfsense_backend/app/podcasts/tts/request.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
"""What the renderer hands a TTS provider to voice a single segment."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
# A provider-native voice reference. OpenAI/Azure/Kokoro name a voice with a
|
||||
# string; Vertex passes a mapping (``languageCode`` + ``name``). The catalog
|
||||
# stores whichever shape the provider expects and we pass it through untouched.
|
||||
VoiceRef = str | Mapping[str, Any]
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class SynthesisRequest:
|
||||
"""One unit of speech to synthesise: the smallest cacheable render step."""
|
||||
|
||||
text: str
|
||||
voice: VoiceRef
|
||||
language: str
|
||||
speed: float = 1.0
|
||||
23
surfsense_backend/app/podcasts/voices/__init__.py
Normal file
23
surfsense_backend/app/podcasts/voices/__init__.py
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
"""Voices: the catalog of selectable TTS voices and the active provider.
|
||||
|
||||
Callers obtain the catalog via :func:`get_voice_catalog` and identify the
|
||||
configured provider via :func:`provider_from_service`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .catalog import VoiceCatalog, get_voice_catalog
|
||||
from .preview import render_voice_preview
|
||||
from .provider import TtsProvider, provider_from_service
|
||||
from .voice import ANY_LANGUAGE, CatalogVoice, VoiceGender
|
||||
|
||||
__all__ = [
|
||||
"ANY_LANGUAGE",
|
||||
"CatalogVoice",
|
||||
"TtsProvider",
|
||||
"VoiceCatalog",
|
||||
"VoiceGender",
|
||||
"get_voice_catalog",
|
||||
"provider_from_service",
|
||||
"render_voice_preview",
|
||||
]
|
||||
51
surfsense_backend/app/podcasts/voices/catalog.py
Normal file
51
surfsense_backend/app/podcasts/voices/catalog.py
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
"""The voice catalog: look up and filter selectable voices.
|
||||
|
||||
A :class:`VoiceCatalog` is the single source of truth for which voices exist.
|
||||
Resolution uses it to pick defaults for a brief, the API exposes it as picker
|
||||
options, and the renderer uses it to turn a stored ``voice_id`` back into the
|
||||
provider-native reference.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import lru_cache
|
||||
|
||||
from .data import AZURE_VOICES, KOKORO_VOICES, OPENAI_VOICES, VERTEX_VOICES
|
||||
from .provider import TtsProvider
|
||||
from .voice import CatalogVoice
|
||||
|
||||
|
||||
class VoiceCatalog:
|
||||
"""An indexed, read-only collection of :class:`CatalogVoice`."""
|
||||
|
||||
def __init__(self, voices: Iterable[CatalogVoice]) -> None:
|
||||
self._by_id: dict[str, CatalogVoice] = {}
|
||||
self._by_provider: dict[TtsProvider, list[CatalogVoice]] = {}
|
||||
for voice in voices:
|
||||
if voice.voice_id in self._by_id:
|
||||
raise ValueError(f"duplicate voice_id: {voice.voice_id}")
|
||||
self._by_id[voice.voice_id] = voice
|
||||
self._by_provider.setdefault(voice.provider, []).append(voice)
|
||||
|
||||
def get(self, voice_id: str) -> CatalogVoice:
|
||||
"""Return the voice with ``voice_id`` or raise ``KeyError``."""
|
||||
return self._by_id[voice_id]
|
||||
|
||||
def for_provider(self, provider: TtsProvider) -> list[CatalogVoice]:
|
||||
"""All voices offered by ``provider``, in catalog order."""
|
||||
return list(self._by_provider.get(provider, ()))
|
||||
|
||||
def for_language(self, provider: TtsProvider, language: str) -> list[CatalogVoice]:
|
||||
"""``provider`` voices that can render ``language``, in catalog order."""
|
||||
return [v for v in self.for_provider(provider) if v.speaks(language)]
|
||||
|
||||
def supports_language(self, provider: TtsProvider, language: str) -> bool:
|
||||
"""Whether ``provider`` has at least one voice for ``language``."""
|
||||
return any(v.speaks(language) for v in self.for_provider(provider))
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_voice_catalog() -> VoiceCatalog:
|
||||
"""The process-wide catalog assembled from every provider's roster."""
|
||||
return VoiceCatalog((*KOKORO_VOICES, *OPENAI_VOICES, *AZURE_VOICES, *VERTEX_VOICES))
|
||||
10
surfsense_backend/app/podcasts/voices/data/__init__.py
Normal file
10
surfsense_backend/app/podcasts/voices/data/__init__.py
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
"""Static per-provider voice rosters that compose the catalog."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .azure import AZURE_VOICES
|
||||
from .kokoro import KOKORO_VOICES
|
||||
from .openai import OPENAI_VOICES
|
||||
from .vertex import VERTEX_VOICES
|
||||
|
||||
__all__ = ["AZURE_VOICES", "KOKORO_VOICES", "OPENAI_VOICES", "VERTEX_VOICES"]
|
||||
32
surfsense_backend/app/podcasts/voices/data/azure.py
Normal file
32
surfsense_backend/app/podcasts/voices/data/azure.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
"""Azure TTS voices, routed through the OpenAI-compatible voice names.
|
||||
|
||||
The deployment fronts Azure with OpenAI-style voice names (matching the legacy
|
||||
podcaster), so these mirror the OpenAI roster and, like it, speak any requested
|
||||
language.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..provider import TtsProvider
|
||||
from ..voice import ANY_LANGUAGE, CatalogVoice, VoiceGender
|
||||
|
||||
|
||||
def _voice(name: str, display: str, gender: VoiceGender) -> CatalogVoice:
|
||||
return CatalogVoice(
|
||||
voice_id=f"azure:{name}",
|
||||
provider=TtsProvider.AZURE,
|
||||
language=ANY_LANGUAGE,
|
||||
display_name=display,
|
||||
gender=gender,
|
||||
native_ref=name,
|
||||
)
|
||||
|
||||
|
||||
AZURE_VOICES: tuple[CatalogVoice, ...] = (
|
||||
_voice("alloy", "Alloy", VoiceGender.NEUTRAL),
|
||||
_voice("echo", "Echo", VoiceGender.MALE),
|
||||
_voice("fable", "Fable", VoiceGender.NEUTRAL),
|
||||
_voice("onyx", "Onyx", VoiceGender.MALE),
|
||||
_voice("nova", "Nova", VoiceGender.FEMALE),
|
||||
_voice("shimmer", "Shimmer", VoiceGender.FEMALE),
|
||||
)
|
||||
63
surfsense_backend/app/podcasts/voices/data/kokoro.py
Normal file
63
surfsense_backend/app/podcasts/voices/data/kokoro.py
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
"""Curated Kokoro voices, the local provider's multilingual roster.
|
||||
|
||||
Kokoro voice names encode language and gender in their first two letters
|
||||
(``a``=American English, ``b``=British, ``e``=Spanish, ``f``=French,
|
||||
``h``=Hindi, ``i``=Italian, ``j``=Japanese, ``p``=Brazilian Portuguese,
|
||||
``z``=Mandarin; second letter ``f``/``m`` = female/male). We carry at least one
|
||||
male and one female voice per language so a two-speaker brief always has a
|
||||
distinct pair. ``native_ref`` is the bare voice name Kokoro expects.
|
||||
|
||||
Reference: https://huggingface.co/hexgrad/Kokoro-82M/tree/main/voices
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..provider import TtsProvider
|
||||
from ..voice import CatalogVoice, VoiceGender
|
||||
|
||||
|
||||
def _voice(name: str, language: str, display: str, gender: VoiceGender) -> CatalogVoice:
|
||||
return CatalogVoice(
|
||||
voice_id=f"kokoro:{name}",
|
||||
provider=TtsProvider.KOKORO,
|
||||
language=language,
|
||||
display_name=display,
|
||||
gender=gender,
|
||||
native_ref=name,
|
||||
)
|
||||
|
||||
|
||||
KOKORO_VOICES: tuple[CatalogVoice, ...] = (
|
||||
# American English
|
||||
_voice("am_adam", "en-US", "Adam (US)", VoiceGender.MALE),
|
||||
_voice("am_michael", "en-US", "Michael (US)", VoiceGender.MALE),
|
||||
_voice("af_bella", "en-US", "Bella (US)", VoiceGender.FEMALE),
|
||||
_voice("af_heart", "en-US", "Heart (US)", VoiceGender.FEMALE),
|
||||
_voice("af_nicole", "en-US", "Nicole (US)", VoiceGender.FEMALE),
|
||||
_voice("af_sarah", "en-US", "Sarah (US)", VoiceGender.FEMALE),
|
||||
# British English
|
||||
_voice("bm_george", "en-GB", "George (UK)", VoiceGender.MALE),
|
||||
_voice("bm_lewis", "en-GB", "Lewis (UK)", VoiceGender.MALE),
|
||||
_voice("bf_emma", "en-GB", "Emma (UK)", VoiceGender.FEMALE),
|
||||
_voice("bf_isabella", "en-GB", "Isabella (UK)", VoiceGender.FEMALE),
|
||||
# Spanish
|
||||
_voice("em_alex", "es", "Alex (ES)", VoiceGender.MALE),
|
||||
_voice("ef_dora", "es", "Dora (ES)", VoiceGender.FEMALE),
|
||||
# French
|
||||
_voice("ff_siwis", "fr", "Siwis (FR)", VoiceGender.FEMALE),
|
||||
# Hindi
|
||||
_voice("hm_omega", "hi", "Omega (HI)", VoiceGender.MALE),
|
||||
_voice("hf_alpha", "hi", "Alpha (HI)", VoiceGender.FEMALE),
|
||||
# Italian
|
||||
_voice("im_nicola", "it", "Nicola (IT)", VoiceGender.MALE),
|
||||
_voice("if_sara", "it", "Sara (IT)", VoiceGender.FEMALE),
|
||||
# Japanese
|
||||
_voice("jm_kumo", "ja", "Kumo (JA)", VoiceGender.MALE),
|
||||
_voice("jf_alpha", "ja", "Alpha (JA)", VoiceGender.FEMALE),
|
||||
# Brazilian Portuguese
|
||||
_voice("pm_alex", "pt-BR", "Alex (BR)", VoiceGender.MALE),
|
||||
_voice("pf_dora", "pt-BR", "Dora (BR)", VoiceGender.FEMALE),
|
||||
# Mandarin Chinese
|
||||
_voice("zm_yunxi", "zh", "Yunxi (ZH)", VoiceGender.MALE),
|
||||
_voice("zf_xiaoxiao", "zh", "Xiaoxiao (ZH)", VoiceGender.FEMALE),
|
||||
)
|
||||
32
surfsense_backend/app/podcasts/voices/data/openai.py
Normal file
32
surfsense_backend/app/podcasts/voices/data/openai.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
"""OpenAI TTS voices: language-agnostic, so each speaks any requested language.
|
||||
|
||||
OpenAI voices follow the language of the input text rather than being tied to a
|
||||
locale, so they are tagged :data:`ANY_LANGUAGE` and match every brief. The
|
||||
``native_ref`` is the plain voice name the API expects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..provider import TtsProvider
|
||||
from ..voice import ANY_LANGUAGE, CatalogVoice, VoiceGender
|
||||
|
||||
|
||||
def _voice(name: str, display: str, gender: VoiceGender) -> CatalogVoice:
|
||||
return CatalogVoice(
|
||||
voice_id=f"openai:{name}",
|
||||
provider=TtsProvider.OPENAI,
|
||||
language=ANY_LANGUAGE,
|
||||
display_name=display,
|
||||
gender=gender,
|
||||
native_ref=name,
|
||||
)
|
||||
|
||||
|
||||
OPENAI_VOICES: tuple[CatalogVoice, ...] = (
|
||||
_voice("alloy", "Alloy", VoiceGender.NEUTRAL),
|
||||
_voice("echo", "Echo", VoiceGender.MALE),
|
||||
_voice("fable", "Fable", VoiceGender.NEUTRAL),
|
||||
_voice("onyx", "Onyx", VoiceGender.MALE),
|
||||
_voice("nova", "Nova", VoiceGender.FEMALE),
|
||||
_voice("shimmer", "Shimmer", VoiceGender.FEMALE),
|
||||
)
|
||||
81
surfsense_backend/app/podcasts/voices/data/vertex.py
Normal file
81
surfsense_backend/app/podcasts/voices/data/vertex.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
"""Vertex AI Studio voices: locale-specific, referenced by a mapping.
|
||||
|
||||
Vertex voices are tied to a locale and named via a ``{languageCode, name}``
|
||||
mapping, which is exactly the ``native_ref`` the LiteLLM adapter forwards. The
|
||||
values mirror the legacy podcaster's English Studio voices.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ..provider import TtsProvider
|
||||
from ..voice import CatalogVoice, VoiceGender
|
||||
|
||||
|
||||
def _voice(
|
||||
key: str,
|
||||
language: str,
|
||||
locale: str,
|
||||
name: str,
|
||||
display: str,
|
||||
gender: VoiceGender,
|
||||
) -> CatalogVoice:
|
||||
return CatalogVoice(
|
||||
voice_id=f"vertex_ai:{key}",
|
||||
provider=TtsProvider.VERTEX_AI,
|
||||
language=language,
|
||||
display_name=display,
|
||||
gender=gender,
|
||||
native_ref={"languageCode": locale, "name": name},
|
||||
)
|
||||
|
||||
|
||||
VERTEX_VOICES: tuple[CatalogVoice, ...] = (
|
||||
_voice(
|
||||
"en-US-Studio-O",
|
||||
"en-US",
|
||||
"en-US",
|
||||
"en-US-Studio-O",
|
||||
"Studio O (US)",
|
||||
VoiceGender.FEMALE,
|
||||
),
|
||||
_voice(
|
||||
"en-US-Studio-M",
|
||||
"en-US",
|
||||
"en-US",
|
||||
"en-US-Studio-M",
|
||||
"Studio M (US)",
|
||||
VoiceGender.MALE,
|
||||
),
|
||||
_voice(
|
||||
"en-GB-Studio-A",
|
||||
"en-GB",
|
||||
"en-UK",
|
||||
"en-UK-Studio-A",
|
||||
"Studio A (UK)",
|
||||
VoiceGender.FEMALE,
|
||||
),
|
||||
_voice(
|
||||
"en-GB-Studio-B",
|
||||
"en-GB",
|
||||
"en-UK",
|
||||
"en-UK-Studio-B",
|
||||
"Studio B (UK)",
|
||||
VoiceGender.MALE,
|
||||
),
|
||||
_voice(
|
||||
"en-AU-Studio-A",
|
||||
"en-AU",
|
||||
"en-AU",
|
||||
"en-AU-Studio-A",
|
||||
"Studio A (AU)",
|
||||
VoiceGender.FEMALE,
|
||||
),
|
||||
_voice(
|
||||
"en-AU-Studio-B",
|
||||
"en-AU",
|
||||
"en-AU",
|
||||
"en-AU-Studio-B",
|
||||
"Studio B (AU)",
|
||||
VoiceGender.MALE,
|
||||
),
|
||||
)
|
||||
65
surfsense_backend/app/podcasts/voices/preview.py
Normal file
65
surfsense_backend/app/podcasts/voices/preview.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
"""Audible previews so users pick voices by sound, not by name.
|
||||
|
||||
A preview is a short sample sentence synthesised in the voice's own language.
|
||||
Samples are served through the same content-addressed cache the renderer uses,
|
||||
so each voice costs at most one synthesis per cache lifetime — repeat listens
|
||||
while comparing voices are free.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.podcasts.rendering.cache import SegmentCache
|
||||
from app.podcasts.tts import SynthesisRequest, TextToSpeech
|
||||
|
||||
from .voice import ANY_LANGUAGE, CatalogVoice
|
||||
|
||||
# Previews are user-independent, so one rendered sample serves everyone.
|
||||
PREVIEW_CACHE_ROOT = Path(tempfile.gettempdir()) / "surfsense_podcasts" / "previews"
|
||||
|
||||
_FALLBACK_LANGUAGE = "en"
|
||||
|
||||
# A voice previews best speaking its own language.
|
||||
_SAMPLE_TEXTS = {
|
||||
"en": "Hi there! This is how I sound when narrating your podcast.",
|
||||
"es": "¡Hola! Así sueno cuando narro tu pódcast.",
|
||||
"fr": "Bonjour ! Voici ma voix quand je raconte votre podcast.",
|
||||
"hi": "नमस्ते! आपका पॉडकास्ट सुनाते समय मेरी आवाज़ ऐसी होती है।",
|
||||
"it": "Ciao! Questa è la mia voce quando racconto il tuo podcast.",
|
||||
"ja": "こんにちは。ポッドキャストをお届けするときの私の声です。",
|
||||
"pt": "Olá! É assim que eu soo ao narrar o seu podcast.",
|
||||
"zh": "你好!这就是我为你播报播客时的声音。", # noqa: RUF001
|
||||
}
|
||||
|
||||
_CONTENT_TYPES = {"mp3": "audio/mpeg", "wav": "audio/wav"}
|
||||
|
||||
|
||||
async def render_voice_preview(
|
||||
voice: CatalogVoice, tts: TextToSpeech
|
||||
) -> tuple[bytes, str]:
|
||||
"""Return ``(audio_bytes, content_type)`` for a sample spoken by ``voice``."""
|
||||
language = _FALLBACK_LANGUAGE if voice.language == ANY_LANGUAGE else voice.language
|
||||
request = SynthesisRequest(
|
||||
text=_sample_text(language), voice=voice.native_ref, language=language
|
||||
)
|
||||
|
||||
cache = SegmentCache(PREVIEW_CACHE_ROOT)
|
||||
key = cache.key(request)
|
||||
cached = cache.get(key, tts.container)
|
||||
if cached is not None:
|
||||
return cached.read_bytes(), _content_type(tts.container)
|
||||
|
||||
audio = await tts.synthesize(request)
|
||||
cache.put(key, audio.container, audio.data)
|
||||
return audio.data, _content_type(audio.container)
|
||||
|
||||
|
||||
def _sample_text(language: str) -> str:
|
||||
primary = language.split("-", 1)[0].strip().lower()
|
||||
return _SAMPLE_TEXTS.get(primary, _SAMPLE_TEXTS[_FALLBACK_LANGUAGE])
|
||||
|
||||
|
||||
def _content_type(container: str) -> str:
|
||||
return _CONTENT_TYPES.get(container, "application/octet-stream")
|
||||
27
surfsense_backend/app/podcasts/voices/provider.py
Normal file
27
surfsense_backend/app/podcasts/voices/provider.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""The TTS providers we carry voices for, and how to name one from config."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class TtsProvider(StrEnum):
|
||||
"""A speech provider whose voices the catalog enumerates."""
|
||||
|
||||
KOKORO = "kokoro"
|
||||
OPENAI = "openai"
|
||||
AZURE = "azure"
|
||||
VERTEX_AI = "vertex_ai"
|
||||
|
||||
|
||||
def provider_from_service(service: str) -> TtsProvider:
|
||||
"""Map a ``TTS_SERVICE`` string to its provider.
|
||||
|
||||
The config value is a LiteLLM-style ``provider/model`` string
|
||||
(``openai/tts-1``, ``vertex_ai/...``) except for local Kokoro, which is
|
||||
spelled ``local/kokoro``; both halves of that special case resolve here.
|
||||
"""
|
||||
prefix = service.split("/", 1)[0].strip().lower()
|
||||
if prefix == "local":
|
||||
return TtsProvider.KOKORO
|
||||
return TtsProvider(prefix)
|
||||
50
surfsense_backend/app/podcasts/voices/voice.py
Normal file
50
surfsense_backend/app/podcasts/voices/voice.py
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
"""A catalog voice: a stable id paired with its provider-native reference."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
from app.podcasts.tts import VoiceRef
|
||||
|
||||
from .provider import TtsProvider
|
||||
|
||||
# A voice that speaks whatever language the input text is in (e.g. OpenAI's
|
||||
# voices), matched against every requested language.
|
||||
ANY_LANGUAGE = "*"
|
||||
|
||||
|
||||
class VoiceGender(StrEnum):
|
||||
"""Perceived voice gender, used to pick distinct voices per speaker."""
|
||||
|
||||
MALE = "male"
|
||||
FEMALE = "female"
|
||||
NEUTRAL = "neutral"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class CatalogVoice:
|
||||
"""One selectable voice.
|
||||
|
||||
``voice_id`` is the provider-prefixed, stable id stored on a speaker in the
|
||||
brief (e.g. ``"kokoro:am_adam"``). ``native_ref`` is the untyped value the
|
||||
TTS adapter passes to the provider — a string for most, a mapping for
|
||||
Vertex — kept separate so renaming the catalog id never breaks synthesis.
|
||||
"""
|
||||
|
||||
voice_id: str
|
||||
provider: TtsProvider
|
||||
language: str
|
||||
display_name: str
|
||||
gender: VoiceGender
|
||||
native_ref: VoiceRef
|
||||
|
||||
def speaks(self, language: str) -> bool:
|
||||
"""Whether this voice can render ``language`` (primary subtag match)."""
|
||||
if self.language == ANY_LANGUAGE:
|
||||
return True
|
||||
return _primary(self.language) == _primary(language)
|
||||
|
||||
|
||||
def _primary(language: str) -> str:
|
||||
return language.split("-", 1)[0].strip().lower()
|
||||
|
|
@ -4,6 +4,7 @@ from app.automations.api import router as automations_router
|
|||
from app.file_storage.api import router as file_storage_router
|
||||
from app.gateway import require_gateway_enabled
|
||||
from app.notifications.api import router as notifications_router
|
||||
from app.podcasts.api import router as podcasts_router
|
||||
|
||||
from .agent_action_log_route import router as agent_action_log_router
|
||||
from .agent_flags_route import router as agent_flags_router
|
||||
|
|
@ -50,7 +51,6 @@ from .notes_routes import router as notes_router
|
|||
from .notion_add_connector_route import router as notion_add_connector_router
|
||||
from .obsidian_plugin_routes import router as obsidian_plugin_router
|
||||
from .onedrive_add_connector_route import router as onedrive_add_connector_router
|
||||
from .podcasts_routes import router as podcasts_router
|
||||
from .prompts_routes import router as prompts_router
|
||||
from .public_chat_routes import router as public_chat_router
|
||||
from .rbac_routes import router as rbac_router
|
||||
|
|
|
|||
|
|
@ -622,11 +622,10 @@ async def create_image_generation(
|
|||
detail={
|
||||
"error_code": "premium_quota_exhausted",
|
||||
"usage_type": exc.usage_type,
|
||||
"used_micros": exc.used_micros,
|
||||
"limit_micros": exc.limit_micros,
|
||||
"balance_micros": exc.balance_micros,
|
||||
"remaining_micros": exc.remaining_micros,
|
||||
"message": (
|
||||
"Out of premium credits for image generation. "
|
||||
"Out of credits for image generation. "
|
||||
"Purchase additional credits or switch to a free model."
|
||||
),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Incentive Tasks API routes.
|
||||
Allows users to complete tasks (like starring GitHub repo) to earn free pages.
|
||||
Allows users to complete tasks (like starring GitHub repo) to earn free credits.
|
||||
Each task can only be completed once per user.
|
||||
"""
|
||||
|
||||
|
|
@ -42,21 +42,21 @@ async def get_incentive_tasks(
|
|||
|
||||
# Build task list with completion status
|
||||
tasks = []
|
||||
total_pages_earned = 0
|
||||
total_credit_micros_earned = 0
|
||||
|
||||
for task_type, config in INCENTIVE_TASKS_CONFIG.items():
|
||||
completed_task = completed_tasks.get(task_type)
|
||||
is_completed = completed_task is not None
|
||||
|
||||
if is_completed:
|
||||
total_pages_earned += completed_task.pages_awarded
|
||||
total_credit_micros_earned += completed_task.credit_micros_awarded
|
||||
|
||||
tasks.append(
|
||||
IncentiveTaskInfo(
|
||||
task_type=task_type,
|
||||
title=config["title"],
|
||||
description=config["description"],
|
||||
pages_reward=config["pages_reward"],
|
||||
credit_micros_reward=config["credit_micros_reward"],
|
||||
action_url=config["action_url"],
|
||||
completed=is_completed,
|
||||
completed_at=completed_task.completed_at if completed_task else None,
|
||||
|
|
@ -65,7 +65,7 @@ async def get_incentive_tasks(
|
|||
|
||||
return IncentiveTasksResponse(
|
||||
tasks=tasks,
|
||||
total_pages_earned=total_pages_earned,
|
||||
total_credit_micros_earned=total_credit_micros_earned,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -79,10 +79,10 @@ async def complete_task(
|
|||
session: AsyncSession = Depends(get_async_session),
|
||||
) -> CompleteTaskResponse | TaskAlreadyCompletedResponse:
|
||||
"""
|
||||
Mark an incentive task as completed and award pages to the user.
|
||||
Mark an incentive task as completed and award credit to the user.
|
||||
|
||||
Each task can only be completed once. If the task was already completed,
|
||||
returns the existing completion information without awarding additional pages.
|
||||
returns the existing completion information without awarding additional credit.
|
||||
"""
|
||||
# Validate task type exists in config
|
||||
task_config = INCENTIVE_TASKS_CONFIG.get(task_type)
|
||||
|
|
@ -109,25 +109,23 @@ async def complete_task(
|
|||
)
|
||||
|
||||
# Create the task completion record
|
||||
pages_reward = task_config["pages_reward"]
|
||||
credit_micros_reward = task_config["credit_micros_reward"]
|
||||
new_task = UserIncentiveTask(
|
||||
user_id=user.id,
|
||||
task_type=task_type,
|
||||
pages_awarded=pages_reward,
|
||||
credit_micros_awarded=credit_micros_reward,
|
||||
)
|
||||
session.add(new_task)
|
||||
|
||||
# pages_used can exceed pages_limit when a document's final page count is
|
||||
# determined after processing. Base the new limit on the higher of the two
|
||||
# so the rewarded pages are fully usable above the current high-water mark.
|
||||
user.pages_limit = max(user.pages_used, user.pages_limit) + pages_reward
|
||||
# Add the reward directly to the user's spendable wallet balance.
|
||||
user.credit_micros_balance = user.credit_micros_balance + credit_micros_reward
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
return CompleteTaskResponse(
|
||||
success=True,
|
||||
message=f"Task completed! You earned {pages_reward} pages.",
|
||||
pages_awarded=pages_reward,
|
||||
new_pages_limit=user.pages_limit,
|
||||
message=f"Task completed! You earned ${credit_micros_reward / 1_000_000:.2f} of credit.",
|
||||
credit_micros_awarded=credit_micros_reward,
|
||||
new_balance_micros=user.credit_micros_balance,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,211 +0,0 @@
|
|||
"""
|
||||
Podcast routes for CRUD operations and audio streaming.
|
||||
|
||||
These routes support the podcast generation feature in new-chat.
|
||||
Frontend polls GET /podcasts/{podcast_id} to check status field.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.db import (
|
||||
Permission,
|
||||
Podcast,
|
||||
SearchSpace,
|
||||
SearchSpaceMembership,
|
||||
User,
|
||||
get_async_session,
|
||||
)
|
||||
from app.schemas import PodcastRead
|
||||
from app.users import current_active_user
|
||||
from app.utils.rbac import check_permission
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/podcasts", response_model=list[PodcastRead])
|
||||
async def read_podcasts(
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
search_space_id: int | None = None,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
List podcasts the user has access to.
|
||||
Requires PODCASTS_READ permission for the search space(s).
|
||||
"""
|
||||
if skip < 0 or limit < 1:
|
||||
raise HTTPException(status_code=400, detail="Invalid pagination parameters")
|
||||
try:
|
||||
if search_space_id is not None:
|
||||
# Check permission for specific search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to read podcasts in this search space",
|
||||
)
|
||||
result = await session.execute(
|
||||
select(Podcast)
|
||||
.filter(Podcast.search_space_id == search_space_id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
else:
|
||||
# Get podcasts from all search spaces user has membership in
|
||||
result = await session.execute(
|
||||
select(Podcast)
|
||||
.join(SearchSpace)
|
||||
.join(SearchSpaceMembership)
|
||||
.filter(SearchSpaceMembership.user_id == user.id)
|
||||
.offset(skip)
|
||||
.limit(limit)
|
||||
)
|
||||
return result.scalars().all()
|
||||
except HTTPException:
|
||||
raise
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while fetching podcasts"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}", response_model=PodcastRead)
|
||||
async def read_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Get a specific podcast by ID.
|
||||
|
||||
Requires authentication with PODCASTS_READ permission.
|
||||
For public podcast access, use /public/{share_token}/podcasts/{podcast_id}/stream
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
podcast = result.scalars().first()
|
||||
|
||||
if not podcast:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Podcast not found",
|
||||
)
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
podcast.search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to read podcasts in this search space",
|
||||
)
|
||||
|
||||
return PodcastRead.from_orm_with_entries(podcast)
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except SQLAlchemyError:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while fetching podcast"
|
||||
) from None
|
||||
|
||||
|
||||
@router.delete("/podcasts/{podcast_id}", response_model=dict)
|
||||
async def delete_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Delete a podcast.
|
||||
Requires PODCASTS_DELETE permission for the search space.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
db_podcast = result.scalars().first()
|
||||
|
||||
if not db_podcast:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
|
||||
# Check permission for the search space
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
db_podcast.search_space_id,
|
||||
Permission.PODCASTS_DELETE.value,
|
||||
"You don't have permission to delete podcasts in this search space",
|
||||
)
|
||||
|
||||
await session.delete(db_podcast)
|
||||
await session.commit()
|
||||
return {"message": "Podcast deleted successfully"}
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except SQLAlchemyError:
|
||||
await session.rollback()
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Database error occurred while deleting podcast"
|
||||
) from None
|
||||
|
||||
|
||||
@router.get("/podcasts/{podcast_id}/stream")
|
||||
@router.get("/podcasts/{podcast_id}/audio")
|
||||
async def stream_podcast(
|
||||
podcast_id: int,
|
||||
session: AsyncSession = Depends(get_async_session),
|
||||
user: User = Depends(current_active_user),
|
||||
):
|
||||
"""
|
||||
Stream a podcast audio file.
|
||||
|
||||
Requires authentication with PODCASTS_READ permission.
|
||||
For public podcast access, use /public/{share_token}/podcasts/{podcast_id}/stream
|
||||
|
||||
Note: Both /stream and /audio endpoints are supported for compatibility.
|
||||
"""
|
||||
try:
|
||||
result = await session.execute(select(Podcast).filter(Podcast.id == podcast_id))
|
||||
podcast = result.scalars().first()
|
||||
|
||||
if not podcast:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
|
||||
await check_permission(
|
||||
session,
|
||||
user,
|
||||
podcast.search_space_id,
|
||||
Permission.PODCASTS_READ.value,
|
||||
"You don't have permission to access podcasts in this search space",
|
||||
)
|
||||
|
||||
file_path = podcast.file_location
|
||||
|
||||
if not file_path or not os.path.isfile(file_path):
|
||||
raise HTTPException(status_code=404, detail="Podcast audio file not found")
|
||||
|
||||
def iterfile():
|
||||
with open(file_path, mode="rb") as file_like:
|
||||
yield from file_like
|
||||
|
||||
return StreamingResponse(
|
||||
iterfile(),
|
||||
media_type="audio/mpeg",
|
||||
headers={
|
||||
"Accept-Ranges": "bytes",
|
||||
"Content-Disposition": f"inline; filename={Path(file_path).name}",
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException as he:
|
||||
raise he
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error streaming podcast: {e!s}"
|
||||
) from e
|
||||
|
|
@ -99,6 +99,17 @@ async def stream_public_podcast(
|
|||
if not podcast_info:
|
||||
raise HTTPException(status_code=404, detail="Podcast not found")
|
||||
|
||||
storage_key = podcast_info.get("storage_key")
|
||||
if storage_key:
|
||||
from app.file_storage.factory import get_storage_backend
|
||||
|
||||
return StreamingResponse(
|
||||
get_storage_backend().open_stream(storage_key),
|
||||
media_type="audio/mpeg",
|
||||
headers={"Accept-Ranges": "bytes"},
|
||||
)
|
||||
|
||||
# Legacy fallback for snapshots taken before the storage migration.
|
||||
file_path = podcast_info.get("file_path")
|
||||
|
||||
if not file_path or not os.path.isfile(file_path):
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue